Skip to main content

hirn_exec/
extensions.rs

1//! Runtime state for DataFusion operators.
2//!
3//! [`HirnSessionExt`] carries shared references — graph store, config, and
4//! an embedder — that operators retrieve at execution time via DataFusion's
5//! `SessionContext` extension mechanism.
6
7use std::any::Any;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::OnceLock;
11use std::sync::atomic::{AtomicU64, Ordering};
12
13use arrow_array::RecordBatch;
14use async_trait::async_trait;
15use datafusion::prelude::SessionContext;
16use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
17use hirn_core::HirnResult;
18use hirn_core::config::HirnConfig;
19use hirn_core::embed::Embedder;
20use hirn_core::id::MemoryId;
21use hirn_core::tokenizer::Tokenizer;
22use hirn_core::types::{EdgeRelation, Namespace};
23use hirn_graph::PprConfig;
24use hirn_query::compiler::plan_compiler::SemanticTargetKindRepr;
25use hirn_storage::PhysicalStore;
26use hirn_storage::store::DistanceMetric;
27use parking_lot::RwLock;
28
29use crate::operators::ActivationMode;
30use crate::operators::SearchNumericFilter;
31use crate::operators::nli_contradiction::NliClassifier;
32
/// Result of a graph activation pass, as returned by
/// [`GraphReadRuntime::activate_graph`].
///
/// NOTE(review): the three vectors look like parallel columns (entry `i` of
/// each describing the same node) — confirm against the implementor.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphActivationOutput {
    /// Identifiers as strings (presumably the activated memory ids — TODO confirm).
    pub ids: Vec<String>,
    /// Scores (presumably one per entry in `ids` — TODO confirm).
    pub scores: Vec<f32>,
    /// Depths (presumably hop distance from a seed, one per `ids` entry — TODO confirm).
    pub depths: Vec<u32>,
}
39
/// One row of the result returned by [`GraphReadRuntime::causal_chain`].
///
/// NOTE(review): field meanings below are read off the names only; the
/// producer lives outside this file — confirm before relying on them.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphCausalChainRow {
    /// Identifier of the chain this row belongs to.
    pub chain_id: String,
    /// Identifier of the edge's source node (presumably the cause — TODO confirm).
    pub source_id: String,
    /// Identifier of the edge's target node (presumably the effect — TODO confirm).
    pub target_id: String,
    /// Edge strength (range not visible here).
    pub strength: f32,
    /// Edge confidence (range not visible here).
    pub confidence: f32,
    /// Number of pieces of evidence recorded for the edge.
    pub evidence_count: u32,
    /// Optional free-text mechanism attached to the edge.
    pub mechanism: Option<String>,
    /// Depth of this row within the chain (presumably hops from a start id — TODO confirm).
    pub depth: u32,
    /// Score of the chain as a whole (how it is aggregated is not visible here).
    pub chain_score: f32,
}
52
/// One row of the result returned by [`GraphReadRuntime::traverse_graph`].
#[derive(Debug, Clone, PartialEq)]
pub struct GraphTraverseRow {
    /// Identifier of the visited node, as a string.
    pub node_id: String,
    /// Depth at which the node was reached (presumably hops from a start id — TODO confirm).
    pub depth: u32,
}
58
/// Query-scoped recall/search bindings used by compiled search operators.
///
/// Stored on [`HirnSessionExt`] via `with_recall_search_binding` and read back
/// through `recall_search_binding()`.
#[derive(Debug, Clone, PartialEq)]
pub struct RecallSearchBinding {
    /// Query embedding to search with (dimensionality is not fixed here).
    pub query_vector: Vec<f32>,
    /// Optional filter expression (syntax is defined by the consumer — not visible here).
    pub filter: Option<String>,
    /// Maximum number of results requested.
    pub limit: usize,
    /// Distance metric to use for the vector search.
    pub metric: DistanceMetric,
    /// Additional numeric filters applied to the search.
    pub numeric_filters: Vec<SearchNumericFilter>,
    /// Optional lower temporal bound (presumably epoch milliseconds — TODO confirm).
    pub temporal_start_ms: Option<i64>,
    /// Optional upper temporal bound (presumably epoch milliseconds — TODO confirm).
    pub temporal_end_ms: Option<i64>,
    /// Enable dual-pass temporal expansion in `LanceHybridSearchExec`.
    pub temporal_expansion: bool,
}
71
/// Async read contract for graph-backed operators.
///
/// `HirnSessionExt` stores an optional implementation and describes it as the
/// authoritative path that graph-aware operators should prefer over
/// downcasting the raw graph handle.
///
/// NOTE(review): the meaning of the tuning parameters (`epsilon`,
/// `inhibition_mu`, `delegation_threshold`, ...) is defined by the
/// implementor and is not visible from this file.
#[async_trait]
pub trait GraphReadRuntime: Send + Sync {
    /// Run graph activation starting from `seeds`.
    ///
    /// Returns a [`GraphActivationOutput`] on success. `allowed_namespaces`
    /// presumably follows the session convention (`None` = no namespace
    /// filtering) — TODO confirm with the implementor.
    async fn activate_graph(
        &self,
        seeds: &[MemoryId],
        mode: ActivationMode,
        ppr_config: Option<&PprConfig>,
        max_depth: u32,
        epsilon: f32,
        inhibition_mu: f32,
        delegation_threshold: usize,
        allowed_namespaces: Option<&[Namespace]>,
    ) -> HirnResult<GraphActivationOutput>;

