use std::sync::Arc;
use datafusion::arrow::array::{Array, BinaryArray, RecordBatch};
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::*;
use parking_lot::RwLock;
use tracing::{info, instrument};
use crate::execution::fao_udf::FaoScalarUdf;
use crate::core::error::{AnamError, Result};
use crate::core::provenance::{PolynomialSemiring, ProvenanceMode, ProvenanceToken, Semiring};
use crate::execution::dispatcher::DevicePool;
use crate::execution::optimizer::ParetoOptimizer;
use crate::hitl::monitor::SemanticMonitor;
use crate::hitl::triage::Anomaly;
use crate::logic::engine::LogicEngine;
use crate::logic::nl_compiler::NlCompiler;
use crate::model::registry::ModelRegistry;
use crate::storage::lance_provider::LanceTableManager;
/// Configuration used to build a [`Session`].
///
/// Start from [`SessionConfig::default`] (optionally with struct-update syntax) or load
/// selected values from a TOML file with [`SessionConfig::load_from_toml`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
/// Provenance mode passed to `LogicEngine::new` and used when annotating query results.
pub provenance_mode: ProvenanceMode,
/// When `true`, `Session::with_config` probes hardware via `DevicePool::detect_hardware`;
/// otherwise a CPU-only device pool is used.
pub enable_hardware_accel: bool,
/// API key forwarded to the NL→Datalog compiler (`NlCompiler::new`).
pub llm_api_key: Option<String>,
/// Endpoint forwarded to the NL→Datalog compiler (`NlCompiler::new`).
pub llm_endpoint: Option<String>,
/// Model name forwarded to the NL→Datalog compiler (`NlCompiler::new`).
pub llm_model: Option<String>,
/// Threshold handed to `SemanticMonitor::new`; its exact interpretation is defined by the
/// monitor (NOTE(review): presumably a score cutoff — confirm against `SemanticMonitor`).
pub anomaly_threshold: f64,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
provenance_mode: ProvenanceMode::Polynomial,
enable_hardware_accel: false,
llm_api_key: None,
llm_endpoint: None,
llm_model: None,
anomaly_threshold: 0.5,
}
}
}
impl SessionConfig {
    /// Loads a configuration from the TOML file at `path`, starting from
    /// [`SessionConfig::default`] and overriding only the keys that are present.
    ///
    /// Recognised keys, all under the `[engine]` table:
    /// - `provenance_mode` (string, case-insensitive): `"boolean"`/`"bool"`,
    ///   `"probability"`/`"prob"`, `"polynomial"`/`"poly"`. Any other string falls back to
    ///   Polynomial.
    /// - `gpu` (bool): sets `enable_hardware_accel`.
    /// - `anomaly_threshold` (float **or** integer, e.g. `0.7` or `1`).
    ///
    /// Keys with an unexpected TOML type are ignored. The `llm_*` fields are not read from
    /// the file and keep their defaults.
    ///
    /// # Errors
    /// Returns a message naming `path` if the file cannot be read or is not valid TOML.
    pub fn load_from_toml(path: &str) -> std::result::Result<Self, String> {
        let content = std::fs::read_to_string(path)
            .map_err(|e| format!("failed to read config file `{path}`: {e}"))?;
        let table: toml::Table = toml::from_str(&content)
            .map_err(|e| format!("failed to parse config file `{path}`: {e}"))?;
        let mut config = SessionConfig::default();
        let Some(engine) = table.get("engine").and_then(|v| v.as_table()) else {
            return Ok(config);
        };
        if let Some(prov) = engine.get("provenance_mode").and_then(|v| v.as_str()) {
            config.provenance_mode = match prov.to_lowercase().as_str() {
                "boolean" | "bool" => ProvenanceMode::Boolean,
                "probability" | "prob" => ProvenanceMode::Probability,
                "polynomial" | "poly" => ProvenanceMode::Polynomial,
                // Unknown values keep the historical lenient fallback rather than failing.
                _ => ProvenanceMode::Polynomial,
            };
        }
        if let Some(gpu) = engine.get("gpu").and_then(|v| v.as_bool()) {
            config.enable_hardware_accel = gpu;
        }
        // TOML distinguishes `1` (integer) from `1.0` (float); accept both so an integer
        // threshold is not silently ignored.
        let threshold = engine
            .get("anomaly_threshold")
            .and_then(|v| v.as_float().or_else(|| v.as_integer().map(|i| i as f64)));
        if let Some(threshold) = threshold {
            config.anomaly_threshold = threshold;
        }
        Ok(config)
    }
}
/// Output of [`Session::sql`] and [`Session::refine_query`].
#[derive(Debug)]
pub struct QueryResult {
/// Result record batches.
pub batches: Vec<RecordBatch>,
/// Anomalies reported by the session's semantic monitor for `batches`.
pub anomalies: Vec<Anomaly>,
/// Rendered provenance explanation; `None` in Boolean provenance mode or when no batch
/// carries a `provenance` column.
pub reasoning_tree: Option<String>,
}
impl QueryResult {
    /// Returns `true` when the semantic monitor reported at least one anomaly.
    pub fn requires_clarification(&self) -> bool {
        !self.anomalies.is_empty()
    }

