use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use super::{
compressor::IndexCompressor,
index_builder::{build_heuristic_index, build_raw_summary},
types::{CompressedIndex, ContextTier, IndexConfig, LlmBackend, RankedMemory, RecallContext},
};
use crate::{
application::services::projection::Projection,
prime::{
projections::{
AdjacencyListProjection, CrossDomainProjection, DomainIndexProjection,
GraphStatsProjection, NodeStateProjection,
},
types::{Node, PrimeStats},
},
};
/// Parameters for a recall-context request.
///
/// The `tier` selects how much material is assembled: L0 is stats-only,
/// L1 adds a recent-node graph slice, and L2 adds the compressed index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallContextQuery {
    /// Free-text description of what the caller wants recalled.
    pub query: String,
    /// Optional agent scope. NOTE(review): not consumed by the recall
    /// paths visible in this file — confirm intended use downstream.
    pub agent_id: Option<String>,
    /// Requested number of ranked memories. NOTE(review): vector ranking
    /// is not wired up in this file, so this is currently unused here.
    pub top_k: usize,
    /// Point-in-time view; `None` means "now". Not read in this file.
    pub as_of: Option<DateTime<Utc>>,
    /// Whether the L2 tier should build/include the compressed index.
    pub include_index: bool,
    /// Hard token budget for the assembled context; `None` = unlimited.
    pub max_tokens: Option<usize>,
    /// Context richness tier (L0 / L1 / L2).
    pub tier: ContextTier,
    /// When set, L1 only considers nodes whose `conversation_id`
    /// property equals this value.
    pub conversation_id: Option<String>,
}
impl Default for RecallContextQuery {
fn default() -> Self {
Self {
query: String::new(),
agent_id: None,
top_k: 5,
as_of: None,
include_index: true,
max_tokens: None,
tier: ContextTier::default(),
conversation_id: None,
}
}
}
/// Bundle of projection handles used to build a [`RecallEngine`] via
/// [`RecallEngine::with_deps`], sharing projections owned elsewhere.
pub struct RecallDeps {
    /// Per-domain index projection (required).
    pub domain_index: Arc<DomainIndexProjection>,
    /// Cross-domain link projection (required).
    pub cross_domain: Arc<CrossDomainProjection>,
    /// Current node state; enables the L1 context tier when present.
    pub node_state: Option<Arc<NodeStateProjection>>,
    /// Outgoing-edge adjacency; enables one-hop expansion in L1.
    pub adjacency: Option<Arc<AdjacencyListProjection>>,
    /// Precomputed graph statistics; used by stats building when present,
    /// otherwise stats are derived from the domain index.
    pub graph_stats: Option<Arc<GraphStatsProjection>>,
}
/// Engine that assembles tiered recall contexts (L0/L1/L2) from graph
/// projections, using an [`IndexCompressor`] to cache and (optionally,
/// via an LLM backend) compress the domain index.
pub struct RecallEngine {
    // Required projections: always present, drive index building.
    domain_index: Arc<DomainIndexProjection>,
    cross_domain: Arc<CrossDomainProjection>,
    // Optional projections: absent when constructed via `new` /
    // `with_dependencies`; needed for the L1 tier and rich stats.
    node_state: Option<Arc<NodeStateProjection>>,
    adjacency: Option<Arc<AdjacencyListProjection>>,
    graph_stats: Option<Arc<GraphStatsProjection>>,
    // Caching compressor wrapping the optional LLM backend.
    compressor: IndexCompressor,
}
impl RecallEngine {
pub fn new(config: &IndexConfig) -> Self {
let llm_backend: Option<Box<dyn LlmBackend>> =
if let Some(ref endpoint) = config.llm_endpoint {
let model = config
.llm_model
.clone()
.unwrap_or_else(|| "mistral".to_string());
Some(Box::new(super::ollama::OllamaBackend::new(
endpoint.clone(),
model,
)))
} else {
None
};
Self {
domain_index: Arc::new(DomainIndexProjection::new()),
cross_domain: Arc::new(CrossDomainProjection::new()),
node_state: None,
adjacency: None,
graph_stats: None,
compressor: IndexCompressor::new(
llm_backend,
config.refresh_interval_events,
config.refresh_interval_seconds,
),
}
}
pub fn with_deps(deps: RecallDeps, config: &IndexConfig) -> Self {
let llm_backend: Option<Box<dyn LlmBackend>> =
if let Some(ref endpoint) = config.llm_endpoint {
let model = config
.llm_model
.clone()
.unwrap_or_else(|| "mistral".to_string());
Some(Box::new(super::ollama::OllamaBackend::new(
endpoint.clone(),
model,
)))
} else {
None
};
Self {
domain_index: deps.domain_index,
cross_domain: deps.cross_domain,
node_state: deps.node_state,
adjacency: deps.adjacency,
graph_stats: deps.graph_stats,
compressor: IndexCompressor::new(
llm_backend,
config.refresh_interval_events,
config.refresh_interval_seconds,
),
}
}
pub fn with_dependencies(
domain_index: Arc<DomainIndexProjection>,
cross_domain: Arc<CrossDomainProjection>,
compressor: IndexCompressor,
) -> Self {
Self {
domain_index,
cross_domain,
node_state: None,
adjacency: None,
graph_stats: None,
compressor,
}
}
pub async fn index(&self) -> CompressedIndex {
let summary = build_raw_summary(&self.domain_index, &self.cross_domain);
let heuristic = build_heuristic_index(&summary);
let event_count = summary.total_nodes as u64 + summary.total_edges as u64;
self.compressor
.compress(&summary, event_count, &heuristic)
.await
}
pub async fn context(&self, query: RecallContextQuery) -> RecallContext {
match query.tier {
ContextTier::L0 => self.context_l0(&query),
ContextTier::L1 => self.context_l1(&query),
ContextTier::L2 => self.context_l2(&query).await,
}
}
fn context_l0(&self, _query: &RecallContextQuery) -> RecallContext {
let stats = self.build_stats();
let stats_json = serde_json::to_string(&stats).unwrap_or_default();
let token_count = super::types::estimate_tokens(&stats_json);
RecallContext {
index: String::new(),
vectors: Vec::new(),
nodes: Vec::new(),
edges: Vec::new(),
stats: Some(stats),
tier: ContextTier::L0,
token_count,
}
}
fn context_l1(&self, query: &RecallContextQuery) -> RecallContext {
let stats = self.build_stats();
let mut nodes: Vec<Node> = Vec::new();
let mut edges: Vec<crate::prime::types::Edge> = Vec::new();
let mut token_count = 0usize;
if let Some(ref node_state) = self.node_state {
let all_nodes = node_state.all_nodes();
let mut candidate_nodes: Vec<Node> = if let Some(ref conv_id) = query.conversation_id {
all_nodes
.into_iter()
.filter(|n| {
n.properties
.get("conversation_id")
.and_then(|v| v.as_str())
.is_some_and(|id| id == conv_id)
})
.collect()
} else {
all_nodes
};
candidate_nodes.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
candidate_nodes.truncate(20);
if let Some(ref adjacency) = self.adjacency {
for node in &candidate_nodes {
let adj_entries = adjacency.outgoing(node.id.as_str());
for adj in &adj_entries {
let peer_node = node_state.get_node(&adj.peer);
if let Some(target) = peer_node
&& !nodes.iter().any(|n| n.id == target.id)
&& !candidate_nodes.iter().any(|n| n.id == target.id)
{
nodes.push(target);
}
edges.push(crate::prime::types::Edge {
id: crate::prime::types::EdgeId::new(&adj.edge_id),
source: node.id.clone(),
target: crate::prime::types::NodeId::new(&adj.peer),
relation: adj.relation.clone(),
properties: None,
weight: adj.weight,
deleted: false,
created_at: Utc::now(),
});
}
}
}
nodes.splice(0..0, candidate_nodes);
}
let stats_json = serde_json::to_string(&stats).unwrap_or_default();
token_count += super::types::estimate_tokens(&stats_json);
for node in &nodes {
let node_json = serde_json::to_string(&node).unwrap_or_default();
token_count += super::types::estimate_tokens(&node_json);
}
if let Some(max) = query.max_tokens {
while token_count > max && nodes.len() > 1 {
nodes.pop(); token_count = super::types::estimate_tokens(&stats_json);
for node in &nodes {
let nj = serde_json::to_string(&node).unwrap_or_default();
token_count += super::types::estimate_tokens(&nj);
}
}
}
RecallContext {
index: String::new(),
vectors: Vec::new(),
nodes,
edges,
stats: Some(stats),
tier: ContextTier::L1,
token_count,
}
}
async fn context_l2(&self, query: &RecallContextQuery) -> RecallContext {
let mut index_text = String::new();
let mut token_count = 0usize;
if query.include_index {
let idx = self.index().await;
index_text = idx.markdown;
token_count += idx.token_count;
}
if let Some(max) = query.max_tokens
&& token_count > max
{
let target_words = max * 10 / 13; let truncated: String = index_text
.split_whitespace()
.take(target_words)
.collect::<Vec<_>>()
.join(" ");
index_text = truncated + "\n...(truncated)";
token_count = max;
}
let vectors: Vec<RankedMemory> = Vec::new();
let nodes: Vec<Node> = Vec::new();
RecallContext {
index: index_text,
vectors,
nodes,
edges: Vec::new(),
stats: None,
tier: ContextTier::L2,
token_count,
}
}
fn build_stats(&self) -> PrimeStats {
if let Some(ref gs) = self.graph_stats {
gs.stats()
} else {
let domain_counts = self.domain_index.domain_counts();
let total_nodes: usize = domain_counts.iter().map(|(_, c)| c).sum();
PrimeStats {
total_nodes,
total_edges: 0,
nodes_by_type: std::collections::HashMap::new(),
edges_by_relation: std::collections::HashMap::new(),
deleted_nodes: 0,
deleted_edges: 0,
event_count: 0,
}
}
}
pub fn domains(&self) -> Vec<String> {
self.domain_index.domains()
}
pub fn cross_domain_links(
&self,
) -> Vec<crate::prime::projections::cross_domain::CrossDomainLink> {
self.cross_domain.cross_domain_links()
}
pub async fn refresh_index(&self) -> CompressedIndex {
self.compressor.invalidate_cache();
self.index().await
}
pub fn projections(&self) -> Vec<Arc<dyn Projection>> {
vec![
Arc::clone(&self.domain_index) as Arc<dyn Projection>,
Arc::clone(&self.cross_domain) as Arc<dyn Projection>,
]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        application::services::projection::Projection,
        domain::entities::Event,
        prime::{projections::GraphStatsProjection, types::event_types},
    };
    use uuid::Uuid;

