1use std::any::Any;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::OnceLock;
11use std::sync::atomic::{AtomicU64, Ordering};
12
13use arrow_array::RecordBatch;
14use async_trait::async_trait;
15use datafusion::prelude::SessionContext;
16use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
17use hirn_core::HirnResult;
18use hirn_core::config::HirnConfig;
19use hirn_core::embed::Embedder;
20use hirn_core::id::MemoryId;
21use hirn_core::tokenizer::Tokenizer;
22use hirn_core::types::{EdgeRelation, Namespace};
23use hirn_graph::PprConfig;
24use hirn_query::compiler::plan_compiler::SemanticTargetKindRepr;
25use hirn_storage::PhysicalStore;
26use hirn_storage::store::DistanceMetric;
27use parking_lot::RwLock;
28
29use crate::operators::ActivationMode;
30use crate::operators::SearchNumericFilter;
31use crate::operators::nli_contradiction::NliClassifier;
32
/// Result of [`GraphReadRuntime::activate_graph`].
///
/// NOTE(review): the three vectors are presumably parallel (index `i` describes one
/// activated node) — confirm against the implementors, nothing here enforces equal lengths.
#[derive(Clone, Debug, PartialEq)]
pub struct GraphActivationOutput {
    /// Ids of the activated nodes, as strings.
    pub ids: Vec<String>,
    /// Activation score per node.
    pub scores: Vec<f32>,
    /// Depth per node.
    pub depths: Vec<u32>,
}
39
/// One row of the result of [`GraphReadRuntime::causal_chain`].
///
/// NOTE(review): each row appears to describe a single `source_id -> target_id` link that
/// belongs to the chain named by `chain_id`; value ranges of the numeric fields are not
/// visible in this file.
#[derive(Clone, Debug, PartialEq)]
pub struct GraphCausalChainRow {
    /// Identifier of the chain this row belongs to.
    pub chain_id: String,
    /// Id of the link's source node.
    pub source_id: String,
    /// Id of the link's target node.
    pub target_id: String,
    /// Strength reported for the link.
    pub strength: f32,
    /// Confidence reported for the link.
    pub confidence: f32,
    /// Number of pieces of evidence recorded for the link.
    pub evidence_count: u32,
    /// Optional free-text mechanism attached to the link.
    pub mechanism: Option<String>,
    /// Depth at which the link was found.
    pub depth: u32,
    /// Score of the chain.
    pub chain_score: f32,
}
52
/// One row of the result of [`GraphReadRuntime::traverse_graph`].
#[derive(Clone, Debug, PartialEq)]
pub struct GraphTraverseRow {
    /// Id of the visited node, as a string.
    pub node_id: String,
    /// Depth reported for the node.
    pub depth: u32,
}
58
59#[derive(Debug, Clone, PartialEq)]
60pub struct RecallSearchBinding {
61 pub query_vector: Vec<f32>,
62 pub filter: Option<String>,
63 pub limit: usize,
64 pub metric: DistanceMetric,
65 pub numeric_filters: Vec<SearchNumericFilter>,
66 pub temporal_start_ms: Option<i64>,
67 pub temporal_end_ms: Option<i64>,
68 pub temporal_expansion: bool,
70}
71
72#[async_trait]
73pub trait GraphReadRuntime: Send + Sync {
74 async fn activate_graph(
75 &self,
76 seeds: &[MemoryId],
77 mode: ActivationMode,
78 ppr_config: Option<&PprConfig>,
79 max_depth: u32,
80 epsilon: f32,
81 inhibition_mu: f32,
82 delegation_threshold: usize,
83 allowed_namespaces: Option<&[Namespace]>,
84 ) -> HirnResult<GraphActivationOutput>;
85
86 async fn causal_chain(
87 &self,
88 start_ids: &[MemoryId],
89 max_depth: u32,
90 confidence_threshold: f32,
91 delegation_threshold: usize,
92 relation: EdgeRelation,
93 allowed_namespaces: Option<&[Namespace]>,
94 ) -> HirnResult<Vec<GraphCausalChainRow>>;
95
96 async fn traverse_graph(
97 &self,
98 start_ids: &[MemoryId],
99 max_depth: u32,
100 delegation_threshold: usize,
101 relation_filter: Option<&[EdgeRelation]>,
102 allowed_namespaces: Option<&[Namespace]>,
103 ) -> HirnResult<Vec<GraphTraverseRow>>;
104}
105
106#[async_trait]
107pub trait QueryReadRuntime: Send + Sync {
108 async fn inspect_json(
109 &self,
110 target: &str,
111 target_kind: SemanticTargetKindRepr,
112 agent_id: &str,
113 allowed_namespaces: Option<&[String]>,
114 ) -> HirnResult<Vec<u8>>;
115
116 async fn trace_json(
117 &self,
118 target: &str,
119 target_kind: SemanticTargetKindRepr,
120 agent_id: &str,
121 allowed_namespaces: Option<&[String]>,
122 ) -> HirnResult<Vec<u8>>;
123
124 async fn explain_causes_json(
125 &self,
126 query: &str,
127 depth: u32,
128 namespace: Option<&str>,
129 allowed_namespaces: Option<&[String]>,
130 ) -> HirnResult<Vec<u8>>;
131
132 async fn what_if_json(
133 &self,
134 intervention: &str,
135 outcome: &str,
136 namespace: Option<&str>,
137 allowed_namespaces: Option<&[String]>,
138 ) -> HirnResult<Vec<u8>>;
139
140 async fn counterfactual_json(
141 &self,
142 antecedent: &str,
143 consequent: &str,
144 namespace: Option<&str>,
145 allowed_namespaces: Option<&[String]>,
146 ) -> HirnResult<Vec<u8>>;
147
148 async fn show_policies_json(
149 &self,
150 principal_kind: Option<&str>,
151 principal_name: Option<&str>,
152 ) -> HirnResult<Vec<u8>>;
153
154 async fn explain_policy_json(
155 &self,
156 principal_kind: &str,
157 principal_name: &str,
158 resource_type: &str,
159 resource_name: &str,
160 action: &str,
161 ) -> HirnResult<Vec<u8>>;
162}
163
164static QUERY_READ_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
165
166fn query_read_runtime_registry() -> &'static RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>> {
167 static REGISTRY: OnceLock<RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>>> = OnceLock::new();
168 REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
169}
170
/// Handle for a runtime stored by [`register_query_read_runtime`].
///
/// Dropping the handle removes the runtime from the registry, so it has to be kept alive
/// for as long as [`RegisteredQueryReadRuntime::key`] should keep resolving. The
/// `#[must_use]` marker makes the compiler warn when the handle is discarded on the spot.
#[derive(Debug)]
#[must_use = "dropping the handle immediately unregisters the runtime"]
pub struct RegisteredQueryReadRuntime {
    /// Registry id assigned at registration time.
    id: u64,
}

