use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::embedding_store::EmbeddingStore;
use crate::error::MemoryError;
use crate::graph::retrieval::find_seed_entities;
use crate::graph::store::GraphStore;
use crate::graph::types::{EdgeType, GraphFact};
const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
const DEFAULT_COMMUNITY_CAP: usize = 3;
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
pub async fn graph_recall_beam(
store: &GraphStore,
embeddings: Option<&EmbeddingStore>,
provider: &zeph_llm::any::AnyProvider,
query: &str,
limit: usize,
beam_width: usize,
max_hops: u32,
edge_types: &[EdgeType],
temporal_decay_rate: f64,
hebbian_enabled: bool,
hebbian_lr: f32,
) -> Result<Vec<GraphFact>, MemoryError> {
let _span = tracing::info_span!("memory.graph.beam", query_len = query.len()).entered();
if limit == 0 {
return Ok(Vec::new());
}
let entity_scores = find_seed_entities(
store,
embeddings,
provider,
query,
limit,
DEFAULT_STRUCTURAL_WEIGHT,
DEFAULT_COMMUNITY_CAP,
)
.await?;
if entity_scores.is_empty() {
return Ok(Vec::new());
}
let now_secs: i64 = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |d| d.as_secs().cast_signed());
let mut beam_scores: Vec<(i64, f32)> = entity_scores.into_iter().collect();
beam_scores.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
beam_scores.truncate(beam_width);
let mut beam_ids: Vec<i64> = beam_scores.iter().map(|(id, _)| *id).collect();
let mut beam_score_map: HashMap<i64, f32> = beam_scores.into_iter().collect();
let mut all_db_edges: Vec<crate::graph::types::Edge> = Vec::new();
let mut entity_name_map: HashMap<i64, String> = HashMap::new();
for _hop in 0..max_hops {
if beam_ids.is_empty() {
break;
}
let edges = store.edges_for_entities(&beam_ids, edge_types).await?;
if edges.is_empty() {
break;
}
let new_entity_ids: Vec<i64> = edges
.iter()
.flat_map(|e| [e.source_entity_id, e.target_entity_id])
.filter(|id| !entity_name_map.contains_key(id))
.collect::<HashSet<_>>()
.into_iter()
.collect();
for id in new_entity_ids {
if let Ok(Some(entity)) = store.find_entity_by_id(id).await {
entity_name_map.insert(id, entity.canonical_name.clone());
}
}
let mut neighbour_scores: HashMap<i64, f32> = HashMap::new();
for edge in &edges {
let edge_conf = edge.confidence;
neighbour_scores
.entry(edge.target_entity_id)
.and_modify(|s| *s = s.max(edge_conf))
.or_insert(edge_conf);
neighbour_scores
.entry(edge.source_entity_id)
.and_modify(|s| *s = s.max(edge_conf))
.or_insert(edge_conf);
}
let mut candidates: Vec<(i64, f32)> = neighbour_scores
.into_iter()
.filter(|(id, _)| !beam_score_map.contains_key(id))
.collect();
candidates.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
candidates.truncate(beam_width);
beam_ids = candidates.iter().map(|(id, _)| *id).collect();
for (id, cand_score) in candidates {
beam_score_map.insert(id, cand_score);
}
all_db_edges.extend(edges);
}
if all_db_edges.is_empty() {
return Ok(Vec::new());
}
let edge_ids: Vec<i64> = all_db_edges.iter().map(|e| e.id).collect();
if let Err(e) = store.record_edge_retrieval(&edge_ids).await {
tracing::warn!(error = %e, "graph_recall_beam: failed to record edge retrieval");
}
if hebbian_enabled
&& !edge_ids.is_empty()
&& let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
{
tracing::warn!(error = %e, "graph_recall_beam: hebbian increment failed");
}
let mut facts: Vec<GraphFact> = Vec::new();
let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
for edge in &all_db_edges {
let entity_name = entity_name_map
.get(&edge.source_entity_id)
.cloned()
.unwrap_or_default();
let target_name = entity_name_map
.get(&edge.target_entity_id)
.cloned()
.unwrap_or_default();
if entity_name.is_empty() || target_name.is_empty() {
continue;
}
let key = (
entity_name.clone(),
edge.relation.clone(),
target_name.clone(),
edge.edge_type,
);
if seen.insert(key) {
let seed_score = beam_score_map
.get(&edge.source_entity_id)
.copied()
.unwrap_or(0.5);
facts.push(GraphFact {
entity_name,
relation: edge.relation.clone(),
target_name,
fact: edge.fact.clone(),
entity_match_score: seed_score,
hop_distance: 1,
confidence: edge.confidence,
valid_from: Some(edge.valid_from.clone()),
edge_type: edge.edge_type,
retrieval_count: edge.retrieval_count,
});
}
}
facts.sort_by(|a, b| {
let sa = a.score_with_decay(temporal_decay_rate, now_secs);
let sb = b.score_with_decay(temporal_decay_rate, now_secs);
sb.total_cmp(&sa)
});
facts.truncate(limit);
Ok(facts)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::store::GraphStore;
use crate::graph::types::EntityType;
use crate::store::SqliteStore;
use zeph_llm::any::AnyProvider;
use zeph_llm::mock::MockProvider;
async fn setup_store() -> GraphStore {
let store = SqliteStore::new(":memory:").await.unwrap();
GraphStore::new(store.pool().clone())
}
fn mock_provider() -> AnyProvider {
AnyProvider::Mock(MockProvider::default())
}
#[tokio::test]
async fn beam_empty_graph_returns_empty() {
let store = setup_store().await;
let provider = mock_provider();
let result = graph_recall_beam(
&store,
None,
&provider,
"anything",
10,
5,
2,
&[],
0.0,
false,
0.0,
)
.await
.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn beam_zero_limit_returns_empty() {
let store = setup_store().await;
let provider = mock_provider();
let result = graph_recall_beam(
&store,
None,
&provider,
"anything",
0,
5,
2,
&[],
0.0,
false,
0.0,
)
.await
.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn beam_finds_direct_edge() {
let store = setup_store().await;
let a = store
.upsert_entity("Alice", "alice", EntityType::Person, None)
.await
.unwrap();
let b = store
.upsert_entity("Bob", "bob", EntityType::Person, None)
.await
.unwrap();
store
.insert_edge(a, b, "knows", "Alice knows Bob", 0.9, None)
.await
.unwrap();
let provider = mock_provider();
let result = graph_recall_beam(
&store,
None,
&provider,
"Alice",
10,
5,
2,
&[],
0.0,
false,
0.0,
)
.await
.unwrap();
assert!(!result.is_empty());
}
}