use std::collections::HashSet;
use anyhow::Result;
use super::{MemoryEntry, MemoryManager};
/// Tracks when memory recall last fired so it is not repeated on every message.
#[derive(Debug, Clone, Default)]
pub struct RecallTiming {
    /// Query text that triggered the most recent recall, or `None` if no
    /// recall has happened yet.
    pub last_recall_topic: Option<String>,
    /// Number of `should_recall` calls that declined to recall since the most
    /// recent recall.
    pub message_count_since_recall: usize,
}

impl RecallTiming {
    /// Creates a tracker with no recorded recall; the first query always recalls.
    pub fn new() -> Self {
        Self::default()
    }

    /// Decides whether a recall should run for `query` and updates the tracker.
    ///
    /// A recall fires when any of the following holds:
    /// - no recall has happened yet (`last_recall_topic` is `None`),
    /// - at least 3 messages were skipped since the last recall and `query` is
    ///   not similar to the topic of that recall,
    /// - at least 10 messages were skipped since the last recall.
    ///
    /// When a recall fires, `query` becomes the remembered topic and the
    /// counter resets to zero; otherwise the counter is incremented.
    pub fn should_recall(&mut self, query: &str) -> bool {
        /// Skipped messages required before a topic change may trigger a recall.
        const TOPIC_CHANGE_MIN_MESSAGES: usize = 3;
        /// Skipped messages after which a recall fires regardless of topic.
        const PERIODIC_MESSAGES: usize = 10;

        let should = match self.last_recall_topic.as_deref() {
            // "First message" is detected from the missing topic rather than a
            // zero counter: the counter is also zero right after every recall,
            // which would otherwise make each following call recall again.
            None => true,
            Some(previous) => {
                let skipped = self.message_count_since_recall;
                skipped >= PERIODIC_MESSAGES
                    || (skipped >= TOPIC_CHANGE_MIN_MESSAGES && !topics_similar(previous, query))
            }
        };

        if should {
            self.last_recall_topic = Some(query.to_string());
            self.message_count_since_recall = 0;
        } else {
            self.message_count_since_recall += 1;
        }
        should
    }
}
/// Heuristically decides whether two queries are about the same topic.
///
/// Each input is lowercased and split on whitespace, and words of three bytes
/// or fewer are discarded. The two remaining word sets are compared by the
/// ratio of shared words to distinct words overall; the topics count as
/// similar when that ratio is strictly greater than 0.3.
///
/// Returns `false` when either input has no word longer than three bytes.
fn topics_similar(a: &str, b: &str) -> bool {
    /// Lowercased words of `text` that are longer than three bytes.
    fn significant_words(text: &str) -> HashSet<String> {
        text.to_lowercase()
            .split_whitespace()
            .filter(|word| word.len() > 3)
            .map(str::to_owned)
            .collect()
    }

    let left = significant_words(a);
    let right = significant_words(b);
    if left.is_empty() || right.is_empty() {
        return false;
    }

    let shared = left.intersection(&right).count();
    // Size of the union by inclusion–exclusion: |A| + |B| - |A ∩ B|.
    let distinct = left.len() + right.len() - shared;
    shared as f32 / distinct as f32 > 0.3
}
/// Settings for proactively gathering memories relevant to a query.
pub struct ProactiveRecall {
    /// Target maximum number of entries gathered by [`ProactiveRecall::recall`];
    /// also passed to the hot-tier listing as its second argument.
    pub limit: usize,
    /// Minimum effective importance an entry must have to be returned.
    pub threshold: f32,
}

impl ProactiveRecall {
    /// Creates a recall configuration from a result limit and importance threshold.
    pub fn new(limit: usize, threshold: f32) -> Self {
        Self { limit, threshold }
    }