impl RegisteredQueryReadRuntime {
    /// Returns the registry id as a decimal string.
    ///
    /// This is the form accepted by [`HirnSessionExt::with_query_read_runtime_key`].
    pub fn key(&self) -> String {
        self.id.to_string()
    }
}
181
182impl Drop for RegisteredQueryReadRuntime {
183 fn drop(&mut self) {
184 query_read_runtime_registry().write().remove(&self.id);
185 }
186}
187
188pub fn register_query_read_runtime(
189 runtime: Arc<dyn QueryReadRuntime>,
190) -> RegisteredQueryReadRuntime {
191 let id = QUERY_READ_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
192 query_read_runtime_registry().write().insert(id, runtime);
193 RegisteredQueryReadRuntime { id }
194}
195
196fn lookup_query_read_runtime(key: &str) -> Option<Arc<dyn QueryReadRuntime>> {
197 let id = key.parse::<u64>().ok()?;
198 query_read_runtime_registry().read().get(&id).cloned()
199}
200
201#[async_trait]
213pub trait ContextAssemblyRuntime: Send + Sync {
214 async fn assemble_from_batches(
220 &self,
221 candidate_batches: Vec<RecordBatch>,
222 ) -> HirnResult<Vec<u8>>;
223}
224
225static CONTEXT_ASSEMBLY_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
226
227fn context_assembly_runtime_registry()
228-> &'static RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>> {
229 static REGISTRY: OnceLock<RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>>> =
230 OnceLock::new();
231 REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
232}
233
/// Handle for a runtime stored by [`register_context_assembly_runtime`].
///
/// Dropping the handle removes the runtime from the registry, so it has to be kept alive
/// for as long as [`RegisteredContextAssemblyRuntime::key`] should keep resolving. The
/// `#[must_use]` marker makes the compiler warn when the handle is discarded on the spot.
#[derive(Debug)]
#[must_use = "dropping the handle immediately unregisters the runtime"]
pub struct RegisteredContextAssemblyRuntime {
    /// Registry id assigned at registration time.
    id: u64,
}

