1use hirn_core::error::{HirnError, HirnResult};
8use hirn_core::types::{Layer, Namespace};
9use tokio::sync::broadcast;
10
11use crate::event::{EventEnvelope, MemoryEvent};
12
13#[derive(Debug, Clone, PartialEq)]
19pub enum WatchFilter {
20 All,
22 Realm(String),
24 Layers(Vec<Layer>),
26 Namespace(String),
28 Namespaces(Vec<String>),
30 AgentId(String),
32 Entities(Vec<String>),
34 ImportanceAbove(f32),
36 Contradictions,
38 EventTypes(Vec<String>),
40 AllOf(Vec<WatchFilter>),
42}
43
44impl WatchFilter {
45 #[must_use]
47 pub fn all_of(filters: Vec<Self>) -> Self {
48 let mut flattened = Vec::new();
49 for filter in filters {
50 match filter {
51 Self::All => {}
52 Self::AllOf(children) => flattened.extend(children),
53 other => flattened.push(other),
54 }
55 }
56
57 match flattened.len() {
58 0 => Self::All,
59 1 => flattened.into_iter().next().unwrap_or(Self::All),
60 _ => Self::AllOf(flattened),
61 }
62 }
63
64 #[must_use]
66 pub fn scoped_to_namespaces(self, allowed_namespaces: &[Namespace]) -> Self {
67 let namespaces = allowed_namespaces
68 .iter()
69 .map(|namespace| namespace.as_str().to_string())
70 .collect();
71 Self::all_of(vec![Self::Namespaces(namespaces), self])
72 }
73
74 pub fn validate_allowed_namespaces(&self, allowed_namespaces: &[Namespace]) -> HirnResult<()> {
76 let mut referenced_namespaces = Vec::new();
77 self.collect_referenced_namespaces(&mut referenced_namespaces);
78
79 for namespace in referenced_namespaces {
80 let allowed = allowed_namespaces
81 .iter()
82 .any(|allowed_namespace| allowed_namespace.as_str() == namespace);
83 if !allowed {
84 return Err(HirnError::AccessDenied(format!(
85 "watch cannot access namespace '{}'",
86 namespace
87 )));
88 }
89 }
90
91 Ok(())
92 }
93
94 fn collect_referenced_namespaces(&self, namespaces: &mut Vec<String>) {
95 match self {
96 Self::Namespace(namespace) => namespaces.push(namespace.clone()),
97 Self::Namespaces(items) => namespaces.extend(items.iter().cloned()),
98 Self::AllOf(filters) => {
99 for filter in filters {
100 filter.collect_referenced_namespaces(namespaces);
101 }
102 }
103 Self::All
104 | Self::Realm(_)
105 | Self::Layers(_)
106 | Self::AgentId(_)
107 | Self::Entities(_)
108 | Self::ImportanceAbove(_)
109 | Self::Contradictions
110 | Self::EventTypes(_) => {}
111 }
112 }
113
114 pub fn matches(&self, envelope: &EventEnvelope) -> bool {
116 match self {
117 WatchFilter::All => true,
118 WatchFilter::Realm(realm) => envelope.realm == *realm,
119 WatchFilter::Layers(layers) => envelope
120 .event
121 .layer()
122 .is_some_and(|layer| layers.contains(&layer)),
123 WatchFilter::Namespace(ns) => envelope.namespace == *ns,
124 WatchFilter::Namespaces(namespaces) => namespaces.contains(&envelope.namespace),
125 WatchFilter::AgentId(agent_id) => envelope.agent_id == *agent_id,
126 WatchFilter::Entities(entities) => {
127 let text = match &envelope.event {
128 MemoryEvent::EpisodeCreated {
129 content_preview, ..
130 } => content_preview.as_str(),
131 MemoryEvent::SemanticCreated { concept_name, .. } => concept_name.as_str(),
132 MemoryEvent::ProceduralCreated { procedure_name, .. } => {
133 procedure_name.as_str()
134 }
135 MemoryEvent::Reconsolidated { reason, .. } => reason.as_str(),
136 _ => "",
137 };
138 let lower = text.to_lowercase();
139 entities.iter().any(|e| lower.contains(&e.to_lowercase()))
140 }
141 WatchFilter::ImportanceAbove(threshold) => {
142 matches!(
143 &envelope.event,
144 MemoryEvent::ImportanceUpdated { new_value, .. }
145 if *new_value > *threshold
146 )
147 }
148 WatchFilter::Contradictions => match &envelope.event {
149 MemoryEvent::ContradictionDetected { .. } => true,
150 MemoryEvent::Reconsolidated { reason, .. } => reason.contains("contradict"),
151 _ => false,
152 },
153 WatchFilter::EventTypes(types) => {
154 let event_type = envelope.event.event_type();
155 types.iter().any(|t| t == event_type)
156 }
157 WatchFilter::AllOf(filters) => filters.iter().all(|filter| filter.matches(envelope)),
158 }
159 }
160}
161
162pub struct WatchSubscription {
168 filter: WatchFilter,
169 rx: broadcast::Receiver<EventEnvelope>,
170}
171
172impl WatchSubscription {
173 pub fn new(rx: broadcast::Receiver<EventEnvelope>, filter: WatchFilter) -> Self {
175 Self { filter, rx }
176 }
177
178 pub async fn next(&mut self) -> HirnResult<EventEnvelope> {
183 loop {
184 match self.rx.recv().await {
185 Ok(envelope) => {
186 if self.filter.matches(&envelope) {
187 return Ok(envelope);
188 }
189 }
191 Err(broadcast::error::RecvError::Lagged(n)) => {
192 return Err(HirnError::LimitExceeded(format!(
193 "watch subscriber lagged, missed {n} events"
194 )));
195 }
196 Err(broadcast::error::RecvError::Closed) => {
197 return Err(HirnError::InvalidInput("event channel closed".to_string()));
198 }
199 }
200 }
201 }
202
203 pub fn try_next(&mut self) -> HirnResult<Option<EventEnvelope>> {
208 loop {
209 match self.rx.try_recv() {
210 Ok(envelope) => {
211 if self.filter.matches(&envelope) {
212 return Ok(Some(envelope));
213 }
214 }
215 Err(broadcast::error::TryRecvError::Empty) => return Ok(None),
216 Err(broadcast::error::TryRecvError::Lagged(n)) => {
217 return Err(HirnError::LimitExceeded(format!(
218 "watch subscriber lagged, missed {n} events"
219 )));
220 }
221 Err(broadcast::error::TryRecvError::Closed) => return Ok(None),
222 }
223 }
224 }
225}
226
227use crate::db::HirnDB;
232
233impl HirnDB {
234 pub fn watch(&self, filter: WatchFilter) -> HirnResult<WatchSubscription> {
238 let event_log = self
239 .event_log()
240 .ok_or_else(|| HirnError::InvalidInput("event log not configured".to_string()))?;
241 let rx = event_log.subscribe();
242 Ok(WatchSubscription::new(rx, filter))
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::id::MemoryId;
    use hirn_core::types::Layer;

    /// Wraps `event` in an envelope with fixed `1` / `"default"` / `"test-agent"` metadata.
    fn make_envelope(event: MemoryEvent, namespace: &str) -> EventEnvelope {
        EventEnvelope::new(1, "default", namespace, "test-agent", event)
    }

    /// A `Forgotten` event (no descriptive text, no importance) in `namespace`.
    fn forgotten_in(namespace: &str) -> EventEnvelope {
        make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            namespace,
        )
    }

    /// An `EpisodeCreated` event with the given preview text in `namespace`.
    fn episode_in(preview: &str, namespace: &str) -> EventEnvelope {
        make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: preview.to_string(),
            },
            namespace,
        )
    }

    #[test]
    fn filter_all_matches_everything() {
        assert!(WatchFilter::All.matches(&forgotten_in("ns1")));
    }

    #[test]
    fn filter_namespace_matches_correct_ns() {
        let filter = WatchFilter::Namespace("shared".to_string());

        assert!(filter.matches(&forgotten_in("shared")));
        assert!(!filter.matches(&forgotten_in("private")));
    }

    #[test]
    fn filter_namespaces_matches_any_allowed_ns() {
        let filter = WatchFilter::Namespaces(vec!["shared".to_string(), "team".to_string()]);

        assert!(filter.matches(&forgotten_in("team")));
        assert!(!filter.matches(&forgotten_in("private")));
    }

    #[test]
    fn filter_entities_case_insensitive() {
        let filter = WatchFilter::Entities(vec!["auth".to_string()]);

        assert!(filter.matches(&episode_in("Discussed Auth flow with OAuth2", "ns")));
        assert!(!filter.matches(&episode_in("Talked about recipes", "ns")));
    }

    #[test]
    fn filter_importance_above_threshold() {
        let filter = WatchFilter::ImportanceAbove(0.8);
        let importance_update = |new_value: f32| {
            make_envelope(
                MemoryEvent::ImportanceUpdated {
                    id: MemoryId::new(),
                    old_value: 0.5,
                    new_value,
                },
                "ns",
            )
        };

        assert!(filter.matches(&importance_update(0.9)));
        assert!(!filter.matches(&importance_update(0.7)));
        assert!(!filter.matches(&forgotten_in("ns")));
    }

    #[test]
    fn filter_layers_match_actual_event_layer() {
        let filter = WatchFilter::Layers(vec![Layer::Procedural]);
        let procedural = make_envelope(
            MemoryEvent::ProceduralCreated {
                id: MemoryId::new(),
                procedure_name: "deploy-to-staging".to_string(),
            },
            "ns",
        );

        assert!(filter.matches(&procedural));
        // Same text, different layer: the filter must look at the layer, not the content.
        assert!(!filter.matches(&episode_in("deploy-to-staging", "ns")));
    }

    #[test]
    fn filter_contradictions_matches_detected_events() {
        let filter = WatchFilter::Contradictions;
        let contradiction = make_envelope(
            MemoryEvent::ContradictionDetected {
                memory_a: MemoryId::new(),
                memory_b: MemoryId::new(),
                confidence: 0.92,
            },
            "ns",
        );

        assert!(filter.matches(&contradiction));
        assert!(!filter.matches(&forgotten_in("ns")));
    }

    #[test]
    fn filter_event_types() {
        let filter = WatchFilter::EventTypes(vec![
            "episode_created".to_string(),
            "semantic_created".to_string(),
        ]);
        let semantic = make_envelope(
            MemoryEvent::SemanticCreated {
                id: MemoryId::new(),
                concept_name: "test".to_string(),
            },
            "ns",
        );

        assert!(filter.matches(&episode_in("test", "ns")));
        assert!(filter.matches(&semantic));
        assert!(!filter.matches(&forgotten_in("ns")));
    }

    #[test]
    fn filter_all_of_requires_every_child_to_match() {
        let filter = WatchFilter::all_of(vec![
            WatchFilter::Namespace("shared".to_string()),
            WatchFilter::Entities(vec!["auth".to_string()]),
        ]);

        assert!(filter.matches(&episode_in("auth rollout completed", "shared")));
        assert!(!filter.matches(&episode_in("auth rollout completed", "private")));
        assert!(!filter.matches(&episode_in("recipe rollout completed", "shared")));
    }

    #[test]
    fn filter_validate_allowed_namespaces_rejects_unauthorized_reference() {
        let filter = WatchFilter::Namespace("private:agent_a".to_string());
        let agent_b = hirn_core::types::AgentId::new("agent_b").unwrap();
        let allowed_namespaces = [Namespace::shared(), Namespace::private_for(&agent_b)];

        assert!(filter
            .validate_allowed_namespaces(&allowed_namespaces)
            .is_err());
    }

    #[test]
    fn multiple_subscribers_independent() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);

        let first = WatchSubscription::new(tx.subscribe(), WatchFilter::All);
        let second =
            WatchSubscription::new(tx.subscribe(), WatchFilter::Namespace("shared".to_string()));

        drop(first);
        assert!(matches!(second.filter, WatchFilter::Namespace(_)));
    }

    #[tokio::test]
    async fn subscription_receives_filtered_events() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);
        let mut subscription =
            WatchSubscription::new(tx.subscribe(), WatchFilter::Namespace("target".to_string()));

        // The non-matching event is sent first and must be skipped by `next`.
        tx.send(forgotten_in("other")).unwrap();
        tx.send(episode_in("test", "target")).unwrap();

        let received = subscription.next().await.unwrap();
        assert_eq!(received.namespace, "target");
    }

    #[tokio::test]
    async fn subscriber_drop_no_error_on_others() {
        let (tx, _rx) = broadcast::channel::<EventEnvelope>(16);

        let dropped = WatchSubscription::new(tx.subscribe(), WatchFilter::All);
        let mut survivor = WatchSubscription::new(tx.subscribe(), WatchFilter::All);
        drop(dropped);

        tx.send(forgotten_in("ns")).unwrap();

        let received = survivor.next().await.unwrap();
        assert_eq!(received.event.event_type(), "forgotten");
    }
}