    /// Follow edges of the given `relation` from `start_ids`, up to `max_depth`.
    ///
    /// Returns one [`GraphCausalChainRow`] per reported edge. How
    /// `confidence_threshold` is applied is up to the implementor.
    async fn causal_chain(
        &self,
        start_ids: &[MemoryId],
        max_depth: u32,
        confidence_threshold: f32,
        delegation_threshold: usize,
        relation: EdgeRelation,
        allowed_namespaces: Option<&[Namespace]>,
    ) -> HirnResult<Vec<GraphCausalChainRow>>;

    /// Traverse the graph from `start_ids`, up to `max_depth`.
    ///
    /// `relation_filter` presumably restricts traversal to the listed edge
    /// relations, with `None` meaning "all relations" — TODO confirm.
    async fn traverse_graph(
        &self,
        start_ids: &[MemoryId],
        max_depth: u32,
        delegation_threshold: usize,
        relation_filter: Option<&[EdgeRelation]>,
        allowed_namespaces: Option<&[Namespace]>,
    ) -> HirnResult<Vec<GraphTraverseRow>>;
}
105
/// Async contract for terminal read commands.
///
/// Implementations are registered with [`register_query_read_runtime`] and
/// resolved through `HirnSessionExt::query_read_runtime`.
///
/// Every method returns raw bytes; judging by the `_json` suffix these are
/// presumably JSON-encoded payloads — TODO confirm with the implementor.
/// `allowed_namespaces` presumably follows the session convention (`None` =
/// no namespace filtering) — TODO confirm.
#[async_trait]
pub trait QueryReadRuntime: Send + Sync {
    /// Inspect `target` (interpreted according to `target_kind`) on behalf of `agent_id`.
    async fn inspect_json(
        &self,
        target: &str,
        target_kind: SemanticTargetKindRepr,
        agent_id: &str,
        allowed_namespaces: Option<&[String]>,
    ) -> HirnResult<Vec<u8>>;

    /// Trace `target` (interpreted according to `target_kind`) on behalf of `agent_id`.
    async fn trace_json(
        &self,
        target: &str,
        target_kind: SemanticTargetKindRepr,
        agent_id: &str,
        allowed_namespaces: Option<&[String]>,
    ) -> HirnResult<Vec<u8>>;

    /// Explain causes for `query`, exploring up to `depth`, optionally scoped to `namespace`.
    async fn explain_causes_json(
        &self,
        query: &str,
        depth: u32,
        namespace: Option<&str>,
        allowed_namespaces: Option<&[String]>,
    ) -> HirnResult<Vec<u8>>;

    /// Evaluate a what-if question relating `intervention` to `outcome`.
    async fn what_if_json(
        &self,
        intervention: &str,
        outcome: &str,
        namespace: Option<&str>,
        allowed_namespaces: Option<&[String]>,
    ) -> HirnResult<Vec<u8>>;

    /// Evaluate a counterfactual relating `antecedent` to `consequent`.
    async fn counterfactual_json(
        &self,
        antecedent: &str,
        consequent: &str,
        namespace: Option<&str>,
        allowed_namespaces: Option<&[String]>,
    ) -> HirnResult<Vec<u8>>;