impl RegisteredContextAssemblyRuntime {
    /// Returns the registry id as a decimal string.
    ///
    /// This is the form accepted by [`HirnSessionExt::with_context_assembly_runtime_key`].
    pub fn key(&self) -> String {
        self.id.to_string()
    }
}
249
250impl Drop for RegisteredContextAssemblyRuntime {
251 fn drop(&mut self) {
252 context_assembly_runtime_registry().write().remove(&self.id);
253 }
254}
255
256pub fn register_context_assembly_runtime(
260 runtime: Arc<dyn ContextAssemblyRuntime>,
261) -> RegisteredContextAssemblyRuntime {
262 let id = CONTEXT_ASSEMBLY_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
263 context_assembly_runtime_registry()
264 .write()
265 .insert(id, runtime);
266 RegisteredContextAssemblyRuntime { id }
267}
268
269pub(crate) fn lookup_context_assembly_runtime(
270 key: &str,
271) -> Option<Arc<dyn ContextAssemblyRuntime>> {
272 let id = key.parse::<u64>().ok()?;
273 context_assembly_runtime_registry().read().get(&id).cloned()
274}
275
276#[derive(Clone)]
281pub struct HirnSessionExt {
282 graph: Arc<dyn Any + Send + Sync>,
286
287 graph_read_runtime: Option<Arc<dyn GraphReadRuntime>>,
293
294 pub config: Arc<HirnConfig>,
296
297 embedder: Option<Arc<dyn Embedder>>,
299
300 storage: Option<Arc<dyn PhysicalStore>>,
302
303 tokenizer: Option<Arc<dyn Tokenizer>>,
305
306 agent_id: Option<String>,
309
310 allowed_namespaces: Option<Vec<String>>,
318
319 query_read_runtime_key: Option<String>,
321
322 context_assembly_runtime_key: Option<String>,
327
328 recall_search_binding: Option<RecallSearchBinding>,
330
331 pub rpe_population_stats: Arc<RwLock<hirn_core::WelfordStats>>,
337
338 nli_classifier: Option<Arc<dyn NliClassifier>>,
344}
345
346const _: () = {
348 const fn assert_send_sync<T: Send + Sync>() {}
349 assert_send_sync::<HirnSessionExt>();
350};
351
352#[allow(clippy::missing_fields_in_debug)] impl std::fmt::Debug for HirnSessionExt {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("HirnSessionExt")
356 .field("graph", &"<type-erased>")
357 .field("has_graph_read_runtime", &self.graph_read_runtime.is_some())
358 .field("config", &self.config)
359 .field("has_embedder", &self.embedder.is_some())
360 .field("has_storage", &self.storage.is_some())
361 .field("has_tokenizer", &self.tokenizer.is_some())
362 .field("agent_id", &self.agent_id)
363 .field("allowed_namespaces", &self.allowed_namespaces)
364 .field(
365 "has_query_read_runtime",
366 &self.query_read_runtime_key.is_some(),
367 )
368 .field(
369 "has_context_assembly_runtime",
370 &self.context_assembly_runtime_key.is_some(),
371 )
372 .field(
373 "nli_classifier_backend",
374 &self
375 .nli_classifier
376 .as_ref()
377 .map(|c| c.backend_name())
378 .unwrap_or("default"),
379 )
380 .finish_non_exhaustive()
383 }
384}
385
386impl HirnSessionExt {
387 pub fn new(
390 graph: Arc<dyn Any + Send + Sync>,
391 config: Arc<HirnConfig>,
392 embedder: Option<Arc<dyn Embedder>>,
393 ) -> Self {
394 Self {
395 graph,
396 graph_read_runtime: None,
397 config,
398 embedder,
399 storage: None,
400 tokenizer: None,
401 agent_id: None,
402 allowed_namespaces: None,
403 query_read_runtime_key: None,
404 context_assembly_runtime_key: None,
405 recall_search_binding: None,
406 rpe_population_stats: Arc::new(RwLock::new(hirn_core::WelfordStats::new())),
407 nli_classifier: None,
408 }
409 }
410
411 pub fn with_rpe_population_stats(
416 mut self,
417 stats: Arc<RwLock<hirn_core::WelfordStats>>,
418 ) -> Self {
419 self.rpe_population_stats = stats;
420 self
421 }
422
423 pub fn with_agent_id(mut self, agent_id: impl Into<String>) -> Self {
425 self.agent_id = Some(agent_id.into());
426 self
427 }
428
429 pub fn with_nli_classifier(mut self, clf: Arc<dyn NliClassifier>) -> Self {
435 self.nli_classifier = Some(clf);
436 self
437 }
438
439 pub fn nli_classifier(&self) -> Option<Arc<dyn NliClassifier>> {
443 self.nli_classifier.clone()
444 }
445
446 pub fn with_storage(mut self, storage: Arc<dyn PhysicalStore>) -> Self {
448 self.storage = Some(storage);
449 self
450 }
451
452 pub fn with_graph_read_runtime(
454 mut self,
455 graph_read_runtime: Arc<dyn GraphReadRuntime>,
456 ) -> Self {
457 self.graph_read_runtime = Some(graph_read_runtime);
458 self
459 }
460
461 pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
463 self.tokenizer = Some(tokenizer);
464 self
465 }
466
467 pub fn with_allowed_namespaces(mut self, namespaces: Option<Vec<String>>) -> Self {
471 self.allowed_namespaces = namespaces;
472 self
473 }
474
475 pub fn with_query_read_runtime_key(mut self, key: Option<String>) -> Self {
477 self.query_read_runtime_key = key;
478 self
479 }
480
481 pub fn with_context_assembly_runtime_key(mut self, key: Option<String>) -> Self {
483 self.context_assembly_runtime_key = key;
484 self
485 }
486
487 pub fn with_recall_search_binding(mut self, binding: Option<RecallSearchBinding>) -> Self {
489 self.recall_search_binding = binding;
490 self
491 }
492
493 pub fn agent_id(&self) -> Option<&str> {
495 self.agent_id.as_deref()
496 }
497
498 pub fn allowed_namespaces(&self) -> Option<&[String]> {
502 self.allowed_namespaces.as_deref()
503 }
504
505 pub fn get(ctx: &SessionContext) -> datafusion_common::Result<Self> {
512 let state = ctx.state();
513 let ext = state
514 .config()
515 .options()
516 .extensions
517 .get::<Self>()
518 .ok_or_else(|| {
519 datafusion_common::DataFusionError::Configuration(
520 "HirnSessionExt not registered in SessionContext — \
521 was the database opened correctly?"
522 .into(),
523 )
524 })?;
525 Ok(ext.clone())
526 }
527
528 pub fn register(self, ctx: &SessionContext) -> datafusion_common::Result<()> {
533 let state = ctx.state_weak_ref().upgrade().ok_or_else(|| {
534 datafusion_common::DataFusionError::Internal(
535 "Cannot register HirnSessionExt: SessionState already dropped".into(),
536 )
537 })?;
538 state
539 .write()
540 .config_mut()
541 .options_mut()
542 .extensions
543 .insert(self);
544 Ok(())
545 }
546
547 pub fn graph_as<T: Send + Sync + 'static>(&self) -> Option<&T> {
551 self.graph.downcast_ref::<T>()
552 }
553
554 pub fn graph_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
558 self.graph.clone().downcast::<T>().ok()
559 }
560
561 pub fn graph_any(&self) -> &Arc<dyn Any + Send + Sync> {
563 &self.graph
564 }
565
566 pub fn graph_read_runtime(&self) -> Option<Arc<dyn GraphReadRuntime>> {
568 self.graph_read_runtime.clone()
569 }
570
571 pub fn query_read_runtime(&self) -> Option<Arc<dyn QueryReadRuntime>> {
573 self.query_read_runtime_key
574 .as_deref()
575 .and_then(lookup_query_read_runtime)
576 }
577
578 pub fn context_assembly_runtime(&self) -> Option<Arc<dyn ContextAssemblyRuntime>> {
580 self.context_assembly_runtime_key
581 .as_deref()
582 .and_then(lookup_context_assembly_runtime)
583 }
584
585 pub fn recall_search_binding(&self) -> Option<&RecallSearchBinding> {
587 self.recall_search_binding.as_ref()
588 }
589
590 pub fn embedder(&self) -> Option<&dyn Embedder> {
592 self.embedder.as_deref()
593 }
594
595 pub fn embedder_arc(&self) -> Option<Arc<dyn Embedder>> {
597 self.embedder.clone()
598 }
599
600 pub fn storage(&self) -> Option<&dyn PhysicalStore> {
602 self.storage.as_deref()
603 }
604
605 pub fn storage_arc(&self) -> Option<Arc<dyn PhysicalStore>> {
607 self.storage.clone()
608 }
609
610 pub fn tokenizer(&self) -> Option<&dyn Tokenizer> {
612 self.tokenizer.as_deref()
613 }
614
615 pub fn tokenizer_arc(&self) -> Option<Arc<dyn Tokenizer>> {
617 self.tokenizer.clone()
618 }
619}
620
621impl ExtensionOptions for HirnSessionExt {
622 fn as_any(&self) -> &dyn Any {
623 self
624 }
625
626 fn as_any_mut(&mut self) -> &mut dyn Any {
627 self
628 }
629
630 fn cloned(&self) -> Box<dyn ExtensionOptions> {
631 Box::new(self.clone())
632 }
633
634 fn set(&mut self, _key: &str, _value: &str) -> datafusion_common::Result<()> {
635 Ok(())
636 }
637
638 fn entries(&self) -> Vec<ConfigEntry> {
639 vec![]
640 }
641}
642
643impl ConfigExtension for HirnSessionExt {
644 const PREFIX: &'static str = "hirn";
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    // Same `Send + Sync` compile-time check as at module level, repeated inside the test
    // build.
    const _: fn() = || {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<HirnSessionExt>();
    };

    /// A registered extension can be read back and still shares the original config.
    #[test]
    fn register_and_retrieve() {
        let session = SessionContext::new();
        let shared_config = Arc::new(HirnConfig::default());

        HirnSessionExt::new(Arc::new(42_u32), Arc::clone(&shared_config), None)
            .register(&session)
            .expect("register should succeed");

        let fetched = HirnSessionExt::get(&session).expect("extension should be present");
        assert!(Arc::ptr_eq(&fetched.config, &shared_config));
        assert!(fetched.embedder().is_none());
        assert!(fetched.tokenizer().is_none());
    }

    /// Reading from a context without the extension yields a descriptive error.
    #[test]
    fn missing_extension_gives_clear_error() {
        let session = SessionContext::new();
        let err = HirnSessionExt::get(&session).unwrap_err();
        assert!(
            err.to_string().contains("HirnSessionExt not registered"),
            "unexpected error: {err}"
        );
    }

    /// The type-erased graph can be downcast back to its concrete type after a round trip.
    #[test]
    fn graph_downcast() {
        let session = SessionContext::new();
        let graph_handle = Arc::new(String::from("test_graph"));
        let default_config = Arc::new(HirnConfig::default());

        HirnSessionExt::new(graph_handle, default_config, None)
            .register(&session)
            .expect("register should succeed");

        let fetched = HirnSessionExt::get(&session).unwrap();
        let graph = fetched.graph_as::<String>().unwrap();
        assert_eq!(graph, "test_graph");
    }
}
695}