use std::collections::HashMap;
use crate::types::Memory;
/// A bucket of memories that share a topic keyword, as built by
/// `ContextGrouper::group_for_context`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryGroup {
    /// Topic keyword shared by every memory in this group.
    pub topic: String,
    /// Human-readable summary of the group's contents.
    pub summary: String,
    /// Number of memories in the group.
    pub count: usize,
    /// IDs of the memories in the group.
    pub memory_ids: Vec<i64>,
}
/// Groups memories by a simple keyword "topic" so they can be presented as
/// compact context.
#[derive(Debug, Clone, PartialEq)]
pub struct ContextGrouper {
    /// Groups with more members than this get a count-only summary instead of
    /// their concatenated content.
    pub max_group_size: usize,
    /// Similarity threshold.
    /// NOTE(review): no method in this file reads this field — confirm whether
    /// it is consumed elsewhere or is reserved for future use.
    pub min_similarity: f32,
}
impl Default for ContextGrouper {
fn default() -> Self {
Self::new()
}
}
impl ContextGrouper {
    /// Maximum number of bytes of concatenated content kept in a group summary.
    const SUMMARY_MAX_BYTES: usize = 500;

    /// Creates a grouper with the default settings.
    ///
    /// `max_group_size` is 5: larger groups get a count-only summary.
    /// `min_similarity` is 0.7; no method in this `impl` currently reads it.
    pub fn new() -> Self {
        Self {
            max_group_size: 5,
            min_similarity: 0.7,
        }
    }

    /// Buckets `memories` by the keyword returned by `extract_topic` and
    /// summarizes each bucket.
    ///
    /// Groups are ordered by member count, largest first. Groups with equal
    /// counts keep the order in which their topic first appears in `memories`,
    /// so the result is deterministic for a given input. Within a group,
    /// `memory_ids` follow input order.
    pub fn group_for_context(&self, memories: &[Memory]) -> Vec<MemoryGroup> {
        // A Vec of buckets plus a topic -> slot index keeps first-seen order;
        // iterating a HashMap directly would make tie order vary between runs.
        // Buckets hold references, so no `Memory` is cloned.
        let mut slot_of: HashMap<String, usize> = HashMap::new();
        let mut buckets: Vec<(String, Vec<&Memory>)> = Vec::new();
        for memory in memories {
            let topic = self.extract_topic(&memory.content);
            let existing = slot_of.get(&topic).copied();
            let slot = match existing {
                Some(slot) => slot,
                None => {
                    let slot = buckets.len();
                    slot_of.insert(topic.clone(), slot);
                    buckets.push((topic, Vec::new()));
                    slot
                }
            };
            buckets[slot].1.push(memory);
        }

        let mut result: Vec<MemoryGroup> = buckets
            .into_iter()
            .map(|(topic, members)| MemoryGroup {
                summary: self.summarize_group(&members),
                count: members.len(),
                memory_ids: members.iter().map(|m| m.id).collect(),
                topic,
            })
            .collect();
        // `sort_by_key` is stable, so equal counts keep first-seen order.
        result.sort_by_key(|g| std::cmp::Reverse(g.count));
        result
    }

    /// Returns the topic keyword of `content`: the first token, lowercased,
    /// whose byte length exceeds 3 and which is not a stopword. Returns
    /// `"general"` when there is no such token.
    ///
    /// Tokens are maximal runs of Unicode alphanumeric characters; everything
    /// else acts as a separator.
    fn extract_topic(&self, content: &str) -> String {
        // Entries of 3 bytes or fewer can never match because of the length
        // check below; they are kept for clarity of intent.
        const STOPWORDS: &[&str] = &[
            "the", "and", "for", "with", "that", "this", "user", "from", "into", "have", "has",
            "was", "were", "are", "but", "not", "you", "your", "our", "their", "his", "her", "its",
            "about",
        ];
        content
            .split(|c: char| !c.is_alphanumeric())
            .filter(|w| !w.is_empty())
            .map(|w| w.to_lowercase())
            .find(|w| w.len() > 3 && !STOPWORDS.contains(&w.as_str()))
            .unwrap_or_else(|| "general".to_string())
    }