    /// Gathers memories for `query` that are not already in `current_context`.
    ///
    /// Hot-tier entries are taken first; if fewer than `limit` entries have
    /// been accepted, the remainder is filled from a search for `query`.
    /// An entry is accepted only if its id has not been seen (in
    /// `current_context` or earlier in this call) and its effective importance
    /// is at least `threshold`. The threshold is checked at admission time so
    /// that rejected entries do not consume slots that qualifying search
    /// results could fill.
    ///
    /// # Errors
    ///
    /// This function currently always returns `Ok`: a failed hot-tier listing
    /// or search is treated as yielding no entries (best-effort recall).
    pub async fn recall(
        &self,
        mgr: &MemoryManager,
        query: &str,
        current_context: &[MemoryEntry],
    ) -> Result<Vec<MemoryEntry>> {
        let mut seen_ids: HashSet<String> =
            current_context.iter().map(|e| e.id.clone()).collect();
        let mut results = Vec::new();

        // Best-effort: a failed listing simply contributes nothing.
        if let Ok(hot_entries) = mgr
            .list_by_tier(crate::memory::MemoryTier::Hot, self.limit)
            .await
        {
            for entry in hot_entries {
                if self.admits(&entry, &mut seen_ids) {
                    results.push(entry);
                }
            }
        }

        if results.len() < self.limit {
            let remaining = self.limit - results.len();
            // Over-fetch so duplicates and low-importance hits can be skipped
            // while still filling the remaining slots; saturate so a very
            // large limit cannot overflow.
            let search_results = mgr
                .search(query, None, remaining.saturating_mul(2))
                .await
                .unwrap_or_default();
            for entry in search_results {
                if results.len() >= self.limit {
                    break;
                }
                if self.admits(&entry, &mut seen_ids) {
                    results.push(entry);
                }
            }
        }

        Ok(results)
    }

    /// Records `entry`'s id in `seen_ids` and reports whether the entry is
    /// both previously unseen and at or above the importance threshold.
    fn admits(&self, entry: &MemoryEntry, seen_ids: &mut HashSet<String>) -> bool {
        if seen_ids.contains(&entry.id) {
            return false;
        }
        seen_ids.insert(entry.id.clone());
        crate::memory::decay::DecayEngine::effective_importance(entry) >= self.threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tracker recalls on the very first query.
    #[test]
    fn test_recall_timing_first_message() {
        let mut tracker = RecallTiming::new();
        let fired = tracker.should_recall("hello");
        assert!(fired);
    }

    /// A different topic triggers a recall once enough messages have passed.
    #[test]
    fn test_recall_timing_topic_change() {
        let mut tracker = RecallTiming::new();
        let _ = tracker.should_recall("rust programming");
        tracker.message_count_since_recall = 5;
        let fired = tracker.should_recall("python deployment");
        assert!(fired);
    }

    /// Staying on a similar topic shortly after a recall does not recall again.
    #[test]
    fn test_recall_timing_same_topic() {
        let mut tracker = RecallTiming::new();
        let _ = tracker.should_recall("rust async runtime");
        tracker.message_count_since_recall = 1;
        let fired = tracker.should_recall("rust async tokio");
        assert!(!fired);
    }

    /// Ten messages without a recall force one.
    #[test]
    fn test_recall_timing_periodic() {
        let mut tracker = RecallTiming::new();
        let _ = tracker.should_recall("rust");
        tracker.message_count_since_recall = 10;
        let fired = tracker.should_recall("rust continued");
        assert!(fired);
    }

    /// Identical queries are similar.
    #[test]
    fn test_topics_similar_same() {
        let query = "rust async runtime";
        assert!(topics_similar(query, query));
    }

    /// Queries sharing most of their significant words are similar.
    #[test]
    fn test_topics_similar_overlap() {
        let first = "rust async runtime tokio";
        let second = "rust async runtime futures";
        assert!(topics_similar(first, second));
    }

    /// Queries with no shared significant words are not similar.
    #[test]
    fn test_topics_different() {
        let first = "rust async runtime";
        let second = "python data science";
        assert!(!topics_similar(first, second));
    }
}