use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::routing::RoutingTable;
/// Default size threshold for choosing a broadcast join: 8 MiB (8 * 1024 * 1024 bytes).
///
/// Intended as the `broadcast_threshold_bytes` argument of [`select_strategy`].
pub const DEFAULT_BROADCAST_THRESHOLD_BYTES: usize = 8 * 1024 * 1024;
/// Serializable description of a broadcast join.
///
/// NOTE(review): field meanings below are inferred from names and types only;
/// the code that builds and consumes this struct is not visible here — confirm
/// against the caller before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BroadcastJoinRequest {
/// Opaque bytes of the side being broadcast (encoding not shown in this file).
pub broadcast_data: Vec<u8>,
/// Name of the collection the broadcast data is joined against
/// (presumably the larger side — TODO confirm).
pub large_collection: String,
/// Pairs of key names to join on. Which element belongs to which side is
/// not established here — TODO confirm the ordering.
pub on_keys: Vec<(String, String)>,
/// Join type as a free-form string; the accepted values are defined by the consumer.
pub join_type: String,
/// Row limit for the join (whether 0 means "unlimited" is not visible here).
pub limit: usize,
/// Identifier of the tenant the request belongs to.
pub tenant_id: u32,
}
/// Serializable chunk of one join input, addressed to a node for a shuffle join.
///
/// NOTE(review): field meanings are inferred from names and types; the producer
/// and consumer of this struct are not visible here — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShufflePartition {
/// Opaque bytes of the partition's rows (encoding not shown in this file).
pub data: Vec<u8>,
/// Which join input these rows come from.
pub side: JoinSide,
/// Node the partition is addressed to (presumably a leader id as produced by
/// `plan_shuffle_partitions` — TODO confirm).
pub target_node: u64,
/// Partition index (presumably a value returned by `partition_for_key` — TODO confirm).
pub partition_id: u32,
}
/// Identifies one of the two inputs of a join.
///
/// Used by [`ShufflePartition`] to tag where rows came from and by
/// [`JoinStrategy::Broadcast`] to name the side that is broadcast.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum JoinSide {
/// The left input of the join.
Left,
/// The right input of the join.
Right,
}
/// Chooses between a broadcast join and a shuffle join from size estimates.
///
/// The smaller of the two inputs is the broadcast candidate; on a tie the left
/// side is used. If that candidate's estimated size is at most
/// `broadcast_threshold_bytes` (inclusive), the result is
/// [`JoinStrategy::Broadcast`] naming that side. Otherwise both sides are too
/// large to broadcast and the result is [`JoinStrategy::Shuffle`].
pub fn select_strategy(
    left_estimated_bytes: usize,
    right_estimated_bytes: usize,
    broadcast_threshold_bytes: usize,
) -> JoinStrategy {
    // Pick the broadcast candidate and its size with a single comparison.
    // The strict `<` keeps the left side on ties.
    let (broadcast_side, candidate_bytes) = if right_estimated_bytes < left_estimated_bytes {
        (JoinSide::Right, right_estimated_bytes)
    } else {
        (JoinSide::Left, left_estimated_bytes)
    };

    if candidate_bytes > broadcast_threshold_bytes {
        return JoinStrategy::Shuffle;
    }

    JoinStrategy::Broadcast { broadcast_side }
}
/// Execution strategy for a distributed join, as returned by [`select_strategy`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinStrategy {
/// Broadcast one input; `broadcast_side` names the input that is broadcast
/// (the one `select_strategy` found to be within the threshold).
Broadcast { broadcast_side: JoinSide },
/// Neither input is within the broadcast threshold; use a shuffle join.
Shuffle,
}
/// Maps a join key to a shuffle partition index.
///
/// The key is hashed with [`crate::routing::fnv1a_hash`] and reduced modulo
/// `num_partitions`, so for a non-zero `num_partitions` the result lies in
/// `0..num_partitions`.
///
/// A `num_partitions` of zero has no valid partition. Rather than panicking
/// with a remainder-by-zero, this returns `0`; callers that can pass zero
/// should treat the result as meaningless.
pub fn partition_for_key(key: &str, num_partitions: usize) -> u32 {
    if num_partitions == 0 {
        return 0;
    }
    let bucket = crate::routing::fnv1a_hash(key) % num_partitions as u64;
    // Partition ids are `u32` in this module; the cast can only truncate when
    // `num_partitions` exceeds `u32::MAX`.
    bucket as u32
}
/// Builds the partition-to-node assignment for a shuffle join.
///
/// Partitions `0..num_partitions` are spread over the routing table's groups
/// round-robin (partition `p` uses `group_ids[p % num_groups]`), and each
/// partition is mapped to the `leader` reported by `routing.group_info` for
/// that group. The returned map has one entry per partition.
///
/// A leader of `0` is recorded when no leader can be determined: either
/// `group_info` has no entry for the group, or the routing table has no groups
/// at all. The second case previously panicked with a remainder-by-zero.
pub fn plan_shuffle_partitions(routing: &RoutingTable, num_partitions: usize) -> HashMap<u32, u64> {
    let group_ids = routing.group_ids();
    let num_groups = group_ids.len();
    let mut partition_map = HashMap::with_capacity(num_partitions);
    for p in 0..num_partitions {
        let leader = if num_groups == 0 {
            // No groups to pick from: use the same "unknown leader" fallback
            // as a missing group entry instead of dividing by zero.
            0
        } else {
            let group_id = group_ids[p % num_groups];
            routing.group_info(group_id).map(|g| g.leader).unwrap_or(0)
        };
        partition_map.insert(p as u32, leader);
    }
    debug!(
        num_partitions,
        num_groups,
        "shuffle partition plan computed"
    );
    partition_map
}
/// Estimates a collection's total size as `doc_count * avg_doc_bytes`.
///
/// The product saturates at `usize::MAX` instead of overflowing. A plain
/// multiply would panic in debug builds and wrap in release builds, and a
/// wrapped (small) estimate could make a huge collection look cheap enough to
/// broadcast.
pub fn estimate_collection_bytes(doc_count: usize, avg_doc_bytes: usize) -> usize {
    doc_count.saturating_mul(avg_doc_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A tiny left input next to a huge right input broadcasts the left side.
    #[test]
    fn broadcast_selected_for_small_side() {
        let strategy = select_strategy(1_000, 100_000_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Left
            }
        );
    }

    /// The mirrored case: the right input is the small one.
    #[test]
    fn broadcast_selected_for_small_right_side() {
        let strategy = select_strategy(100_000_000, 1_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Right
            }
        );
    }

    /// Equal estimates resolve to the left side.
    #[test]
    fn broadcast_tie_prefers_left() {
        let strategy = select_strategy(1_000, 1_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(
            strategy,
            JoinStrategy::Broadcast {
                broadcast_side: JoinSide::Left
            }
        );
    }

    /// When both inputs exceed the threshold, the join must shuffle.
    #[test]
    fn shuffle_selected_for_large_sides() {
        let strategy = select_strategy(100_000_000, 200_000_000, DEFAULT_BROADCAST_THRESHOLD_BYTES);
        assert_eq!(strategy, JoinStrategy::Shuffle);
    }

    /// The same key always lands in the same partition, and every result is a
    /// valid index into the partition range.
    #[test]
    fn partition_deterministic() {
        let p1 = partition_for_key("alice", 16);
        let p2 = partition_for_key("alice", 16);
        assert_eq!(p1, p2);
        assert!(p1 < 16);

        // A second key need not differ from the first, but it must be in range.
        let p3 = partition_for_key("bob", 16);
        assert!(p3 < 16);
    }

    /// Every partition id in `0..num_partitions` gets an entry in the plan.
    #[test]
    fn shuffle_plan_covers_all_partitions() {
        let routing = RoutingTable::uniform(4, &[1, 2, 3], 2);
        let plan = plan_shuffle_partitions(&routing, 8);
        assert_eq!(plan.len(), 8);
        for p in 0..8u32 {
            assert!(plan.contains_key(&p));
        }
    }

    /// The threshold is inclusive: exactly at the limit broadcasts, one byte over shuffles.
    #[test]
    fn broadcast_threshold() {
        let strategy = select_strategy(
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
            100_000_000,
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
        );
        assert!(matches!(strategy, JoinStrategy::Broadcast { .. }));

        let strategy = select_strategy(
            DEFAULT_BROADCAST_THRESHOLD_BYTES + 1,
            100_000_000,
            DEFAULT_BROADCAST_THRESHOLD_BYTES,
        );
        assert_eq!(strategy, JoinStrategy::Shuffle);
    }

    /// Small inputs are a plain product.
    #[test]
    fn estimate_is_product_for_small_inputs() {
        assert_eq!(estimate_collection_bytes(10, 100), 1_000);
        assert_eq!(estimate_collection_bytes(0, 4096), 0);
    }
}