use anyhow::Result;
use chrono::Utc;
use uuid::Uuid;
use crate::schema::{Memory, RelationType};
use crate::store::Store;
/// Input to [`QueryEngine::recall`]: what to look for and how many results to return.
#[derive(Debug, Clone)]
pub struct QueryRequest {
/// Raw query text. `QueryEngine::recall` in this file does not read it;
/// only `embedding` drives the search.
pub text: String,
/// Query embedding handed to `Store::vector_search`.
pub embedding: Vec<f32>,
/// Maximum number of results returned. The vector search is asked for
/// three times this many candidates before filtering.
pub limit: usize,
/// Optional constraints and ranking hints applied to the candidates.
pub filters: QueryFilters,
}
/// Optional constraints for a recall query. The `Default` value applies none.
#[derive(Debug, Clone, Default)]
pub struct QueryFilters {
/// When set, only memories whose `source` equals this value are kept.
pub source: Option<String>,
/// When set, only memories whose `memory_type` equals this value are kept.
pub memory_type: Option<crate::schema::MemoryType>,
/// When set, memories whose `confidence` is below this value are dropped.
pub min_confidence: Option<f32>,
/// Forwarded to the graph-relevance computation, which currently ignores it.
pub entity_names: Vec<String>,
/// When set, memories whose `project_path` equals this value have their
/// score multiplied by `PROJECT_AFFINITY_BOOST`. Non-matching memories are
/// not filtered out.
pub project_path: Option<String>,
}
/// One ranked hit produced by [`QueryEngine::recall`].
#[derive(Debug, Clone)]
pub struct QueryResult {
/// The memory returned by the store's vector search.
pub memory: Memory,
/// Blended ranking score (weighted similarity, graph relevance and recency,
/// optionally multiplied by the project-affinity boost). Higher ranks first.
pub score: f32,
/// Always empty when produced by `recall` in this file.
/// NOTE(review): presumably reserved for a graph traversal path — confirm.
pub path: Vec<Uuid>,
}
/// Similarity floor: `recall` skips any candidate whose vector similarity is
/// below this value, before any other score component is considered.
const MIN_SIMILARITY: f32 = 0.59;
/// Multiplier applied to a candidate's blended score when its `project_path`
/// equals the query's `project_path` filter.
const PROJECT_AFFINITY_BOOST: f32 = 1.15;
/// Ranks memories from a borrowed [`Store`] by blending vector similarity,
/// graph relevance and recency with configurable weights.
pub struct QueryEngine<'a, S: Store> {
/// Backing store used for vector search and relation lookups.
store: &'a S,
/// Weight multiplied with the vector similarity (0.5 from `new`).
vector_weight: f32,
/// Weight multiplied with the graph relevance score (0.3 from `new`).
graph_weight: f32,
/// Weight multiplied with the recency score (0.2 from `new`).
recency_weight: f32,
}
impl<'a, S: Store> QueryEngine<'a, S> {
pub fn new(store: &'a S) -> Self {
Self {
store,
vector_weight: 0.5,
graph_weight: 0.3,
recency_weight: 0.2,
}
}
pub fn with_weights(mut self, vector: f32, graph: f32, recency: f32) -> Self {
self.vector_weight = vector;
self.graph_weight = graph;
self.recency_weight = recency;
self
}
pub fn recall(&self, request: &QueryRequest) -> Result<Vec<QueryResult>> {
let vector_results = self
.store
.vector_search(&request.embedding, request.limit * 3)?;
let mut scored: Vec<QueryResult> = Vec::new();
for (memory, similarity) in vector_results {
if similarity < MIN_SIMILARITY {
continue;
}
if let Some(min_conf) = request.filters.min_confidence
&& memory.confidence < min_conf
{
continue;
}
if let Some(ref source) = request.filters.source
&& &memory.source != source
{
continue;
}
if let Some(ref mt) = request.filters.memory_type
&& &memory.memory_type != mt
{
continue;
}
let recency_score = self.compute_recency(&memory);
let graph_score = self
.compute_graph_relevance(&memory, &request.filters.entity_names)
.unwrap_or(0.0);
let mut rank_score = (similarity * self.vector_weight)
+ (graph_score * self.graph_weight)
+ (recency_score * self.recency_weight);
if let Some(ref qp) = request.filters.project_path {
if memory.project_path.as_deref() == Some(qp.as_str()) {
rank_score *= PROJECT_AFFINITY_BOOST;
}
}
scored.push(QueryResult {
memory,
score: rank_score,
path: Vec::new(),
});
}
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored.truncate(request.limit);
Ok(scored)
}
fn compute_recency(&self, memory: &Memory) -> f32 {
let hours_since_access = Utc::now()
.signed_duration_since(memory.last_accessed)
.num_hours() as f32;
let decay = (-hours_since_access / (24.0 * 30.0)).exp();
let access_boost = (memory.access_count as f32).ln_1p() / 10.0;
(decay + access_boost).min(1.0)
}
fn compute_graph_relevance(&self, memory: &Memory, _entity_names: &[String]) -> Result<f32> {
let scored_types = [
(RelationType::Reinforces, 0.3_f32),
(RelationType::RelatesTo, 0.2),
(RelationType::DistilledFrom, 0.15),
(RelationType::Mentions, 0.1),
(RelationType::DerivedFrom, 0.05),
(RelationType::Contradicts, -0.1),
(RelationType::Supersedes, -0.2),
];
let mut relevance = 0.0_f32;
for (rt, boost) in &scored_types {
if let Ok(rels) = self.store.get_relations(memory.id, Some(*rt)) {
for rel in &rels {
let b = if *rt == RelationType::RelatesTo {
boost * rel.strength
} else {
*boost
};
relevance += b;
}
}
}
Ok(relevance.clamp(0.0, 1.0))
}
pub fn store(&self) -> &S {
self.store
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Conversation, Entity, MemoryType, Relation};
    use chrono::Duration;

