use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use arrow_array::RecordBatch;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
use hirn_core::HirnResult;
use hirn_core::config::HirnConfig;
use hirn_core::embed::Embedder;
use hirn_core::id::MemoryId;
use hirn_core::tokenizer::Tokenizer;
use hirn_core::types::{EdgeRelation, Namespace};
use hirn_graph::PprConfig;
use hirn_query::compiler::plan_compiler::SemanticTargetKindRepr;
use hirn_storage::PhysicalStore;
use hirn_storage::store::DistanceMetric;
use parking_lot::RwLock;
use crate::operators::ActivationMode;
use crate::operators::SearchNumericFilter;
use crate::operators::nli_contradiction::NliClassifier;
/// Result returned by [`GraphReadRuntime::activate_graph`].
///
/// NOTE(review): the three vectors look like parallel columns (entry `i` of each
/// describing the same node) — confirm against the runtime that produces them.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphActivationOutput {
/// Node identifiers in string form.
pub ids: Vec<String>,
/// One `f32` score per entry; presumably the activation score of the node — TODO confirm.
pub scores: Vec<f32>,
/// One `u32` depth per entry; presumably the hop distance from the seeds — TODO confirm.
pub depths: Vec<u32>,
}
/// One row of the result returned by [`GraphReadRuntime::causal_chain`].
///
/// NOTE(review): field meanings below are read off the names only; the value
/// ranges and exact semantics are defined by the implementing runtime.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphCausalChainRow {
/// Identifier of the chain this row belongs to.
pub chain_id: String,
/// Identifier of the edge's source node.
pub source_id: String,
/// Identifier of the edge's target node.
pub target_id: String,
/// Strength value attached to this link.
pub strength: f32,
/// Confidence value attached to this link.
pub confidence: f32,
/// Number of pieces of evidence recorded for this link.
pub evidence_count: u32,
/// Optional free-text mechanism; `None` when none is recorded.
pub mechanism: Option<String>,
/// Depth of this link; presumably hops from the start nodes — TODO confirm.
pub depth: u32,
/// Score of the chain as a whole.
pub chain_score: f32,
}
/// One row of the result returned by [`GraphReadRuntime::traverse_graph`].
#[derive(Debug, Clone, PartialEq)]
pub struct GraphTraverseRow {
/// Identifier of the visited node, in string form.
pub node_id: String,
/// Depth recorded for the node; presumably hops from the start set — TODO confirm.
pub depth: u32,
}
/// Search parameters that can be attached to a [`HirnSessionExt`] through
/// [`HirnSessionExt::with_recall_search_binding`] and read back with
/// [`HirnSessionExt::recall_search_binding`].
///
/// NOTE(review): how each field is interpreted is decided by the consumer of the
/// binding, which is not part of this file; the field notes are hedged accordingly.
#[derive(Debug, Clone, PartialEq)]
pub struct RecallSearchBinding {
/// Query embedding as raw `f32` components.
pub query_vector: Vec<f32>,
/// Optional filter expression in string form; syntax is defined by the consumer.
pub filter: Option<String>,
/// Maximum number of results requested.
pub limit: usize,
/// Distance metric to search with.
pub metric: DistanceMetric,
/// Additional numeric filters.
pub numeric_filters: Vec<SearchNumericFilter>,
/// Optional start of a temporal window; the name suggests milliseconds — TODO confirm epoch/zone.
pub temporal_start_ms: Option<i64>,
/// Optional end of a temporal window; the name suggests milliseconds — TODO confirm epoch/zone.
pub temporal_end_ms: Option<i64>,
/// Flag controlling temporal expansion; exact effect is defined by the consumer.
pub temporal_expansion: bool,
}
/// Read-side graph operations that a [`HirnSessionExt`] can carry.
///
/// An implementation is installed with [`HirnSessionExt::with_graph_read_runtime`]
/// and fetched with [`HirnSessionExt::graph_read_runtime`]. Implementations must be
/// `Send + Sync` because the extension is shared across threads.
///
/// NOTE(review): parameter semantics are not defined in this file; the notes on
/// each method describe only what the signatures show.
#[async_trait]
pub trait GraphReadRuntime: Send + Sync {
/// Runs an activation pass starting from `seeds` and returns the activated
/// nodes as a [`GraphActivationOutput`].
///
/// `mode` selects the activation algorithm and `ppr_config` optionally supplies
/// a [`PprConfig`]. `allowed_namespaces`, when `Some`, presumably restricts the
/// pass to those namespaces — TODO confirm what `None` means to implementors.
async fn activate_graph(
&self,
seeds: &[MemoryId],
mode: ActivationMode,
ppr_config: Option<&PprConfig>,
max_depth: u32,
epsilon: f32,
inhibition_mu: f32,
delegation_threshold: usize,
allowed_namespaces: Option<&[Namespace]>,
) -> HirnResult<GraphActivationOutput>;
/// Follows edges of the given `relation` from `start_ids` and returns one
/// [`GraphCausalChainRow`] per reported link.
///
/// `max_depth` and `confidence_threshold` presumably bound the walk — TODO
/// confirm the exact cut-off rules with the implementation.
async fn causal_chain(
&self,
start_ids: &[MemoryId],
max_depth: u32,
confidence_threshold: f32,
delegation_threshold: usize,
relation: EdgeRelation,
allowed_namespaces: Option<&[Namespace]>,
) -> HirnResult<Vec<GraphCausalChainRow>>;
/// Traverses the graph from `start_ids` and returns one [`GraphTraverseRow`]
/// per reported node.
///
/// `relation_filter`, when `Some`, presumably limits the traversal to the
/// listed edge relations — TODO confirm.
async fn traverse_graph(
&self,
start_ids: &[MemoryId],
max_depth: u32,
delegation_threshold: usize,
relation_filter: Option<&[EdgeRelation]>,
allowed_namespaces: Option<&[Namespace]>,
) -> HirnResult<Vec<GraphTraverseRow>>;
}
/// Read-side query operations that are resolved indirectly through a registry.
///
/// Implementations are registered with [`register_query_read_runtime`]; a
/// [`HirnSessionExt`] stores only the resulting key and looks the runtime up again
/// in [`HirnSessionExt::query_read_runtime`].
///
/// Every method returns raw bytes. NOTE(review): the `_json` suffix suggests the
/// payload is a serialized JSON document — confirm the encoding with implementors.
#[async_trait]
pub trait QueryReadRuntime: Send + Sync {
/// Produces the inspection payload for `target` of kind `target_kind`, on
/// behalf of `agent_id`.
async fn inspect_json(
&self,
target: &str,
target_kind: SemanticTargetKindRepr,
agent_id: &str,
allowed_namespaces: Option<&[String]>,
) -> HirnResult<Vec<u8>>;
/// Produces the trace payload for `target` of kind `target_kind`, on behalf
/// of `agent_id`.
async fn trace_json(
&self,
target: &str,
target_kind: SemanticTargetKindRepr,
agent_id: &str,
allowed_namespaces: Option<&[String]>,
) -> HirnResult<Vec<u8>>;
/// Produces the cause-explanation payload for `query`, with `depth` as the
/// requested depth and an optional `namespace`.
async fn explain_causes_json(
&self,
query: &str,
depth: u32,
namespace: Option<&str>,
allowed_namespaces: Option<&[String]>,
) -> HirnResult<Vec<u8>>;
/// Produces the what-if payload relating `intervention` to `outcome`.
async fn what_if_json(
&self,
intervention: &str,
outcome: &str,
namespace: Option<&str>,
allowed_namespaces: Option<&[String]>,
) -> HirnResult<Vec<u8>>;
/// Produces the counterfactual payload relating `antecedent` to `consequent`.
async fn counterfactual_json(
&self,
antecedent: &str,
consequent: &str,
namespace: Option<&str>,
allowed_namespaces: Option<&[String]>,
) -> HirnResult<Vec<u8>>;
/// Produces the policy listing payload; both principal arguments are optional
/// and presumably act as filters when present — TODO confirm.
async fn show_policies_json(
&self,
principal_kind: Option<&str>,
principal_name: Option<&str>,
) -> HirnResult<Vec<u8>>;
/// Produces the policy explanation payload for one principal, one resource
/// and one action.
async fn explain_policy_json(
&self,
principal_kind: &str,
principal_name: &str,
resource_type: &str,
resource_name: &str,
action: &str,
) -> HirnResult<Vec<u8>>;
}
/// Counter that hands out ids for registered query-read runtimes; the first id issued is 1.
static QUERY_READ_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
/// Returns the process-wide table mapping registration ids to query-read runtimes.
///
/// The table is created empty on first use and lives for the rest of the process.
fn query_read_runtime_registry() -> &'static RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>> {
    type Registry = RwLock<HashMap<u64, Arc<dyn QueryReadRuntime>>>;

    static REGISTRY: OnceLock<Registry> = OnceLock::new();

    REGISTRY.get_or_init(|| {
        let empty = HashMap::new();
        RwLock::new(empty)
    })
}
/// Handle for a runtime stored in the process-wide query-read registry.
///
/// The handle owns the registration: when it is dropped the runtime is removed
/// from the registry again, so it must be kept alive for as long as lookups by
/// its key are expected to succeed.
#[must_use = "dropping the handle immediately removes the runtime from the registry"]
#[derive(Debug)]
pub struct RegisteredQueryReadRuntime {
    /// Registry id assigned when the runtime was registered.
    id: u64,
}
impl RegisteredQueryReadRuntime {
    /// Returns the registration id rendered as a decimal string.
    ///
    /// This is the textual form the registry lookup parses back into a `u64`.
    pub fn key(&self) -> String {
        let id: u64 = self.id;
        id.to_string()
    }
}
impl Drop for RegisteredQueryReadRuntime {
    /// Removes this handle's runtime from the process-wide registry.
    fn drop(&mut self) {
        // Bind the removed entry first: the write guard is a temporary of this
        // `let` statement and is released when the statement ends, so the
        // runtime's own destructor (if this was the last `Arc`) runs without the
        // registry lock held. parking_lot locks are not reentrant, so a
        // destructor that touched the registry under the guard would deadlock.
        let removed = query_read_runtime_registry().write().remove(&self.id);
        drop(removed);
    }
}
/// Stores `runtime` in the process-wide registry under a freshly issued id.
///
/// Returns the handle that owns the registration; the entry is removed again
/// when that handle is dropped.
pub fn register_query_read_runtime(
    runtime: Arc<dyn QueryReadRuntime>,
) -> RegisteredQueryReadRuntime {
    let id = QUERY_READ_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
    {
        let mut registry = query_read_runtime_registry().write();
        registry.insert(id, runtime);
    }
    RegisteredQueryReadRuntime { id }
}
/// Resolves a registration key back to its runtime.
///
/// Returns `None` when `key` is not a valid `u64` or when no runtime is
/// currently registered under that id.
fn lookup_query_read_runtime(key: &str) -> Option<Arc<dyn QueryReadRuntime>> {
    let id: u64 = key.parse().ok()?;
    let registry = query_read_runtime_registry().read();
    registry.get(&id).map(Arc::clone)
}
/// Context-assembly operation that is resolved indirectly through a registry.
///
/// Implementations are registered with [`register_context_assembly_runtime`]; a
/// [`HirnSessionExt`] stores only the resulting key and looks the runtime up again
/// in [`HirnSessionExt::context_assembly_runtime`].
#[async_trait]
pub trait ContextAssemblyRuntime: Send + Sync {
/// Consumes the candidate record batches and returns the assembled result as
/// raw bytes.
///
/// NOTE(review): the byte encoding and the expected batch schema are defined by
/// the implementor — confirm before relying on either.
async fn assemble_from_batches(
&self,
candidate_batches: Vec<RecordBatch>,
) -> HirnResult<Vec<u8>>;
}
/// Counter that hands out ids for registered context-assembly runtimes; the first id issued is 1.
static CONTEXT_ASSEMBLY_RUNTIME_IDS: AtomicU64 = AtomicU64::new(1);
/// Returns the process-wide table mapping registration ids to context-assembly runtimes.
///
/// The table is created empty on first use and lives for the rest of the process.
fn context_assembly_runtime_registry()
-> &'static RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>> {
    type Registry = RwLock<HashMap<u64, Arc<dyn ContextAssemblyRuntime>>>;

    static REGISTRY: OnceLock<Registry> = OnceLock::new();

    REGISTRY.get_or_init(|| {
        let empty = HashMap::new();
        RwLock::new(empty)
    })
}
/// Handle for a runtime stored in the process-wide context-assembly registry.
///
/// The handle owns the registration: when it is dropped the runtime is removed
/// from the registry again, so it must be kept alive for as long as lookups by
/// its key are expected to succeed.
#[must_use = "dropping the handle immediately removes the runtime from the registry"]
#[derive(Debug)]
pub struct RegisteredContextAssemblyRuntime {
    /// Registry id assigned when the runtime was registered.
    id: u64,
}
impl RegisteredContextAssemblyRuntime {
    /// Returns the registration id rendered as a decimal string.
    ///
    /// This is the textual form the registry lookup parses back into a `u64`.
    pub fn key(&self) -> String {
        let id: u64 = self.id;
        id.to_string()
    }
}
impl Drop for RegisteredContextAssemblyRuntime {
    /// Removes this handle's runtime from the process-wide registry.
    fn drop(&mut self) {
        // Bind the removed entry first: the write guard is a temporary of this
        // `let` statement and is released when the statement ends, so the
        // runtime's own destructor (if this was the last `Arc`) runs without the
        // registry lock held. parking_lot locks are not reentrant, so a
        // destructor that touched the registry under the guard would deadlock.
        let removed = context_assembly_runtime_registry()
            .write()
            .remove(&self.id);
        drop(removed);
    }
}
/// Stores `runtime` in the process-wide registry under a freshly issued id.
///
/// Returns the handle that owns the registration; the entry is removed again
/// when that handle is dropped.
pub fn register_context_assembly_runtime(
    runtime: Arc<dyn ContextAssemblyRuntime>,
) -> RegisteredContextAssemblyRuntime {
    let id = CONTEXT_ASSEMBLY_RUNTIME_IDS.fetch_add(1, Ordering::Relaxed);
    {
        let mut registry = context_assembly_runtime_registry().write();
        registry.insert(id, runtime);
    }
    RegisteredContextAssemblyRuntime { id }
}
/// Resolves a registration key back to its runtime.
///
/// Returns `None` when `key` is not a valid `u64` or when no runtime is
/// currently registered under that id.
pub(crate) fn lookup_context_assembly_runtime(
    key: &str,
) -> Option<Arc<dyn ContextAssemblyRuntime>> {
    let id: u64 = key.parse().ok()?;
    let registry = context_assembly_runtime_registry().read();
    registry.get(&id).map(Arc::clone)
}
/// Per-session state attached to a DataFusion `SessionContext` as a config extension.
///
/// The value is stored with [`HirnSessionExt::register`] and fetched (as a clone)
/// with [`HirnSessionExt::get`]. Most members are `Arc`-wrapped, so cloning shares
/// them rather than copying the underlying objects.
#[derive(Clone)]
pub struct HirnSessionExt {
// Type-erased graph handle; recovered with `graph_as` / `graph_arc`.
graph: Arc<dyn Any + Send + Sync>,
// Optional runtime for read-side graph operations.
graph_read_runtime: Option<Arc<dyn GraphReadRuntime>>,
/// Shared configuration.
pub config: Arc<HirnConfig>,
// Optional embedder.
embedder: Option<Arc<dyn Embedder>>,
// Optional physical store.
storage: Option<Arc<dyn PhysicalStore>>,
// Optional tokenizer.
tokenizer: Option<Arc<dyn Tokenizer>>,
// Optional id of the agent this session belongs to.
agent_id: Option<String>,
// Optional namespace allow-list. NOTE(review): presumably `None` means
// "no restriction" — confirm with the code that enforces it.
allowed_namespaces: Option<Vec<String>>,
// Registry key of the query-read runtime, resolved on each access.
query_read_runtime_key: Option<String>,
// Registry key of the context-assembly runtime, resolved on each access.
context_assembly_runtime_key: Option<String>,
// Optional search parameters bound to this session.
recall_search_binding: Option<RecallSearchBinding>,
/// Welford statistics behind a shared lock; clones of the extension share the
/// same instance. NOTE(review): `rpe` presumably stands for reward prediction
/// error — confirm.
pub rpe_population_stats: Arc<RwLock<hirn_core::WelfordStats>>,
// Optional NLI classifier.
nli_classifier: Option<Arc<dyn NliClassifier>>,
}
// Compile-time check that `HirnSessionExt` is `Send + Sync`: the call below only
// type-checks when both bounds hold, so a regression fails the build.
const _: () = {
const fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<HirnSessionExt>();
};
// Trait-object members have no `Debug` of their own, so the output reports only
// whether they are present; the remaining fields are elided on purpose.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for HirnSessionExt {
    /// Writes a summary of the extension: plain fields verbatim, trait-object
    /// members as presence flags, and the NLI backend name (or `"default"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_graph_read_runtime = self.graph_read_runtime.is_some();
        let has_embedder = self.embedder.is_some();
        let has_storage = self.storage.is_some();
        let has_tokenizer = self.tokenizer.is_some();
        let has_query_read_runtime = self.query_read_runtime_key.is_some();
        let has_context_assembly_runtime = self.context_assembly_runtime_key.is_some();
        let nli_backend = match &self.nli_classifier {
            Some(classifier) => classifier.backend_name(),
            None => "default",
        };

        let mut out = f.debug_struct("HirnSessionExt");
        out.field("graph", &"<type-erased>");
        out.field("has_graph_read_runtime", &has_graph_read_runtime);
        out.field("config", &self.config);
        out.field("has_embedder", &has_embedder);
        out.field("has_storage", &has_storage);
        out.field("has_tokenizer", &has_tokenizer);
        out.field("agent_id", &self.agent_id);
        out.field("allowed_namespaces", &self.allowed_namespaces);
        out.field("has_query_read_runtime", &has_query_read_runtime);
        out.field(
            "has_context_assembly_runtime",
            &has_context_assembly_runtime,
        );
        out.field("nli_classifier_backend", &nli_backend);
        out.finish_non_exhaustive()
    }
}
impl HirnSessionExt {
    /// Builds an extension from a type-erased graph handle, the shared
    /// configuration and an optional embedder.
    ///
    /// Every other optional component starts unset, and `rpe_population_stats`
    /// starts as a fresh `WelfordStats`; use the `with_*` builders to fill them in.
    #[must_use]
    pub fn new(
        graph: Arc<dyn Any + Send + Sync>,
        config: Arc<HirnConfig>,
        embedder: Option<Arc<dyn Embedder>>,
    ) -> Self {
        Self {
            graph,
            graph_read_runtime: None,
            config,
            embedder,
            storage: None,
            tokenizer: None,
            agent_id: None,
            allowed_namespaces: None,
            query_read_runtime_key: None,
            context_assembly_runtime_key: None,
            recall_search_binding: None,
            rpe_population_stats: Arc::new(RwLock::new(hirn_core::WelfordStats::new())),
            nli_classifier: None,
        }
    }