    /// List policies, optionally narrowed by principal kind and/or name.
    async fn show_policies_json(
        &self,
        principal_kind: Option<&str>,
        principal_name: Option<&str>,
    ) -> HirnResult<Vec<u8>>;

    /// Explain the policy decision for one principal / resource / action triple.
    async fn explain_policy_json(
        &self,
        principal_kind: &str,
        principal_name: &str,
        resource_type: &str,
        resource_name: &str,
        action: &str,
    ) -> HirnResult<Vec<u8>>;
}
163
/// Source of registry ids for query-read runtimes: starts at 1 and is bumped
/// by `register_query_read_runtime` on every registration.
static QUERY_READ_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
165
166fn query_read_runtime_registry() -> &'static RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>> {
167    static REGISTRY: OnceLock<RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>>> = OnceLock::new();
168    REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
169}
170
/// Handle for a registered [`QueryReadRuntime`].
///
/// Created by [`register_query_read_runtime`]. Dropping the handle removes
/// the runtime from the global registry, so it must be kept alive for as long
/// as lookups by its key are expected to succeed.
#[derive(Debug)]
#[must_use = "dropping the handle immediately unregisters the runtime"]
pub struct RegisteredQueryReadRuntime {
    /// Registry id assigned at registration time.
    id: u64,
}
175
176impl RegisteredQueryReadRuntime {
177    pub fn key(&self) -> String {
178        self.id.to_string()
179    }
180}
181
182impl Drop for RegisteredQueryReadRuntime {
183    fn drop(&mut self) {
184        query_read_runtime_registry().write().remove(&self.id);
185    }
186}
187
188pub fn register_query_read_runtime(
189    runtime: Arc<dyn QueryReadRuntime>,
190) -> RegisteredQueryReadRuntime {
191    let id = QUERY_READ_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
192    query_read_runtime_registry().write().insert(id, runtime);
193    RegisteredQueryReadRuntime { id }
194}
195
196fn lookup_query_read_runtime(key: &str) -> Option<Arc<dyn QueryReadRuntime>> {
197    let id = key.parse::<u64>().ok()?;
198    query_read_runtime_registry().read().get(&id).cloned()
199}
200
201// ── ContextAssemblyRuntime ─────────────────────────────────────────────
202
/// Per-query runtime bridge for the THINK context assembly operator.
///
/// Registered once per THINK query execution (with actor identity, config,
/// and recall context captured at registration time), then looked up by
/// key inside `ContextAssemblyExec::execute()`.
///
/// The implementation in `hirn-engine` calls `assemble_think_context` and
/// JSON-serialises the full `ThinkAssemblyOutput` (including decoded
/// `ScoredMemory` records) so the operator can return a single opaque row.
///
/// Register implementations with [`register_context_assembly_runtime`]; the
/// returned handle's key is what `HirnSessionExt` stores.
#[async_trait]
pub trait ContextAssemblyRuntime: Send + Sync {
    /// Assemble context from scored candidate batches.
    ///
    /// Receives the raw Arrow output from `ContextBudgetExec` (or `McfaDefenseExec`
    /// if MCFA defense is enabled).  Returns opaque JSON bytes that the engine
    /// decodes into a fully-hydrated `ThinkAssemblyOutput`.
    ///
    /// # Errors
    /// Propagates whatever `HirnResult` error the implementation produces.
    async fn assemble_from_batches(
        &self,
        candidate_batches: Vec<RecordBatch>,
    ) -> HirnResult<Vec<u8>>;
}
224
/// Source of registry ids for context-assembly runtimes: starts at 1 and is
/// bumped by `register_context_assembly_runtime` on every registration.
static CONTEXT_ASSEMBLY_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
226
227fn context_assembly_runtime_registry()
228-> &'static RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>> {
229    static REGISTRY: OnceLock<RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>>> =
230        OnceLock::new();
231    REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
232}
233
/// RAII handle for a registered [`ContextAssemblyRuntime`].
///
/// Removes the runtime from the global registry on drop so resources are freed
/// as soon as the query scope exits. Because of that, the handle must be kept
/// alive for as long as lookups by its key are expected to succeed.
#[derive(Debug)]
#[must_use = "dropping the handle immediately unregisters the runtime"]
pub struct RegisteredContextAssemblyRuntime {
    /// Registry id assigned at registration time.
    id: u64,
}
242
243impl RegisteredContextAssemblyRuntime {
244    /// Opaque string key for injecting into [`HirnSessionExt`].
245    pub fn key(&self) -> String {
246        self.id.to_string()
247    }
248}
249
250impl Drop for RegisteredContextAssemblyRuntime {
251    fn drop(&mut self) {
252        context_assembly_runtime_registry().write().remove(&self.id);
253    }
254}
255
256/// Register a `ContextAssemblyRuntime` for the current query scope.
257///
258/// Returns a RAII guard; drop it after plan execution to clean up the registry.
259pub fn register_context_assembly_runtime(
260    runtime: Arc<dyn ContextAssemblyRuntime>,
261) -> RegisteredContextAssemblyRuntime {
262    let id = CONTEXT_ASSEMBLY_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
263    context_assembly_runtime_registry()
264        .write()
265        .insert(id, runtime);
266    RegisteredContextAssemblyRuntime { id }
267}
268
269pub(crate) fn lookup_context_assembly_runtime(
270    key: &str,
271) -> Option<Arc<dyn ContextAssemblyRuntime>> {
272    let id = key.parse::<u64>().ok()?;
273    context_assembly_runtime_registry().read().get(&id).cloned()
274}
275
/// Shared runtime state accessible by all hirn DataFusion operators.
///
/// Registered in [`SessionContext`] at database open time. Operators retrieve
/// it via [`HirnSessionExt::get`] — never through constructor injection.
///
/// Built with [`HirnSessionExt::new`] and refined through the consuming
/// `with_*` builders. `get` hands out a clone of the whole struct: most fields
/// are `Arc` handles, but `agent_id`, `allowed_namespaces`, the runtime keys
/// and `recall_search_binding` (including its query vector) are copied.
#[derive(Clone)]
pub struct HirnSessionExt {
    /// Hot+cold two-tier graph (CachedGraphStore lives in hirn-engine;
    /// we store as `Arc<dyn Any + Send + Sync>` to avoid depending on
    /// hirn-engine from hirn-exec).
    graph: Arc<dyn Any + Send + Sync>,

