use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
use super::definition::{ShapeDefinition, ShapeId};
use super::registry::ShapeRegistry;
use crate::control::server::sync::wire::*;
/// Subscribe request handled by `handle_subscribe`.
///
/// Carries the full shape definition that should be registered for the requesting
/// session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSubscribeMsg {
    /// Shape to subscribe to; `handle_subscribe` clones it into the `ShapeRegistry`
    /// and also hands it to the snapshot provider.
    pub shape: ShapeDefinition,
}
/// Initial contents of a shape, sent as a `ShapeSnapshot` frame by `handle_subscribe`
/// in reply to a subscribe request.
// NOTE(review): if the frame encoding is positional (e.g. msgpack arrays rather than
// maps), field order is part of the wire format — confirm before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSnapshotMsg {
    /// Shape this snapshot belongs to.
    pub shape_id: ShapeId,
    /// Snapshot bytes exactly as returned by the snapshot provider; their format is
    /// opaque to this module.
    pub data: Vec<u8>,
    /// LSN the snapshot was taken at (the `current_lsn` passed to `handle_subscribe`).
    pub snapshot_lsn: u64,
    /// Number of documents in the snapshot, as reported by the snapshot provider.
    pub doc_count: usize,
}
/// Notification to one subscriber that a document matching one of its shapes was
/// mutated; built by `evaluate_and_generate_deltas` and sent as a `ShapeDelta` frame.
// NOTE(review): if the frame encoding is positional (e.g. msgpack arrays rather than
// maps), field order is part of the wire format — confirm before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeDeltaMsg {
    /// Subscribed shape that matched the mutation.
    pub shape_id: ShapeId,
    /// Collection containing the mutated document.
    pub collection: String,
    /// Identifier of the mutated document.
    pub document_id: String,
    /// Mutation kind as a free-form string (e.g. `"INSERT"`).
    pub operation: String,
    /// Delta payload, copied verbatim from the caller; opaque to this module.
    pub delta: Vec<u8>,
    /// LSN passed to `evaluate_and_generate_deltas` for this mutation.
    pub lsn: u64,
}
/// Unsubscribe request handled by `handle_unsubscribe`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeUnsubscribeMsg {
    /// Shape to remove from the requesting session's subscriptions.
    pub shape_id: ShapeId,
}
/// Snapshot payload produced by the `snapshot_provider` callback of `handle_subscribe`.
#[derive(Clone, Debug)]
pub struct ShapeSnapshotData {
    /// Encoded snapshot bytes; copied verbatim into `ShapeSnapshotMsg::data`.
    pub data: Vec<u8>,
    /// Number of documents in the snapshot; copied into `ShapeSnapshotMsg::doc_count`.
    pub doc_count: usize,
}

