use serde::{Deserialize, Serialize};
use super::merge::{PartialAgg, PartialAggMerger};
/// Describes one scatter-gather aggregation query: which collection and shards to
/// query, over which range, and how results are bucketed.
///
/// NOTE(review): the per-field notes below are inferred from field names; confirm
/// them against the code that builds and executes plans.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ScatterGatherPlan {
    /// Name of the collection being queried.
    pub collection: String,
    /// Start of the queried range; presumably epoch milliseconds (inclusivity
    /// unconfirmed).
    pub start_ms: i64,
    /// End of the queried range; presumably epoch milliseconds (inclusivity
    /// unconfirmed).
    pub end_ms: i64,
    /// Name of the column whose values are aggregated.
    pub value_column: String,
    /// Width of each aggregation bucket; presumably milliseconds and expected to
    /// be positive (not validated here).
    pub bucket_interval_ms: i64,
    /// IDs of the shards the query targets.
    pub shard_ids: Vec<u16>,
}
/// Partial aggregates reported by a single shard.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ShardResult {
    /// ID of the shard that produced `partials`.
    pub shard_id: u16,
    /// This shard's partial aggregates; [`merge_shard_results`] combines them
    /// with the partials of other shards.
    pub partials: Vec<PartialAgg>,
}
pub fn merge_shard_results(shard_results: &[ShardResult]) -> Vec<PartialAgg> {
let mut merger = PartialAggMerger::new();
for result in shard_results {
merger.add_shard_results(&result.partials);
}
merger.finalize()
}
/// Returns every shard ID in `0..total_shards`, in ascending order.
///
/// The collection name is currently unused, so every collection maps to all
/// shards. `total_shards == 0` yields an empty vector.
pub fn shards_for_collection(_collection: &str, total_shards: u16) -> Vec<u16> {
    let mut ids = Vec::with_capacity(usize::from(total_shards));
    ids.extend(0..total_shards);
    ids
}
/// Merges per-shard partial aggregates into the final result set.
///
/// Currently a thin alias for [`merge_shard_results`]: it delegates directly and
/// returns that function's output unchanged.
pub fn consolidate_aggregate_results(shard_results: &[ShardResult]) -> Vec<PartialAgg> {
merge_shard_results(shard_results)
}
/// Estimates scan cost as the total row count across all shards scaled by
/// `selectivity`, truncated toward zero.
///
/// # Arguments
/// * `shard_row_counts` - row count for each shard.
/// * `selectivity` - multiplier applied to the row total; presumably a fraction
///   in `[0.0, 1.0]`, but it is not clamped.
///
/// # Returns
/// The estimated cost. The row total saturates at `u64::MAX` instead of
/// overflowing. The final float-to-integer conversion also saturates:
/// - a negative or NaN product yields `0`;
/// - a product above `u64::MAX` yields `u64::MAX`.
pub fn estimate_scatter_cost(shard_row_counts: &[u64], selectivity: f64) -> u64 {
    // `Iterator::sum` panics on overflow in debug builds and silently wraps in
    // release builds; saturating keeps the estimate monotonic and panic-free.
    let total_rows = shard_row_counts
        .iter()
        .fold(0u64, |acc, &rows| acc.saturating_add(rows));
    // f64 -> u64 `as` casts saturate (NaN maps to 0), so this cannot panic.
    (total_rows as f64 * selectivity) as u64
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn merge_two_shards() {
        let shard_a = ShardResult {
            shard_id: 0,
            partials: vec![
                PartialAgg {
                    count: 100,
                    sum: 5000.0,
                    ..PartialAgg::from_single(0, 1, 50.0)
                },
                PartialAgg {
                    count: 100,
                    sum: 6000.0,
                    ..PartialAgg::from_single(60_000, 60_001, 60.0)
                },
            ],
        };
        let shard_b = ShardResult {
            shard_id: 1,
            partials: vec![
                PartialAgg {
                    count: 80,
                    sum: 4000.0,
                    ..PartialAgg::from_single(0, 2, 50.0)
                },
                PartialAgg {
                    count: 80,
                    sum: 5600.0,
                    ..PartialAgg::from_single(60_000, 60_002, 70.0)
                },
            ],
        };
        let merged = merge_shard_results(&[shard_a, shard_b]);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].count, 180);
        assert_eq!(merged[0].sum, 9000.0);
        assert_eq!(merged[1].count, 180);
        // Second bucket: 6000.0 + 5600.0, both exactly representable.
        assert_eq!(merged[1].sum, 11_600.0);
    }

    #[test]
    fn consolidate_delegates_to_merge() {
        // A single partial keeps the comparison independent of bucket ordering.
        let shards = [ShardResult {
            shard_id: 3,
            partials: vec![PartialAgg::from_single(0, 1, 50.0)],
        }];
        let merged = merge_shard_results(&shards);
        let consolidated = consolidate_aggregate_results(&shards);
        assert_eq!(merged.len(), consolidated.len());
        for (m, c) in merged.iter().zip(&consolidated) {
            assert_eq!(m.count, c.count);
            assert_eq!(m.sum, c.sum);
        }
    }

    #[test]
    fn shards_for_collection_all() {
        let shards = shards_for_collection("metrics", 10);
        assert_eq!(shards.len(), 10);
        assert_eq!(shards, (0..10).collect::<Vec<u16>>());
    }

    #[test]
    fn shards_for_collection_zero() {
        assert!(shards_for_collection("metrics", 0).is_empty());
    }

    #[test]
    fn estimate_cost() {
        let row_counts = [1_000_000u64; 10];
        let cost = estimate_scatter_cost(&row_counts, 0.01);
        assert_eq!(cost, 100_000);
    }

    #[test]
    fn estimate_cost_empty() {
        assert_eq!(estimate_scatter_cost(&[], 0.5), 0);
    }
}