    /// Replaces the statistics handle, so that this extension shares `stats`
    /// with whoever else holds the same `Arc`.
    #[must_use]
    pub fn with_rpe_population_stats(
        mut self,
        stats: Arc<RwLock<hirn_core::WelfordStats>>,
    ) -> Self {
        self.rpe_population_stats = stats;
        self
    }

    /// Sets the agent id.
    #[must_use]
    pub fn with_agent_id(mut self, agent_id: impl Into<String>) -> Self {
        self.agent_id = Some(agent_id.into());
        self
    }

    /// Sets the NLI classifier.
    #[must_use]
    pub fn with_nli_classifier(mut self, clf: Arc<dyn NliClassifier>) -> Self {
        self.nli_classifier = Some(clf);
        self
    }

    /// Returns a shared handle to the NLI classifier, if one was set.
    pub fn nli_classifier(&self) -> Option<Arc<dyn NliClassifier>> {
        self.nli_classifier.clone()
    }

    /// Sets the physical store.
    #[must_use]
    pub fn with_storage(mut self, storage: Arc<dyn PhysicalStore>) -> Self {
        self.storage = Some(storage);
        self
    }

    /// Sets the runtime used for read-side graph operations.
    #[must_use]
    pub fn with_graph_read_runtime(
        mut self,
        graph_read_runtime: Arc<dyn GraphReadRuntime>,
    ) -> Self {
        self.graph_read_runtime = Some(graph_read_runtime);
        self
    }

