use anyhow::Result;
use chrono::Utc;
use uuid::Uuid;
use crate::schema::{Memory, RelationType};
use crate::store::Store;
/// Input for a single [`QueryEngine::recall`] call.
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// Raw query text. Not read by `QueryEngine::recall`, which works from
    /// `embedding` instead.
    pub text: String,
    /// Query embedding handed to the store's vector search.
    pub embedding: Vec<f32>,
    /// Maximum number of results returned. `recall` asks the store for up to
    /// three times this many candidates before filtering.
    pub limit: usize,
    /// Filters applied to the vector-search candidates.
    pub filters: QueryFilters,
}
/// Optional constraints applied to recall candidates; `Default` means "no filtering".
#[derive(Clone, Debug, Default)]
pub struct QueryFilters {
    /// When set, only memories whose `source` equals this value are kept.
    pub source: Option<String>,
    /// When set, only memories of exactly this type are kept.
    pub memory_type: Option<crate::schema::MemoryType>,
    /// When set, memories whose `confidence` is below this value are dropped.
    pub min_confidence: Option<f32>,
    /// Entity names forwarded to graph scoring. Not used as a filter, and the
    /// graph scorer currently ignores it.
    pub entity_names: Vec<String>,
}
/// One ranked hit produced by [`QueryEngine::recall`].
#[derive(Clone, Debug)]
pub struct QueryResult {
    /// The matched memory.
    pub memory: Memory,
    /// Weighted blend of vector, graph and recency scores; higher is better.
    pub score: f32,
    /// Sequence of ids; `recall` always leaves this empty.
    /// NOTE(review): presumably reserved for a graph traversal path — confirm.
    pub path: Vec<Uuid>,
}
/// Cutoff for `QueryEngine::recall`: results whose blended score is below this are discarded.
const MIN_RELEVANCE_SCORE: f32 = 0.59_f32;
/// Retrieval engine over a borrowed store.
///
/// It ranks memories by a weighted blend of vector similarity, graph
/// relevance and recency.
pub struct QueryEngine<'a, S>
where
    S: Store,
{
    /// Store used for vector search and relation lookups.
    store: &'a S,
    /// Multiplier for the vector-similarity component.
    vector_weight: f32,
    /// Multiplier for the graph-relevance component.
    graph_weight: f32,
    /// Multiplier for the recency component.
    recency_weight: f32,
}
impl<'a, S: Store> QueryEngine<'a, S> {
pub fn new(store: &'a S) -> Self {
Self {
store,
vector_weight: 0.5,
graph_weight: 0.3,
recency_weight: 0.2,
}
}
pub fn with_weights(mut self, vector: f32, graph: f32, recency: f32) -> Self {
self.vector_weight = vector;
self.graph_weight = graph;
self.recency_weight = recency;
self
}
pub fn recall(&self, request: &QueryRequest) -> Result<Vec<QueryResult>> {
let vector_results = self
.store
.vector_search(&request.embedding, request.limit * 3)?;
let mut scored: Vec<QueryResult> = Vec::new();
for (memory, similarity) in vector_results {
if let Some(min_conf) = request.filters.min_confidence
&& memory.confidence < min_conf
{
continue;
}
if let Some(ref source) = request.filters.source
&& &memory.source != source
{
continue;
}
if let Some(ref mt) = request.filters.memory_type
&& &memory.memory_type != mt
{
continue;
}
let recency_score = self.compute_recency(&memory);
let graph_score = self
.compute_graph_relevance(&memory, &request.filters.entity_names)
.unwrap_or(0.0);
let final_score = (similarity * self.vector_weight)
+ (graph_score * self.graph_weight)
+ (recency_score * self.recency_weight);
scored.push(QueryResult {
memory,
score: final_score,
path: Vec::new(),
});
}
scored.retain(|r| r.score >= MIN_RELEVANCE_SCORE);
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored.truncate(request.limit);
Ok(scored)
}
fn compute_recency(&self, memory: &Memory) -> f32 {
let hours_since_access = Utc::now()
.signed_duration_since(memory.last_accessed)
.num_hours() as f32;
let decay = (-hours_since_access / (24.0 * 30.0)).exp();
let access_boost = (memory.access_count as f32).ln_1p() / 10.0;
(decay + access_boost).min(1.0)
}
fn compute_graph_relevance(&self, memory: &Memory, _entity_names: &[String]) -> Result<f32> {
let scored_types = [
(RelationType::Reinforces, 0.3_f32),
(RelationType::RelatesTo, 0.2),
(RelationType::DistilledFrom, 0.15),
(RelationType::Mentions, 0.1),
(RelationType::DerivedFrom, 0.05),
(RelationType::Contradicts, -0.1),
(RelationType::Supersedes, -0.2),
];
let mut relevance = 0.0_f32;
for (rt, boost) in &scored_types {
if let Ok(rels) = self.store.get_relations(memory.id, Some(*rt)) {
for rel in &rels {
let b = if *rt == RelationType::RelatesTo {
boost * rel.strength
} else {
*boost
};
relevance += b;
}
}
}
Ok(relevance.clamp(0.0, 1.0))
}
pub fn store(&self) -> &S {
self.store
}
}