use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Instant;
use tracing::{debug, info, warn};
use nodedb_query::metadata_filter::matches_metadata_filter;
use nodedb_types::filter::MetadataFilter;
use super::definition::{ShapeDefinition, ShapeId, ShapeType};
/// Subscription state the registry keeps for one session.
struct ClientShapes {
    /// Tenant recorded for the session when its entry was created.
    tenant_id: u64,
    /// Shapes the session is subscribed to, keyed by shape id.
    shapes: HashMap<ShapeId, ShapeDefinition>,
    /// Time of the most recent recorded change to this session's shapes;
    /// read by `ShapeRegistry::compact`.
    last_modified: Instant,
}
/// Registry of shape subscriptions, grouped by session id.
///
/// All state lives behind a single `RwLock`, so every method takes `&self`.
pub struct ShapeRegistry {
    /// Session id -> that session's tenant and subscribed shapes.
    sessions: RwLock<HashMap<String, ClientShapes>>,
}
impl Default for ShapeRegistry {
fn default() -> Self {
Self::new()
}
}
impl ShapeRegistry {
/// Creates a registry that holds no sessions and no shapes.
pub fn new() -> Self {
    let sessions = RwLock::new(HashMap::new());
    Self { sessions }
}
pub fn subscribe(&self, session_id: &str, tenant_id: u64, shape: ShapeDefinition) {
let mut sessions =
crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
let client = sessions
.entry(session_id.to_string())
.or_insert_with(|| ClientShapes {
shapes: HashMap::new(),
tenant_id,
last_modified: Instant::now(),
});
info!(
session = session_id,
shape_id = %shape.shape_id,
"shape subscribed"
);
client
.shapes
.insert(ShapeId::from_validated(shape.shape_id.clone()), shape);
client.last_modified = Instant::now();
}
/// Removes the subscription `shape_id` from `session_id`.
///
/// Returns `true` only when the session exists and held that shape; in that
/// case the session's `last_modified` is refreshed. The session entry itself
/// is kept even if it ends up with no shapes.
pub fn unsubscribe(&self, session_id: &str, shape_id: &str) -> bool {
    let mut sessions =
        crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
    let Some(client) = sessions.get_mut(session_id) else {
        return false;
    };
    if client.shapes.remove(shape_id).is_none() {
        return false;
    }
    debug!(session = session_id, shape_id, "shape unsubscribed");
    client.last_modified = Instant::now();
    true
}
/// Drops `session_id` together with all of its shapes.
///
/// Logs how many shapes were discarded; does nothing when the session is
/// not registered.
pub fn remove_session(&self, session_id: &str) {
    let mut sessions =
        crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
    let Some(client) = sessions.remove(session_id) else {
        return;
    };
    info!(
        session = session_id,
        shapes = client.shapes.len(),
        "session shapes removed"
    );
}
/// Returns the ids of every shape `session_id` is subscribed to.
///
/// The order follows the underlying `HashMap` and is therefore unspecified.
/// An unknown session yields an empty vector.
pub fn shapes_for_session(&self, session_id: &str) -> Vec<ShapeId> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    match sessions.get(session_id) {
        Some(client) => client.shapes.keys().cloned().collect(),
        None => Vec::new(),
    }
}
/// Finds every `(session id, shape id)` subscription that a document
/// mutation should be routed to.
///
/// Only sessions whose tenant equals `tenant_id` are considered. A shape is
/// reported when its `could_match(collection, doc_id)` check passes and, for
/// `ShapeType::Document` shapes carrying a non-empty predicate, the decoded
/// `MetadataFilter` accepts `doc_json`. A predicate that fails to decode is
/// logged and treated as non-matching.
///
/// The predicate is decoded on every call for every candidate shape.
pub fn evaluate_mutation(
    &self,
    tenant_id: u64,
    collection: &str,
    doc_id: &str,
    doc_json: &serde_json::Value,
) -> Vec<(String, ShapeId)> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let mut routed = Vec::new();
    let same_tenant = sessions
        .iter()
        .filter(|(_, client)| client.tenant_id == tenant_id);
    for (session_id, client) in same_tenant {
        for (shape_id, shape) in client.shapes.iter() {
            if !shape.could_match(collection, doc_id) {
                continue;
            }
            // Shapes without a (non-empty) document predicate accept the
            // mutation as soon as `could_match` passed.
            let accepted = match &shape.shape_type {
                ShapeType::Document { predicate, .. } if !predicate.is_empty() => {
                    match zerompk::from_msgpack::<MetadataFilter>(predicate) {
                        Ok(filter) => matches_metadata_filter(doc_json, &filter),
                        Err(err) => {
                            warn!(
                                shape_id = %shape_id,
                                error = %err,
                                "failed to decode shape predicate; treating shape as non-matching"
                            );
                            false
                        }
                    }
                }
                _ => true,
            };
            if accepted {
                routed.push((session_id.clone(), shape_id.clone()));
            }
        }
    }
    routed
}
/// Returns `(tenant id, number of shapes)` for `session_id`, or `None` when
/// the session is not registered.
pub fn session_info(&self, session_id: &str) -> Option<(u64, usize)> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let client = sessions.get(session_id)?;
    Some((client.tenant_id, client.shapes.len()))
}
/// Counts the shapes registered across all sessions.
pub fn total_shapes(&self) -> usize {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let mut total: usize = 0;
    for client in sessions.values() {
        total += client.shapes.len();
    }
    total
}
/// Returns how many session entries the registry currently holds, including
/// sessions whose shape set is empty.
pub fn active_sessions(&self) -> usize {
    let guard =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let session_count = guard.len();
    session_count
}
/// Finds every `(session id, shape id)` subscription that an array mutation
/// should be routed to.
///
/// Only sessions whose tenant equals `tenant_id` are considered; within
/// them a shape is reported when `matches_array_op(array_name, coord)`
/// returns `true`.
pub fn evaluate_array_mutation(
    &self,
    tenant_id: u64,
    array_name: &str,
    coord: &[u64],
) -> Vec<(String, ShapeId)> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let mut routed = Vec::new();
    let same_tenant = sessions
        .iter()
        .filter(|(_, client)| client.tenant_id == tenant_id);
    for (session_id, client) in same_tenant {
        routed.extend(
            client
                .shapes
                .iter()
                .filter(|(_, shape)| shape.matches_array_op(array_name, coord))
                .map(|(shape_id, _)| (session_id.clone(), shape_id.clone())),
        );
    }
    routed
}
/// Returns a clone of the shape `shape_id` held by `session_id`, or `None`
/// when either the session or the shape is unknown.
pub fn get_shape(&self, session_id: &str, shape_id: &str) -> Option<ShapeDefinition> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let client = sessions.get(session_id)?;
    client.shapes.get(shape_id).cloned()
}
/// Snapshots the whole registry as `(session id, tenant id, shape)` triples,
/// one per registered shape.
///
/// Sessions without shapes contribute nothing, and the order of the triples
/// follows `HashMap` iteration and is unspecified.
pub fn export_all(&self) -> Vec<(String, u64, ShapeDefinition)> {
    let sessions =
        crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
    let mut exported = Vec::new();
    for (session_id, client) in sessions.iter() {
        exported.extend(
            client
                .shapes
                .values()
                .map(|shape| (session_id.clone(), client.tenant_id, shape.clone())),
        );
    }
    exported
}
/// Loads `(session id, tenant id, shape)` triples into the registry.
///
/// Missing sessions are created and bound to the tenant id of the first
/// triple that names them; an existing session keeps its tenant. A shape
/// whose id is already present in the session is replaced.
///
/// Every session that receives a shape has its `last_modified` refreshed,
/// matching `subscribe`, so `compact` does not judge a pre-existing session
/// by a timestamp older than the import.
pub fn import(&self, shapes: Vec<(String, u64, ShapeDefinition)>) {
    let mut sessions =
        crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
    // One timestamp for the whole batch: all triples are applied under the
    // same write lock.
    let imported_at = Instant::now();
    for (session_id, tenant_id, shape) in shapes {
        let client = sessions.entry(session_id).or_insert_with(|| ClientShapes {
            shapes: HashMap::new(),
            tenant_id,
            last_modified: imported_at,
        });
        client
            .shapes
            .insert(ShapeId::from_validated(shape.shape_id.clone()), shape);
        client.last_modified = imported_at;
    }
}
/// Evicts sessions and returns how many were removed.
///
/// A session is kept only if it still has at least one shape *and* its
/// `last_modified` is younger than `max_idle`; everything else is dropped.
///
/// NOTE(review): this also evicts sessions that still hold shapes once they
/// have gone `max_idle` without a subscription change — confirm that is the
/// intended policy rather than "drop only empty, idle sessions".
pub fn compact(&self, max_idle: std::time::Duration) -> usize {
    let mut sessions =
        crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
    let sessions_before = sessions.len();
    sessions.retain(|_, client| {
        if client.shapes.is_empty() {
            return false;
        }
        client.last_modified.elapsed() < max_idle
    });
    sessions_before - sessions.len()
}
}
#[cfg(test)]
mod tests {
use super::super::definition::ShapeType;
use super::*;
/// Builds a document shape over `collection` with an empty predicate.
fn make_doc_shape(id: &str, collection: &str) -> ShapeDefinition {
    let shape_type = ShapeType::Document {
        collection: collection.into(),
        predicate: vec![],
    };
    ShapeDefinition {
        shape_id: id.into(),
        tenant_id: 1,
        shape_type,
        description: format!("all {collection}"),
        field_filter: Vec::new(),
    }
}
/// Two shapes on one session show up in every counter.
#[test]
fn subscribe_and_query() {
    let registry = ShapeRegistry::new();
    for (shape_id, collection) in [("sh1", "orders"), ("sh2", "users")] {
        registry.subscribe("s1", 1, make_doc_shape(shape_id, collection));
    }
    assert_eq!(registry.total_shapes(), 2);
    assert_eq!(registry.active_sessions(), 1);
    assert_eq!(registry.shapes_for_session("s1").len(), 2);
}
/// Only sessions of the mutating tenant are matched.
#[test]
fn evaluate_mutation_matches() {
    let registry = ShapeRegistry::new();
    // s1 and s2 are tenant 1; s3 is tenant 2 and must not be routed to.
    let subscriptions = [("s1", 1u64, "sh1"), ("s2", 1u64, "sh2"), ("s3", 2u64, "sh3")];
    for (session, tenant, shape_id) in subscriptions {
        registry.subscribe(session, tenant, make_doc_shape(shape_id, "orders"));
    }
    let document = serde_json::json!({"status": "open"});
    let routed = registry.evaluate_mutation(1, "orders", "o1", &document);
    assert_eq!(routed.len(), 2);
}
/// Unsubscribing removes the shape once; a repeat attempt reports `false`.
#[test]
fn unsubscribe() {
    let registry = ShapeRegistry::new();
    registry.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
    assert_eq!(registry.total_shapes(), 1);

    let first_attempt = registry.unsubscribe("s1", "sh1");
    assert!(first_attempt);
    assert_eq!(registry.total_shapes(), 0);

    // Nothing is left to remove the second time.
    let second_attempt = registry.unsubscribe("s1", "sh1");
    assert!(!second_attempt);
}
/// Removing a session discards the session entry and all of its shapes.
#[test]
fn remove_session() {
    let registry = ShapeRegistry::new();
    for (shape_id, collection) in [("sh1", "orders"), ("sh2", "users")] {
        registry.subscribe("s1", 1, make_doc_shape(shape_id, collection));
    }
    registry.remove_session("s1");
    assert_eq!(registry.total_shapes(), 0);
    assert_eq!(registry.active_sessions(), 0);
}
/// A mutation in a different collection is routed nowhere.
#[test]
fn no_match_wrong_collection() {
    let registry = ShapeRegistry::new();
    registry.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
    let empty_doc = serde_json::json!({});
    let routed = registry.evaluate_mutation(1, "users", "u1", &empty_doc);
    assert!(routed.is_empty());
}
/// Builds a document shape over `collection` carrying the given encoded
/// `predicate` bytes.
fn make_doc_shape_with_predicate(
    id: &str,
    collection: &str,
    predicate: Vec<u8>,
) -> ShapeDefinition {
    let shape_type = ShapeType::Document {
        collection: collection.into(),
        predicate,
    };
    ShapeDefinition {
        shape_id: id.into(),
        tenant_id: 1,
        shape_type,
        description: format!("filtered {collection}"),
        field_filter: Vec::new(),
    }
}
/// An equality predicate routes only documents whose field matches.
#[test]
fn predicate_eq_routes_matching_doc() {
    use nodedb_types::filter::MetadataFilter;
    let registry = ShapeRegistry::new();
    let encoded = zerompk::to_msgpack_vec(&MetadataFilter::eq("kind", "Rule")).expect("encode");
    let shape = make_doc_shape_with_predicate("sh1", "entries", encoded);
    registry.subscribe("s1", 1, shape);

    let rule_doc = serde_json::json!({"kind": "Rule", "name": "test"});
    let routed = registry.evaluate_mutation(1, "entries", "doc1", &rule_doc);
    assert_eq!(routed.len(), 1, "matching doc must route to shape");

    let note_doc = serde_json::json!({"kind": "Note", "name": "test"});
    let routed = registry.evaluate_mutation(1, "entries", "doc2", &note_doc);
    assert!(routed.is_empty(), "non-matching doc must not route to shape");
}
/// An AND predicate requires every clause to hold.
#[test]
fn predicate_and_routes_correctly() {
    use nodedb_types::filter::MetadataFilter;
    let registry = ShapeRegistry::new();
    let clauses = vec![
        MetadataFilter::eq("kind", "Rule"),
        MetadataFilter::eq("share", "team"),
    ];
    let encoded = zerompk::to_msgpack_vec(&MetadataFilter::and(clauses)).expect("encode");
    let shape = make_doc_shape_with_predicate("sh1", "entries", encoded);
    registry.subscribe("s1", 1, shape);

    let satisfies_both = serde_json::json!({"kind": "Rule", "share": "team"});
    let routed_full = registry.evaluate_mutation(1, "entries", "d1", &satisfies_both);
    assert_eq!(routed_full.len(), 1);

    let satisfies_one = serde_json::json!({"kind": "Rule", "share": "personal"});
    let routed_partial = registry.evaluate_mutation(1, "entries", "d2", &satisfies_one);
    assert!(routed_partial.is_empty());
}
/// A shape with an empty predicate accepts any document in its collection.
#[test]
fn empty_predicate_matches_all() {
    let registry = ShapeRegistry::new();
    registry.subscribe("s1", 1, make_doc_shape("sh1", "entries"));
    let arbitrary_doc = serde_json::json!({"anything": "goes"});
    let routed = registry.evaluate_mutation(1, "entries", "d1", &arbitrary_doc);
    assert_eq!(routed.len(), 1);
}
/// Predicate bytes that cannot be decoded make the shape non-matching.
#[test]
fn invalid_predicate_bytes_excluded() {
    let registry = ShapeRegistry::new();
    let undecodable: Vec<u8> = vec![0xFF, 0xFE, 0x01, 0x02];
    let shape = make_doc_shape_with_predicate("sh1", "entries", undecodable);
    registry.subscribe("s1", 1, shape);

    let document = serde_json::json!({"kind": "Rule"});
    let routed = registry.evaluate_mutation(1, "entries", "d1", &document);
    assert!(
        routed.is_empty(),
        "bad predicate must be non-matching (fail-closed)"
    );
}
/// Builds an array shape over `array` with no coordinate range.
fn make_array_shape(id: &str, array: &str) -> ShapeDefinition {
    let shape_type = ShapeType::Array {
        array_name: array.into(),
        coord_range: None,
    };
    ShapeDefinition {
        shape_id: id.into(),
        tenant_id: 1,
        shape_type,
        description: format!("all {array}"),
        field_filter: Vec::new(),
    }
}
/// Array mutations reach only same-tenant sessions subscribed to that array.
#[test]
fn evaluate_array_mutation_matches() {
    let registry = ShapeRegistry::new();
    // s3 is another tenant and s4 watches a different array; neither may match.
    let subscriptions = [
        ("s1", 1u64, "ah1", "prices"),
        ("s2", 1u64, "ah2", "prices"),
        ("s3", 2u64, "ah3", "prices"),
        ("s4", 1u64, "ah4", "other"),
    ];
    for (session, tenant, shape_id, array) in subscriptions {
        registry.subscribe(session, tenant, make_array_shape(shape_id, array));
    }
    let routed = registry.evaluate_array_mutation(1, "prices", &[5, 5]);
    assert_eq!(routed.len(), 2);
    for expected in ["s1", "s2"] {
        assert!(routed.iter().any(|(session, _)| session.as_str() == expected));
    }
}
/// A shape with a coordinate range stops matching for a coordinate outside
/// that range, while a shape without a range keeps matching.
#[test]
fn evaluate_array_mutation_coord_range_filter() {
    use nodedb_types::sync::shape::ArrayCoordRange;
    let registry = ShapeRegistry::new();
    let bounded_range = ArrayCoordRange {
        start: vec![0],
        end: Some(vec![10]),
    };
    let bounded = ShapeDefinition {
        shape_id: "ar1".into(),
        tenant_id: 1,
        shape_type: ShapeType::Array {
            array_name: "mat".into(),
            coord_range: Some(bounded_range),
        },
        description: "narrow".into(),
        field_filter: vec![],
    };
    let unbounded = make_array_shape("ar2", "mat");
    registry.subscribe("s1", 1, bounded);
    registry.subscribe("s2", 1, unbounded);

    // Coordinate 5: both shapes match.
    let routed_inside = registry.evaluate_array_mutation(1, "mat", &[5]);
    assert_eq!(routed_inside.len(), 2);

    // Coordinate 50: only the shape without a range still matches.
    let routed_outside = registry.evaluate_array_mutation(1, "mat", &[50]);
    assert_eq!(routed_outside.len(), 1);
    assert_eq!(routed_outside[0].0, "s2");
}
}