    /// Builds a NODE_CREATED event for a node with the given identifiers.
    fn make_node(node_id: &str, node_type: &str, domain: &str, name: &str) -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            event_types::NODE_CREATED.to_string(),
            format!("node:{node_type}:{node_id}"),
            "default".to_string(),
            serde_json::json!({
                "node_id": node_id,
                "node_type": node_type,
                "domain": domain,
                "properties": {"name": name}
            }),
            Utc::now(),
            None,
            1,
        )
    }

    /// Builds an EDGE_CREATED event connecting two node ids.
    fn make_edge(edge_id: &str, source: &str, target: &str, relation: &str) -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            event_types::EDGE_CREATED.to_string(),
            format!("edge:{edge_id}"),
            "default".to_string(),
            serde_json::json!({
                "edge_id": edge_id,
                "source": source,
                "target": target,
                "relation": relation,
            }),
            Utc::now(),
            None,
            1,
        )
    }

    /// Creates an engine (default config, no LLM backend) seeded with a
    /// small three-domain graph: 4 nodes, 2 cross-domain edges.
    fn seed_engine() -> RecallEngine {
        let engine = RecallEngine::new(&IndexConfig::default());
        let events = vec![
            make_node("n1", "metric", "revenue", "Q3 Revenue"),
            make_node("n2", "metric", "revenue", "Churn Rate"),
            make_node("n3", "service", "engineering", "Core API"),
            make_node("n4", "feature", "product", "Dark Mode"),
            make_edge("e1", "n1", "n3", "impacts"),
            make_edge("e2", "n4", "n3", "depends_on"),
        ];
        // Feed every event through the engine's core projections.
        let projections = engine.projections();
        for event in &events {
            for proj in &projections {
                proj.process(event).unwrap();
            }
        }
        engine
    }

    // The compressed index should mention seeded domains and cost tokens.
    #[tokio::test]
    async fn test_index_returns_compressed_index() {
        let engine = seed_engine();
        let index = engine.index().await;
        assert!(index.markdown.contains("revenue"));
        assert!(index.markdown.contains("engineering"));
        assert!(index.token_count > 0);
        assert!(!index.domains.is_empty());
    }

    // Default tier is L2; include_index should populate the index text.
    #[tokio::test]
    async fn test_context_with_include_index() {
        let engine = seed_engine();
        let query = RecallContextQuery {
            query: "How does revenue relate to engineering?".to_string(),
            include_index: true,
            ..Default::default()
        };
        let ctx = engine.context(query).await;
        assert!(!ctx.index.is_empty());
        assert!(ctx.token_count > 0);
        assert_eq!(ctx.tier, ContextTier::L2);
    }

    // With include_index=false, L2 returns an empty, zero-token context.
    #[tokio::test]
    async fn test_context_without_index() {
        let engine = seed_engine();
        let query = RecallContextQuery {
            query: "test".to_string(),
            include_index: false,
            ..Default::default()
        };
        let ctx = engine.context(query).await;
        assert!(ctx.index.is_empty());
        assert_eq!(ctx.token_count, 0);
        assert_eq!(ctx.tier, ContextTier::L2);
    }

    // A tiny max_tokens budget forces word-level truncation with a marker.
    #[tokio::test]
    async fn test_context_with_max_tokens_truncates() {
        let engine = seed_engine();
        let query = RecallContextQuery {
            query: "test".to_string(),
            include_index: true,
            max_tokens: Some(10),
            ..Default::default()
        };
        let ctx = engine.context(query).await;
        assert!(ctx.token_count <= 10);
        assert!(ctx.index.contains("truncated"));
    }

    // L0 must be stats-only: no index, no vectors, no nodes, small budget.
    #[tokio::test]
    async fn test_l0_returns_stats_only() {
        let engine = seed_engine();
        let query = RecallContextQuery {
            query: String::new(),
            tier: ContextTier::L0,
            ..Default::default()
        };
        let ctx = engine.context(query).await;
        assert_eq!(ctx.tier, ContextTier::L0);
        assert!(ctx.index.is_empty(), "L0 should not generate index");
        assert!(ctx.vectors.is_empty(), "L0 should not run vector search");
        assert!(ctx.nodes.is_empty(), "L0 should not return graph nodes");
        assert!(ctx.stats.is_some(), "L0 should return stats");
        assert!(ctx.token_count > 0, "L0 should have non-zero token count");
        assert!(ctx.token_count < 300, "L0 should be under 300 tokens");
    }

    // L1 requires the optional projections, so this test constructs the
    // engine directly (field-level, legal from this child module) and
    // feeds events through all five projections.
    #[tokio::test]
    async fn test_l1_returns_recent_nodes() {
        let node_state = Arc::new(NodeStateProjection::new("prime.node_state"));
        let adjacency = Arc::new(AdjacencyListProjection::forward("prime.adjacency"));
        let graph_stats = Arc::new(GraphStatsProjection::new("prime.graph_stats"));
        let domain_index = Arc::new(DomainIndexProjection::new());
        let cross_domain = Arc::new(CrossDomainProjection::new());
        let compressor = IndexCompressor::new(None, 100, 300);
        let engine = RecallEngine {
            domain_index: domain_index.clone(),
            cross_domain: cross_domain.clone(),
            node_state: Some(node_state.clone()),
            adjacency: Some(adjacency.clone()),
            graph_stats: Some(graph_stats.clone()),
            compressor,
        };
        let events = vec![
            make_node("n1", "metric", "revenue", "Q3 Revenue"),
            make_node("n2", "metric", "revenue", "Churn Rate"),
            make_node("n3", "service", "engineering", "Core API"),
        ];
        let all_projections: Vec<Arc<dyn Projection>> = vec![
            node_state.clone(),
            adjacency.clone(),
            graph_stats.clone(),
            domain_index.clone(),
            cross_domain.clone(),
        ];
        for event in &events {
            for proj in &all_projections {
                proj.process(event).unwrap();
            }
        }
        let query = RecallContextQuery {
            query: "test".to_string(),
            tier: ContextTier::L1,
            ..Default::default()
        };
        let ctx = engine.context(query).await;
        assert_eq!(ctx.tier, ContextTier::L1);
        assert!(ctx.index.is_empty(), "L1 should not generate index");
        assert!(ctx.vectors.is_empty(), "L1 should not run vector search");
        assert!(ctx.stats.is_some(), "L1 should return stats");
        assert!(!ctx.nodes.is_empty(), "L1 should return recent nodes");
    }

    // Sanity check on the tier cost ordering: L0 < L2 (with index).
    #[tokio::test]
    async fn test_l0_cheaper_than_l2() {
        let engine = seed_engine();
        let l0_ctx = engine
            .context(RecallContextQuery {
                tier: ContextTier::L0,
                ..Default::default()
            })
            .await;
        let l2_ctx = engine
            .context(RecallContextQuery {
                include_index: true,
                ..Default::default()
            })
            .await;
        assert!(
            l0_ctx.token_count < l2_ctx.token_count,
            "L0 ({}) should cost fewer tokens than L2 ({})",
            l0_ctx.token_count,
            l2_ctx.token_count
        );
    }

    // Back-to-back index() calls hit the cache (identical markdown);
    // refresh_index() invalidates and rebuilds.
    #[tokio::test]
    async fn test_refresh_index_regenerates() {
        let engine = seed_engine();
        let idx1 = engine.index().await;
        let idx2 = engine.index().await;
        assert_eq!(idx1.markdown, idx2.markdown);
        let idx3 = engine.refresh_index().await;
        assert!(idx3.last_updated >= idx1.last_updated);
    }

    // All three seeded domains should be reported.
    #[tokio::test]
    async fn test_domains_returns_known_domains() {
        let engine = seed_engine();
        let domains = engine.domains();
        assert!(domains.contains(&"revenue".to_string()));
        assert!(domains.contains(&"engineering".to_string()));
        assert!(domains.contains(&"product".to_string()));
    }

    // Both seeded edges span domains, so both count as cross-domain links.
    #[tokio::test]
    async fn test_cross_domain_links_detected() {
        let engine = seed_engine();
        let links = engine.cross_domain_links();
        assert_eq!(links.len(), 2);
    }

    // projections() exposes exactly the two core projections by name.
    #[tokio::test]
    async fn test_projections_returns_arc_projections() {
        let engine = RecallEngine::new(&IndexConfig::default());
        let projections = engine.projections();
        assert_eq!(projections.len(), 2);
        let names: Vec<&str> = projections.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"prime.domain_index"));
        assert!(names.contains(&"prime.cross_domain"));
    }

    // Shared projections fed outside the engine are visible through it.
    #[tokio::test]
    async fn test_with_dependencies_constructor() {
        let domain_index = Arc::new(DomainIndexProjection::new());
        let cross_domain = Arc::new(CrossDomainProjection::new());
        let compressor = IndexCompressor::new(None, 100, 300);
        let engine =
            RecallEngine::with_dependencies(domain_index.clone(), cross_domain.clone(), compressor);
        let event = make_node("n1", "metric", "revenue", "Q3 Revenue");
        domain_index.process(&event).unwrap();
        cross_domain.process(&event).unwrap();
        let domains = engine.domains();
        assert!(domains.contains(&"revenue".to_string()));
    }

    // A mock LlmBackend injected via the compressor drives index output.
    #[tokio::test]
    async fn test_with_dependencies_custom_llm_backend() {
        use std::{future::Future, pin::Pin};
        struct MockBackend;
        impl LlmBackend for MockBackend {
            fn generate(
                &self,
                _prompt: &str,
            ) -> Pin<Box<dyn Future<Output = std::result::Result<String, String>> + Send + '_>>
            {
                Box::pin(async { Ok("# Mock Index".to_string()) })
            }
        }
        let domain_index = Arc::new(DomainIndexProjection::new());
        let cross_domain = Arc::new(CrossDomainProjection::new());
        let compressor = IndexCompressor::new(Some(Box::new(MockBackend)), 100, 300);
        let engine =
            RecallEngine::with_dependencies(domain_index.clone(), cross_domain.clone(), compressor);
        let event = make_node("n1", "metric", "revenue", "Q3 Revenue");
        domain_index.process(&event).unwrap();
        cross_domain.process(&event).unwrap();
        let index = engine.index().await;
        assert_eq!(index.markdown, "# Mock Index");
    }
}