    /// Sets the tokenizer.
    #[must_use]
    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
        self.tokenizer = Some(tokenizer);
        self
    }

    /// Sets (or, with `None`, clears) the namespace allow-list.
    #[must_use]
    pub fn with_allowed_namespaces(mut self, namespaces: Option<Vec<String>>) -> Self {
        self.allowed_namespaces = namespaces;
        self
    }

    /// Sets (or clears) the registry key of the query-read runtime.
    ///
    /// The key is the string produced by [`RegisteredQueryReadRuntime::key`].
    #[must_use]
    pub fn with_query_read_runtime_key(mut self, key: Option<String>) -> Self {
        self.query_read_runtime_key = key;
        self
    }

    /// Sets (or clears) the registry key of the context-assembly runtime.
    ///
    /// The key is the string produced by [`RegisteredContextAssemblyRuntime::key`].
    #[must_use]
    pub fn with_context_assembly_runtime_key(mut self, key: Option<String>) -> Self {
        self.context_assembly_runtime_key = key;
        self
    }

    /// Sets (or clears) the search parameters bound to this session.
    #[must_use]
    pub fn with_recall_search_binding(mut self, binding: Option<RecallSearchBinding>) -> Self {
        self.recall_search_binding = binding;
        self
    }

    /// Returns the agent id, if one was set.
    pub fn agent_id(&self) -> Option<&str> {
        self.agent_id.as_deref()
    }

    /// Returns the namespace allow-list, if one was set.
    pub fn allowed_namespaces(&self) -> Option<&[String]> {
        self.allowed_namespaces.as_deref()
    }

    /// Fetches a clone of the extension registered on `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a `Configuration` error when no `HirnSessionExt` has been
    /// registered on the context.
    pub fn get(ctx: &SessionContext) -> datafusion_common::Result<Self> {
        // Read through the shared state handle instead of `ctx.state()`: the
        // latter hands back a copy of the whole session state, which is wasted
        // work when only this one extension is needed. The read guard is a
        // temporary of the `let` statement, so it is released before the error
        // (if any) is built.
        let registered = ctx
            .state_ref()
            .read()
            .config()
            .options()
            .extensions
            .get::<Self>()
            .cloned();
        registered.ok_or_else(|| {
            datafusion_common::DataFusionError::Configuration(
                "HirnSessionExt not registered in SessionContext — \
was the database opened correctly?"
                    .into(),
            )
        })
    }

    /// Stores `self` in the config extensions of `ctx`.
    ///
    /// # Errors
    ///
    /// Returns an `Internal` error when the context's session state can no
    /// longer be reached through its weak reference.
    pub fn register(self, ctx: &SessionContext) -> datafusion_common::Result<()> {
        let state = ctx.state_weak_ref().upgrade().ok_or_else(|| {
            datafusion_common::DataFusionError::Internal(
                "Cannot register HirnSessionExt: SessionState already dropped".into(),
            )
        })?;
        state
            .write()
            .config_mut()
            .options_mut()
            .extensions
            .insert(self);
        Ok(())
    }

    /// Borrows the graph handle as `T`, or returns `None` when the stored value
    /// is not a `T`.
    pub fn graph_as<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.graph.downcast_ref::<T>()
    }

    /// Returns a shared handle to the graph as `Arc<T>`, or `None` when the
    /// stored value is not a `T`.
    pub fn graph_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.graph.clone().downcast::<T>().ok()
    }

    /// Returns the graph handle in its type-erased form.
    pub fn graph_any(&self) -> &Arc<dyn Any + Send + Sync> {
        &self.graph
    }

    /// Returns a shared handle to the graph read runtime, if one was set.
    pub fn graph_read_runtime(&self) -> Option<Arc<dyn GraphReadRuntime>> {
        self.graph_read_runtime.clone()
    }

    /// Resolves the stored key through the query-read registry.
    ///
    /// Returns `None` when no key is stored, when the key is not a valid id, or
    /// when the runtime is no longer registered (its handle was dropped).
    pub fn query_read_runtime(&self) -> Option<Arc<dyn QueryReadRuntime>> {
        self.query_read_runtime_key
            .as_deref()
            .and_then(lookup_query_read_runtime)
    }

    /// Resolves the stored key through the context-assembly registry.
    ///
    /// Returns `None` when no key is stored, when the key is not a valid id, or
    /// when the runtime is no longer registered (its handle was dropped).
    pub fn context_assembly_runtime(&self) -> Option<Arc<dyn ContextAssemblyRuntime>> {
        self.context_assembly_runtime_key
            .as_deref()
            .and_then(lookup_context_assembly_runtime)
    }

    /// Returns the search parameters bound to this session, if any.
    pub fn recall_search_binding(&self) -> Option<&RecallSearchBinding> {
        self.recall_search_binding.as_ref()
    }

    /// Borrows the embedder, if one was set.
    pub fn embedder(&self) -> Option<&dyn Embedder> {
        self.embedder.as_deref()
    }

    /// Returns a shared handle to the embedder, if one was set.
    pub fn embedder_arc(&self) -> Option<Arc<dyn Embedder>> {
        self.embedder.clone()
    }

    /// Borrows the physical store, if one was set.
    pub fn storage(&self) -> Option<&dyn PhysicalStore> {
        self.storage.as_deref()
    }

    /// Returns a shared handle to the physical store, if one was set.
    pub fn storage_arc(&self) -> Option<Arc<dyn PhysicalStore>> {
        self.storage.clone()
    }

    /// Borrows the tokenizer, if one was set.
    pub fn tokenizer(&self) -> Option<&dyn Tokenizer> {
        self.tokenizer.as_deref()
    }

    /// Returns a shared handle to the tokenizer, if one was set.
    pub fn tokenizer_arc(&self) -> Option<Arc<dyn Tokenizer>> {
        self.tokenizer.clone()
    }
}
/// Minimal `ExtensionOptions` implementation: the extension carries runtime
/// objects rather than string-settable options.
impl ExtensionOptions for HirnSessionExt {
    /// Exposes the extension for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Exposes the extension for mutable downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Returns a boxed clone of the extension.
    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Accepts any key/value pair and ignores it; nothing is stored.
    fn set(&mut self, _key: &str, _value: &str) -> datafusion_common::Result<()> {
        Ok(())
    }