impl ShapeSnapshotData {
    /// Returns a snapshot with no bytes and a document count of zero.
    pub fn empty() -> Self {
        let data: Vec<u8> = vec![];
        ShapeSnapshotData { doc_count: 0, data }
    }
}
/// Handles a shape subscription request and builds the initial snapshot frame.
///
/// Registers `msg.shape` for `session_id` under `tenant_id` in `registry`, calls
/// `snapshot_provider` once with the shape and `current_lsn`, and wraps the result in a
/// [`ShapeSnapshotMsg`] whose `snapshot_lsn` is `current_lsn`.
///
/// # Parameters
/// - `session_id`: session the subscription is recorded for.
/// - `tenant_id`: tenant passed through to `ShapeRegistry::subscribe`.
/// - `msg`: the subscribe request; its shape definition is cloned into the registry.
/// - `registry`: registry the subscription is added to.
/// - `current_lsn`: LSN the snapshot is requested at and echoed as `snapshot_lsn`.
/// - `snapshot_provider`: invoked exactly once to produce the snapshot contents.
///
/// # Returns
/// The frame produced by `SyncFrame::encode_or_empty` for the snapshot message.
pub fn handle_subscribe<F>(
    session_id: &str,
    tenant_id: u32,
    msg: &ShapeSubscribeMsg,
    registry: &ShapeRegistry,
    current_lsn: u64,
    snapshot_provider: F,
) -> SyncFrame
where
    F: FnOnce(&ShapeDefinition, u64) -> ShapeSnapshotData,
{
    let shape = &msg.shape;

    // NOTE(review): `shape.tenant_id` comes from the request and is never compared with
    // `tenant_id` here; confirm the registry / snapshot provider enforce tenancy.
    //
    // The subscription is registered before the snapshot is built — presumably so that
    // mutations landing while the snapshot is produced are still routed as deltas.
    // Only the registry needs an owned definition; the provider borrows it, so a single
    // clone is enough.
    registry.subscribe(session_id, tenant_id, shape.clone());
    let snapshot_data = snapshot_provider(shape, current_lsn);

    let snapshot = ShapeSnapshotMsg {
        shape_id: shape.shape_id.clone(),
        data: snapshot_data.data,
        snapshot_lsn: current_lsn,
        doc_count: snapshot_data.doc_count,
    };
    info!(
        session = session_id,
        shape_id = %shape.shape_id,
        lsn = current_lsn,
        doc_count = snapshot.doc_count,
        "shape subscribed, snapshot sent"
    );
    SyncFrame::encode_or_empty(SyncMessageType::ShapeSnapshot, &snapshot)
}
pub fn handle_unsubscribe(session_id: &str, msg: &ShapeUnsubscribeMsg, registry: &ShapeRegistry) {
registry.unsubscribe(session_id, &msg.shape_id);
debug!(session = session_id, shape_id = %msg.shape_id, "shape unsubscribed");
}
pub fn evaluate_and_generate_deltas(
tenant_id: u32,
collection: &str,
doc_id: &str,
operation: &str,
delta: &[u8],
lsn: u64,
registry: &ShapeRegistry,
) -> Vec<(String, SyncFrame)> {
let matches = registry.evaluate_mutation(tenant_id, collection, doc_id);
matches
.into_iter()
.filter_map(|(session_id, shape_id)| {
let msg = ShapeDeltaMsg {
shape_id,
collection: collection.to_string(),
document_id: doc_id.to_string(),
operation: operation.to_string(),
delta: delta.to_vec(),
lsn,
};
let frame = SyncFrame::new_msgpack(SyncMessageType::ShapeDelta, &msg)?;
Some((session_id, frame))
})
.collect()
}
/// Counts deltas sent per `(session_id, shape_id)` pair and reports when a pair has
/// accumulated enough of them to be compacted.
pub struct ShapeCompactor {
    /// Deltas recorded per `(session_id, shape_id)` since the pair was last cleared by
    /// `compaction_done` (or since it was first recorded).
    delta_counts: HashMap<(String, ShapeId), usize>,
    /// Count at or above which a pair is reported as needing compaction.
    compact_threshold: usize,
    /// Value given to `new`; only exposed through `max_delta_age_secs()` and not
    /// consulted by any other method here.
    max_delta_age_secs: u64,
}

impl ShapeCompactor {
    /// Creates a compactor with no recorded deltas.
    pub fn new(compact_threshold: usize, max_delta_age_secs: u64) -> Self {
        let delta_counts = HashMap::new();
        Self {
            compact_threshold,
            max_delta_age_secs,
            delta_counts,
        }
    }

    /// Records one delta for `(session_id, shape_id)`.
    ///
    /// Returns `true` once the pair's running count is at or above the threshold, and
    /// keeps returning `true` on later calls until `compaction_done` resets the pair.
    pub fn record_delta(&mut self, session_id: &str, shape_id: &str) -> bool {
        let pending = self
            .delta_counts
            .entry((session_id.to_owned(), shape_id.to_owned()))
            .or_default();
        *pending += 1;
        *pending >= self.compact_threshold
    }

    /// Forgets the counter for `(session_id, shape_id)` after it has been compacted.
    pub fn compaction_done(&mut self, session_id: &str, shape_id: &str) {
        let pair = (session_id.to_owned(), shape_id.to_owned());
        let _ = self.delta_counts.remove(&pair);
    }

    /// Returns the `max_delta_age_secs` value supplied to `new`.
    pub fn max_delta_age_secs(&self) -> u64 {
        self.max_delta_age_secs
    }

    /// Lists every `(session_id, shape_id)` pair whose count has reached the threshold.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn shapes_needing_compaction(&self) -> Vec<(String, ShapeId)> {
        let threshold = self.compact_threshold;
        let mut due = Vec::new();
        for ((session, shape), &count) in &self.delta_counts {
            if count >= threshold {
                due.push((session.clone(), shape.clone()));
            }
        }
        due
    }
}
/// Chooses how to bring a client whose last seen LSN is `client_lsn` up to date.
///
/// The rules are checked in this order, and the first that holds wins:
/// 1. `client_lsn >= oldest_wal_lsn` gives [`HistoricalResolution::WalDelta`].
/// 2. `client_lsn >= loro_history_lsn` gives [`HistoricalResolution::LoroDelta`].
/// 3. Otherwise the result is [`HistoricalResolution::FullSnapshot`].
pub fn resolve_historical_delta(
    client_lsn: u64,
    oldest_wal_lsn: u64,
    loro_history_lsn: u64,
) -> HistoricalResolution {
    match client_lsn {
        lsn if lsn >= oldest_wal_lsn => HistoricalResolution::WalDelta,
        lsn if lsn >= loro_history_lsn => HistoricalResolution::LoroDelta,
        _ => HistoricalResolution::FullSnapshot,
    }
}

