Skip to main content

nodedb_cluster/distributed_document/
merge_sort.rs

1// SPDX-License-Identifier: BUSL-1.1
2
//! Distributed ORDER BY + LIMIT merge for document scans.
//!
//! Each shard applies ORDER BY + LIMIT locally and returns its top-N rows.
//! The coordinator gathers the resulting (up to shards × N) rows, sorts
//! them globally, and returns the global top-N.
//!
//! This is NOT simple concatenation — Shard A's top-10 might all rank
//! below Shard B's top-10 globally.
11
12use serde::{Deserialize, Serialize};
13
14/// A row from a shard, with a sort key for merge-sorting.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ShardRow {
17    /// The row payload (JSON bytes or MessagePack).
18    pub payload: Vec<u8>,
19    /// Sort key extracted from the ORDER BY column(s).
20    /// Encoded as comparable bytes (big-endian for numbers, UTF-8 for strings).
21    pub sort_key: Vec<u8>,
22    /// Which shard produced this row.
23    pub shard_id: u32,
24}
25
26/// Sort direction for ORDER BY.
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum SortDirection {
29    Ascending,
30    Descending,
31}
32
33/// N-way merge-sort merger for distributed ORDER BY + LIMIT.
34pub struct OrderByMerger {
35    /// All rows from all shards, unsorted.
36    rows: Vec<ShardRow>,
37    /// Sort direction.
38    direction: SortDirection,
39}
40
41impl OrderByMerger {
42    pub fn new(direction: SortDirection) -> Self {
43        Self {
44            rows: Vec::new(),
45            direction,
46        }
47    }
48
49    /// Add a shard's locally-sorted, locally-limited rows.
50    pub fn add_shard_rows(&mut self, rows: Vec<ShardRow>) {
51        self.rows.extend(rows);
52    }
53
54    /// Perform the global merge sort and apply the global LIMIT.
55    ///
56    /// Each shard already applied `ORDER BY + LIMIT` locally, so we have
57    /// at most `num_shards × limit` rows. The global sort + limit produces
58    /// the correct result.
59    pub fn merge(&mut self, global_limit: usize) -> Vec<ShardRow> {
60        match self.direction {
61            SortDirection::Ascending => {
62                self.rows.sort_by(|a, b| a.sort_key.cmp(&b.sort_key));
63            }
64            SortDirection::Descending => {
65                self.rows.sort_by(|a, b| b.sort_key.cmp(&a.sort_key));
66            }
67        }
68        self.rows.truncate(global_limit);
69        self.rows.clone()
70    }
71
72    /// Total rows collected before merge.
73    pub fn total_rows(&self) -> usize {
74        self.rows.len()
75    }
76}
77
/// Encodes a signed 64-bit integer as an 8-byte sort key whose plain
/// byte-wise (lexicographic) order matches the numeric order of the input.
///
/// The sign bit is toggled so negative values sort before positive ones
/// when the bytes are compared as unsigned, and the result is written
/// big-endian so the most significant byte is compared first.
pub fn encode_sort_key_i64(value: i64) -> Vec<u8> {
    // XOR with `i64::MIN` toggles only the top (sign) bit.
    let biased = value ^ i64::MIN;
    Vec::from(biased.to_be_bytes())
}
87
/// Encodes an `f64` as an 8-byte sort key whose plain byte-wise
/// (lexicographic) order matches the numeric order of the input.
///
/// The IEEE 754 bit pattern is transformed so unsigned big-endian comparison
/// works: negative values have every bit inverted, non-negative values have
/// the sign bit set.
///
/// Two classes of input are normalized first so that equal values always
/// produce equal keys:
///
/// * `-0.0` is encoded exactly like `+0.0` (they compare equal numerically).
/// * Every NaN, regardless of sign bit or payload, is encoded as one
///   canonical NaN that sorts after `+∞`. Without this, NaNs with the sign
///   bit set would sort before `-∞` while the others sort after `+∞`.
///
/// The encoding of every other value is unchanged by the normalization.
pub fn encode_sort_key_f64(value: f64) -> Vec<u8> {
    const SIGN_BIT: u64 = 1 << 63;
    // Quiet NaN with a clear sign bit and empty payload.
    const CANONICAL_NAN_BITS: u64 = 0x7FF8_0000_0000_0000;

    let bits = if value.is_nan() {
        CANONICAL_NAN_BITS
    } else if value == 0.0 {
        // True for both +0.0 and -0.0; use the all-zero pattern of +0.0.
        0
    } else {
        value.to_bits()
    };

    let ordered = if bits & SIGN_BIT == 0 {
        // Non-negative: set the sign bit so it sorts above every negative.
        bits | SIGN_BIT
    } else {
        // Negative: invert all bits so larger magnitudes sort first.
        !bits
    };
    ordered.to_be_bytes().to_vec()
}
98
/// Encodes a string as a sort key by copying its UTF-8 bytes unchanged, so
/// keys compare in plain byte-wise lexicographic order.
pub fn encode_sort_key_string(value: &str) -> Vec<u8> {
    Vec::from(value.as_bytes())
}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row whose sort key is the byte-comparable encoding of `key`.
    fn int_row(payload: &[u8], key: i64, shard_id: u32) -> ShardRow {
        ShardRow {
            payload: payload.to_vec(),
            sort_key: encode_sort_key_i64(key),
            shard_id,
        }
    }

    /// Borrows each row's payload, in result order, for compact assertions.
    fn payloads(rows: &[ShardRow]) -> Vec<&[u8]> {
        rows.iter().map(|row| row.payload.as_slice()).collect()
    }

    #[test]
    fn merge_sort_ascending() {
        let mut merger = OrderByMerger::new(SortDirection::Ascending);

        // Shard 0: ages [20, 30, 40] (locally sorted, limit 3).
        merger.add_shard_rows(vec![
            int_row(b"alice", 20, 0),
            int_row(b"bob", 30, 0),
            int_row(b"carol", 40, 0),
        ]);

        // Shard 1: ages [15, 25, 35] (locally sorted, limit 3).
        merger.add_shard_rows(vec![
            int_row(b"dave", 15, 1),
            int_row(b"eve", 25, 1),
            int_row(b"frank", 35, 1),
        ]);

        // Global LIMIT 3.
        let result = merger.merge(3);
        assert_eq!(result.len(), 3);

        // Youngest 3: dave(15), alice(20), eve(25).
        let expected: [&[u8]; 3] = [b"dave", b"alice", b"eve"];
        assert_eq!(payloads(&result), expected);
    }

    #[test]
    fn merge_sort_descending() {
        let mut merger = OrderByMerger::new(SortDirection::Descending);

        merger.add_shard_rows(vec![int_row(b"a", 100, 0), int_row(b"b", 50, 0)]);
        merger.add_shard_rows(vec![int_row(b"c", 90, 1), int_row(b"d", 10, 1)]);

        let result = merger.merge(2);
        assert_eq!(result.len(), 2);

        // 100 (highest) first, then 90.
        let expected: [&[u8]; 2] = [b"a", b"c"];
        assert_eq!(payloads(&result), expected);
    }

    #[test]
    fn sort_key_i64_ordering() {
        // Negative < zero < positive once encoded.
        let keys = [-100, 0, 100].map(encode_sort_key_i64);
        assert!(keys[0] < keys[1]);
        assert!(keys[1] < keys[2]);
    }

    #[test]
    fn sort_key_f64_ordering() {
        // Negative < zero < positive once encoded.
        let keys = [-1.5, 0.0, 1.5].map(encode_sort_key_f64);
        assert!(keys[0] < keys[1]);
        assert!(keys[1] < keys[2]);
    }

    #[test]
    fn sort_key_string_ordering() {
        let alice = encode_sort_key_string("alice");
        let bob = encode_sort_key_string("bob");
        assert!(alice < bob);
    }
}