nodedb_cluster/
distributed_join.rs1use std::collections::HashMap;
14
15use serde::{Deserialize, Serialize};
16use tracing::debug;
17
18use crate::routing::RoutingTable;
19
20pub const DEFAULT_BROADCAST_THRESHOLD_BYTES: usize = 8 * 1024 * 1024; #[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct BroadcastJoinRequest {
29 pub broadcast_data: Vec<u8>,
31 pub large_collection: String,
33 pub on_keys: Vec<(String, String)>,
35 pub join_type: String,
37 pub limit: usize,
39 pub tenant_id: u32,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ShufflePartition {
46 pub data: Vec<u8>,
48 pub side: JoinSide,
50 pub target_node: u64,
52 pub partition_id: u32,
54}
55
56#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
57pub enum JoinSide {
58 Left,
59 Right,
60}
61
62pub fn select_strategy(
68 left_estimated_bytes: usize,
69 right_estimated_bytes: usize,
70 broadcast_threshold_bytes: usize,
71) -> JoinStrategy {
72 let (smaller, _larger) = if left_estimated_bytes <= right_estimated_bytes {
73 (left_estimated_bytes, right_estimated_bytes)
74 } else {
75 (right_estimated_bytes, left_estimated_bytes)
76 };
77
78 if smaller <= broadcast_threshold_bytes {
79 JoinStrategy::Broadcast {
80 broadcast_side: if left_estimated_bytes <= right_estimated_bytes {
81 JoinSide::Left
82 } else {
83 JoinSide::Right
84 },
85 }
86 } else {
87 JoinStrategy::Shuffle
88 }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq)]
93pub enum JoinStrategy {
94 Broadcast { broadcast_side: JoinSide },
96 Shuffle,
98}
99
100pub fn partition_for_key(key: &str, num_partitions: usize) -> u32 {
105 (crate::routing::fnv1a_hash(key) % num_partitions as u64) as u32
106}
107
108pub fn plan_shuffle_partitions(routing: &RoutingTable, num_partitions: usize) -> HashMap<u32, u64> {
112 let group_ids = routing.group_ids();
113 let mut partition_map = HashMap::new();
114
115 for p in 0..num_partitions {
116 let group_idx = p % group_ids.len();
117 let group_id = group_ids[group_idx];
118 let leader = routing.group_info(group_id).map(|g| g.leader).unwrap_or(0);
119 partition_map.insert(p as u32, leader);
120 }
121
122 debug!(
123 num_partitions,
124 num_groups = group_ids.len(),
125 "shuffle partition plan computed"
126 );
127 partition_map
128}
129
/// Estimates a collection's total size in bytes as `doc_count * avg_doc_bytes`.
///
/// The product saturates at `usize::MAX` instead of overflowing. A plain
/// multiplication panics in debug builds and wraps in release builds. A
/// wrapped, tiny estimate would make [`select_strategy`] broadcast a huge
/// collection.
pub fn estimate_collection_bytes(doc_count: usize, avg_doc_bytes: usize) -> usize {
    doc_count.saturating_mul(avg_doc_bytes)
}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn broadcast_selected_for_small_side() {
        let strategy = select_strategy(1_000, 100_000_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert!(matches!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Left
            }
        ));
    }

    #[test]
    fn broadcast_selects_right_when_right_is_smaller() {
        let strategy = select_strategy(100_000_000, 1_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Right
            }
        );
    }

    #[test]
    fn equal_sizes_broadcast_left() {
        // Ties are resolved with `<=`, so the left input is broadcast.
        let strategy = select_strategy(1_000, 1_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Left
            }
        );
    }

    #[test]
    fn shuffle_selected_for_large_sides() {
        let strategy = select_strategy(100_000_000, 200_000_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(strategy, JoinStrategy::Shuffle);
    }

    #[test]
    fn partition_deterministic() {
        let p1 = partition_for_key("alice", 16);
        let p2 = partition_for_key("alice", 16);
        assert_eq!(p1, p2);
        assert!(p1 < 16);

        // Every key must land inside `0..num_partitions`.
        for key in ["bob", "", "tenant:42", "ünïcödé"] {
            assert!(partition_for_key(key, 16) < 16);
            assert!(partition_for_key(key, 7) < 7);
            assert_eq!(partition_for_key(key, 1), 0);
        }
    }

    #[test]
    fn shuffle_plan_covers_all_partitions() {
        let routing = RoutingTable::uniform(4, &[1, 2, 3], 2);
        let plan = plan_shuffle_partitions(&routing, 8);
        assert_eq!(plan.len(), 8);
        for p in 0..8u32 {
            assert!(plan.contains_key(&p));
        }
    }

    #[test]
    fn broadcast_threshold() {
        let strategy = select_strategy(
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
            100_000_000,
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
        );
        assert!(matches!(strategy, JoinStrategy::Broadcast { .. }));

        let strategy = select_strategy(
            DEFAULT_BROADCAST_THRESHOLD_BYTES + 1,
            100_000_000,
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
        );
        assert_eq!(strategy, JoinStrategy::Shuffle);
    }

    #[test]
    fn estimate_multiplies_count_by_average() {
        assert_eq!(estimate_collection_bytes(10, 20), 200);
        assert_eq!(estimate_collection_bytes(0, 20), 0);
    }
}