use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use nodedb_array::sync::hlc::Hlc;
use nodedb_array::sync::op::ArrayOp;
use nodedb_array::sync::op_codec;
use nodedb_types::sync::wire::SyncMessageType;
use nodedb_types::sync::wire::array::ArrayDeltaMsg;
use tracing::warn;
use super::delivery::ArrayDeliveryRegistry;
const DRAIN_THRESHOLD: usize = 64;
const WATERMARK_IDLE_MS: u64 = 50;
/// Mutable state of a [`MultiShardMerger`], kept behind its mutex.
#[derive(Default)]
struct MergerInner {
    /// Pending ops, keyed (and therefore iterated) by HLC.
    buffer: BTreeMap<Hlc, ArrayOp>,
    /// When each shard last contributed an op.
    shard_last_seen: HashMap<u16, Instant>,
    /// Every shard id that has contributed at least one op.
    known_shards: std::collections::HashSet<u16>,
}

impl MergerInner {
    /// Returns empty state: nothing buffered and no shards seen yet.
    fn new() -> Self {
        Self::default()
    }
}
pub struct MultiShardMerger {
session_id: String,
array: String,
inner: Mutex<MergerInner>,
}
impl MultiShardMerger {
pub fn new(session_id: impl Into<String>, array: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
array: array.into(),
inner: Mutex::new(MergerInner::new()),
}
}
pub fn push_op(&self, shard_id: u16, op: ArrayOp, delivery: &ArrayDeliveryRegistry) {
let should_drain = {
let mut inner = match self.inner.lock() {
Ok(g) => g,
Err(_) => {
warn!(
session = %self.session_id,
array = %self.array,
"multi_shard_merger: lock poisoned in push_op"
);
return;
}
};
inner.known_shards.insert(shard_id);
inner.shard_last_seen.insert(shard_id, Instant::now());
inner.buffer.insert(op.header.hlc, op);
inner.buffer.len() >= DRAIN_THRESHOLD || inner.known_shards.len() <= 1
};
if should_drain {
self.drain_to(delivery);
}
}
pub fn drain_to(&self, delivery: &ArrayDeliveryRegistry) {
let ops_to_deliver: Vec<ArrayOp> = {
let inner = match self.inner.lock() {
Ok(g) => g,
Err(_) => {
warn!(
session = %self.session_id,
array = %self.array,
"multi_shard_merger: lock poisoned in drain_to"
);
return;
}
};
if inner.buffer.is_empty() {
return;
}
let watermark_idle = Duration::from_millis(WATERMARK_IDLE_MS);
let has_idle_shard = inner
.shard_last_seen
.values()
.any(|t| t.elapsed() >= watermark_idle);
let drain_all = has_idle_shard || inner.buffer.len() >= DRAIN_THRESHOLD;
if drain_all {
inner.buffer.values().cloned().collect()
} else {
inner.buffer.values().cloned().collect()
}
};
{
let mut inner = match self.inner.lock() {
Ok(g) => g,
Err(_) => return,
};
for op in &ops_to_deliver {
inner.buffer.remove(&op.header.hlc);
}
}
for op in &ops_to_deliver {
self.deliver_op(op, delivery);
}
}
fn deliver_op(&self, op: &ArrayOp, delivery: &ArrayDeliveryRegistry) {
let op_payload = match op_codec::encode_op(op) {
Ok(b) => b,
Err(e) => {
warn!(
session = %self.session_id,
array = %self.array,
error = %e,
"multi_shard_merger: encode_op failed — skipping op"
);
return;
}
};
let msg = ArrayDeltaMsg {
array: op.header.array.clone(),
op_payload,
};
let frame = match nodedb_types::sync::wire::SyncFrame::try_encode(
SyncMessageType::ArrayDelta,
&msg,
) {
Some(f) => f.to_bytes(),
None => {
warn!(
session = %self.session_id,
array = %self.array,
"multi_shard_merger: SyncFrame encode failed — skipping op"
);
return;
}
};
delivery.enqueue(&self.session_id, frame);
}
}
/// Registry of [`MultiShardMerger`]s, one per `(session_id, array)` pair.
pub struct MergerRegistry {
    /// Mergers keyed by `(session_id, array)`.
    mergers: Mutex<HashMap<(String, String), Arc<MultiShardMerger>>>,
}

