use serde::{Deserialize, Serialize};
/// One row contributed to a cross-shard `ORDER BY` merge.
///
/// `OrderByMerger::merge` orders rows purely by comparing `sort_key` bytes;
/// the other fields are carried along unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardRow {
/// Row contents as raw bytes; the merge logic never inspects them.
pub payload: Vec<u8>,
/// Bytes compared lexicographically to order rows. The `encode_sort_key_*`
/// helpers produce keys whose byte order follows the value order.
pub sort_key: Vec<u8>,
/// Shard identifier — presumably the shard that produced the row; it is
/// not used for ordering or tie-breaking by the merger.
pub shard_id: u32,
}
/// Direction in which `OrderByMerger::merge` orders rows by `sort_key`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortDirection {
/// Rows with the smallest `sort_key` bytes come first.
Ascending,
/// Rows with the largest `sort_key` bytes come first.
Descending,
}
/// Buffers rows from any number of shards and produces a single ordered,
/// length-limited result.
pub struct OrderByMerger {
    /// Rows buffered so far; sorted in place (and truncated) by `merge`.
    rows: Vec<ShardRow>,
    /// Direction applied when comparing `sort_key` bytes.
    direction: SortDirection,
}

impl OrderByMerger {
    /// Creates an empty merger that will order rows in `direction`.
    pub fn new(direction: SortDirection) -> Self {
        let rows = Vec::new();
        Self { rows, direction }
    }

    /// Appends `rows` to the buffer, after everything added previously.
    pub fn add_shard_rows(&mut self, rows: Vec<ShardRow>) {
        let mut incoming = rows;
        self.rows.append(&mut incoming);
    }

    /// Sorts the buffered rows by `sort_key` and returns the first
    /// `global_limit` of them.
    ///
    /// The sort is stable, so rows with equal keys keep the order in which
    /// they were added, in either direction.
    ///
    /// Note that the buffer itself is truncated to `global_limit`: rows past
    /// the limit are dropped from the merger, and `total_rows` reflects the
    /// truncated length afterwards. The returned vector is a copy of the rows
    /// that remain buffered.
    pub fn merge(&mut self, global_limit: usize) -> Vec<ShardRow> {
        let direction = self.direction;
        self.rows.sort_by(|lhs, rhs| {
            let natural = lhs.sort_key.cmp(&rhs.sort_key);
            match direction {
                SortDirection::Ascending => natural,
                // Reversing the ordering leaves `Equal` untouched, so ties
                // still keep insertion order.
                SortDirection::Descending => natural.reverse(),
            }
        });
        self.rows.truncate(global_limit);
        self.rows.to_vec()
    }

    /// Number of rows currently held in the buffer.
    pub fn total_rows(&self) -> usize {
        self.rows.len()
    }
}
/// Encodes a signed 64-bit integer as 8 bytes whose lexicographic order
/// matches the numeric order of the inputs.
///
/// Flipping the sign bit maps `i64::MIN..=i64::MAX` onto `0..=u64::MAX`
/// monotonically; big-endian output then makes byte-wise comparison agree
/// with integer comparison.
pub fn encode_sort_key_i64(value: i64) -> Vec<u8> {
    const SIGN_BIT: u64 = 1 << 63;
    // Reinterpret the two's-complement bits, then bias negatives below positives.
    let biased = (value as u64) ^ SIGN_BIT;
    Vec::from(biased.to_be_bytes())
}
/// Encodes an `f64` as 8 bytes whose lexicographic order follows the
/// numeric order of the inputs.
///
/// Values with the sign bit set have every bit inverted (so larger
/// magnitudes sort lower); all other values only have the sign bit set (so
/// they sort above every negative). As a consequence `-0.0` encodes strictly
/// below `+0.0`, and NaNs land at the extremes according to their sign bit.
pub fn encode_sort_key_f64(value: f64) -> Vec<u8> {
    const SIGN_BIT: u64 = 1 << 63;
    let raw = value.to_bits();
    // XOR with all ones is a full inversion; XOR with just the sign bit sets
    // it, because it is known to be clear on that branch.
    let mask = if raw & SIGN_BIT != 0 { u64::MAX } else { SIGN_BIT };
    (raw ^ mask).to_be_bytes().to_vec()
}
/// Encodes a string sort key as a copy of its UTF-8 bytes.
///
/// The bytes are copied unchanged, so keys compare exactly as the raw UTF-8
/// byte sequences do.
pub fn encode_sort_key_string(value: &str) -> Vec<u8> {
    value.bytes().collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row whose sort key is the order-preserving encoding of `key`.
    fn row(payload: &[u8], key: i64, shard_id: u32) -> ShardRow {
        ShardRow {
            payload: payload.to_vec(),
            sort_key: encode_sort_key_i64(key),
            shard_id,
        }
    }

    /// Payloads of `rows`, in order, for compact comparisons.
    fn payloads(rows: &[ShardRow]) -> Vec<&[u8]> {
        rows.iter().map(|r| r.payload.as_slice()).collect()
    }

    #[test]
    fn merge_sort_ascending() {
        let mut merger = OrderByMerger::new(SortDirection::Ascending);
        merger.add_shard_rows(vec![
            row(b"alice", 20, 0),
            row(b"bob", 30, 0),
            row(b"carol", 40, 0),
        ]);
        merger.add_shard_rows(vec![
            row(b"dave", 15, 1),
            row(b"eve", 25, 1),
            row(b"frank", 35, 1),
        ]);

        let result = merger.merge(3);

        // The three smallest keys overall are 15, 20 and 25.
        assert_eq!(result.len(), 3);
        assert_eq!(
            payloads(&result),
            [&b"dave"[..], &b"alice"[..], &b"eve"[..]]
        );
    }

    #[test]
    fn merge_sort_descending() {
        let mut merger = OrderByMerger::new(SortDirection::Descending);
        merger.add_shard_rows(vec![row(b"a", 100, 0), row(b"b", 50, 0)]);
        merger.add_shard_rows(vec![row(b"c", 90, 1), row(b"d", 10, 1)]);

        let result = merger.merge(2);

        // The two largest keys overall are 100 and 90.
        assert_eq!(result.len(), 2);
        assert_eq!(payloads(&result), [&b"a"[..], &b"c"[..]]);
    }

    #[test]
    fn sort_key_i64_ordering() {
        let keys: Vec<Vec<u8>> = [-100, 0, 100]
            .iter()
            .map(|&v| encode_sort_key_i64(v))
            .collect();
        assert!(keys.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn sort_key_f64_ordering() {
        let keys: Vec<Vec<u8>> = [-1.5, 0.0, 1.5]
            .iter()
            .map(|&v| encode_sort_key_f64(v))
            .collect();
        assert!(keys.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn sort_key_string_ordering() {
        let first = encode_sort_key_string("alice");
        let second = encode_sort_key_string("bob");
        assert!(first < second);
    }
}