/// Catch-up strategy selected by [`resolve_historical_delta`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HistoricalResolution {
    /// The client's LSN is at or above `oldest_wal_lsn`.
    WalDelta,
    /// The client's LSN is below `oldest_wal_lsn` but at or above `loro_history_lsn`.
    LoroDelta,
    /// The client's LSN is below both thresholds.
    FullSnapshot,
}
/// Unit tests for the subscribe/unsubscribe handlers, delta fan-out, compactor and
/// historical-resolution logic in this module.
#[cfg(test)]
mod tests {
    use super::super::definition::ShapeType;
    use super::*;

    /// Builds a document shape over the `orders` collection for tenant 1.
    fn make_shape(id: &str) -> ShapeDefinition {
        ShapeDefinition {
            shape_id: id.into(),
            tenant_id: 1,
            shape_type: ShapeType::Document {
                collection: "orders".into(),
                predicate: Vec::new(),
            },
            description: "test shape".into(),
            field_filter: vec![],
        }
    }

    #[test]
    fn subscribe_sends_snapshot() {
        let registry = ShapeRegistry::new();
        let msg = ShapeSubscribeMsg {
            shape: make_shape("sh1"),
        };
        let frame = handle_subscribe("s1", 1, &msg, &registry, 100, |_shape, _lsn| {
            ShapeSnapshotData {
                data: rmp_serde::to_vec_named(&vec!["doc1", "doc2"]).unwrap_or_default(),
                doc_count: 2,
            }
        });
        assert_eq!(frame.msg_type, SyncMessageType::ShapeSnapshot);
        let snapshot: ShapeSnapshotMsg = frame.decode_body().unwrap();
        assert_eq!(snapshot.shape_id, "sh1");
        assert_eq!(snapshot.snapshot_lsn, 100);
        assert_eq!(snapshot.doc_count, 2);
        assert!(!snapshot.data.is_empty());
        assert_eq!(registry.total_shapes(), 1);
    }

    #[test]
    fn unsubscribe_removes() {
        let registry = ShapeRegistry::new();
        registry.subscribe("s1", 1, make_shape("sh1"));
        assert_eq!(registry.total_shapes(), 1);
        handle_unsubscribe(
            "s1",
            &ShapeUnsubscribeMsg {
                shape_id: "sh1".into(),
            },
            &registry,
        );
        assert_eq!(registry.total_shapes(), 0);
    }

    #[test]
    fn evaluate_generates_deltas() {
        let registry = ShapeRegistry::new();
        registry.subscribe("s1", 1, make_shape("sh1"));
        let deltas = evaluate_and_generate_deltas(
            1,
            "orders",
            "o42",
            "INSERT",
            b"delta_bytes",
            200,
            &registry,
        );
        assert_eq!(deltas.len(), 1);
        assert_eq!(deltas[0].0, "s1");
        assert_eq!(deltas[0].1.msg_type, SyncMessageType::ShapeDelta);
    }

    #[test]
    fn compactor_triggers() {
        let mut compactor = ShapeCompactor::new(3, 3600);
        assert!(!compactor.record_delta("s1", "sh1"));
        assert!(!compactor.record_delta("s1", "sh1"));
        assert!(compactor.record_delta("s1", "sh1"));
        assert_eq!(compactor.shapes_needing_compaction().len(), 1);
        compactor.compaction_done("s1", "sh1");
        assert!(compactor.shapes_needing_compaction().is_empty());
    }

    #[test]
    fn compactor_counts_pairs_independently() {
        let mut compactor = ShapeCompactor::new(2, 3600);
        assert!(!compactor.record_delta("s1", "sh1"));
        assert!(!compactor.record_delta("s2", "sh1"));
        assert!(!compactor.record_delta("s1", "sh2"));
        assert!(compactor.record_delta("s1", "sh1"));
        assert_eq!(
            compactor.shapes_needing_compaction(),
            vec![("s1".to_string(), "sh1".to_string())]
        );
    }

    #[test]
    fn historical_resolution() {
        assert_eq!(
            resolve_historical_delta(100, 50, 10),
            HistoricalResolution::WalDelta
        );
        assert_eq!(
            resolve_historical_delta(30, 50, 10),
            HistoricalResolution::LoroDelta
        );
        assert_eq!(
            resolve_historical_delta(5, 50, 10),
            HistoricalResolution::FullSnapshot
        );
    }

    #[test]
    fn historical_resolution_boundaries_are_inclusive() {
        assert_eq!(
            resolve_historical_delta(50, 50, 10),
            HistoricalResolution::WalDelta
        );
        assert_eq!(
            resolve_historical_delta(10, 50, 10),
            HistoricalResolution::LoroDelta
        );
    }
}