impl MergerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            mergers: Mutex::default(),
        }
    }

    /// Returns the merger registered for `(session_id, array)`, creating and
    /// registering one on first use.
    ///
    /// If the registry lock is poisoned, a warning is logged and a new,
    /// unregistered merger is returned instead.
    pub fn get_or_create(&self, session_id: &str, array: &str) -> Arc<MultiShardMerger> {
        let build = || Arc::new(MultiShardMerger::new(session_id, array));
        match self.mergers.lock() {
            Ok(mut by_key) => Arc::clone(
                by_key
                    .entry((session_id.to_owned(), array.to_owned()))
                    .or_insert_with(build),
            ),
            Err(e) => {
                warn!("merger_registry: lock poisoned — returning fresh merger: {e}");
                build()
            }
        }
    }

    /// Drops every merger belonging to `session_id`.
    ///
    /// Silently does nothing if the registry lock is poisoned.
    pub fn remove_session(&self, session_id: &str) {
        if let Ok(mut by_key) = self.mergers.lock() {
            by_key.retain(|(sid, _), _| sid.as_str() != session_id);
        }
    }
}

impl Default for MergerRegistry {
    /// Same as [`MergerRegistry::new`].
    fn default() -> Self {
        MergerRegistry::new()
    }
}
/// Spawns a background task that periodically flushes every registered merger.
///
/// On each tick, the task copies the registry's mergers out under the registry
/// lock, releases the lock, then calls [`MultiShardMerger::drain_to`] on each
/// one. A poisoned registry lock skips that tick with a warning.
///
/// The loop never exits on its own. It runs until aborted (e.g. via the
/// returned handle) or until the runtime shuts down.
///
/// `interval_ms` is clamped to at least 1 ms. `tokio::time::interval` panics
/// on a zero period, which would end the drain task before its first tick.
pub fn spawn_drain_task(
    registry: Arc<MergerRegistry>,
    delivery: Arc<ArrayDeliveryRegistry>,
    interval_ms: u64,
) -> tokio::task::JoinHandle<()> {
    let period = std::time::Duration::from_millis(interval_ms.max(1));
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        loop {
            ticker.tick().await;
            // Clone the Arcs so the registry lock is not held while draining.
            let mergers: Vec<Arc<MultiShardMerger>> = {
                match registry.mergers.lock() {
                    Ok(g) => g.values().cloned().collect(),
                    Err(_) => {
                        warn!("merger_registry: drain task: lock poisoned");
                        continue;
                    }
                }
            };
            for merger in &mergers {
                merger.drain_to(&delivery);
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_array::sync::op::{ArrayOpHeader, ArrayOpKind};
    use nodedb_array::sync::replica_id::ReplicaId;
    use nodedb_array::types::coord::value::CoordValue;

    /// HLC at `ms` milliseconds with logical counter 0, stamped by replica 1.
    fn hlc(ms: u64) -> Hlc {
        Hlc::new(ms, 0, ReplicaId::new(1)).unwrap()
    }

    /// A `Put` op on `array` at coordinate `[ms]` whose HLC is `hlc(ms)`.
    fn make_op(array: &str, ms: u64) -> ArrayOp {
        let header = ArrayOpHeader {
            array: array.into(),
            hlc: hlc(ms),
            schema_hlc: hlc(1),
            valid_from_ms: 0,
            valid_until_ms: -1,
            system_from_ms: ms as i64,
        };
        ArrayOp {
            header,
            kind: ArrayOpKind::Put,
            coord: vec![CoordValue::Int64(ms as i64)],
            attrs: None,
        }
    }

    #[test]
    fn ops_delivered_in_hlc_order() {
        let merger = MultiShardMerger::new("s1", "mat");
        let delivery = ArrayDeliveryRegistry::new();
        let mut rx = delivery.register("s1".into());
        for (shard, ms) in [(0, 300), (1, 100), (0, 200)] {
            merger.push_op(shard, make_op("mat", ms), &delivery);
        }
        merger.drain_to(&delivery);
        // NOTE(review): frames are not decoded here, so only the delivery
        // count and non-empty payloads are asserted — not the HLC order.
        let frames: Vec<_> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(frames.iter().all(|frame| !frame.is_empty()));
        assert_eq!(frames.len(), 3, "expected 3 frames delivered");
    }

    #[test]
    fn drain_threshold_triggers_immediate_drain() {
        let merger = MultiShardMerger::new("s1", "mat");
        let delivery = ArrayDeliveryRegistry::new();
        let mut rx = delivery.register("s1".into());
        for ms in 1..=DRAIN_THRESHOLD as u64 {
            merger.push_op(0, make_op("mat", ms), &delivery);
        }
        let count = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert_eq!(
            count, DRAIN_THRESHOLD,
            "all {DRAIN_THRESHOLD} ops should be delivered"
        );
    }

    #[test]
    fn remove_session_clears_mergers() {
        let reg = MergerRegistry::new();
        for (session, array) in [("s1", "arr"), ("s1", "arr2"), ("s2", "arr")] {
            let _ = reg.get_or_create(session, array);
        }
        reg.remove_session("s1");
        let remaining = reg.mergers.lock().unwrap().len();
        assert_eq!(remaining, 1, "only s2's merger should remain");
    }
}