use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
#[allow(unused_imports)]
use zeph_db::sql;
use crate::embedding_store::EmbeddingStore;
use crate::error::MemoryError;
use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
use super::store::GraphStore;
use super::types::{EdgeType, GraphFact};
/// Recalls graph facts relevant to `query` by breadth-first traversal from seed entities.
///
/// Seeds come from [`find_seed_entities`] using a fixed structural weight (0.4) and
/// community cap (3). For every seed one BFS is run, chosen as follows:
///
/// * `at_timestamp` is `Some` — `bfs_at_timestamp` (takes precedence over `edge_types`);
/// * otherwise, `edge_types` non-empty — `bfs_typed`;
/// * otherwise — `bfs_with_depth`.
///
/// Each traversed edge whose endpoints both resolve to a non-empty canonical name becomes a
/// [`GraphFact`]. Facts are ranked by `score_with_decay(temporal_decay_rate, now)`
/// (descending), de-duplicated on `(entity_name, relation, target_name, edge_type)` keeping
/// the best-ranked occurrence, and cut to `limit` entries.
///
/// Recording edge retrievals is best-effort: a failure is logged and does not fail the call.
///
/// # Errors
///
/// Propagates errors from seed discovery and from the BFS queries.
// One flat parameter list is kept so existing callers stay source-compatible.
#[allow(clippy::too_many_arguments)]
pub async fn graph_recall(
    store: &GraphStore,
    embeddings: Option<&crate::embedding_store::EmbeddingStore>,
    provider: &zeph_llm::any::AnyProvider,
    query: &str,
    limit: usize,
    max_hops: u32,
    at_timestamp: Option<&str>,
    temporal_decay_rate: f64,
    edge_types: &[EdgeType],
) -> Result<Vec<GraphFact>, MemoryError> {
    const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
    const DEFAULT_COMMUNITY_CAP: usize = 3;

    if limit == 0 {
        return Ok(Vec::new());
    }

    let seeds = find_seed_entities(
        store,
        embeddings,
        provider,
        query,
        limit,
        DEFAULT_STRUCTURAL_WEIGHT,
        DEFAULT_COMMUNITY_CAP,
    )
    .await?;
    if seeds.is_empty() {
        return Ok(Vec::new());
    }

    // Falls back to 0 when the system clock reports a time before the Unix epoch.
    let now_secs: i64 = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs().cast_signed());

    let mut candidates: Vec<GraphFact> = Vec::new();
    for (&seed_id, &seed_score) in &seeds {
        let (entities, edges, depth_map) = match at_timestamp {
            Some(ts) => store.bfs_at_timestamp(seed_id, max_hops, ts).await?,
            None if edge_types.is_empty() => store.bfs_with_depth(seed_id, max_hops).await?,
            None => store.bfs_typed(seed_id, max_hops, edge_types).await?,
        };

        let name_map: HashMap<i64, &str> = entities
            .iter()
            .map(|entity| (entity.id, entity.canonical_name.as_str()))
            .collect();

        candidates.extend(edges.iter().filter_map(|edge| {
            // Depth of the source endpoint wins; the target's depth is only a fallback.
            let hop_distance = depth_map
                .get(&edge.source_entity_id)
                .or_else(|| depth_map.get(&edge.target_entity_id))
                .copied()?;

            let entity_name = name_map
                .get(&edge.source_entity_id)
                .copied()
                .unwrap_or_default();
            let target_name = name_map
                .get(&edge.target_entity_id)
                .copied()
                .unwrap_or_default();
            // Skip edges whose endpoints were not returned (or are unnamed) in this BFS.
            if entity_name.is_empty() || target_name.is_empty() {
                return None;
            }

            Some(GraphFact {
                entity_name: entity_name.to_owned(),
                relation: edge.relation.clone(),
                target_name: target_name.to_owned(),
                fact: edge.fact.clone(),
                entity_match_score: seed_score,
                hop_distance,
                confidence: edge.confidence,
                valid_from: Some(edge.valid_from.clone()),
                edge_type: edge.edge_type,
                retrieval_count: edge.retrieval_count,
            })
        }));

        // Every traversed edge counts as retrieved, including ones that produced no fact.
        let traversed_edge_ids: Vec<i64> = edges.iter().map(|edge| edge.id).collect();
        if !traversed_edge_ids.is_empty()
            && let Err(e) = store.record_edge_retrieval(&traversed_edge_ids).await
        {
            tracing::warn!(error = %e, "graph_recall: failed to record edge retrieval");
        }
    }

    let mut ranked: Vec<(f32, GraphFact)> = Vec::with_capacity(candidates.len());
    for fact in candidates {
        let score = fact.score_with_decay(temporal_decay_rate, now_secs);
        ranked.push((score, fact));
    }
    // Stable sort: equally scored facts keep discovery order, which decides the dedup winner.
    ranked.sort_by(|lhs, rhs| rhs.0.total_cmp(&lhs.0));

    let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
    let top_facts: Vec<GraphFact> = ranked
        .into_iter()
        .map(|(_, fact)| fact)
        .filter(|fact| {
            seen.insert((
                fact.entity_name.clone(),
                fact.relation.clone(),
                fact.target_name.clone(),
                fact.edge_type,
            ))
        })
        .take(limit)
        .collect();

    Ok(top_facts)
}
/// Fills `fts_map` with seed entities found by vector similarity instead of text search.
///
/// The query is embedded with `provider`, the `zeph_graph_entities` collection is searched
/// for up to `limit` hits, and every hit whose payload carries an integer `entity_id` that
/// resolves through `store.find_entity_by_id` is inserted into `fts_map` with the hit's
/// similarity score (replacing any entry already stored under that id).
///
/// Returns `false` only when embedding the query fails. A failed collection search is
/// logged and still yields `true`; hits that cannot be resolved are skipped silently.
async fn seed_embedding_fallback(
    store: &GraphStore,
    emb_store: &EmbeddingStore,
    provider: &zeph_llm::any::AnyProvider,
    query: &str,
    limit: usize,
    fts_map: &mut HashMap<i64, (super::types::Entity, f32)>,
) -> bool {
    use zeph_llm::LlmProvider as _;
    const ENTITY_COLLECTION: &str = "zeph_graph_entities";

    let query_vector = match provider.embed(query).await {
        Ok(vector) => vector,
        Err(e) => {
            tracing::warn!(error = %e, "seed fallback: embed() failed, returning empty seeds");
            return false;
        }
    };

    let hits = match emb_store
        .search_collection(ENTITY_COLLECTION, &query_vector, limit, None)
        .await
    {
        Ok(hits) => hits,
        Err(e) => {
            tracing::warn!(error = %e, "seed fallback: embedding search failed");
            // The embedding itself succeeded, so the caller is told to carry on.
            return true;
        }
    };

    for hit in hits {
        let Some(entity_id) = hit
            .payload
            .get("entity_id")
            .and_then(serde_json::Value::as_i64)
        else {
            continue;
        };
        if let Ok(Some(entity)) = store.find_entity_by_id(entity_id).await {
            fts_map.insert(entity_id, (entity, hit.score));
        }
    }

    true
}
/// Selects the seed entities for a graph recall and assigns each a match score.
///
/// Steps:
///
/// 1. Split `query` on whitespace, keep words of at least three bytes, and use at most
///    `MAX_WORDS` of them. If nothing survives and `query` is non-empty, the whole query is
///    used as a single search term.
/// 2. Run `store.find_entities_ranked` per word, over-fetching `2 * limit` rows, and keep
///    the highest text score seen for each entity.
/// 3. If text search found nothing and `embeddings` is available, try
///    [`seed_embedding_fallback`]; when it reports failure the result is empty.
/// 4. Rank by `fts_score * (1 - structural_weight) + structural_score * structural_weight`
///    (descending), admit at most `community_cap` entities per community
///    (`community_cap == 0` disables the cap; entities without a community are never
///    capped), and keep the first `limit`.
///
/// Returned scores are the hybrid score clamped to `[0.1, 1.0]`, keyed by entity id.
///
/// # Errors
///
/// Propagates errors from the ranked text search and from the structural-score and
/// community-id lookups.
async fn find_seed_entities(
    store: &GraphStore,
    embeddings: Option<&EmbeddingStore>,
    provider: &zeph_llm::any::AnyProvider,
    query: &str,
    limit: usize,
    structural_weight: f32,
    community_cap: usize,
) -> Result<HashMap<i64, f32>, MemoryError> {
    use crate::graph::types::ScoredEntity;
    const MAX_WORDS: usize = 5;

    let filtered: Vec<&str> = query
        .split_whitespace()
        .filter(|w| w.len() >= 3)
        .take(MAX_WORDS)
        .collect();
    let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
        vec![query]
    } else {
        filtered
    };

    // Over-fetch per word so the later re-ranking has room to work. Saturating keeps a very
    // large `limit` from overflowing (a debug-build panic, or a wrapped value in release).
    let per_word_limit = limit.saturating_mul(2);

    let mut fts_map: HashMap<i64, (super::types::Entity, f32)> = HashMap::new();
    for word in &words {
        let ranked = store.find_entities_ranked(word, per_word_limit).await?;
        for (entity, fts_score) in ranked {
            // An entity matched by several words keeps its best text score.
            fts_map
                .entry(entity.id)
                .and_modify(|(_, s)| *s = s.max(fts_score))
                .or_insert((entity, fts_score));
        }
    }

    // The fallback only runs when text search found nothing; `false` means the query could
    // not be embedded, so there is nothing to seed from.
    if fts_map.is_empty()
        && let Some(emb_store) = embeddings
        && !seed_embedding_fallback(store, emb_store, provider, query, limit, &mut fts_map).await
    {
        return Ok(HashMap::new());
    }
    if fts_map.is_empty() {
        return Ok(HashMap::new());
    }

    let entity_ids: Vec<i64> = fts_map.keys().copied().collect();
    let structural_scores = store.entity_structural_scores(&entity_ids).await?;
    let community_ids = store.entity_community_ids(&entity_ids).await?;

    let fts_weight = 1.0 - structural_weight;
    // Single definition of the hybrid score, shared by ranking and the returned values.
    let hybrid_score = |se: &ScoredEntity| -> f32 {
        se.fts_score * fts_weight + se.structural_score * structural_weight
    };

    let mut scored: Vec<ScoredEntity> = fts_map
        .into_values()
        .map(|(entity, fts_score)| {
            // Entities without a structural score count as 0.0.
            let struct_score = structural_scores.get(&entity.id).copied().unwrap_or(0.0);
            let community_id = community_ids.get(&entity.id).copied();
            ScoredEntity {
                entity,
                fts_score,
                structural_score: struct_score,
                community_id,
            }
        })
        .collect();
    scored.sort_by(|a, b| hybrid_score(b).total_cmp(&hybrid_score(a)));

    let capped: Vec<&ScoredEntity> = if community_cap == 0 {
        scored.iter().collect()
    } else {
        let mut community_counts: HashMap<i64, usize> = HashMap::new();
        let mut result: Vec<&ScoredEntity> = Vec::new();
        for se in &scored {
            match se.community_id {
                Some(cid) => {
                    let count = community_counts.entry(cid).or_insert(0);
                    if *count < community_cap {
                        *count += 1;
                        result.push(se);
                    }
                }
                // Entities outside any community are never capped.
                None => result.push(se),
            }
        }
        result
    };

    // Defensive guard: the cap must never remove every seed. With the logic above the
    // best-ranked entity is always admitted, so this branch is a safety net only.
    let selected: Vec<&ScoredEntity> = if capped.is_empty() && !scored.is_empty() {
        scored.iter().take(limit).collect()
    } else {
        capped.into_iter().take(limit).collect()
    };

    let entity_scores: HashMap<i64, f32> = selected
        .into_iter()
        .map(|se| (se.entity.id, hybrid_score(se).clamp(0.1, 1.0)))
        .collect();
    Ok(entity_scores)
}
/// Recalls graph facts for `query` using spreading activation instead of plain BFS.
///
/// Seeds are chosen by [`find_seed_entities`] with the structural weight and community cap
/// taken from `params`; activation is then spread from them over `edge_types`. The
/// resulting facts are ordered by `activation_score` (descending), de-duplicated on
/// `(source_entity_id, relation, target_entity_id, edge_type)` keeping the best-ranked
/// occurrence, and cut to `limit` entries.
///
/// Retrieval is recorded for every edge produced by the spread, before de-duplication and
/// truncation. A failure to record is logged and does not fail the call.
///
/// Returns an empty vector when `limit` is zero or no seed entity is found.
///
/// # Errors
///
/// Propagates errors from seed discovery and from the spreading-activation run.
pub async fn graph_recall_activated(
    store: &GraphStore,
    embeddings: Option<&EmbeddingStore>,
    provider: &zeph_llm::any::AnyProvider,
    query: &str,
    limit: usize,
    params: SpreadingActivationParams,
    edge_types: &[EdgeType],
) -> Result<Vec<ActivatedFact>, MemoryError> {
    if limit == 0 {
        return Ok(Vec::new());
    }

    let seed_scores = find_seed_entities(
        store,
        embeddings,
        provider,
        query,
        limit,
        params.seed_structural_weight,
        params.seed_community_cap,
    )
    .await?;
    if seed_scores.is_empty() {
        return Ok(Vec::new());
    }

    tracing::debug!(
        seeds = seed_scores.len(),
        "spreading activation: starting recall"
    );

    let activation = SpreadingActivation::new(params);
    let (_, mut activated) = activation.spread(store, seed_scores, edge_types).await?;

    let touched_edge_ids: Vec<i64> = activated.iter().map(|fact| fact.edge.id).collect();
    if !touched_edge_ids.is_empty()
        && let Err(e) = store.record_edge_retrieval(&touched_edge_ids).await
    {
        tracing::warn!(error = %e, "graph_recall_activated: failed to record edge retrieval");
    }

    // Stable sort: equally activated facts keep spread order, which decides the dedup winner.
    activated.sort_by(|lhs, rhs| rhs.activation_score.total_cmp(&lhs.activation_score));

    let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
    let facts: Vec<ActivatedFact> = activated
        .into_iter()
        .filter(|fact| {
            seen.insert((
                fact.edge.source_entity_id,
                fact.edge.relation.clone(),
                fact.edge.target_entity_id,
                fact.edge.edge_type,
            ))
        })
        .take(limit)
        .collect();

    tracing::debug!(
        result_count = facts.len(),
        "spreading activation: recall complete"
    );
    Ok(facts)
}
// Integration-style tests for `graph_recall`, run against an in-memory SQLite store and a
// mock LLM provider. No embedding store is supplied in any test, so seed discovery relies on
// the store's text search only.
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::store::GraphStore;
use crate::graph::types::EntityType;
use crate::store::SqliteStore;
use zeph_llm::any::AnyProvider;
use zeph_llm::mock::MockProvider;
// Builds a fresh `GraphStore` backed by an in-memory (`:memory:`) SQLite database.
async fn setup_store() -> GraphStore {
let store = SqliteStore::new(":memory:").await.unwrap();
GraphStore::new(store.pool().clone())
}
// Provider stub: a default `MockProvider` wrapped in `AnyProvider::Mock`.
fn mock_provider() -> AnyProvider {
AnyProvider::Mock(MockProvider::default())
}
// A store with no entities yields no facts (and no error).
#[tokio::test]
async fn graph_recall_empty_graph_returns_empty() {
let store = setup_store().await;
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "anything", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(result.is_empty());
}
// `limit == 0` short-circuits to an empty result.
#[tokio::test]
async fn graph_recall_zero_limit_returns_empty() {
let store = setup_store().await;
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "user", 0, 2, None, 0.0, &[])
.await
.unwrap();
assert!(result.is_empty());
}
// Query "Ali neovim" against an Alice --uses--> neovim edge must surface that edge first.
// NOTE(review): presumably "Ali" matches "Alice" by prefix in the store's text search —
// confirm against `find_entities_ranked`.
#[tokio::test]
async fn graph_recall_fuzzy_match_returns_facts() {
let store = setup_store().await;
let user_id = store
.upsert_entity("Alice", "Alice", EntityType::Person, None)
.await
.unwrap();
let tool_id = store
.upsert_entity("neovim", "neovim", EntityType::Tool, None)
.await
.unwrap();
store
.insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Ali neovim", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(!result.is_empty());
assert_eq!(result[0].relation, "uses");
}
// Chain Alpha -> Beta -> Gamma queried with `max_hops = 1`: every returned fact must
// report a hop distance of at most 1.
#[tokio::test]
async fn graph_recall_respects_max_hops() {
let store = setup_store().await;
let a = store
.upsert_entity("Alpha", "Alpha", EntityType::Person, None)
.await
.unwrap();
let b = store
.upsert_entity("Beta", "Beta", EntityType::Person, None)
.await
.unwrap();
let c = store
.upsert_entity("Gamma", "Gamma", EntityType::Person, None)
.await
.unwrap();
store
.insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
.await
.unwrap();
store
.insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Alp", 10, 1, None, 0.0, &[])
.await
.unwrap();
assert!(result.iter().all(|f| f.hop_distance <= 1));
}
// The query names both endpoints of the single edge, so the edge may be reached from more
// than one seed; the result must still contain each (entity, relation, target) only once.
#[tokio::test]
async fn graph_recall_deduplicates_facts() {
let store = setup_store().await;
let alice = store
.upsert_entity("Alice", "Alice", EntityType::Person, None)
.await
.unwrap();
let bob = store
.upsert_entity("Bob", "Bob", EntityType::Person, None)
.await
.unwrap();
store
.insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Ali Bob", 10, 2, None, 0.0, &[])
.await
.unwrap();
let mut seen = std::collections::HashSet::new();
for f in &result {
let key = (&f.entity_name, &f.relation, &f.target_name);
assert!(seen.insert(key), "duplicate fact found: {f:?}");
}
}
// Two edges from Alpha with confidence 1.0 and 0.1: the first two results must be in
// non-increasing `composite_score` order.
#[tokio::test]
async fn graph_recall_sorts_by_composite_score() {
let store = setup_store().await;
let a = store
.upsert_entity("Alpha", "Alpha", EntityType::Person, None)
.await
.unwrap();
let b = store
.upsert_entity("Beta", "Beta", EntityType::Tool, None)
.await
.unwrap();
let c = store
.upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
.await
.unwrap();
store
.insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
.await
.unwrap();
store
.insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Alp", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(result.len() >= 2);
let s0 = result[0].composite_score();
let s1 = result[1].composite_score();
assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
}
// Root has ten outgoing edges; with `limit = 3` at most three facts may come back.
#[tokio::test]
async fn graph_recall_limit_truncates() {
let store = setup_store().await;
let root = store
.upsert_entity("Root", "Root", EntityType::Person, None)
.await
.unwrap();
for i in 0..10 {
let target = store
.upsert_entity(
&format!("Target{i}"),
&format!("Target{i}"),
EntityType::Tool,
None,
)
.await
.unwrap();
store
.insert_edge(
root,
target,
"has",
&format!("Root has Target{i}"),
0.8,
None,
)
.await
.unwrap();
}
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Roo", 3, 2, None, 0.0, &[])
.await
.unwrap();
assert!(result.len() <= 3);
}
// An edge whose `valid_from` is in 2100 must not be returned when recalling as of 2026.
// The edge is inserted with raw SQL so that `valid_from` can be set explicitly.
#[tokio::test]
async fn graph_recall_at_timestamp_excludes_future_edges() {
let store = setup_store().await;
let alice = store
.upsert_entity("Alice", "Alice", EntityType::Person, None)
.await
.unwrap();
let bob = store
.upsert_entity("Bob", "Bob", EntityType::Person, None)
.await
.unwrap();
zeph_db::query(
sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')"),
)
.bind(alice)
.bind(bob)
.execute(store.pool())
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(
&store,
None,
&provider,
"Ali",
10,
2,
Some("2026-01-01 00:00:00"),
0.0,
&[],
)
.await
.unwrap();
assert!(result.is_empty(), "future edge should be excluded");
}
// An edge valid from 2020-01-01 to 2021-01-01 (and marked expired) must be hidden from a
// recall without a timestamp, yet visible when recalling as of 2020-06-01.
#[tokio::test]
async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
let store = setup_store().await;
let alice = store
.upsert_entity("Alice", "Alice", EntityType::Person, None)
.await
.unwrap();
let carol = store
.upsert_entity("Carol", "Carol", EntityType::Person, None)
.await
.unwrap();
zeph_db::query(
sql!("INSERT INTO graph_edges
(source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
'2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')"),
)
.bind(alice)
.bind(carol)
.execute(store.pool())
.await
.unwrap();
let provider = mock_provider();
// No timestamp: the expired edge must not be traversed.
let result_current = graph_recall(&store, None, &provider, "Ali", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(
result_current.is_empty(),
"expired edge should be invisible at current time"
);
// Timestamp inside the validity window: the edge must be found.
let result_historical = graph_recall(
&store,
None,
&provider,
"Ali",
10,
2,
Some("2020-06-01 00:00:00"),
0.0,
&[],
)
.await
.unwrap();
assert!(
!result_historical.is_empty(),
"edge should be visible within its validity window"
);
}
// Five entities share one community (more than the cap of 3 that `graph_recall` passes to
// seed selection) and are all linked from a hub. Capping seeds per community must still
// leave at least one seed, so the recall must not come back empty.
#[tokio::test]
async fn graph_recall_community_cap_guard_non_empty() {
let store = setup_store().await;
let mut entity_ids = Vec::new();
for i in 0..5usize {
let id = store
.upsert_entity(
&format!("Entity{i}"),
&format!("entity{i}"),
crate::graph::types::EntityType::Concept,
None,
)
.await
.unwrap();
entity_ids.push(id);
}
let community_id = store
.upsert_community("TestComm", "test", &entity_ids, Some("fp"))
.await
.unwrap();
// The community id itself is not needed; only the membership it creates matters here.
let _ = community_id;
let hub = store
.upsert_entity("Hub", "hub", crate::graph::types::EntityType::Concept, None)
.await
.unwrap();
for &target in &entity_ids {
store
.insert_edge(hub, target, "has", "Hub has entity", 0.9, None)
.await
.unwrap();
}
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "entity", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(
!result.is_empty(),
"SA-INV-10: community cap must not zero out all seeds"
);
}
// A query that matches no entity, with no embedding store to fall back on, must produce an
// empty result rather than an error.
#[tokio::test]
async fn graph_recall_no_fts_match_no_embeddings_returns_empty() {
let store = setup_store().await;
let a = store
.upsert_entity(
"Zephyr",
"zephyr",
crate::graph::types::EntityType::Concept,
None,
)
.await
.unwrap();
let b = store
.upsert_entity(
"Concept",
"concept",
crate::graph::types::EntityType::Concept,
None,
)
.await
.unwrap();
store
.insert_edge(a, b, "rel", "Zephyr rel Concept", 0.9, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(
&store,
None,
&provider,
"xyzzyquuxfrob",
10,
2,
None,
0.0,
&[],
)
.await
.unwrap();
assert!(
result.is_empty(),
"must return empty (not error) when FTS5 returns 0 and no embeddings available"
);
}
// With `temporal_decay_rate = 0.0` the first two results must be in non-increasing
// `composite_score` order.
// NOTE(review): fixture, call and assertions are the same as in
// `graph_recall_sorts_by_composite_score`; no non-zero decay rate is exercised here.
#[tokio::test]
async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
let store = setup_store().await;
let a = store
.upsert_entity("Alpha", "Alpha", EntityType::Person, None)
.await
.unwrap();
let b = store
.upsert_entity("Beta", "Beta", EntityType::Tool, None)
.await
.unwrap();
let c = store
.upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
.await
.unwrap();
store
.insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
.await
.unwrap();
store
.insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall(&store, None, &provider, "Alp", 10, 2, None, 0.0, &[])
.await
.unwrap();
assert!(result.len() >= 2);
let s0 = result[0].composite_score();
let s1 = result[1].composite_score();
assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
}
}