    /// Prints the reasoning tree to stdout, or a notice when no tree is attached.
    ///
    /// Always returns `Ok(())`.
    pub async fn explain_reasoning(&self) -> Result<()> {
        let Some(tree) = self.reasoning_tree.as_deref() else {
            println!("(no reasoning tree attached — provenance mode may be Boolean)");
            return Ok(());
        };
        println!("═══ Reasoning Tree ═══\n{tree}");
        Ok(())
    }
}
/// A neurosymbolic query session.
///
/// Bundles a DataFusion [`SessionContext`] with the logic engine, model registry,
/// NL→Datalog compiler, Pareto optimizer, device pool, semantic monitor and Lance table
/// manager. Construct with [`Session::new`], [`Session::new_with_npu`] or
/// [`Session::with_config`].
pub struct Session {
    /// DataFusion context holding registered tables and FAO scalar UDFs.
    pub(crate) df_ctx: SessionContext,
    /// Datalog rule engine, behind a lock because rules are registered through `&self`.
    pub(crate) logic_engine: Arc<RwLock<LogicEngine>>,
    /// Catalog of model entries and their FAO operators.
    pub(crate) model_registry: Arc<ModelRegistry>,
    /// Compiles natural-language rule descriptions and corrections into Datalog.
    pub(crate) nl_compiler: Arc<NlCompiler>,
    /// Parses optimizer constraints out of SQL and executes constrained plans.
    pub(crate) optimizer: Arc<ParetoOptimizer>,
    /// Execution devices; CPU-only unless hardware acceleration was enabled.
    pub(crate) device_pool: Arc<DevicePool>,
    /// Inspects result batches for anomalies.
    pub(crate) monitor: Arc<SemanticMonitor>,
    // Constructed in `with_config` but not read by any method in this file; presumably
    // reserved for (or used by) other crate modules — TODO confirm before removing the allow.
    #[allow(dead_code)]
    pub(crate) lance_mgr: Arc<LanceTableManager>,
    /// The configuration this session was built from.
    pub config: SessionConfig,
}
impl Session {
/// Creates a session from [`SessionConfig::default`] (CPU-only, polynomial provenance).
///
/// # Errors
/// Propagates any error from [`Session::with_config`].
#[instrument(name = "Session::new")]
pub async fn new() -> Result<Self> {
    let defaults = SessionConfig::default();
    Self::with_config(defaults).await
}
#[instrument(name = "Session::new_with_npu")]
pub async fn new_with_npu() -> Result<Self> {
let config = SessionConfig {
enable_hardware_accel: true,
..Default::default()
};
Self::with_config(config).await
}
/// Builds a session from an explicit [`SessionConfig`].
///
/// Wires up the DataFusion context, the logic engine (using `config.provenance_mode`),
/// the model registry, the NL compiler (using the `llm_*` settings), the device pool, the
/// Pareto optimizer, the semantic monitor and the Lance table manager.
///
/// # Errors
/// Propagates failures from `LogicEngine::new` and, when hardware acceleration is
/// enabled, from `DevicePool::detect_hardware`.
pub async fn with_config(config: SessionConfig) -> Result<Self> {
    info!(
        provenance = ?config.provenance_mode,
        hw_accel = config.enable_hardware_accel,
        "initializing AnamDB session"
    );
    let df_ctx = SessionContext::new();
    let engine = LogicEngine::new(config.provenance_mode)?;
    let logic_engine = Arc::new(RwLock::new(engine));
    let model_registry = Arc::new(ModelRegistry::new());
    let nl_compiler = Arc::new(NlCompiler::new(
        config.llm_api_key.clone(),
        config.llm_endpoint.clone(),
        config.llm_model.clone(),
    ));
    let pool = if config.enable_hardware_accel {
        DevicePool::detect_hardware().await?
    } else {
        DevicePool::cpu_only()
    };
    let device_pool = Arc::new(pool);
    let optimizer = Arc::new(ParetoOptimizer::new(
        Arc::clone(&model_registry),
        Arc::clone(&device_pool),
    ));
    let monitor = Arc::new(SemanticMonitor::new(config.anomaly_threshold));
    let lance_mgr = Arc::new(LanceTableManager::new());
    Ok(Self {
        df_ctx,
        logic_engine,
        model_registry,
        nl_compiler,
        optimizer,
        device_pool,
        monitor,
        lance_mgr,
        config,
    })
}
/// Opens the Lance dataset at `path` as a streaming table provider and registers it with
/// the DataFusion context under `name`.
///
/// # Errors
/// Propagates failures from opening the dataset and from DataFusion registration.
#[instrument(skip(self))]
pub async fn register_table(&self, name: &str, path: &str) -> Result<()> {
    use crate::storage::streaming_provider::LanceStreamingProvider;
    info!(table = name, path, "registering Lance table (streaming)");
    let provider = Arc::new(LanceStreamingProvider::open(path).await?);
    self.df_ctx
        .register_table(name, provider)
        .map(|_previous| ())
        .map_err(AnamError::DataFusion)
}
/// Appends `batches` (described by `schema`) to the Lance dataset at `lance_path`.
///
/// `_table_name` is only used for logging and tracing.
///
/// # Errors
/// Propagates any error from the storage write path.
#[instrument(skip(self, batches))]
pub async fn insert(
    &self,
    _table_name: &str,
    lance_path: &str,
    batches: Vec<RecordBatch>,
    schema: Arc<datafusion::arrow::datatypes::Schema>,
) -> Result<crate::storage::write_path::WriteResult> {
    use crate::storage::write_path::insert_rows;
    info!(table = _table_name, "INSERT into table");
    insert_rows(lance_path, batches, schema).await
}
/// Deletes rows matching `predicate` from the Lance dataset at `lance_path`.
///
/// `_table_name` is only used for logging and tracing.
///
/// # Errors
/// Propagates any error from the storage write path.
#[instrument(skip(self))]
pub async fn delete(
    &self,
    _table_name: &str,
    lance_path: &str,
    predicate: &str,
) -> Result<crate::storage::write_path::WriteResult> {
    use crate::storage::write_path::delete_rows;
    info!(table = _table_name, predicate, "DELETE from table");
    delete_rows(lance_path, predicate).await
}
/// Compiles a natural-language rule description into Datalog with the NL compiler and
/// registers the result in the logic engine under `name`.
///
/// `table` is forwarded to the compiler as context; the generated Datalog is logged.
///
/// # Errors
/// Propagates compilation and rule-registration failures.
#[instrument(skip(self))]
pub async fn register_logic_from_nl(
    &self,
    name: &str,
    table: &str,
    nl_description: &str,
) -> Result<()> {
    info!(rule = name, table, "compiling NL → Datalog");
    let compiled = self.nl_compiler.compile(nl_description, table).await?;
    info!(datalog = %compiled, "generated Datalog");
    // The guard is taken after the last `.await`, so it never lives across a suspension.
    let mut engine = self.logic_engine.write();
    engine.register_rule(name, &compiled)?;
    Ok(())
}
/// Registers the Datalog rule `datalog` under `name` in the logic engine.
///
/// # Errors
/// Propagates any error reported by `LogicEngine::register_rule`.
pub fn register_logic(&self, name: &str, datalog: &str) -> Result<()> {
    let mut engine = self.logic_engine.write();
    engine.register_rule(name, datalog)
}
/// Runs a neurosymbolic SQL query.
///
/// The steps are:
/// 1. Strip optimizer constraints from `query` and plan the rest with DataFusion.
/// 2. Execute the plan, through the Pareto optimizer when constraints were present and
///    with a plain `collect` otherwise.
/// 3. Filter the results through the registered logic rules.
/// 4. Attach a `provenance` column, run the semantic monitor and build the reasoning tree.
///
/// # Errors
/// Propagates failures from any stage of the pipeline.
#[instrument(skip(self))]
pub async fn sql(&self, query: &str) -> Result<QueryResult> {
    info!(query, "executing neurosymbolic query");
    let (clean_sql, constraints) = self.optimizer.parse_constraints(query)?;
    let plan = self
        .df_ctx
        .sql(&clean_sql)
        .await
        .map_err(AnamError::DataFusion)?;
    let raw = match constraints {
        Some(c) => self.optimizer.execute_with_constraints(plan, c).await?,
        None => plan.collect().await.map_err(AnamError::DataFusion)?,
    };
    // The read guard is a temporary, released at the end of this statement.
    let filtered = self.logic_engine.read().filter_batches(&raw)?;
    let batches = self.attach_provenance(&filtered)?;
    let anomalies = self.monitor.inspect_batches(&batches)?;
    let reasoning_tree = self.build_reasoning_tree(&batches)?;
    Ok(QueryResult {
        batches,
        anomalies,
        reasoning_tree,
    })
}
/// Applies a natural-language correction and re-evaluates all registered rules.
///
/// The correction is compiled to Datalog with `"__refinement__"` as the table context and
/// registered under the fixed rule name `__refinement_patch__`. Every rule is then
/// evaluated. As in [`Session::sql`], a `provenance` column is attached before the
/// semantic monitor runs and the reasoning tree is built.
///
/// # Errors
/// Propagates compilation, registration, evaluation, provenance and monitoring failures.
#[instrument(skip(self))]
pub async fn refine_query(&self, correction: &str) -> Result<QueryResult> {
    info!(correction, "refining query with human feedback");
    let patch = self
        .nl_compiler
        .compile(correction, "__refinement__")
        .await?;
    // Register and evaluate under ONE write guard. With two separate acquisitions a
    // concurrent `refine_query` could overwrite `__refinement_patch__` in between, and this
    // call would evaluate and report someone else's correction. Trade-off: readers such as
    // `sql` wait while `evaluate_all` runs. No `.await` happens while the guard is held.
    let batches = {
        let mut engine = self.logic_engine.write();
        engine.register_rule("__refinement_patch__", &patch)?;
        engine.evaluate_all()?
    };
    // Mirror `sql`, so the reasoning tree is populated in non-Boolean modes. Batches that
    // already carry a `provenance` column pass through unchanged.
    let batches = self.attach_provenance(&batches)?;
    let anomalies = self.monitor.inspect_batches(&batches)?;
    let reasoning_tree = self.build_reasoning_tree(&batches)?;
    Ok(QueryResult {
        batches,
        anomalies,
        reasoning_tree,
    })
}
/// Returns the model registry backing this session.
pub fn models(&self) -> &ModelRegistry {
    Arc::as_ref(&self.model_registry)
}
/// Returns the shared, lock-protected logic engine.
pub fn logic_engine(&self) -> &Arc<RwLock<LogicEngine>> {
    &self.logic_engine
}
/// Returns the device pool selected when the session was constructed.
pub fn device_pool(&self) -> &DevicePool {
    Arc::as_ref(&self.device_pool)
}
/// Loads an ONNX model from `model_path` and registers it with default catalog metadata.
///
/// The model is registered in the model registry and exposed to SQL as a DataFusion scalar
/// UDF. It uses version `"1.0.0"` and placeholder metrics (`avg_latency_ms = 1.0`,
/// `accuracy = 0.95`); use [`Session::load_onnx_model_with_metrics`] to supply real values.
///
/// The input schema has `num_input_features` non-nullable `Float32` columns named
/// `feature_0 .. feature_{n-1}`. The output schema is a single non-nullable `Float64`
/// column named `score`. The artifact size is read from file metadata and recorded as `0`
/// if that read fails.
///
/// Returns the `model_id` of the built registry entry.
///
/// # Errors
/// Fails if the operator cannot be loaded, or if registering the model or operator fails.
/// The operator is loaded before the registry entry is created, so an unloadable artifact
/// leaves no orphaned catalog entry behind.
#[instrument(skip(self))]
pub fn load_onnx_model(
    &self,
    name: &str,
    model_path: &str,
    function_id: &str,
    num_input_features: usize,
) -> Result<String> {
    use crate::model::ai_tables::{AiModelEntry, DeviceAffinity, ModelFormat};
    use crate::model::onnx_adapter::OnnxFaoOperator;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    // Fixed metadata for models loaded without explicit metrics.
    let version = "1.0.0";
    let (avg_latency_ms, accuracy) = (1.0, 0.95);
    info!(name, model_path, function_id, "loading ONNX model");
    let input_fields: Vec<Field> = (0..num_input_features)
        .map(|i| Field::new(format!("feature_{i}"), DataType::Float32, false))
        .collect();
    let input_schema = Arc::new(Schema::new(input_fields));
    let output_schema = Arc::new(Schema::new(vec![Field::new(
        "score",
        DataType::Float64,
        false,
    )]));
    // Best-effort: an unreadable path records size 0 here.
    let file_size = std::fs::metadata(model_path).map(|m| m.len()).unwrap_or(0);
    let entry = AiModelEntry::builder(name, version)
        .format(ModelFormat::Onnx)
        .artifact_path(model_path)
        .avg_latency_ms(avg_latency_ms)
        .accuracy(accuracy)
        .size_bytes(file_size)
        .device_affinity(DeviceAffinity::Any)
        .build();
    let model_id = entry.model_id.clone();
    // Load the operator BEFORE touching the registry, so a missing or invalid artifact does
    // not leave a catalog entry without an operator.
    let operator = OnnxFaoOperator::load(
        model_path,
        function_id,
        version,
        &model_id,
        input_schema,
        output_schema,
        avg_latency_ms,
        accuracy,
    )?;
    self.model_registry.register_model(entry)?;
    let operator: Arc<dyn crate::model::fao::FaoOperator> = Arc::new(operator);
    self.model_registry
        .register_operator(Arc::clone(&operator))?;
    self.register_fao_udf(operator);
    info!(model_id = %model_id, "ONNX model registered");
    Ok(model_id)
}
/// Loads an ONNX model variant with caller-supplied version and metrics.
///
/// The model is registered in the model registry and exposed to SQL as a DataFusion scalar
/// UDF.
///
/// * `name` / `version`: catalog identity of the model entry.
/// * `model_path`: path to the ONNX artifact.
/// * `function_id`: FAO function id passed to `OnnxFaoOperator::load`.
/// * `num_input_features`: number of non-nullable `Float32` input columns, named
///   `feature_0 .. feature_{n-1}`.
/// * `avg_latency_ms` / `accuracy`: metrics recorded on both the registry entry and the
///   operator.
///
/// The output schema is a single non-nullable `Float64` column named `score`. The artifact
/// size is read from file metadata and recorded as `0` if that read fails.
///
/// Returns the `model_id` of the built registry entry.
///
/// # Errors
/// Fails if the operator cannot be loaded, or if registering the model or operator fails.
/// The operator is loaded before the registry entry is created, so an unloadable artifact
/// leaves no orphaned catalog entry behind.
// Public API takes every metric positionally; kept as-is for backward compatibility.
#[allow(clippy::too_many_arguments)]
#[instrument(skip(self))]
pub fn load_onnx_model_with_metrics(
    &self,
    name: &str,
    version: &str,
    model_path: &str,
    function_id: &str,
    num_input_features: usize,
    avg_latency_ms: f64,
    accuracy: f64,
) -> Result<String> {
    use crate::model::ai_tables::{AiModelEntry, DeviceAffinity, ModelFormat};
    use crate::model::onnx_adapter::OnnxFaoOperator;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    info!(
        name,
        version,
        model_path,
        function_id,
        avg_latency_ms,
        accuracy,
        "loading ONNX model variant"
    );
    let input_fields: Vec<Field> = (0..num_input_features)
        .map(|i| Field::new(format!("feature_{i}"), DataType::Float32, false))
        .collect();
    let input_schema = Arc::new(Schema::new(input_fields));
    let output_schema = Arc::new(Schema::new(vec![Field::new(
        "score",
        DataType::Float64,
        false,
    )]));
    // Best-effort: an unreadable path records size 0 here.
    let file_size = std::fs::metadata(model_path).map(|m| m.len()).unwrap_or(0);
    let entry = AiModelEntry::builder(name, version)
        .format(ModelFormat::Onnx)
        .artifact_path(model_path)
        .avg_latency_ms(avg_latency_ms)
        .accuracy(accuracy)
        .size_bytes(file_size)
        .device_affinity(DeviceAffinity::Any)
        .build();
    let model_id = entry.model_id.clone();
    // Load the operator BEFORE touching the registry, so a missing or invalid artifact does
    // not leave a catalog entry without an operator.
    let operator = OnnxFaoOperator::load(
        model_path,
        function_id,
        version,
        &model_id,
        input_schema,
        output_schema,
        avg_latency_ms,
        accuracy,
    )?;
    self.model_registry.register_model(entry)?;
    let operator: Arc<dyn crate::model::fao::FaoOperator> = Arc::new(operator);
    self.model_registry
        .register_operator(Arc::clone(&operator))?;
    self.register_fao_udf(operator);
    info!(model_id = %model_id, "ONNX model variant registered");
    Ok(model_id)
}
/// Loads a Logic Pack's rules and models, then returns `pack.summary()`.
///
/// All rules are registered under a single write lock. Each model is then loaded through
/// [`Session::load_onnx_model_with_metrics`], using the model name as both catalog name
/// and FAO function id, and the pack's version for every model.
///
/// # Errors
/// Stops at the first failing rule or model. Items registered before the failure are not
/// rolled back.
#[instrument(skip(self, pack), fields(pack_name = %pack.name, pack_version = %pack.version))]
pub fn load_logic_pack(&self, pack: &crate::sdk::LogicPack) -> Result<String> {
    info!(
        name = %pack.name,
        version = %pack.version,
        rules = pack.rules.len(),
        models = pack.models.len(),
        "loading Logic Pack"
    );
    {
        // The guard is released at the end of this scope, before any model is loaded.
        let mut engine = self.logic_engine.write();
        pack.rules
            .iter()
            .try_for_each(|rule| engine.register_rule(&rule.name, &rule.datalog))?;
    }
    for model in pack.models.iter() {
        self.load_onnx_model_with_metrics(
            &model.name,
            &pack.version,
            &model.artifact_path,
            &model.name,
            model.num_features,
            model.avg_latency_ms,
            model.accuracy,
        )?;
    }
    let summary = pack.summary();
    info!("Logic Pack loaded successfully");
    Ok(summary)
}
/// Explains `batches` at the requested `level`.
///
/// Builds an explanation context from the registered rules, the registered models, the
/// provenance mode and the device summary. A rule whose body cannot be retrieved is listed
/// with the body `<unknown>`.
///
/// # Errors
/// Propagates any error from `Explainer::explain`.
pub fn explain_query(
    &self,
    batches: &[RecordBatch],
    level: crate::hitl::explainer::ExplainLevel,
) -> Result<crate::hitl::explainer::QueryExplanation> {
    use crate::hitl::explainer::{ExplainContext, Explainer};
    // Snapshot rule names and bodies inside a scope so the read lock is released early.
    let rules: Vec<(String, String)> = {
        let engine = self.logic_engine.read();
        let mut snapshot = Vec::new();
        for rule_name in engine.list_rules() {
            let body = engine
                .get_rule_body(&rule_name)
                .unwrap_or_else(|| "<unknown>".to_string());
            snapshot.push((rule_name, body));
        }
        snapshot
    };
    let mut models: Vec<(String, String)> = Vec::new();
    for entry in self.model_registry.list_models() {
        models.push((entry.name.clone(), entry.version.clone()));
    }
    let provenance_mode = format!("{:?}", self.config.provenance_mode);
    let device_summary = self.device_pool.summary();
    let context = ExplainContext {
        rules,
        models,
        provenance_mode,
        device_summary,
    };
    Explainer::explain(level, batches, &context)
}
/// Runs the self-repair agent for a failed operator.
///
/// The agent is told the names of every registered model and asked to diagnose
/// `error_msg` for `operator_name` in the given `context`.
///
/// # Errors
/// Propagates any error from `SelfRepairAgent::diagnose_and_repair`.
pub fn self_repair(
    &self,
    error_msg: &str,
    operator_name: &str,
    context: &str,
) -> Result<crate::hitl::self_repair::RepairReport> {
    use crate::hitl::self_repair::SelfRepairAgent;
    let mut agent = SelfRepairAgent::new();
    let mut available = Vec::new();
    for entry in self.model_registry.list_models() {
        available.push(entry.name.clone());
    }
    let available: Vec<String> = available;
    agent.register_available_models(available);
    agent.diagnose_and_repair(error_msg, operator_name, context)
}
/// Wraps `operator` in a `FaoScalarUdf` and registers it with the DataFusion context so
/// SQL queries can call it as a scalar function.
fn register_fao_udf(&self, operator: Arc<dyn crate::model::fao::FaoOperator>) {
    let scalar = ScalarUDF::from(FaoScalarUdf::new(operator));
    self.df_ctx.register_udf(scalar);
    info!("registered FAO as DataFusion scalar UDF");
}
/// Appends a nullable `Binary` column named `provenance` to every batch that lacks one.
///
/// The per-row payload depends on `self.config.provenance_mode`:
/// * `Boolean`: the single byte `1`.
/// * `Probability`: `1.0_f64` as 8 little-endian bytes.
/// * `Polynomial`: a serialized singleton polynomial whose token references
///   `query_pipeline` / `sql` and the row's index within its batch (`row_{i}`). If
///   serialization fails, the payload is empty.
///
/// Batches that already have a `provenance` column are returned unchanged. Schema-level
/// metadata of each input batch is preserved on the extended schema.
///
/// # Errors
/// Returns an Arrow error if the extended batch cannot be assembled.
fn attach_provenance(&self, batches: &[RecordBatch]) -> Result<Vec<RecordBatch>> {
    use datafusion::arrow::array::{ArrayRef, BinaryArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    let mode = self.config.provenance_mode;
    let mut result = Vec::with_capacity(batches.len());
    for batch in batches {
        let schema = batch.schema();
        // Idempotent: never add a second provenance column.
        if schema.column_with_name("provenance").is_some() {
            result.push(batch.clone());
            continue;
        }
        let prov_bytes: Vec<Vec<u8>> = (0..batch.num_rows())
            .map(|row_idx| match mode {
                ProvenanceMode::Boolean => vec![1u8],
                ProvenanceMode::Probability => 1.0_f64.to_le_bytes().to_vec(),
                ProvenanceMode::Polynomial => {
                    let token = ProvenanceToken {
                        model_ver_id: "query_pipeline".to_string(),
                        func_id: "sql".to_string(),
                        source_record_ids: vec![format!("row_{row_idx}")],
                    };
                    let poly = PolynomialSemiring::singleton(token);
                    // NOTE(review): serialization failures become an empty payload, which
                    // the reasoning tree later renders as `<raw 0 bytes>`.
                    poly.to_bytes().unwrap_or_default()
                }
            })
            .collect();
        let prov_refs: Vec<&[u8]> = prov_bytes.iter().map(Vec::as_slice).collect();
        let prov_array: ArrayRef = Arc::new(BinaryArray::from(prov_refs));
        let mut fields: Vec<Arc<Field>> = schema.fields().to_vec();
        fields.push(Arc::new(Field::new("provenance", DataType::Binary, true)));
        // `Schema::new` would silently drop schema-level metadata; carry it over.
        let new_schema = Arc::new(Schema::new_with_metadata(
            fields,
            schema.metadata().clone(),
        ));
        let mut columns: Vec<ArrayRef> = batch.columns().to_vec();
        columns.push(prov_array);
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(AnamError::Arrow)?;
        result.push(new_batch);
    }
    Ok(result)
}
/// Renders the `provenance` column of each batch as a human-readable tree.
///
/// Returns `Ok(None)` in Boolean provenance mode, or when no batch has a `provenance`
/// column. Every batch with that column contributes a header line. When the column is a
/// `BinaryArray`, each non-null row is decoded with `PolynomialSemiring::from_bytes` and
/// rendered with `explain()`; rows that fail to decode are shown as `<raw N bytes>`.
fn build_reasoning_tree(&self, batches: &[RecordBatch]) -> Result<Option<String>> {
    if self.config.provenance_mode == ProvenanceMode::Boolean {
        return Ok(None);
    }
    let mut tree = String::new();
    for (i, batch) in batches.iter().enumerate() {
        let schema = batch.schema();
        let Some((prov_idx, _)) = schema.column_with_name("provenance") else {
            continue;
        };
        tree.push_str(&format!("── Batch {i} ({} rows) ──\n", batch.num_rows()));
        let Some(prov) = batch
            .column(prov_idx)
            .as_any()
            .downcast_ref::<BinaryArray>()
        else {
            continue;
        };
        // `iter()` yields `None` for null slots; those rows are skipped.
        for (row, slot) in prov.iter().enumerate() {
            let Some(bytes) = slot else {
                continue;
            };
            let line = match PolynomialSemiring::from_bytes(bytes) {
                Ok(poly) => format!(" row {row}: {}\n", poly.explain()),
                Err(_) => format!(" row {row}: <raw {} bytes>\n", bytes.len()),
            };
            tree.push_str(&line);
        }
    }
    Ok((!tree.is_empty()).then_some(tree))
}
}