use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use rusqlite::{Connection, params};
use serde::Serialize;
use crate::error::MemoryError;
use crate::memory::Memory;
use crate::relationship::get_relationships_for_memory;
use crate::{ENERGY_THRESHOLD, MAX_SPREADING_ITERATIONS, SearchParams};
/// One memory in a search result, together with the activation energy it
/// accumulated during the search.
///
/// Serialized as the flattened fields of [`Memory`] plus `energy`, and
/// `is_context` only when it is `true`.
#[derive(Debug, Serialize)]
pub struct ActivatedMemory {
    /// The stored memory; `#[serde(flatten)]` inlines its fields into this object.
    #[serde(flatten)]
    pub memory: Memory,
    /// Activation energy accumulated by this memory (`0.0` for context entries
    /// produced by `surface_candidates`).
    pub energy: f64,
    /// `true` when the memory was included only as a neighbouring-id context
    /// entry rather than because it was activated. Omitted from serialized
    /// output when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub is_context: bool,
}
/// Outcome of a spreading-activation search, as returned by `surface_candidates`.
#[derive(Debug, Serialize)]
pub struct SearchResult {
    /// The query text exactly as it was passed in.
    pub query: String,
    /// Number of full-text-search hits used as activation seeds.
    pub seed_count: usize,
    /// Number of activated (non-context) memories kept after truncating to the
    /// search limit.
    pub total_activated: usize,
    /// Iteration count reported by the spreading-activation loop.
    pub iterations: usize,
    /// Activated memories plus any context memories, sorted by memory id.
    pub memories: Vec<ActivatedMemory>,
}
/// A pending unit of activation energy waiting to be delivered to one memory.
///
/// Items are ordered by `energy`, so a `BinaryHeap` (a max-heap) pops the
/// strongest pending activation first. Equal energies are ordered by
/// `mem_id`, with the lower id ranking higher, so the pop order among ties does
/// not depend on insertion order.
///
/// `PartialEq`/`Eq` are defined through the same ordering. This guarantees
/// `a == b` exactly when `a.cmp(&b) == Ordering::Equal`, as the `Ord` contract
/// requires. Previously equality compared only `mem_id` while ordering compared
/// only `energy`.
#[derive(Debug, Clone)]
struct ActivationItem {
    energy: f64,
    mem_id: i64,
}