    /// Test double for `Store`: replays canned vector-search hits and serves
    /// relations from an in-memory list. Every other method is unimplemented.
    struct StubStore {
        vector_results: Vec<(Memory, f32)>,
        relations: Vec<Relation>,
    }

    impl Store for StubStore {
        // Ignores the embedding and limit; always returns the canned hits.
        fn vector_search(&self, _embedding: &[f32], _limit: usize) -> Result<Vec<(Memory, f32)>> {
            Ok(self.vector_results.clone())
        }

        // Outgoing relations of `node_id`, optionally restricted to one type.
        fn get_relations(
            &self,
            node_id: Uuid,
            relation_type: Option<RelationType>,
        ) -> Result<Vec<Relation>> {
            let mut matching = Vec::new();
            for relation in &self.relations {
                let type_matches = match relation_type {
                    Some(rt) => rt == relation.relation_type,
                    None => true,
                };
                if relation.from_id == node_id && type_matches {
                    matching.push(relation.clone());
                }
            }
            Ok(matching)
        }

        fn store_memory(&self, _m: &Memory) -> Result<()> {
            unimplemented!()
        }

        fn get_memory(&self, _id: Uuid) -> Result<Option<Memory>> {
            unimplemented!()
        }

        fn delete_memory(&self, _id: Uuid) -> Result<()> {
            unimplemented!()
        }

        fn store_entity(&self, _e: &Entity) -> Result<()> {
            unimplemented!()
        }

        fn get_entity(&self, _id: Uuid) -> Result<Option<Entity>> {
            unimplemented!()
        }

        fn find_entity_by_name(&self, _name: &str) -> Result<Option<Entity>> {
            unimplemented!()
        }

        fn store_conversation(&self, _c: &Conversation) -> Result<()> {
            unimplemented!()
        }

        fn store_relation(&self, _r: &Relation) -> Result<()> {
            unimplemented!()
        }

        fn traverse(&self, _id: Uuid, _depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
            unimplemented!()
        }

        fn memories_by_source(&self, _s: &str) -> Result<Vec<Memory>> {
            unimplemented!()
        }

        fn memories_by_type(&self, _mt: MemoryType) -> Result<Vec<Memory>> {
            unimplemented!()
        }

        fn memories_needing_decay(&self, _days: u32) -> Result<Vec<Memory>> {
            unimplemented!()
        }

        fn update_memory(&self, _m: &Memory) -> Result<()> {
            unimplemented!()
        }

        fn record_access(&self, _memory: &Memory) -> Result<()> {
            unimplemented!()
        }

        fn text_search(&self, _q: &str, _limit: usize) -> Result<Vec<Memory>> {
            unimplemented!()
        }

        fn memory_count(&self) -> Result<usize> {
            unimplemented!()
        }

        fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
            unimplemented!()
        }

        fn all_relations(&self) -> Result<Vec<Relation>> {
            unimplemented!()
        }
    }

    /// Builds a semantic memory whose creation and last-access timestamps both
    /// lie `days_old` days in the past.
    fn memory_aged(content: &str, days_old: i64) -> Memory {
        let timestamp = Utc::now() - Duration::days(days_old);
        let mut memory = Memory::new(
            content.to_owned(),
            MemoryType::Semantic,
            String::from("test"),
            String::new(),
        );
        memory.created_at = timestamp;
        memory.last_accessed = timestamp;
        memory
    }

    /// A filterless request for ten results with an all-zero embedding.
    fn query() -> QueryRequest {
        QueryRequest {
            text: String::from("query text"),
            embedding: vec![0.0_f32; 384],
            limit: 10,
            filters: QueryFilters::default(),
        }
    }

    #[test]
    fn aged_unconnected_memory_survives_on_strong_similarity() {
        // 180 days old with no relations, but similarity 0.92 clears the floor.
        let aged = memory_aged("kuzu was chosen as the embedded graph store", 180);
        let store = StubStore {
            vector_results: vec![(aged, 0.92)],
            relations: Vec::new(),
        };

        let engine = QueryEngine::new(&store);
        let results = engine.recall(&query()).unwrap();

        assert_eq!(results.len(), 1);
    }

    #[test]
    fn well_connected_low_similarity_memory_is_excluded() {
        // Similarity 0.45 is below the floor; three reinforcing relations must
        // not rescue the memory.
        let weak = memory_aged("an entirely unrelated topic", 0);
        let weak_id = weak.id;
        let relations: Vec<Relation> = (0..3)
            .map(|_| Relation {
                from_id: weak_id,
                to_id: Uuid::new_v4(),
                relation_type: RelationType::Reinforces,
                strength: 1.0,
                context: None,
            })
            .collect();
        let store = StubStore {
            vector_results: vec![(weak, 0.45)],
            relations,
        };

        let engine = QueryEngine::new(&store);
        let results = engine.recall(&query()).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn results_are_ranked_by_blended_score() {
        // Same similarity and no relations: only the age differs, so the
        // fresher note has to come out on top.
        let older = memory_aged("older relevant note", 200);
        let fresher = memory_aged("fresher relevant note", 0);
        let store = StubStore {
            vector_results: vec![(older, 0.80), (fresher, 0.80)],
            relations: Vec::new(),
        };

        let engine = QueryEngine::new(&store);
        let results = engine.recall(&query()).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].memory.content, "fresher relevant note");
    }
}