    /// Authoritative graph read runtime.
    ///
    /// When present, graph-aware operators should prefer this contract over
    /// downcasting the raw hot graph handle so the engine can enforce a single
    /// hot-vs-cold delegation rule.
    graph_read_runtime: Option<Arc<dyn GraphReadRuntime>>,

    /// Scoring weights and database configuration.
    pub config: Arc<HirnConfig>,

    /// Embedding provider (optional — not all operators need it).
    embedder: Option<Arc<dyn Embedder>>,

    /// Storage backend (optional — operators needing vector search use it).
    storage: Option<Arc<dyn PhysicalStore>>,

    /// Authoritative tokenizer for token-aware budgeted operators.
    tokenizer: Option<Arc<dyn Tokenizer>>,

    /// Authenticated agent identity for the current session.
    /// Used by `PolicyPushdownRule` to identify the requesting agent.
    agent_id: Option<String>,

    /// Pre-resolved namespace access list from the policy engine.
    ///
    /// - `None` — open mode: no namespace filtering applied.
    /// - `Some(vec)` — restrict scans to the listed namespaces.
    ///   An empty vec means deny all access.
    ///
    /// Resolved once at session setup by evaluating Cedar policies.
    allowed_namespaces: Option<Vec<String>>,

    /// Query-scoped runtime handle for compiled terminal read commands.
    ///
    /// Holds the string key of a `RegisteredQueryReadRuntime`; the runtime
    /// itself is resolved from the global registry on each access.
    query_read_runtime_key: Option<String>,

    /// Query-scoped runtime handle for THINK context assembly.
    ///
    /// Set once per THINK query, looked up by `ContextAssemblyExec` at
    /// execution time to retrieve the per-query `ContextAssemblyRuntime`.
    context_assembly_runtime_key: Option<String>,

    /// Query-scoped recall/search bindings used by compiled search operators.
    recall_search_binding: Option<RecallSearchBinding>,

    /// Shared historical RPE population statistics (Welford's online algorithm).
    ///
    /// Seeded from `WriteRuntime` at session setup; updated by `RpeScoreExec`
    /// after each batch so that z-scores compare against the full historical
    /// distribution, not just the current write batch (N-H08).
    pub rpe_population_stats: Arc<RwLock<hirn_core::WelfordStats>>,

