Skip to main content

hirn_engine/
watch.rs

1//! Watch subscriptions for real-time reactive memory.
2//!
3//! Builds on the existing `EventLog` broadcast channel by adding
4//! filter-based subscriptions. Subscribers receive only events
5//! matching their `WatchFilter`.
6
7use hirn_core::error::{HirnError, HirnResult};
8use hirn_core::types::{Layer, Namespace};
9use tokio::sync::broadcast;
10
11use crate::event::{EventEnvelope, MemoryEvent};
12
13// ═══════════════════════════════════════════════════════════════════════════
14// Watch Filter
15// ═══════════════════════════════════════════════════════════════════════════
16
/// Filter criteria for watch subscriptions.
///
/// A filter is evaluated against each event envelope by [`WatchFilter::matches`].
/// Variants that carry a list match when *any* element of the list matches;
/// [`WatchFilter::AllOf`] requires *every* child filter to match.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchFilter {
    /// Receive all events.
    All,
    /// Events whose envelope realm equals this value exactly.
    Realm(String),
    /// Events whose `MemoryEvent::layer()` is one of the provided memory layers.
    /// Events that report no layer never match.
    Layers(Vec<Layer>),
    /// Events whose envelope namespace equals this value exactly.
    Namespace(String),
    /// Events whose envelope namespace equals any of the provided values.
    Namespaces(Vec<String>),
    /// Events whose envelope agent id equals this value exactly.
    AgentId(String),
    /// Events mentioning any of these entities, by case-insensitive substring
    /// search over the event's text: the episode content preview, the semantic
    /// concept name, the procedure name, or the reconsolidation reason.
    /// Other event kinds carry no searchable text.
    Entities(Vec<String>),
    /// Only `ImportanceUpdated` events whose new value is strictly greater
    /// than this threshold.
    ImportanceAbove(f32),
    /// Only contradiction-related events: `ContradictionDetected`, plus
    /// `Reconsolidated` events whose reason text mentions "contradict".
    Contradictions,
    /// Only events whose `event_type()` string is in this list.
    EventTypes(Vec<String>),
    /// All child filters must match.
    AllOf(Vec<WatchFilter>),
}
43
44impl WatchFilter {
45    /// Combine filters conjunctively, flattening nested `AllOf` nodes.
46    #[must_use]
47    pub fn all_of(filters: Vec<Self>) -> Self {
48        let mut flattened = Vec::new();
49        for filter in filters {
50            match filter {
51                Self::All => {}
52                Self::AllOf(children) => flattened.extend(children),
53                other => flattened.push(other),
54            }
55        }
56
57        match flattened.len() {
58            0 => Self::All,
59            1 => flattened.into_iter().next().unwrap_or(Self::All),
60            _ => Self::AllOf(flattened),
61        }
62    }
63
64    /// Restrict this filter to a namespace set.
65    #[must_use]
66    pub fn scoped_to_namespaces(self, allowed_namespaces: &[Namespace]) -> Self {
67        let namespaces = allowed_namespaces
68            .iter()
69            .map(|namespace| namespace.as_str().to_string())
70            .collect();
71        Self::all_of(vec![Self::Namespaces(namespaces), self])
72    }
73
74    /// Reject filters that explicitly name namespaces outside the allowed set.
75    pub fn validate_allowed_namespaces(&self, allowed_namespaces: &[Namespace]) -> HirnResult<()> {
76        let mut referenced_namespaces = Vec::new();
77        self.collect_referenced_namespaces(&mut referenced_namespaces);
78
79        for namespace in referenced_namespaces {
80            let allowed = allowed_namespaces
81                .iter()
82                .any(|allowed_namespace| allowed_namespace.as_str() == namespace);
83            if !allowed {
84                return Err(HirnError::AccessDenied(format!(
85                    "watch cannot access namespace '{}'",
86                    namespace
87                )));
88            }
89        }
90
91        Ok(())
92    }
93
94    fn collect_referenced_namespaces(&self, namespaces: &mut Vec<String>) {
95        match self {
96            Self::Namespace(namespace) => namespaces.push(namespace.clone()),
97            Self::Namespaces(items) => namespaces.extend(items.iter().cloned()),
98            Self::AllOf(filters) => {
99                for filter in filters {
100                    filter.collect_referenced_namespaces(namespaces);
101                }
102            }
103            Self::All
104            | Self::Realm(_)
105            | Self::Layers(_)
106            | Self::AgentId(_)
107            | Self::Entities(_)
108            | Self::ImportanceAbove(_)
109            | Self::Contradictions
110            | Self::EventTypes(_) => {}
111        }
112    }
113
114    /// Check whether an event envelope matches this filter.
115    pub fn matches(&self, envelope: &EventEnvelope) -> bool {
116        match self {
117            WatchFilter::All => true,
118            WatchFilter::Realm(realm) => envelope.realm == *realm,
119            WatchFilter::Layers(layers) => envelope
120                .event
121                .layer()
122                .is_some_and(|layer| layers.contains(&layer)),
123            WatchFilter::Namespace(ns) => envelope.namespace == *ns,
124            WatchFilter::Namespaces(namespaces) => namespaces.contains(&envelope.namespace),
125            WatchFilter::AgentId(agent_id) => envelope.agent_id == *agent_id,
126            WatchFilter::Entities(entities) => {
127                let text = match &envelope.event {
128                    MemoryEvent::EpisodeCreated {
129                        content_preview, ..
130                    } => content_preview.as_str(),
131                    MemoryEvent::SemanticCreated { concept_name, .. } => concept_name.as_str(),
132                    MemoryEvent::ProceduralCreated { procedure_name, .. } => {
133                        procedure_name.as_str()
134                    }
135                    MemoryEvent::Reconsolidated { reason, .. } => reason.as_str(),
136                    _ => "",
137                };
138                let lower = text.to_lowercase();
139                entities.iter().any(|e| lower.contains(&e.to_lowercase()))
140            }
141            WatchFilter::ImportanceAbove(threshold) => {
142                matches!(
143                    &envelope.event,
144                    MemoryEvent::ImportanceUpdated { new_value, .. }
145                        if *new_value > *threshold
146                )
147            }
148            WatchFilter::Contradictions => match &envelope.event {
149                MemoryEvent::ContradictionDetected { .. } => true,
150                MemoryEvent::Reconsolidated { reason, .. } => reason.contains("contradict"),
151                _ => false,
152            },
153            WatchFilter::EventTypes(types) => {
154                let event_type = envelope.event.event_type();
155                types.iter().any(|t| t == event_type)
156            }
157            WatchFilter::AllOf(filters) => filters.iter().all(|filter| filter.matches(envelope)),
158        }
159    }
160}
161
162// ═══════════════════════════════════════════════════════════════════════════
163// Watch Subscription
164// ═══════════════════════════════════════════════════════════════════════════
165
/// A filtered watch subscription over the event stream.
///
/// Wraps a `tokio::sync::broadcast` receiver. Every received envelope is tested
/// against `filter`; envelopes that do not match are consumed and discarded by
/// `next` / `try_next`.
pub struct WatchSubscription {
    /// Predicate applied to every received envelope.
    filter: WatchFilter,
    /// This subscription's own broadcast receiver.
    rx: broadcast::Receiver<EventEnvelope>,
}
171
172impl WatchSubscription {
173    /// Create a new subscription from a broadcast receiver and filter.
174    pub fn new(rx: broadcast::Receiver<EventEnvelope>, filter: WatchFilter) -> Self {
175        Self { filter, rx }
176    }
177
178    /// Receive the next matching event, blocking until one arrives.
179    ///
180    /// Returns `Err` if the channel is closed or the subscriber lagged
181    /// (missed events due to slow consumption).
182    pub async fn next(&mut self) -> HirnResult<EventEnvelope> {
183        loop {
184            match self.rx.recv().await {
185                Ok(envelope) => {
186                    if self.filter.matches(&envelope) {
187                        return Ok(envelope);
188                    }
189                    // Skip non-matching events.
190                }
191                Err(broadcast::error::RecvError::Lagged(n)) => {
192                    return Err(HirnError::LimitExceeded(format!(
193                        "watch subscriber lagged, missed {n} events"
194                    )));
195                }
196                Err(broadcast::error::RecvError::Closed) => {
197                    return Err(HirnError::InvalidInput("event channel closed".to_string()));
198                }
199            }
200        }
201    }
202
203    /// Try to receive the next matching event without blocking indefinitely.
204    ///
205    /// Returns `Ok(None)` if the channel is closed.
206    /// Returns `Err` if the subscriber lagged (missed events).
207    pub fn try_next(&mut self) -> HirnResult<Option<EventEnvelope>> {
208        loop {
209            match self.rx.try_recv() {
210                Ok(envelope) => {
211                    if self.filter.matches(&envelope) {
212                        return Ok(Some(envelope));
213                    }
214                }
215                Err(broadcast::error::TryRecvError::Empty) => return Ok(None),
216                Err(broadcast::error::TryRecvError::Lagged(n)) => {
217                    return Err(HirnError::LimitExceeded(format!(
218                        "watch subscriber lagged, missed {n} events"
219                    )));
220                }
221                Err(broadcast::error::TryRecvError::Closed) => return Ok(None),
222            }
223        }
224    }
225}
226
227// ═══════════════════════════════════════════════════════════════════════════
228// HirnDB::watch integration
229// ═══════════════════════════════════════════════════════════════════════════
230
231use crate::db::HirnDB;
232
233impl HirnDB {
234    /// Create a watch subscription with the given filter.
235    ///
236    /// Requires an active `EventLog` (returns error otherwise).
237    pub fn watch(&self, filter: WatchFilter) -> HirnResult<WatchSubscription> {
238        let event_log = self
239            .event_log()
240            .ok_or_else(|| HirnError::InvalidInput("event log not configured".to_string()))?;
241        let rx = event_log.subscribe();
242        Ok(WatchSubscription::new(rx, filter))
243    }
244}
245
246// ═══════════════════════════════════════════════════════════════════════════
247// Tests
248// ═══════════════════════════════════════════════════════════════════════════
249
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::id::MemoryId;
    use hirn_core::types::Layer;

    /// Wrap `event` in an envelope for realm "default" / agent "test-agent".
    fn make_envelope(event: MemoryEvent, namespace: &str) -> EventEnvelope {
        EventEnvelope::new(1, "default", namespace, "test-agent", event)
    }

    /// A text-free event, handy when only envelope metadata matters.
    fn forgotten() -> MemoryEvent {
        MemoryEvent::Forgotten {
            id: MemoryId::new(),
        }
    }

    #[test]
    fn filter_all_matches_everything() {
        let filter = WatchFilter::All;
        let env = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "ns1",
        );
        assert!(filter.matches(&env));
    }

    #[test]
    fn filter_namespace_matches_correct_ns() {
        let filter = WatchFilter::Namespace("shared".to_string());

        let matching = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "shared",
        );
        let non_matching = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "private",
        );

        assert!(filter.matches(&matching));
        assert!(!filter.matches(&non_matching));
    }

    #[test]
    fn filter_namespaces_matches_any_allowed_ns() {
        let filter = WatchFilter::Namespaces(vec!["shared".to_string(), "team".to_string()]);

        let matching = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "team",
        );
        let non_matching = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "private",
        );

        assert!(filter.matches(&matching));
        assert!(!filter.matches(&non_matching));
    }

    #[test]
    fn filter_entities_case_insensitive() {
        let filter = WatchFilter::Entities(vec!["auth".to_string()]);

        let matching = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "Discussed Auth flow with OAuth2".to_string(),
            },
            "ns",
        );
        let non_matching = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "Talked about recipes".to_string(),
            },
            "ns",
        );

        assert!(filter.matches(&matching));
        assert!(!filter.matches(&non_matching));
    }

    #[test]
    fn filter_importance_above_threshold() {
        let filter = WatchFilter::ImportanceAbove(0.8);

        let above = make_envelope(
            MemoryEvent::ImportanceUpdated {
                id: MemoryId::new(),
                old_value: 0.5,
                new_value: 0.9,
            },
            "ns",
        );
        let below = make_envelope(
            MemoryEvent::ImportanceUpdated {
                id: MemoryId::new(),
                old_value: 0.5,
                new_value: 0.7,
            },
            "ns",
        );
        let other = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "ns",
        );

        assert!(filter.matches(&above));
        assert!(!filter.matches(&below));
        assert!(!filter.matches(&other));
    }

    #[test]
    fn filter_layers_match_actual_event_layer() {
        let filter = WatchFilter::Layers(vec![Layer::Procedural]);

        let matching = make_envelope(
            MemoryEvent::ProceduralCreated {
                id: MemoryId::new(),
                procedure_name: "deploy-to-staging".to_string(),
            },
            "ns",
        );
        // Same text, different layer: the filter must look at the layer, not the text.
        let non_matching = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "deploy-to-staging".to_string(),
            },
            "ns",
        );

        assert!(filter.matches(&matching));
        assert!(!filter.matches(&non_matching));
    }

    #[test]
    fn filter_contradictions_matches_detected_events() {
        let filter = WatchFilter::Contradictions;

        let contradiction = make_envelope(
            MemoryEvent::ContradictionDetected {
                memory_a: MemoryId::new(),
                memory_b: MemoryId::new(),
                confidence: 0.92,
            },
            "ns",
        );
        let other = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "ns",
        );

        assert!(filter.matches(&contradiction));
        assert!(!filter.matches(&other));
    }

    #[test]
    fn filter_event_types() {
        let filter = WatchFilter::EventTypes(vec![
            "episode_created".to_string(),
            "semantic_created".to_string(),
        ]);

        let ep = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "test".to_string(),
            },
            "ns",
        );
        let sem = make_envelope(
            MemoryEvent::SemanticCreated {
                id: MemoryId::new(),
                concept_name: "test".to_string(),
            },
            "ns",
        );
        let other = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "ns",
        );

        assert!(filter.matches(&ep));
        assert!(filter.matches(&sem));
        assert!(!filter.matches(&other));
    }

    #[test]
    fn filter_all_of_requires_every_child_to_match() {
        let filter = WatchFilter::all_of(vec![
            WatchFilter::Namespace("shared".to_string()),
            WatchFilter::Entities(vec!["auth".to_string()]),
        ]);

        let matching = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "auth rollout completed".to_string(),
            },
            "shared",
        );
        let wrong_namespace = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "auth rollout completed".to_string(),
            },
            "private",
        );
        let wrong_entity = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "recipe rollout completed".to_string(),
            },
            "shared",
        );

        assert!(filter.matches(&matching));
        assert!(!filter.matches(&wrong_namespace));
        assert!(!filter.matches(&wrong_entity));
    }

    #[test]
    fn all_of_drops_all_and_flattens_nested_conjunctions() {
        // Nothing left after dropping the identity element -> `All`.
        assert_eq!(WatchFilter::all_of(vec![]), WatchFilter::All);
        assert_eq!(
            WatchFilter::all_of(vec![WatchFilter::All, WatchFilter::All]),
            WatchFilter::All
        );

        // A single remaining child is returned unwrapped.
        assert_eq!(
            WatchFilter::all_of(vec![
                WatchFilter::All,
                WatchFilter::Namespace("shared".to_string()),
            ]),
            WatchFilter::Namespace("shared".to_string())
        );

        // Nested `AllOf` children are spliced into the parent conjunction.
        assert_eq!(
            WatchFilter::all_of(vec![
                WatchFilter::AllOf(vec![
                    WatchFilter::Namespace("shared".to_string()),
                    WatchFilter::AgentId("test-agent".to_string()),
                ]),
                WatchFilter::Contradictions,
            ]),
            WatchFilter::AllOf(vec![
                WatchFilter::Namespace("shared".to_string()),
                WatchFilter::AgentId("test-agent".to_string()),
                WatchFilter::Contradictions,
            ])
        );
    }

    #[test]
    fn scoped_to_namespaces_restricts_matches() {
        let allowed = [Namespace::shared()];
        let filter = WatchFilter::All.scoped_to_namespaces(&allowed);

        let inside = make_envelope(forgotten(), allowed[0].as_str());
        let outside = make_envelope(forgotten(), "somewhere-else");

        assert!(filter.matches(&inside));
        assert!(!filter.matches(&outside));
    }

    #[test]
    fn filter_validate_allowed_namespaces_rejects_unauthorized_reference() {
        let filter = WatchFilter::Namespace("private:agent_a".to_string());
        let agent_b = hirn_core::types::AgentId::new("agent_b").unwrap();
        let allowed_namespaces = [Namespace::shared(), Namespace::private_for(&agent_b)];

        let result = filter.validate_allowed_namespaces(&allowed_namespaces);
        assert!(result.is_err());
    }

    #[test]
    fn filter_validate_allowed_namespaces_accepts_authorized_reference() {
        let allowed_namespaces = [Namespace::shared()];
        let filter = WatchFilter::all_of(vec![
            WatchFilter::Namespace(allowed_namespaces[0].as_str().to_string()),
            WatchFilter::Contradictions,
        ]);

        assert!(filter.validate_allowed_namespaces(&allowed_namespaces).is_ok());
    }

    #[test]
    fn filter_validate_allowed_namespaces_checks_nested_filters() {
        let allowed_namespaces = [Namespace::shared()];
        // The disallowed reference is hidden inside an `AllOf`.
        let filter = WatchFilter::AllOf(vec![
            WatchFilter::All,
            WatchFilter::Namespaces(vec!["private:agent_a".to_string()]),
        ]);

        assert!(filter.validate_allowed_namespaces(&allowed_namespaces).is_err());
    }

    #[test]
    fn multiple_subscribers_independent() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);

        let sub1 = WatchSubscription::new(tx.subscribe(), WatchFilter::All);
        let sub2 =
            WatchSubscription::new(tx.subscribe(), WatchFilter::Namespace("shared".to_string()));

        // Both created — dropping one doesn't affect the other.
        drop(sub1);
        assert!(matches!(sub2.filter, WatchFilter::Namespace(_)));
    }

    #[test]
    fn try_next_reports_empty_and_closed_as_none() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);
        let mut sub = WatchSubscription::new(tx.subscribe(), WatchFilter::All);

        // Nothing buffered yet.
        assert!(matches!(sub.try_next(), Ok(None)));

        // A buffered matching event is returned without waiting.
        tx.send(make_envelope(forgotten(), "ns")).unwrap();
        assert!(matches!(sub.try_next(), Ok(Some(_))));

        // A closed channel is reported as `Ok(None)`, not as an error.
        drop(tx);
        assert!(matches!(sub.try_next(), Ok(None)));
    }

    #[test]
    fn try_next_reports_lag_as_error() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(1);
        let mut sub = WatchSubscription::new(tx.subscribe(), WatchFilter::All);

        // Capacity 1: the second send overwrites the first, so `sub` lags.
        tx.send(make_envelope(forgotten(), "ns")).unwrap();
        tx.send(make_envelope(forgotten(), "ns")).unwrap();

        assert!(sub.try_next().is_err());
        // After the lag is reported, the oldest retained event is still delivered.
        assert!(matches!(sub.try_next(), Ok(Some(_))));
    }

    #[tokio::test]
    async fn subscription_receives_filtered_events() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);

        let mut sub =
            WatchSubscription::new(tx.subscribe(), WatchFilter::Namespace("target".to_string()));

        // Send matching and non-matching events.
        let matching = make_envelope(
            MemoryEvent::EpisodeCreated {
                id: MemoryId::new(),
                content_preview: "test".to_string(),
            },
            "target",
        );
        let non_matching = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "other",
        );

        tx.send(non_matching).unwrap();
        tx.send(matching.clone()).unwrap();

        // The non-matching event is skipped; the first result is the match.
        let received = sub.next().await.unwrap();
        assert_eq!(received.namespace, "target");
    }

    #[tokio::test]
    async fn subscriber_drop_no_error_on_others() {
        let (tx, _rx) = broadcast::channel::<EventEnvelope>(16);

        let sub1 = WatchSubscription::new(tx.subscribe(), WatchFilter::All);
        let mut sub2 = WatchSubscription::new(tx.subscribe(), WatchFilter::All);

        drop(sub1);

        // sub2 should still work fine.
        let env = make_envelope(
            MemoryEvent::Forgotten {
                id: MemoryId::new(),
            },
            "ns",
        );
        tx.send(env).unwrap();

        let received = sub2.next().await.unwrap();
        assert_eq!(received.event.event_type(), "forgotten");
    }

    #[tokio::test]
    async fn next_errors_once_channel_closed() {
        let (tx, _) = broadcast::channel::<EventEnvelope>(16);
        let mut sub = WatchSubscription::new(tx.subscribe(), WatchFilter::All);

        // With every sender gone and nothing buffered, `next` must not hang.
        drop(tx);
        assert!(sub.next().await.is_err());
    }
}