use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use super::stream_def::ChangeStreamDef;
const WILDCARD: &str = "*";
struct Inner {
by_name: HashMap<(u64, String), Arc<ChangeStreamDef>>,
by_collection: HashMap<(u64, String), Vec<Arc<ChangeStreamDef>>>,
}
impl Inner {
fn new() -> Self {
Self {
by_name: HashMap::new(),
by_collection: HashMap::new(),
}
}
fn insert(&mut self, def: ChangeStreamDef) {
let key_name = (def.tenant_id, def.name.clone());
let key_coll = (def.tenant_id, def.collection.clone());
let def_arc = Arc::new(def);
if let Some(prev) = self.by_name.remove(&key_name) {
let prev_coll_key = (prev.tenant_id, prev.collection.clone());
if let Some(bucket) = self.by_collection.get_mut(&prev_coll_key) {
bucket.retain(|d| d.name != prev.name);
if bucket.is_empty() {
self.by_collection.remove(&prev_coll_key);
}
}
}
self.by_collection
.entry(key_coll)
.or_default()
.push(Arc::clone(&def_arc));
self.by_name.insert(key_name, def_arc);
}
fn remove(&mut self, tenant_id: u64, name: &str) -> bool {
let key_name = (tenant_id, name.to_string());
let Some(prev) = self.by_name.remove(&key_name) else {
return false;
};
let key_coll = (prev.tenant_id, prev.collection.clone());
if let Some(bucket) = self.by_collection.get_mut(&key_coll) {
bucket.retain(|d| d.name != prev.name);
if bucket.is_empty() {
self.by_collection.remove(&key_coll);
}
}
true
}
}
pub struct StreamRegistry {
inner: RwLock<Inner>,
}
impl StreamRegistry {
pub fn new() -> Self {
Self {
inner: RwLock::new(Inner::new()),
}
}
fn read(&self) -> std::sync::RwLockReadGuard<'_, Inner> {
self.inner.read().unwrap_or_else(|p| p.into_inner())
}
fn write(&self) -> std::sync::RwLockWriteGuard<'_, Inner> {
self.inner.write().unwrap_or_else(|p| p.into_inner())
}
pub fn register(&self, def: ChangeStreamDef) {
self.write().insert(def);
}
pub fn unregister(&self, tenant_id: u64, name: &str) -> bool {
self.write().remove(tenant_id, name)
}
pub fn get(&self, tenant_id: u64, name: &str) -> Option<ChangeStreamDef> {
let key = (tenant_id, name.to_string());
self.read().by_name.get(&key).map(|a| (**a).clone())
}
pub fn find_matching(&self, tenant_id: u64, collection: &str) -> Vec<Arc<ChangeStreamDef>> {
let inner = self.read();
let mut out: Vec<Arc<ChangeStreamDef>> = Vec::new();
if let Some(bucket) = inner
.by_collection
.get(&(tenant_id, collection.to_string()))
{
out.extend(bucket.iter().cloned());
}
if collection != WILDCARD
&& let Some(bucket) = inner.by_collection.get(&(tenant_id, WILDCARD.to_string()))
{
out.extend(bucket.iter().cloned());
}
out
}
pub fn list_all(&self) -> Vec<ChangeStreamDef> {
self.read()
.by_name
.values()
.map(|a| (**a).clone())
.collect()
}
pub fn clear_and_reload(
&self,
catalog: &crate::control::security::catalog::types::SystemCatalog,
) -> crate::Result<()> {
let fresh = catalog.load_all_change_streams()?;
let mut inner = self.write();
*inner = Inner::new();
for stream in fresh {
inner.insert(stream);
}
Ok(())
}
pub fn list_for_tenant(&self, tenant_id: u64) -> Vec<ChangeStreamDef> {
self.read()
.by_name
.iter()
.filter(|((tid, _), _)| *tid == tenant_id)
.map(|(_, a)| (**a).clone())
.collect()
}
pub fn load_from_catalog(
&self,
catalog: &crate::control::security::catalog::types::SystemCatalog,
) {
let streams = match catalog.load_all_change_streams() {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "failed to load change streams from catalog");
return;
}
};
if streams.is_empty() {
return;
}
let mut inner = self.write();
let count_before = inner.by_name.len();
for stream in streams {
inner.insert(stream);
}
tracing::info!(
count = inner.by_name.len() - count_before,
"loaded change streams from catalog"
);
}
pub fn len(&self) -> usize {
self.read().by_name.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for StreamRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::event::cdc::stream_def::*;
fn sample_def(name: &str, collection: &str) -> ChangeStreamDef {
ChangeStreamDef {
tenant_id: 1,
name: name.into(),
collection: collection.into(),
op_filter: OpFilter::all(),
format: StreamFormat::Json,
retention: RetentionConfig::default(),
compaction: CompactionConfig::default(),
webhook: crate::event::webhook::WebhookConfig::default(),
late_data: LateDataPolicy::default(),
kafka: crate::event::kafka::KafkaDeliveryConfig::default(),
owner: "admin".into(),
created_at: 0,
}
}
#[test]
fn register_and_lookup() {
let reg = StreamRegistry::new();
reg.register(sample_def("orders_stream", "orders"));
assert!(reg.get(1, "orders_stream").is_some());
assert!(reg.get(1, "nonexistent").is_none());
assert!(reg.get(2, "orders_stream").is_none());
}
#[test]
fn find_matching_collection() {
let reg = StreamRegistry::new();
reg.register(sample_def("s1", "orders"));
reg.register(sample_def("s2", "orders"));
reg.register(sample_def("s3", "users"));
let matches = reg.find_matching(1, "orders");
assert_eq!(matches.len(), 2);
let matches = reg.find_matching(1, "users");
assert_eq!(matches.len(), 1);
}
#[test]
fn wildcard_matches_everything() {
let reg = StreamRegistry::new();
reg.register(sample_def("all", "*"));
assert_eq!(reg.find_matching(1, "orders").len(), 1);
assert_eq!(reg.find_matching(1, "users").len(), 1);
assert_eq!(reg.find_matching(1, "anything").len(), 1);
}
#[test]
fn unregister() {
let reg = StreamRegistry::new();
reg.register(sample_def("s1", "orders"));
assert!(reg.unregister(1, "s1"));
assert!(!reg.unregister(1, "s1"));
assert!(reg.is_empty());
}
}