    /// NLI classifier for `InterferenceDetectorExec` Check 3.
    ///
    /// `None` — operator uses its own default (`HeuristicNliClassifier`).
    /// `Some(clf)` — operator uses the injected classifier, enabling ONNX upgrade
    /// without recompiling or changing `InterferenceConfig`.
    nli_classifier: Option<Arc<dyn NliClassifier>>,
}
345
// Compile-time check that `HirnSessionExt` is `Send + Sync`: the generic bound
// on `assert_send_sync` fails to type-check otherwise. No `unsafe` is involved;
// the anonymous const exists only to host the check.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<HirnSessionExt>();
};
351
352#[allow(clippy::missing_fields_in_debug)] // rpe_population_stats and recall_search_binding intentionally omitted (lock + query-scoped)
353impl std::fmt::Debug for HirnSessionExt {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        f.debug_struct("HirnSessionExt")
356            .field("graph", &"<type-erased>")
357            .field("has_graph_read_runtime", &self.graph_read_runtime.is_some())
358            .field("config", &self.config)
359            .field("has_embedder", &self.embedder.is_some())
360            .field("has_storage", &self.storage.is_some())
361            .field("has_tokenizer", &self.tokenizer.is_some())
362            .field("agent_id", &self.agent_id)
363            .field("allowed_namespaces", &self.allowed_namespaces)
364            .field(
365                "has_query_read_runtime",
366                &self.query_read_runtime_key.is_some(),
367            )
368            .field(
369                "has_context_assembly_runtime",
370                &self.context_assembly_runtime_key.is_some(),
371            )
372            .field(
373                "nli_classifier_backend",
374                &self
375                    .nli_classifier
376                    .as_ref()
377                    .map(|c| c.backend_name())
378                    .unwrap_or("default"),
379            )
380            // rpe_population_stats and recall_search_binding omitted from Debug
381            // (locking during format is undesirable; binding is query-scoped).
382            .finish_non_exhaustive()
383    }
384}
385
386impl HirnSessionExt {
387    /// Create a new extension bundle.
388    /// Create a new extension bundle.
389    pub fn new(
390        graph: Arc<dyn Any + Send + Sync>,
391        config: Arc<HirnConfig>,
392        embedder: Option<Arc<dyn Embedder>>,
393    ) -> Self {
394        Self {
395            graph,
396            graph_read_runtime: None,
397            config,
398            embedder,
399            storage: None,
400            tokenizer: None,
401            agent_id: None,
402            allowed_namespaces: None,
403            query_read_runtime_key: None,
404            context_assembly_runtime_key: None,
405            recall_search_binding: None,
406            rpe_population_stats: Arc::new(RwLock::new(hirn_core::WelfordStats::new())),
407            nli_classifier: None,
408        }
409    }
410
411    /// Seed historical RPE population statistics (from `WriteRuntime`).
412    ///
413    /// Allows `RpeScoreExec` to z-score against the full historical
414    /// distribution instead of only the current write batch (N-H08).
415    pub fn with_rpe_population_stats(
416        mut self,
417        stats: Arc<RwLock<hirn_core::WelfordStats>>,
418    ) -> Self {
419        self.rpe_population_stats = stats;
420        self
421    }
422
423    /// Set the authenticated agent identity.
424    pub fn with_agent_id(mut self, agent_id: impl Into<String>) -> Self {
425        self.agent_id = Some(agent_id.into());
426        self
427    }
428
429    /// Inject an NLI classifier for `InterferenceDetectorExec` Check 3.
430    ///
431    /// Use this to upgrade from the default `HeuristicNliClassifier` to a
432    /// DeBERTa-MNLI ONNX model at database open time, without changing any
433    /// `InterferenceConfig` fields.
434    pub fn with_nli_classifier(mut self, clf: Arc<dyn NliClassifier>) -> Self {
435        self.nli_classifier = Some(clf);
436        self
437    }
438
439    /// Returns the injected NLI classifier, if any.
440    ///
441    /// `None` — `InterferenceDetectorExec` will use its own default.
442    pub fn nli_classifier(&self) -> Option<Arc<dyn NliClassifier>> {
443        self.nli_classifier.clone()
444    }
445
446    /// Set the storage backend for operators needing vector search.
447    pub fn with_storage(mut self, storage: Arc<dyn PhysicalStore>) -> Self {
448        self.storage = Some(storage);
449        self
450    }
451
452    /// Set the authoritative graph read runtime.
453    pub fn with_graph_read_runtime(
454        mut self,
455        graph_read_runtime: Arc<dyn GraphReadRuntime>,
456    ) -> Self {
457        self.graph_read_runtime = Some(graph_read_runtime);
458        self
459    }
460
461    /// Set the tokenizer for operators needing authoritative token counts.
462    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
463        self.tokenizer = Some(tokenizer);
464        self
465    }
466
467    /// Set pre-resolved allowed namespaces.
468    ///
469    /// `None` means open mode (no filtering). `Some(vec)` restricts to those namespaces.
470    pub fn with_allowed_namespaces(mut self, namespaces: Option<Vec<String>>) -> Self {
471        self.allowed_namespaces = namespaces;
472        self
473    }
474
475    /// Set a query-scoped runtime handle for compiled terminal read operators.
476    pub fn with_query_read_runtime_key(mut self, key: Option<String>) -> Self {
477        self.query_read_runtime_key = key;
478        self
479    }
480
481    /// Set a query-scoped runtime handle for THINK context assembly.
482    pub fn with_context_assembly_runtime_key(mut self, key: Option<String>) -> Self {
483        self.context_assembly_runtime_key = key;
484        self
485    }
486
487    /// Set query-scoped compiled recall/search bindings.
488    pub fn with_recall_search_binding(mut self, binding: Option<RecallSearchBinding>) -> Self {
489        self.recall_search_binding = binding;
490        self
491    }
492
493    /// Returns the agent ID, if set.
494    pub fn agent_id(&self) -> Option<&str> {
495        self.agent_id.as_deref()
496    }
497
498    /// Returns the pre-resolved allowed namespaces.
499    ///
500    /// `None` = open mode (no filtering), `Some(&[])` = deny all.
501    pub fn allowed_namespaces(&self) -> Option<&[String]> {
502        self.allowed_namespaces.as_deref()
503    }
504
505    /// Retrieve `HirnSessionExt` from a [`SessionContext`].
506    ///
507    /// Returns a clone since `SessionContext::state()` returns by value.
508    ///
509    /// # Errors
510    /// Returns an error if the extension was never registered.
511    pub fn get(ctx: &SessionContext) -> datafusion_common::Result<Self> {
512        let state = ctx.state();
513        let ext = state
514            .config()
515            .options()
516            .extensions
517            .get::<Self>()
518            .ok_or_else(|| {
519                datafusion_common::DataFusionError::Configuration(
520                    "HirnSessionExt not registered in SessionContext — \
521                     was the database opened correctly?"
522                        .into(),
523                )
524            })?;
525        Ok(ext.clone())
526    }
527
528    /// Register this extension in a [`SessionContext`].
529    ///
530    /// # Errors
531    /// Returns an error if the `SessionState` has already been dropped.
532    pub fn register(self, ctx: &SessionContext) -> datafusion_common::Result<()> {
533        let state = ctx.state_weak_ref().upgrade().ok_or_else(|| {
534            datafusion_common::DataFusionError::Internal(
535                "Cannot register HirnSessionExt: SessionState already dropped".into(),
536            )
537        })?;
538        state
539            .write()
540            .config_mut()
541            .options_mut()
542            .extensions
543            .insert(self);
544        Ok(())
545    }
546
547    /// Downcast the type-erased graph handle to the concrete type `T`.
548    ///
549    /// Returns `None` if the stored graph is not of type `T`.
550    pub fn graph_as<T: Send + Sync + 'static>(&self) -> Option<&T> {
551        self.graph.downcast_ref::<T>()
552    }
553
554    /// Clone the graph `Arc` and downcast to `Arc<T>`.
555    ///
556    /// Returns `None` if the stored graph is not of type `T`.
557    pub fn graph_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
558        self.graph.clone().downcast::<T>().ok()
559    }
560
561    /// Raw `Arc<dyn Any>` graph handle.
562    pub fn graph_any(&self) -> &Arc<dyn Any + Send + Sync> {
563        &self.graph
564    }
565
566    /// Optional graph read runtime.
567    pub fn graph_read_runtime(&self) -> Option<Arc<dyn GraphReadRuntime>> {
568        self.graph_read_runtime.clone()
569    }
570
571    /// Optional terminal-read runtime resolved from the query-scoped registry.
572    pub fn query_read_runtime(&self) -> Option<Arc<dyn QueryReadRuntime>> {
573        self.query_read_runtime_key
574            .as_deref()
575            .and_then(lookup_query_read_runtime)
576    }
577
578    /// Optional context-assembly runtime resolved from the query-scoped registry.
579    pub fn context_assembly_runtime(&self) -> Option<Arc<dyn ContextAssemblyRuntime>> {
580        self.context_assembly_runtime_key
581            .as_deref()
582            .and_then(lookup_context_assembly_runtime)
583    }
584
585    /// Optional query-scoped compiled recall/search binding.
586    pub fn recall_search_binding(&self) -> Option<&RecallSearchBinding> {
587        self.recall_search_binding.as_ref()
588    }
589
590    /// Optional embedder reference.
591    pub fn embedder(&self) -> Option<&dyn Embedder> {
592        self.embedder.as_deref()
593    }
594
595    /// Optional embedder Arc clone.
596    pub fn embedder_arc(&self) -> Option<Arc<dyn Embedder>> {
597        self.embedder.clone()
598    }
599
600    /// Optional storage reference.
601    pub fn storage(&self) -> Option<&dyn PhysicalStore> {
602        self.storage.as_deref()
603    }
604
605    /// Optional storage Arc clone.
606    pub fn storage_arc(&self) -> Option<Arc<dyn PhysicalStore>> {
607        self.storage.clone()
608    }
609
610    /// Optional tokenizer reference.
611    pub fn tokenizer(&self) -> Option<&dyn Tokenizer> {
612        self.tokenizer.as_deref()
613    }
614
615    /// Optional tokenizer Arc clone.
616    pub fn tokenizer_arc(&self) -> Option<Arc<dyn Tokenizer>> {
617        self.tokenizer.clone()
618    }
619}
620
621impl ExtensionOptions for HirnSessionExt {
622    fn as_any(&self) -> &dyn Any {
623        self
624    }
625
626    fn as_any_mut(&mut self) -> &mut dyn Any {
627        self
628    }
629
630    fn cloned(&self) -> Box<dyn ExtensionOptions> {
631        Box::new(self.clone())
632    }
633
634    fn set(&mut self, _key: &str, _value: &str) -> datafusion_common::Result<()> {
635        Ok(())
636    }
637
638    fn entries(&self) -> Vec<ConfigEntry> {
639        vec![]
640    }
641}
642
// Marks `HirnSessionExt` as a DataFusion config extension; this is what lets
// it be stored in and fetched from the session's `extensions` by type.
impl ConfigExtension for HirnSessionExt {
    /// Prefix under which this extension is known to DataFusion.
    const PREFIX: &'static str = "hirn";
}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time assertion: `HirnSessionExt` is `Send + Sync`.
    const _: fn() = || {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<HirnSessionExt>();
    };

    /// Registering and then fetching the extension yields the same config Arc
    /// and leaves the optional components unset.
    #[test]
    fn register_and_retrieve() {
        let shared_config = Arc::new(HirnConfig::default());
        let session = SessionContext::new();
        HirnSessionExt::new(Arc::new(42_u32), Arc::clone(&shared_config), None)
            .register(&session)
            .expect("register should succeed");

        let fetched = HirnSessionExt::get(&session).expect("extension should be present");
        // Same Arc — pointer equality proves we got the same config back.
        assert!(Arc::ptr_eq(&fetched.config, &shared_config));
        assert!(fetched.embedder().is_none());
        assert!(fetched.tokenizer().is_none());
    }

    /// Fetching from a context without the extension reports a clear error.
    #[test]
    fn missing_extension_gives_clear_error() {
        let session = SessionContext::new();
        let err = HirnSessionExt::get(&session).unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("HirnSessionExt not registered"),
            "unexpected error: {err}"
        );
    }

    /// The type-erased graph handle can be downcast back to its concrete type.
    #[test]
    fn graph_downcast() {
        let session = SessionContext::new();
        let graph_handle = Arc::new(String::from("test_graph"));
        let default_config = Arc::new(HirnConfig::default());
        HirnSessionExt::new(graph_handle, default_config, None)
            .register(&session)
            .expect("register should succeed");

        let fetched = HirnSessionExt::get(&session).unwrap();
        let graph: &String = fetched.graph_as().unwrap();
        assert_eq!(graph, "test_graph");
    }
}