use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Instant;
use tracing::{debug, info};
use super::definition::{ShapeDefinition, ShapeId};
/// Subscription state kept by [`ShapeRegistry`] for a single session.
struct ClientShapes {
    /// Tenant captured when this session entry was first created.
    tenant_id: u32,
    /// Subscribed shapes for the session, keyed by their shape id.
    shapes: HashMap<ShapeId, ShapeDefinition>,
    /// Timestamp of the last recorded modification; `ShapeRegistry::compact`
    /// evicts the session once this is at least its idle limit old.
    last_modified: Instant,
}
pub struct ShapeRegistry {
sessions: RwLock<HashMap<String, ClientShapes>>,
}
impl Default for ShapeRegistry {
fn default() -> Self {
Self::new()
}
}
impl ShapeRegistry {
    /// Creates an empty registry with no sessions.
    pub fn new() -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `shape` for `session_id`, replacing any shape with the same id.
    ///
    /// The session entry is created on first use and records `tenant_id`. For a
    /// session that already exists, the tenant stored at creation is kept and the
    /// `tenant_id` argument is not consulted.
    pub fn subscribe(&self, session_id: &str, tenant_id: u32, shape: ShapeDefinition) {
        let mut sessions =
            crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
        // NOTE(review): if this session was created under a different tenant, the
        // shape is filed under that original tenant and will be evaluated against its
        // mutations. Confirm that a session can never switch tenants.
        let client = sessions
            .entry(session_id.to_string())
            .or_insert_with(|| ClientShapes {
                shapes: HashMap::new(),
                tenant_id,
                last_modified: Instant::now(),
            });
        info!(
            session = session_id,
            shape_id = %shape.shape_id,
            "shape subscribed"
        );
        client.shapes.insert(shape.shape_id.clone(), shape);
        client.last_modified = Instant::now();
    }

    /// Removes `shape_id` from `session_id`'s subscriptions.
    ///
    /// Returns `true` if the shape was present. The session entry is kept even if
    /// this leaves it empty; [`ShapeRegistry::compact`] cleans up empty sessions.
    pub fn unsubscribe(&self, session_id: &str, shape_id: &str) -> bool {
        let mut sessions =
            crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
        if let Some(client) = sessions.get_mut(session_id) {
            let removed = client.shapes.remove(shape_id).is_some();
            if removed {
                debug!(session = session_id, shape_id, "shape unsubscribed");
                client.last_modified = Instant::now();
            }
            removed
        } else {
            false
        }
    }

    /// Drops a session and all of its shapes. Unknown sessions are ignored.
    pub fn remove_session(&self, session_id: &str) {
        let mut sessions =
            crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
        if let Some(client) = sessions.remove(session_id) {
            info!(
                session = session_id,
                shapes = client.shapes.len(),
                "session shapes removed"
            );
        }
    }

    /// Returns the ids of all shapes subscribed by `session_id`.
    ///
    /// Returns an empty vector for unknown sessions. Order is unspecified
    /// (`HashMap` iteration order).
    pub fn shapes_for_session(&self, session_id: &str) -> Vec<ShapeId> {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions
            .get(session_id)
            .map(|c| c.shapes.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Finds every `(session_id, shape_id)` whose shape may be affected by a
    /// mutation of `doc_id` in `collection`.
    ///
    /// Only sessions whose stored tenant equals `tenant_id` are considered; each
    /// of their shapes is tested with `ShapeDefinition::could_match`. Result order
    /// is unspecified.
    pub fn evaluate_mutation(
        &self,
        tenant_id: u32,
        collection: &str,
        doc_id: &str,
    ) -> Vec<(String, ShapeId)> {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        let mut matches = Vec::new();
        for (session_id, client) in sessions.iter() {
            // Tenant isolation: never report shapes belonging to another tenant.
            if client.tenant_id != tenant_id {
                continue;
            }
            for (shape_id, shape) in &client.shapes {
                if shape.could_match(collection, doc_id) {
                    matches.push((session_id.clone(), shape_id.clone()));
                }
            }
        }
        matches
    }

    /// Returns `(tenant_id, shape_count)` for `session_id`, or `None` if unknown.
    pub fn session_info(&self, session_id: &str) -> Option<(u32, usize)> {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions
            .get(session_id)
            .map(|c| (c.tenant_id, c.shapes.len()))
    }

    /// Total number of shapes across all sessions.
    pub fn total_shapes(&self) -> usize {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions.values().map(|c| c.shapes.len()).sum()
    }

    /// Number of session entries, including sessions whose shape set is empty.
    pub fn active_sessions(&self) -> usize {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions.len()
    }

    /// Returns a clone of the shape `shape_id` subscribed by `session_id`, if any.
    pub fn get_shape(&self, session_id: &str, shape_id: &str) -> Option<ShapeDefinition> {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions
            .get(session_id)
            .and_then(|c| c.shapes.get(shape_id).cloned())
    }

    /// Snapshots every subscription as `(session_id, tenant_id, shape)` triples.
    ///
    /// The output can be fed back into [`ShapeRegistry::import`]. Order is
    /// unspecified.
    pub fn export_all(&self) -> Vec<(String, u32, ShapeDefinition)> {
        let sessions =
            crate::control::lock_utils::read_or_recover(self.sessions.read(), "shape_sessions");
        sessions
            .iter()
            .flat_map(|(session_id, client)| {
                client
                    .shapes
                    .values()
                    .map(move |shape| (session_id.clone(), client.tenant_id, shape.clone()))
            })
            .collect()
    }

    /// Merges `(session_id, tenant_id, shape)` triples into the registry.
    ///
    /// Missing sessions are created with the given tenant. Shapes with an existing
    /// id in the same session are replaced. Every touched session has its idle
    /// clock refreshed.
    pub fn import(&self, shapes: Vec<(String, u32, ShapeDefinition)>) {
        let mut sessions =
            crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
        let now = Instant::now();
        for (session_id, tenant_id, shape) in shapes {
            let client = sessions.entry(session_id).or_insert_with(|| ClientShapes {
                shapes: HashMap::new(),
                tenant_id,
                last_modified: now,
            });
            client.shapes.insert(shape.shape_id.clone(), shape);
            // Refresh existing sessions too. Otherwise a long-idle session that
            // just received imported shapes could be evicted by the next `compact`.
            client.last_modified = now;
        }
    }

    /// Removes sessions that have no shapes or whose last modification is at
    /// least `max_idle` ago.
    ///
    /// Returns the number of sessions removed. `Duration::ZERO` evicts every
    /// session.
    pub fn compact(&self, max_idle: std::time::Duration) -> usize {
        let mut sessions =
            crate::control::lock_utils::write_or_recover(self.sessions.write(), "shape_sessions");
        let before = sessions.len();
        sessions.retain(|_, client| {
            !client.shapes.is_empty() && client.last_modified.elapsed() < max_idle
        });
        before - sessions.len()
    }
}
#[cfg(test)]
mod tests {
    use super::super::definition::ShapeType;
    use super::*;
    use std::time::Duration;

    /// Builds a document shape over `collection` with no predicate or field filter.
    fn make_doc_shape(id: &str, collection: &str) -> ShapeDefinition {
        ShapeDefinition {
            shape_id: id.into(),
            tenant_id: 1,
            shape_type: ShapeType::Document {
                collection: collection.into(),
                predicate: Vec::new(),
            },
            description: format!("all {collection}"),
            field_filter: vec![],
        }
    }

    #[test]
    fn subscribe_and_query() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        reg.subscribe("s1", 1, make_doc_shape("sh2", "users"));
        assert_eq!(reg.total_shapes(), 2);
        assert_eq!(reg.active_sessions(), 1);
        assert_eq!(reg.shapes_for_session("s1").len(), 2);
        assert!(reg.shapes_for_session("unknown").is_empty());
    }

    #[test]
    fn subscribe_same_id_replaces() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        reg.subscribe("s1", 1, make_doc_shape("sh1", "users"));
        assert_eq!(reg.total_shapes(), 1);
        let shape = reg.get_shape("s1", "sh1").expect("sh1 was subscribed");
        assert_eq!(shape.description, "all users");
    }