impl PartialEq for ActivationItem {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ActivationItem {}

impl PartialOrd for ActivationItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ActivationItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order on f64 (NaN included), unlike
        // `partial_cmp(..).unwrap_or(Equal)`, which is not transitive for NaN.
        self.energy
            .total_cmp(&other.energy)
            // Reversed so that, in a max-heap, the smaller id is popped first.
            .then_with(|| other.mem_id.cmp(&self.mem_id))
    }
}
/// Full-text search over `memories_fts`, returning up to `limit`
/// `(rowid, score)` pairs with the best match first.
///
/// FTS5's `bm25()` gives better matches *lower* values, so the raw ranks are
/// rescaled linearly: the best hit becomes `1.0` and the worst `0.1`. If every
/// hit has the same rank (within `1e-9`), all hits get `1.0`.
///
/// # Errors
///
/// * `MemoryError::InvalidInput` when `query` is empty or whitespace-only.
/// * Any SQLite error from preparing or executing the statement. `query` is
///   handed to `MATCH` unchanged, so malformed FTS5 query syntax surfaces here.
fn fts_search(
    conn: &Connection,
    query: &str,
    limit: usize,
) -> Result<Vec<(i64, f64)>, MemoryError> {
    if query.trim().is_empty() {
        return Err(MemoryError::InvalidInput(
            "Search query cannot be empty".to_string(),
        ));
    }

    let sql = "SELECT rowid, bm25(memories_fts) FROM memories_fts\n\
               WHERE memories_fts MATCH ?1\n\
               ORDER BY bm25(memories_fts)\n\
               LIMIT ?2";
    let mut stmt = conn.prepare(sql)?;
    let mapped = stmt.query_map(params![query, limit as i64], |row| {
        let rowid: i64 = row.get(0)?;
        let rank: f64 = row.get(1)?;
        Ok((rowid, rank))
    })?;

    let mut hits: Vec<(i64, f64)> = Vec::new();
    for hit in mapped {
        hits.push(hit?);
    }
    if hits.is_empty() {
        return Ok(Vec::new());
    }

    // bm25: smaller is better, so `best` is the minimum and `worst` the maximum.
    let (best, worst) = hits.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &(_, rank)| (f64::min(lo, rank), f64::max(hi, rank)),
    );
    let spread = worst - best;

    let scored: Vec<(i64, f64)> = if spread.abs() < 1e-9 {
        hits.into_iter().map(|(rowid, _)| (rowid, 1.0)).collect()
    } else {
        hits.into_iter()
            .map(|(rowid, rank)| (rowid, 0.1 + 0.9 * ((worst - rank) / spread)))
            .collect()
    };
    Ok(scored)
}
/// Loads the `memories` rows whose `id` appears in `ids`.
///
/// Ids with no matching row are simply absent from the result. Rows come back
/// in whatever order SQLite yields them, not the order of `ids`, so callers
/// should key the result by `Memory::id`. An empty `ids` slice returns an empty
/// vector without touching the database.
///
/// Ids are bound as `?` parameters in batches of at most `MAX_IDS_PER_QUERY`,
/// with one `SELECT ... WHERE id IN (...)` per batch. A single statement with
/// one placeholder per id can exceed SQLite's per-statement host-parameter
/// limit (`SQLITE_MAX_VARIABLE_NUMBER`, 999 by default before SQLite 3.32.0).
///
/// # Errors
///
/// Returns any SQLite error raised while preparing or executing a batch or
/// reading a column.
fn get_memories_by_ids(conn: &Connection, ids: &[i64]) -> Result<Vec<Memory>, MemoryError> {
    use chrono::{DateTime, Utc};

    // Comfortably below the historical default limit of 999 host parameters.
    const MAX_IDS_PER_QUERY: usize = 500;

    if ids.is_empty() {
        return Ok(vec![]);
    }

    let mut memories = Vec::with_capacity(ids.len());
    for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
        let placeholders = vec!["?"; chunk.len()].join(",");
        let query = format!(
            "SELECT id, datetime, text, source FROM memories WHERE id IN ({})",
            placeholders
        );
        let mut stmt = conn.prepare(&query)?;
        let params: Vec<&dyn rusqlite::ToSql> =
            chunk.iter().map(|id| id as &dyn rusqlite::ToSql).collect();
        let rows = stmt.query_map(params.as_slice(), |row| {
            let datetime_str: String = row.get(1)?;
            // NOTE(review): a value that is not valid RFC 3339 silently becomes
            // the current time instead of raising an error. If timestamps are
            // ever stored in another format, they will be misreported here.
            // Confirm against the code that writes `memories.datetime`.
            let datetime = DateTime::parse_from_rfc3339(&datetime_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());
            Ok(Memory {
                id: row.get(0)?,
                datetime,
                text: row.get(2)?,
                source: row.get(3)?,
            })
        })?;
        for memory in rows {
            memories.push(memory?);
        }
    }
    Ok(memories)
}
/// Finds memories relevant to `query` by spreading activation outward from
/// full-text-search hits across the memory relationship graph.
///
/// 1. `fts_search` yields up to `params.limit` seed memories, each with an
///    initial energy in `[0.1, 1.0]`.
/// 2. Energy spreads along relationships (see `spread_activation`).
/// 3. The `params.limit` memories with the most accumulated energy are kept.
///    Equal energies are ordered by ascending id, so the cut-off is
///    deterministic rather than dependent on `HashMap` iteration order.
/// 4. If `params.context > 0`, memories whose ids lie within `params.context`
///    of a kept memory are added as context entries (`energy == 0.0`,
///    `is_context == true`).
///
/// Ids without a row in `memories` are dropped. The returned memories are
/// sorted by id.
///
/// # Errors
///
/// Propagates errors from the FTS query (including an empty `query`), from
/// relationship lookups, and from loading memory rows.
pub(crate) fn surface_candidates(
    conn: &Connection,
    query: &str,
    params: &SearchParams,
    current_max_mem: i64,
) -> Result<SearchResult, MemoryError> {
    let limit = params.limit;
    let seeds = fts_search(conn, query, limit)?;
    let seed_count = seeds.len();

    let (energy_map, iterations) = spread_activation(conn, &seeds, params, current_max_mem)?;

    let mut activated: Vec<(i64, f64)> = energy_map.into_iter().collect();
    // Highest energy first; ties fall back to ascending id so that truncation
    // keeps the same memories on every run.
    activated.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    activated.truncate(limit);
    // NOTE(review): this counts the memories kept after truncation, not every
    // memory that received energy. Confirm that matches what consumers of
    // `total_activated` expect.
    let total_activated = activated.len();

    let context_ids = context_window_ids(&activated, params);

    let mut all_ids: Vec<i64> = activated.iter().map(|&(id, _)| id).collect();
    all_ids.extend(context_ids.iter().copied());

    let mut memory_map: HashMap<i64, Memory> = get_memories_by_ids(conn, &all_ids)?
        .into_iter()
        .map(|m| (m.id, m))
        .collect();
    let energy_by_id: HashMap<i64, f64> = activated.into_iter().collect();

    let mut results: Vec<ActivatedMemory> = all_ids
        .iter()
        .filter_map(|id| {
            // `all_ids` has no duplicates: activated ids are distinct map keys
            // and context ids exclude them. Each memory can therefore be moved
            // out of the map instead of cloned.
            let memory = memory_map.remove(id)?;
            let (energy, is_context) = match energy_by_id.get(id) {
                Some(&e) => (e, false),
                None => (0.0, true),
            };
            Some(ActivatedMemory {
                memory,
                energy,
                is_context,
            })
        })
        .collect();
    results.sort_by_key(|m| m.memory.id);

    Ok(SearchResult {
        query: query.to_string(),
        seed_count,
        total_activated,
        iterations,
        memories: results,
    })
}

