use std::collections::HashSet;
use crate::config::AssociationConfig;
use crate::embeddings::EmbeddingProvider;
use crate::storage::Storage;
/// Selects candidate memory IDs from a [`Storage`] for a newly added memory.
///
/// The selector owns no data of its own; it only holds a shared borrow of the
/// storage it queries, so it cannot outlive that storage (`'a`).
pub struct CandidateSelector<'a> {
/// Borrowed storage backend that every candidate lookup is issued against.
storage: &'a Storage,
}
impl<'a> CandidateSelector<'a> {
    /// Creates a selector that issues all of its lookups against `storage`.
    pub fn new(storage: &'a Storage) -> Self {
        Self { storage }
    }

    /// Collects the IDs of stored memories that are candidates for association
    /// with a new memory.
    ///
    /// Candidates are gathered from three sources and concatenated in this order:
    ///
    /// 1. **Temporal** – IDs returned by `Storage::get_memory_ids_since` for the
    ///    timestamp `new_memory_created_at - temporal_window_days * 86400`.
    /// 2. **Entity** – up to 20 full-text hits for the `entities` joined with
    ///    `" OR "` (skipped when `entities` is empty).
    /// 3. **Embedding** – the 20 stored embeddings with the highest cosine
    ///    similarity to `embedding` (skipped when `embedding` is `None`).
    ///
    /// The combined list is then de-duplicated (first occurrence wins),
    /// `new_memory_id` itself is removed, and the result is truncated to
    /// `config.candidate_limit` entries. All lookups use the hard-coded
    /// `"default"` namespace.
    ///
    /// # Errors
    ///
    /// Returns the `rusqlite::Error` produced by the temporal or full-text
    /// lookup. A failure of the embedding lookup is ignored, so the embedding
    /// source is best-effort only.
    pub fn select_candidates(
        &self,
        new_memory_id: &str,
        new_memory_created_at: f64,
        entities: &[String],
        embedding: Option<&[f32]>,
        config: &AssociationConfig,
    ) -> Result<Vec<String>, rusqlite::Error> {
        // Namespace every lookup in this function is restricted to.
        const NAMESPACE: &str = "default";
        // Number of nearest embeddings contributed by the similarity source.
        const EMBEDDING_TOP_K: usize = 20;

        let mut candidate_ids: Vec<String> = Vec::new();

        // 1. Temporal source: window is configured in days, timestamps are seconds.
        let window_secs = config.temporal_window_days as f64 * 86400.0;
        let since = new_memory_created_at - window_secs;
        candidate_ids.extend(self.storage.get_memory_ids_since(since, NAMESPACE)?);

        // 2. Entity source: one full-text query matching any of the entities.
        if !entities.is_empty() {
            let entity_query = entities.join(" OR ");
            let fts_results = self
                .storage
                .search_fts_ns(&entity_query, 20, Some(NAMESPACE))?;
            candidate_ids.extend(fts_results.into_iter().map(|record| record.id));
        }

        // 3. Embedding source (best-effort: a storage error here is ignored).
        if let Some(emb) = embedding {
            if let Ok(all_embeddings) = self
                .storage
                .get_embeddings_in_namespace(Some(NAMESPACE), "*")
            {
                // Score against borrowed IDs; only the survivors are cloned below.
                let mut scored: Vec<(_, f64)> = all_embeddings
                    .iter()
                    .map(|(id, stored_emb)| {
                        let sim = EmbeddingProvider::cosine_similarity(emb, stored_emb) as f64;
                        // A NaN score carries no ranking information. Push it to the
                        // very end so it can never displace a real match and so the
                        // comparator below is a genuine total order.
                        let sim = if sim.is_nan() { f64::NEG_INFINITY } else { sim };
                        (id, sim)
                    })
                    .collect();

                // Descending by similarity. `total_cmp` is a total order, which
                // `sort_by` requires; an inconsistent comparator yields an
                // unspecified order and may panic. The sort stays stable so ties
                // keep storage order.
                scored.sort_by(|a, b| b.1.total_cmp(&a.1));

                candidate_ids.extend(
                    scored
                        .into_iter()
                        .take(EMBEDDING_TOP_K)
                        .map(|(id, _)| id.clone()),
                );
            }
        }

        // Drop the new memory itself and any repeated ID, keeping first occurrences.
        // The self-check short-circuits, so `new_memory_id` is never recorded in `seen`.
        let mut seen = HashSet::new();
        candidate_ids.retain(|id| id != new_memory_id && seen.insert(id.clone()));

        candidate_ids.truncate(config.candidate_limit);
        Ok(candidate_ids)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AssociationConfig;
    use crate::types::{MemoryLayer, MemoryRecord, MemoryType};
    use chrono::Utc;

    /// Opens a throwaway `Storage` at the `:memory:` path for a single test.
    fn test_storage() -> Storage {
        Storage::new(":memory:").expect("in-memory storage")
    }

    /// Builds a minimal factual, working-layer record whose only varying parts
    /// are its `id`, `content` and creation time.
    fn make_record(id: &str, content: &str, created_at: chrono::DateTime<Utc>) -> MemoryRecord {
        MemoryRecord {
            // Identity and payload.
            id: id.to_string(),
            content: content.to_string(),
            source: String::new(),
            metadata: None,
            // Classification.
            memory_type: MemoryType::Factual,
            layer: MemoryLayer::Working,
            // Timing: the record has been accessed exactly once, at creation.
            created_at,
            access_times: vec![created_at],
            last_consolidated: None,
            // Strength and bookkeeping values.
            working_strength: 1.0,
            core_strength: 0.0,
            importance: 0.5,
            pinned: false,
            consolidation_count: 0,
            // No contradiction or supersession links.
            contradicts: None,
            contradicted_by: None,
            superseded_by: None,
        }
    }

    /// Memories created now and two days ago are selected; one created thirty
    /// days ago is not, when using the default configuration.
    #[test]
    fn test_candidate_selection_temporal() {
        let mut storage = test_storage();
        let now = Utc::now();
        let now_ts = now.timestamp() as f64;

        // (id, content, age in days relative to `now`)
        let fixtures = [
            ("recent1", "recent memory about cats", 0),
            ("older1", "older memory about dogs", 2),
            ("ancient1", "ancient memory about fish", 30),
        ];
        for &(id, content, age_days) in &fixtures {
            let created_at = now - chrono::Duration::days(age_days);
            let record = make_record(id, content, created_at);
            storage.add(&record, "default").unwrap();
        }

        let config = AssociationConfig::default();
        let selector = CandidateSelector::new(&storage);
        let candidates = selector
            .select_candidates("new_mem", now_ts, &[], None, &config)
            .unwrap();

        let has = |wanted: &str| candidates.iter().any(|found| found == wanted);
        assert!(has("recent1"));
        assert!(has("older1"));
        assert!(!has("ancient1"));
    }

    /// The memory being processed must never be returned as its own candidate,
    /// even though it is already present in storage.
    #[test]
    fn test_candidate_selection_excludes_self() {
        let mut storage = test_storage();
        let now = Utc::now();
        let now_ts = now.timestamp() as f64;

        let own_record = make_record("self_mem", "test memory content", now);
        storage.add(&own_record, "default").unwrap();

        let config = AssociationConfig::default();
        let selector = CandidateSelector::new(&storage);
        let candidates = selector
            .select_candidates("self_mem", now_ts, &[], None, &config)
            .unwrap();

        let has = |wanted: &str| candidates.iter().any(|found| found == wanted);
        assert!(!has("self_mem"));
    }
}