    /// Reports no configuration entries.
    fn entries(&self) -> Vec<ConfigEntry> {
        Vec::new()
    }
}
/// Lets `HirnSessionExt` be stored in and fetched from DataFusion's config extensions.
impl ConfigExtension for HirnSessionExt {
/// Prefix identifying this extension. NOTE(review): presumably the namespace for
/// `hirn.<key>` style option names — confirm against the DataFusion version in use.
const PREFIX: &'static str = "hirn";
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: this closure only type-checks while `HirnSessionExt`
    // is `Send + Sync`.
    const _: fn() = || {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<HirnSessionExt>();
    };

    /// A registered extension can be fetched again and shares the same config.
    #[test]
    fn register_and_retrieve() {
        let config = Arc::new(HirnConfig::default());
        let graph: Arc<dyn Any + Send + Sync> = Arc::new(42_u32);
        let ctx = SessionContext::new();

        HirnSessionExt::new(graph, Arc::clone(&config), None)
            .register(&ctx)
            .expect("register should succeed");

        let retrieved = HirnSessionExt::get(&ctx).expect("extension should be present");
        assert!(Arc::ptr_eq(&retrieved.config, &config));
        assert!(retrieved.embedder().is_none());
        assert!(retrieved.tokenizer().is_none());
    }

    /// Fetching from a context without a registration reports a clear error.
    #[test]
    fn missing_extension_gives_clear_error() {
        let ctx = SessionContext::new();
        let err = HirnSessionExt::get(&ctx).unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("HirnSessionExt not registered"),
            "unexpected error: {err}"
        );
    }

    /// The type-erased graph handle can be downcast back to its concrete type.
    #[test]
    fn graph_downcast() {
        let graph: Arc<dyn Any + Send + Sync> = Arc::new(String::from("test_graph"));
        let config = Arc::new(HirnConfig::default());
        let ctx = SessionContext::new();

        HirnSessionExt::new(graph, config, None)
            .register(&ctx)
            .expect("register should succeed");

        let retrieved = HirnSessionExt::get(&ctx).unwrap();
        let graph = retrieved.graph_as::<String>().unwrap();
        assert_eq!(graph, "test_graph");
    }
}