/// Runs spreading activation from `seeds` and returns the accumulated energy
/// per memory id, together with the number of queued items actually processed.
///
/// Pending activations are processed strongest-first from a max-heap. Each
/// processed item adds its energy to that memory's running total. If the part
/// of the total not yet forwarded exceeds `ENERGY_THRESHOLD`, that part is sent
/// to every neighbour from `get_relationships_for_memory`, scaled by
/// `params.energy_decay * effective_strength`. Neighbour contributions at or
/// below the threshold are not queued.
///
/// Processing stops when the heap is empty or after `MAX_SPREADING_ITERATIONS`
/// items. The returned count never exceeds that cap; the previous loop
/// reported cap + 1 when the cap was hit.
fn spread_activation(
    conn: &Connection,
    seeds: &[(i64, f64)],
    params: &SearchParams,
    current_max_mem: i64,
) -> Result<(HashMap<i64, f64>, usize), MemoryError> {
    let mut energy_map: HashMap<i64, f64> = HashMap::new();
    // Portion of each memory's total that has already been forwarded.
    let mut propagated_energy: HashMap<i64, f64> = HashMap::new();
    let mut heap: BinaryHeap<ActivationItem> = BinaryHeap::new();
    for &(seed_id, seed_energy) in seeds {
        heap.push(ActivationItem {
            energy: seed_energy,
            mem_id: seed_id,
        });
    }

    let mut iterations: usize = 0;
    while let Some(item) = heap.pop() {
        if iterations >= MAX_SPREADING_ITERATIONS {
            break;
        }
        iterations += 1;

        let total = energy_map.entry(item.mem_id).or_insert(0.0);
        *total += item.energy;
        let total = *total;
        let already_propagated = propagated_energy
            .get(&item.mem_id)
            .copied()
            .unwrap_or(0.0);
        let to_propagate = total - already_propagated;
        if to_propagate > ENERGY_THRESHOLD {
            propagated_energy.insert(item.mem_id, total);
            let neighbors =
                get_relationships_for_memory(conn, item.mem_id, current_max_mem, params)?;
            for (neighbor_id, effective_strength) in neighbors {
                let new_energy = to_propagate * params.energy_decay * effective_strength;
                if new_energy > ENERGY_THRESHOLD {
                    heap.push(ActivationItem {
                        energy: new_energy,
                        mem_id: neighbor_id,
                    });
                }
            }
        }
    }
    Ok((energy_map, iterations))
}

/// Collects every id within `params.context` of an activated id, clamped below
/// at 1. Ids that are themselves activated are excluded. Returns an empty set
/// unless `params.context > 0`.
fn context_window_ids(activated: &[(i64, f64)], params: &SearchParams) -> HashSet<i64> {
    let mut context_ids: HashSet<i64> = HashSet::new();
    if params.context > 0 {
        let activated_ids: HashSet<i64> = activated.iter().map(|&(id, _)| id).collect();
        let radius = params.context as i64;
        for &(id, _) in activated {
            let start = (id - radius).max(1);
            let end = id + radius;
            context_ids.extend(
                (start..=end).filter(|ctx_id| *ctx_id != id && !activated_ids.contains(ctx_id)),
            );
        }
    }
    context_ids
}