use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use dashmap::DashMap;
use datafusion_expr::LogicalPlan;
use datafusion_expr::logical_plan::Extension;
use parking_lot::Mutex;
use hirn_core::error::{HirnError, HirnResult};
use super::plan_compiler;
use super::typed_ast::{self, AnalyzeContext, TypedStatement};
use crate::parser;
use crate::parser::ast::Statement;
/// The artifacts produced by compiling one query: the query text plus the
/// output of each pipeline stage (parse, analyze/rewrite, plan compilation).
#[derive(Debug, Clone)]
pub struct CompiledPlan {
/// Query text exactly as it was handed to the pipeline.
pub source: String,
/// Parser output for the query.
pub ast: Statement,
/// Analyzed statement after it has passed through the rewrite stage.
pub typed: TypedStatement,
/// Logical plan compiled from `typed`.
pub plan: LogicalPlan,
}
/// Concurrent cache of compiled plans keyed by a 64-bit hash, bounded by
/// `max_entries` and evicting the entry with the lowest access count first.
pub struct PlanCache {
// Live cache entries, keyed by the caller-supplied 64-bit key.
entries: DashMap<u64, CacheEntry>,
// Eviction candidates as `(Reverse(access_count snapshot), key)`. `BinaryHeap`
// is a max-heap, so `Reverse` makes the smallest access count pop first.
// A record is only honoured if its snapshot still matches the entry's
// current `access_count`; anything else is treated as stale and discarded.
eviction_heap: Mutex<BinaryHeap<(Reverse<u64>, u64)>>,
// Entry count at which `put` starts evicting before inserting.
max_entries: usize,
}
/// One cached plan together with the bookkeeping needed for collision
/// detection and eviction.
#[derive(Clone)]
struct CacheEntry {
// Normalized query text; compared on lookup to detect 64-bit key collisions.
normalized_source: Arc<str>,
// The shared compiled plan handed out on a cache hit.
plan: Arc<CompiledPlan>,
// Starts at 1 on insertion and is incremented on every successful lookup.
access_count: u64,
}
impl PlanCache {
    /// Lower bound for the heap size at which stale eviction records are
    /// compacted, so tiny caches do not rebuild the heap on every few hits.
    const MIN_COMPACTION_THRESHOLD: usize = 64;

    /// Creates an empty cache that holds at most `max_entries` plans.
    ///
    /// The initial allocations are capped at 256 slots regardless of
    /// `max_entries`; the containers grow on demand beyond that.
    /// A `max_entries` of zero disables caching: [`PlanCache::put`] becomes a
    /// no-op.
    pub fn new(max_entries: usize) -> Self {
        let initial_capacity = max_entries.min(256);
        Self {
            entries: DashMap::with_capacity(initial_capacity),
            eviction_heap: Mutex::new(BinaryHeap::with_capacity(initial_capacity)),
            max_entries,
        }
    }

    /// Looks up the plan stored under `key`.
    ///
    /// Returns `None` when the key is absent, or when the stored normalized
    /// source differs from `normalized_source` (a 64-bit hash collision, which
    /// is logged as a warning). A hit bumps the entry's access count and
    /// records the new count in the eviction heap.
    pub fn get(&self, key: u64, normalized_source: &str) -> Option<Arc<CompiledPlan>> {
        let mut entry = self.entries.get_mut(&key)?;
        if entry.normalized_source.as_ref() != normalized_source {
            tracing::warn!(
                key,
                cached_source = %entry.normalized_source,
                incoming_source = %normalized_source,
                "plan cache: 64-bit hash collision — skipping cached plan"
            );
            return None;
        }
        entry.access_count += 1;
        let access_count = entry.access_count;
        let plan = Arc::clone(&entry.plan);
        // Release the shard lock BEFORE touching the heap mutex. Eviction
        // locks heap -> shard; holding the shard guard while waiting for the
        // heap here would be the opposite order and could deadlock.
        drop(entry);

        let mut heap = self.eviction_heap.lock();
        heap.push((Reverse(access_count), key));
        self.compact_if_bloated(&mut heap);
        Some(plan)
    }