    /// Builds a summary for one group.
    ///
    /// Groups with more than `max_group_size` members get a count-only
    /// summary. Otherwise, member contents are joined with spaces and
    /// truncated to at most `SUMMARY_MAX_BYTES` bytes on a UTF-8 character
    /// boundary.
    fn summarize_group(&self, memories: &[&Memory]) -> String {
        if memories.len() > self.max_group_size {
            return format!(
                "Summary of {} memories about related topics",
                memories.len()
            );
        }
        let combined = memories
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<&str>>()
            .join(" ");
        if combined.len() > Self::SUMMARY_MAX_BYTES {
            // Slicing a `str` inside a multi-byte character panics, so back off
            // to the nearest preceding char boundary (index 0 always is one).
            let mut end = Self::SUMMARY_MAX_BYTES;
            while !combined.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}... (truncated)", &combined[..end])
        } else {
            combined
        }
    }

    /// Returns clones of the memories whose extracted topic equals `topic`,
    /// in input order.
    ///
    /// The comparison is case-insensitive: `extract_topic` always yields
    /// lowercase, so the query is lowercased the same way.
    pub fn find_similar_by_topic(&self, topic: &str, memories: &[Memory]) -> Vec<Memory> {
        let wanted = topic.to_lowercase();
        memories
            .iter()
            .filter(|m| self.extract_topic(&m.content) == wanted)
            .cloned()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;
    use chrono::Utc;

    /// Builds a `Note` memory with the given id and content; every other field
    /// is a fixed placeholder value.
    fn make_memory(id: i64, content: &str) -> Memory {
        Memory {
            id,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            importance: 0.5,
            access_count: 0,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Default::default(),
            scope: Default::default(),
            workspace: "default".to_string(),
            tier: Default::default(),
            version: 1,
            has_embedding: false,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: Default::default(),
            media_url: None,
        }
    }

    /// Builds one memory per `(id, content)` pair, in order.
    fn make_all(samples: &[(i64, &str)]) -> Vec<Memory> {
        samples
            .iter()
            .map(|&(id, text)| make_memory(id, text))
            .collect()
    }

    #[test]
    fn test_group_for_context() {
        let grouper = ContextGrouper::new();
        let memories = make_all(&[
            (1, "User prefers dark mode in UI settings"),
            (2, "User likes coffee with oat milk"),
            (3, "UI color scheme set to dark"),
            (4, "Coffee preference: oat milk only"),
            (5, "User works on Engram project"),
        ]);
        let grouped = grouper.group_for_context(&memories);
        assert!(!grouped.is_empty());
    }

    #[test]
    fn test_extract_topic() {
        let grouper = ContextGrouper::new();
        assert_eq!(
            grouper.extract_topic("User prefers dark mode in UI settings"),
            "prefers"
        );
        let first = grouper.extract_topic("Rust is a systems programming language");
        let second = grouper.extract_topic("Rust has memory safety");
        assert_eq!(first, second);
        assert_eq!(grouper.extract_topic("a b c"), "general");
    }

    #[test]
    fn test_find_similar_by_topic() {
        let grouper = ContextGrouper::new();
        let memories = make_all(&[
            (1, "Rust is a systems programming language"),
            (2, "Python is great for AI"),
            (3, "Rust has memory safety"),
        ]);
        let found = grouper.find_similar_by_topic("rust", &memories);
        let ids: Vec<i64> = found.iter().map(|m| m.id).collect();
        assert_eq!(ids, vec![1, 3]);
    }

    #[test]
    fn groups_share_keys_for_similar_content() {
        let grouper = ContextGrouper::new();
        let memories = make_all(&[
            (1, "Rust ownership and borrowing"),
            (2, "Rust traits and generics"),
            (3, "Python decorators are cool"),
        ]);
        let groups = grouper.group_for_context(&memories);
        assert_eq!(groups.len(), 2);
        let largest = &groups[0];
        assert_eq!(largest.count, 2);
        assert_eq!(largest.topic, "rust");
    }
}