    #[test]
    fn evaluate_mutation_matches() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        reg.subscribe("s2", 1, make_doc_shape("sh2", "orders"));
        reg.subscribe("s3", 2, make_doc_shape("sh3", "orders"));
        let matches = reg.evaluate_mutation(1, "orders", "o1");
        // Only tenant-1 sessions may match; s3 belongs to tenant 2.
        let mut sessions: Vec<String> = matches.into_iter().map(|(s, _)| s).collect();
        sessions.sort();
        assert_eq!(sessions, vec!["s1", "s2"]);
    }

    #[test]
    fn unsubscribe() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        assert_eq!(reg.total_shapes(), 1);
        assert!(reg.unsubscribe("s1", "sh1"));
        assert_eq!(reg.total_shapes(), 0);
        assert!(!reg.unsubscribe("s1", "sh1"));
        assert!(!reg.unsubscribe("nobody", "sh1"));
    }

    #[test]
    fn remove_session() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        reg.subscribe("s1", 1, make_doc_shape("sh2", "users"));
        reg.remove_session("s1");
        assert_eq!(reg.total_shapes(), 0);
        assert_eq!(reg.active_sessions(), 0);
    }

    #[test]
    fn no_match_wrong_collection() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        let matches = reg.evaluate_mutation(1, "users", "u1");
        assert!(matches.is_empty());
    }

    #[test]
    fn get_shape_and_session_info() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 7, make_doc_shape("sh1", "orders"));
        let shape = reg.get_shape("s1", "sh1").expect("sh1 was subscribed");
        assert_eq!(shape.description, "all orders");
        assert!(reg.get_shape("s1", "missing").is_none());
        assert!(reg.get_shape("nobody", "sh1").is_none());
        assert_eq!(reg.session_info("s1"), Some((7, 1)));
        assert_eq!(reg.session_info("nobody"), None);
    }

    #[test]
    fn export_import_round_trip() {
        let src = ShapeRegistry::new();
        src.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        src.subscribe("s1", 1, make_doc_shape("sh2", "users"));
        src.subscribe("s2", 2, make_doc_shape("sh3", "orders"));
        let exported = src.export_all();
        assert_eq!(exported.len(), 3);

        let dst = ShapeRegistry::new();
        dst.import(exported);
        assert_eq!(dst.total_shapes(), 3);
        assert_eq!(dst.active_sessions(), 2);
        assert_eq!(dst.session_info("s1"), Some((1, 2)));
        assert_eq!(dst.session_info("s2"), Some((2, 1)));
    }

    #[test]
    fn compact_removes_empty_then_idle_sessions() {
        let reg = ShapeRegistry::new();
        reg.subscribe("s1", 1, make_doc_shape("sh1", "orders"));
        reg.subscribe("s2", 1, make_doc_shape("sh2", "orders"));
        assert!(reg.unsubscribe("s1", "sh1"));
        // s1 is now empty and is removed; s2 is fresh and survives a generous limit.
        assert_eq!(reg.compact(Duration::from_secs(3600)), 1);
        assert_eq!(reg.active_sessions(), 1);
        // A zero idle limit evicts everything that remains.
        assert_eq!(reg.compact(Duration::ZERO), 1);
        assert_eq!(reg.active_sessions(), 0);
    }
}