    /// Stores `plan` under `key`, evicting one entry first if the cache is
    /// full and `key` is not already present.
    ///
    /// Eviction prefers the least-frequently-used entry according to the
    /// heap; if the heap yields no live candidate, an arbitrary entry is
    /// removed instead so the size bound still holds.
    pub fn put(&self, key: u64, normalized_source: Arc<str>, plan: Arc<CompiledPlan>) {
        if self.max_entries == 0 {
            // Capacity zero means "cache nothing"; inserting would break the bound.
            return;
        }
        // Replacing an existing key does not grow the map, so only make room
        // for genuinely new keys.
        let needs_room =
            !self.entries.contains_key(&key) && self.entries.len() >= self.max_entries;
        if needs_room {
            let evicted = self.try_evict_one();
            if !evicted && self.entries.len() >= self.max_entries {
                let arbitrary_key = self.entries.iter().next().map(|e| *e.key());
                if let Some(victim) = arbitrary_key {
                    self.entries.remove(&victim);
                }
            }
        }
        // Insert first, then publish the heap record, so an eviction running
        // concurrently can never consume the record before the entry exists.
        self.entries.insert(
            key,
            CacheEntry {
                normalized_source,
                plan,
                access_count: 1,
            },
        );
        self.eviction_heap.lock().push((Reverse(1), key));
    }

    /// Pops heap records until one still matches a live entry's access count,
    /// removes that entry and returns `true`. Returns `false` if the heap runs
    /// dry without finding a live candidate.
    fn try_evict_one(&self) -> bool {
        let mut heap = self.eviction_heap.lock();
        while let Some((Reverse(snapshot_count), evict_key)) = heap.pop() {
            // A record is only valid if the entry has not been touched since
            // the snapshot was taken; stale records are simply discarded.
            let removed = self
                .entries
                .remove_if(&evict_key, |_, v| v.access_count == snapshot_count);
            if removed.is_some() {
                return true;
            }
        }
        false
    }

    /// Rebuilds `heap` from the live entries once stale records dominate it.
    ///
    /// Every cache hit pushes a fresh record, so without compaction the heap
    /// would grow without bound on a hot cache. Must be called with the heap
    /// lock held and no shard guard held (lock order: heap -> shard).
    fn compact_if_bloated(&self, heap: &mut BinaryHeap<(Reverse<u64>, u64)>) {
        let limit = self
            .max_entries
            .saturating_mul(2)
            .max(Self::MIN_COMPACTION_THRESHOLD);
        if heap.len() <= limit {
            return;
        }
        heap.clear();
        heap.extend(
            self.entries
                .iter()
                .map(|entry| (Reverse(entry.access_count), *entry.key())),
        );
    }

    /// Number of plans currently cached.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no plan is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every cached plan and drops all pending eviction records.
    pub fn clear(&self) {
        // Take the heap lock first to keep the heap -> shard lock order.
        let mut heap = self.eviction_heap.lock();
        self.entries.clear();
        heap.clear();
    }
}
impl std::fmt::Debug for PlanCache {
    /// Prints only the current size and the configured bound; the entry map
    /// and eviction heap are deliberately left out (non-exhaustive output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current_len = self.entries.len();
        let mut out = f.debug_struct("PlanCache");
        out.field("len", &current_len);
        out.field("max_entries", &self.max_entries);
        out.finish_non_exhaustive()
    }
}
/// Compiles query text into a `CompiledPlan` by running the parse, analyze,
/// rewrite and plan-compilation stages, optionally backed by a plan cache.
pub struct QueryPipeline {
// Analysis context used by `compile` when the caller supplies none.
ctx: AnalyzeContext,
// Shared plan cache; `None` means every compile runs the full pipeline.
cache: Option<Arc<PlanCache>>,
}
impl QueryPipeline {
pub fn new(ctx: AnalyzeContext) -> Self {
Self { ctx, cache: None }
}
pub fn with_cache(mut self, cache: Arc<PlanCache>) -> Self {
self.cache = Some(cache);
self
}
pub fn compile(&self, query: &str) -> HirnResult<Arc<CompiledPlan>> {
self.compile_with_ctx(query, &self.ctx)
}
pub fn compile_with_ctx(&self, query: &str, ctx: &AnalyzeContext) -> HirnResult<Arc<CompiledPlan>> {
let (normalized, base_key) = plan_compiler::query_normalize_and_hash(query);
let ns_id = ctx.default_namespace.as_interned_id();
let key = base_key
.wrapping_mul(0x9e37_79b9_7f4a_7c15_u64)
.wrapping_add(ns_id as u64);
if let Some(ref cache) = self.cache {
if let Some(plan) = cache.get(key, &normalized) {
return Ok(plan);
}
}
let ast = parser::parse(query)
.map_err(|e| HirnError::InvalidInput(format!("parse error: {e}")))?;
let typed = typed_ast::analyze(&ast, ctx)?;
let typed = self.rewrite(typed)?;
let plan = plan_compiler::compile(&typed)?;
let compiled = Arc::new(CompiledPlan {
source: query.to_string(),
ast,
typed,
plan,
});
if let Some(ref cache) = self.cache {
cache.put(key, normalized.into(), Arc::clone(&compiled));
}
Ok(compiled)
}
fn rewrite(&self, typed: TypedStatement) -> HirnResult<TypedStatement> {
Ok(typed)
}
pub fn explain(&self, query: &str) -> HirnResult<String> {
let compiled = self.compile(query)?;
Ok(format_plan_tree(&compiled.plan))
}
pub fn context(&self) -> &AnalyzeContext {
&self.ctx
}
}
/// Renders `plan` as text, one node per line, with children indented beneath
/// their parent and lines separated by `\n`.
pub fn format_plan_tree(plan: &LogicalPlan) -> String {
    let mut rendered: Vec<String> = Vec::new();
    format_plan_node(plan, 0, &mut rendered);
    rendered.join("\n")
}
/// Appends the label of `plan` and of all its descendants to `lines` in
/// pre-order, indenting each node by one space per level below `depth` 0.
fn format_plan_node(plan: &LogicalPlan, depth: usize, lines: &mut Vec<String>) {
    // Explicit stack instead of recursion; children are pushed in reverse so
    // they are popped (and emitted) in their original left-to-right order.
    let mut pending = vec![(plan, depth)];
    while let Some((node, level)) = pending.pop() {
        let padding = " ".repeat(level);
        lines.push(format!("{}{}", padding, plan_node_label(node)));
        pending.extend(
            node.inputs()
                .into_iter()
                .rev()
                .map(|child| (child, level + 1)),
        );
    }
}
/// Label for a single plan node: the extension node's name for
/// `LogicalPlan::Extension`, otherwise the plan's own display text.
fn plan_node_label(plan: &LogicalPlan) -> String {
    if let LogicalPlan::Extension(Extension { node }) = plan {
        return node.name().to_string();
    }
    plan.display().to_string()
}
impl std::fmt::Debug for QueryPipeline {
    /// Prints the analysis context and whether a cache is attached (as a
    /// boolean), not the cache contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_cache = self.cache.is_some();
        let mut out = f.debug_struct("QueryPipeline");
        out.field("ctx", &self.ctx);
        out.field("cache", &has_cache);
        out.finish()
    }
}
// Unit tests for the query pipeline, the plan cache and plan-tree formatting.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// Pipeline with the default analysis context and no cache attached.
fn pipeline() -> QueryPipeline {
QueryPipeline::new(AnalyzeContext::default())
}
// A RECALL query compiles through every stage and its plan mentions HybridSearch.
#[test]
fn compile_recall_produces_plan() {
let p = pipeline();
let result = p.compile(r#"RECALL episodic ABOUT "test" LIMIT 5"#);
assert!(result.is_ok());
let compiled = result.unwrap();
assert!(matches!(compiled.ast, Statement::Recall(_)));
assert!(matches!(compiled.typed, TypedStatement::Recall(_)));
let display = format!("{}", compiled.plan);
assert!(display.contains("HybridSearch"), "plan: {display}");
}
// A THINK query yields a typed Think statement and a plan mentioning QualityGate.
#[test]
fn compile_think_produces_plan() {
let p = pipeline();
let compiled = p.compile(r#"THINK ABOUT "test" BUDGET 4096"#).unwrap();
assert!(matches!(compiled.typed, TypedStatement::Think(_)));
let display = format!("{}", compiled.plan);
assert!(display.contains("QualityGate"), "plan: {display}");
}
// Each listed mutation verb must fail to compile with a "not supported" error.
#[test]
fn compile_rejects_removed_embedded_mutation_verbs() {
let p = pipeline();
for query in [
r#"REMEMBER episode CONTENT "event happened""#,
r#"FORGET "01J000000000000000000000""#,
"WATCH ALL FORMAT json",
"CONSOLIDATE WHERE episodic.access_count > 5",
] {
let err = p.compile(query).unwrap_err();
assert!(
err.to_string().contains("not supported"),
"unexpected error for `{query}`: {err}"
);
}
}
// Unparseable input surfaces as HirnError::InvalidInput.
#[test]
fn compile_parse_error() {
let p = pipeline();
let err = p.compile("NOT_A_QUERY").unwrap_err();
assert!(matches!(err, HirnError::InvalidInput(_)));
}
// Compiling the same query twice leaves exactly one cache entry.
#[test]
fn cache_hit() {
let cache = Arc::new(PlanCache::new(100));
let p = pipeline().with_cache(cache.clone());
let q = r#"RECALL episodic ABOUT "test" LIMIT 10"#;
p.compile(q).unwrap();
assert_eq!(cache.len(), 1);
p.compile(q).unwrap();
assert_eq!(cache.len(), 1);
}
// Two distinct queries occupy two cache entries.
#[test]
fn cache_different_queries() {
let cache = Arc::new(PlanCache::new(100));
let p = pipeline().with_cache(cache.clone());
p.compile(r#"RECALL episodic ABOUT "a""#).unwrap();
p.compile(r#"RECALL episodic ABOUT "b""#).unwrap();
assert_eq!(cache.len(), 2);
}
// With capacity 2, inserting a third plan keeps the cache at two entries.
#[test]
fn cache_eviction() {
let cache = Arc::new(PlanCache::new(2));
let p = pipeline().with_cache(cache.clone());
p.compile(r#"RECALL episodic ABOUT "a""#).unwrap();
p.compile(r#"RECALL episodic ABOUT "b""#).unwrap();
assert_eq!(cache.len(), 2);
p.compile(r#"RECALL episodic ABOUT "c""#).unwrap();
assert_eq!(cache.len(), 2); }
// clear() empties a populated cache.
#[test]
fn cache_clear() {
let cache = Arc::new(PlanCache::new(100));
let p = pipeline().with_cache(cache.clone());
p.compile(r#"RECALL episodic ABOUT "a""#).unwrap();
p.compile(r#"RECALL episodic ABOUT "b""#).unwrap();
assert_eq!(cache.len(), 2);
cache.clear();
assert!(cache.is_empty());
}
// Compilation works without a cache and records the query text in `source`.
#[test]
fn pipeline_without_cache() {
let p = pipeline();
let compiled = p.compile(r#"RECALL episodic ABOUT "test""#).unwrap();
assert!(!compiled.source.is_empty());
}
// The parse, analyze and compile stages can be driven directly, bypassing QueryPipeline.
#[test]
fn stages_independently_callable() {
let ast = parser::parse(r#"RECALL episodic ABOUT "test" LIMIT 5"#).unwrap();
let ctx = AnalyzeContext::default();
let typed = typed_ast::analyze(&ast, &ctx).unwrap();
let plan = plan_compiler::compile(&typed).unwrap();
let display = format!("{plan}");
assert!(display.contains("HybridSearch"), "plan: {display}");
}
// explain() renders a multi-line tree that mentions HybridSearch or Limit.
#[test]
fn explain_returns_plan_tree() {
let p = pipeline();
let tree = p
.explain(r#"RECALL episodic ABOUT "test" LIMIT 5"#)
.unwrap();
assert!(
tree.contains("HybridSearch") || tree.contains("Limit"),
"plan tree: {tree}"
);
assert!(tree.lines().count() > 1, "plan tree: {tree}");
}
// EXPLAIN CORRECT renders the HirnDirectCorrect extension node by name.
#[test]
fn explain_correct_shows_extension_name() {
let p = pipeline();
let tree = p
.explain(r#"EXPLAIN CORRECT "01ARZ3NDEKTSV4RRFFQ69G5FAV" SET description = "updated""#)
.unwrap();
assert!(tree.contains("HirnDirectCorrect"), "plan tree: {tree}");
}
// EXPLAIN SUPERSEDE renders the HirnDirectSupersede extension node by name.
#[test]
fn explain_supersede_shows_extension_name() {
let p = pipeline();
let tree = p
.explain(
r#"EXPLAIN SUPERSEDE "01ARZ3NDEKTSV4RRFFQ69G5FAV" SET description = "replacement""#,
)
.unwrap();
assert!(tree.contains("HirnDirectSupersede"), "plan tree: {tree}");
}
// EXPLAIN MERGE MEMORY renders the HirnDirectMergeMemory extension node by name.
#[test]
fn explain_merge_memory_shows_extension_name() {
let p = pipeline();
let tree = p
.explain(
r#"EXPLAIN MERGE MEMORY "01ARZ3NDEKTSV4RRFFQ69G5FAA" INTO "01ARZ3NDEKTSV4RRFFQ69G5FAV""#,
)
.unwrap();
assert!(tree.contains("HirnDirectMergeMemory"), "plan tree: {tree}");
}
// EXPLAIN HISTORY renders the HirnSemanticHistoryScan extension node by name.
#[test]
fn explain_history_shows_extension_name() {
let p = pipeline();
let tree = p
.explain(r#"EXPLAIN HISTORY "01ARZ3NDEKTSV4RRFFQ69G5FAV" NAMESPACE custom"#)
.unwrap();
assert!(
tree.contains("HirnSemanticHistoryScan"),
"plan tree: {tree}"
);
}
// EXPLAIN RETRACT renders the HirnDirectRetract extension node by name.
#[test]
fn explain_retract_shows_extension_name() {
let p = pipeline();
let tree = p
.explain(r#"EXPLAIN RETRACT "01ARZ3NDEKTSV4RRFFQ69G5FAV" REASON "obsolete""#)
.unwrap();
assert!(tree.contains("HirnDirectRetract"), "plan tree: {tree}");
}
// Explaining the same query twice with a cache attached yields identical, non-empty output.
#[test]
fn explain_of_cached_query_still_shows_plan() {
let cache = Arc::new(PlanCache::new(10));
let p = pipeline().with_cache(cache);
let tree1 = p.explain(r#"RECALL episodic ABOUT "test""#).unwrap();
let tree2 = p.explain(r#"RECALL episodic ABOUT "test""#).unwrap();
assert_eq!(tree1, tree2);
assert!(!tree1.is_empty());
}
// The root line is not indented; a second line, if present, starts with a space.
#[test]
fn format_plan_tree_indents_children() {
let p = pipeline();
let compiled = p
.compile(r#"RECALL episodic ABOUT "test" EXPAND GRAPH DEPTH 2 LIMIT 5"#)
.unwrap();
let tree = super::format_plan_tree(&compiled.plan);
let lines: Vec<&str> = tree.lines().collect();
assert!(!lines.is_empty());
assert!(!lines[0].starts_with(' '), "root: {}", lines[0]);
if lines.len() > 1 {
assert!(lines[1].starts_with(" "), "child: {}", lines[1]);
}
}
// Averages 1000 compiles of an already-cached query and requires < 5µs each.
// NOTE(review): this asserts on wall-clock time, so it may be flaky on
// unoptimized builds or loaded machines — consider moving it to a benchmark.
#[test]
fn cached_query_executes_under_5us() {
let cache = Arc::new(PlanCache::new(100));
let p = pipeline().with_cache(cache.clone());
// Warm the cache so the timed loop only measures cache hits.
p.compile(r#"RECALL episodic ABOUT "test""#).unwrap();
let start = std::time::Instant::now();
let iterations = 1000;
for _ in 0..iterations {
let _ = p.compile(r#"RECALL episodic ABOUT "test""#).unwrap();
}
let elapsed = start.elapsed();
let per_op = elapsed / iterations;
assert!(
per_op.as_micros() < 5,
"cached query took {per_op:?} per op, expected < 5µs"
);
}
}