use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
use crate::retrieval::ScoredCandidate;
use crate::types::Memory;
/// How `group_by_session` orders the session groups it returns.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Serialize, Deserialize)]
pub enum OrderBy {
    /// Groups sorted by `session_time`, earliest first. This is the default.
    #[default]
    Chronological,
    /// Groups sorted by `group_score`, highest first.
    Relevance,
}
/// Caller-facing options for a grouped recall request.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RecallGroupedConfig {
    /// Candidate limit for the recall.
    /// NOTE(review): presumably the number of scored candidates fetched before
    /// grouping — confirm against the caller.
    pub limit: usize,
    /// How the resulting session groups are ordered.
    pub order: OrderBy,
    /// Optional cap on the number of groups returned.
    /// NOTE(review): presumably forwarded to `group_by_session`'s `max_groups`,
    /// where `None` means "no cap" — confirm against the caller.
    pub max_groups: Option<usize>,
}
impl Default for RecallGroupedConfig {
fn default() -> Self {
Self {
limit: 50,
order: OrderBy::Chronological,
max_groups: None,
}
}
}
/// A recalled memory paired with the score it carries inside a session group.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ScoredMemory {
    /// The recalled memory itself.
    pub memory: Memory,
    /// Score attached to this memory. `group_by_session` copies the candidate's
    /// `final_score` here; `attach_observations_to_groups` uses the owning
    /// group's `group_score` for the observations it appends.
    pub score: f32,
}
/// A cluster of recalled memories that belong to the same session (episode).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SessionGroup {
    /// Episode id shared by the group's members, or `None` for a singleton group
    /// built from a memory that has no episode (semantic or procedural).
    pub session_id: Option<Uuid>,
    /// Earliest timeline time among the group's members.
    pub session_time: DateTime<Utc>,
    /// Members of the group. `group_by_session` emits them oldest first;
    /// `attach_observations_to_groups` appends observations after them.
    pub memories: Vec<ScoredMemory>,
    /// Highest member score; the sort key for `OrderBy::Relevance`.
    pub group_score: f32,
}
/// Returns the instant used to place `memory` on the session timeline.
///
/// Episodic and observation memories prefer their explicit `event_time` and fall
/// back to when they were recorded (`timestamp` / `created_at` respectively).
/// Semantic memories use `valid_at`; procedural memories use `created_at`.
fn memory_time(memory: &Memory) -> DateTime<Utc> {
    match memory {
        Memory::Semantic(fact) => fact.valid_at,
        Memory::Procedural(procedure) => procedure.created_at,
        Memory::Episodic(episode) => match episode.event_time {
            Some(when) => when,
            None => episode.timestamp,
        },
        Memory::Observation(observation) => match observation.event_time {
            Some(when) => when,
            None => observation.created_at,
        },
    }
}
/// Returns the episode (session) a memory belongs to, if it has one.
///
/// Only episodic and observation memories carry an episode id; semantic and
/// procedural memories yield `None`.
fn memory_episode_id(memory: &Memory) -> Option<Uuid> {
    let episode = match memory {
        Memory::Semantic(_) | Memory::Procedural(_) => return None,
        Memory::Observation(obs) => obs.episode_id,
        Memory::Episodic(ep) => ep.episode_id,
    };
    Some(episode)
}
/// Clusters scored recall candidates into per-session groups.
///
/// Candidates whose memory carries an episode id (episodic and observation
/// memories) are bucketed by that id. Every other candidate (semantic and
/// procedural memories) becomes its own singleton group with `session_id: None`.
///
/// Within a group, members are sorted oldest-first by `memory_time`, and each
/// keeps its own `final_score`. A group's `session_time` is its earliest member
/// time. Its `group_score` is the maximum member score.
///
/// Groups are then ordered by `order`, with ties keeping first-seen input order:
/// - `Chronological`: `session_time` ascending.
/// - `Relevance`: `group_score` descending.
///
/// Finally the result is truncated to `max_groups` when given; `Some(0)` yields
/// an empty result.
pub fn group_by_session(
    candidates: Vec<ScoredCandidate>,
    order: OrderBy,
    max_groups: Option<usize>,
) -> Vec<SessionGroup> {
    // Buckets are kept in first-seen order so the stable sorts below break ties
    // by input (retrieval) order. Episode-bearing candidates are routed to their
    // episode's bucket through an index map. Everything else gets a fresh bucket
    // of its own, so no synthetic (random) key is needed for it.
    let mut buckets: Vec<Vec<ScoredCandidate>> = Vec::new();
    let mut bucket_for_episode: HashMap<Uuid, usize> = HashMap::new();
    for candidate in candidates {
        match memory_episode_id(&candidate.memory) {
            Some(episode_id) => {
                let idx = *bucket_for_episode.entry(episode_id).or_insert_with(|| {
                    buckets.push(Vec::new());
                    buckets.len() - 1
                });
                buckets[idx].push(candidate);
            }
            None => buckets.push(vec![candidate]),
        }
    }

    let mut groups: Vec<SessionGroup> = buckets.into_iter().map(build_session_group).collect();

    match order {
        OrderBy::Chronological => groups.sort_by_key(|g| g.session_time),
        // `group_score` folds with `f32::max`, which skips NaN inputs, so it is
        // never NaN. `total_cmp` still guarantees a total order for the sort.
        OrderBy::Relevance => groups.sort_by(|a, b| b.group_score.total_cmp(&a.group_score)),
    }

    if let Some(cap) = max_groups {
        groups.truncate(cap);
    }
    groups
}

/// Builds one [`SessionGroup`] from a non-empty bucket of candidates that share
/// a session, or from a single session-less candidate.
fn build_session_group(mut members: Vec<ScoredCandidate>) -> SessionGroup {
    // Stable sort: members with identical times keep their input order.
    members.sort_by_key(|c| memory_time(&c.memory));
    // After the ascending sort, the first member is the earliest one, so no
    // separate min() scan is needed.
    let first = members.first().expect("session buckets are never empty");
    let session_id = memory_episode_id(&first.memory);
    let session_time = memory_time(&first.memory);
    let group_score = members
        .iter()
        .map(|c| c.final_score)
        .fold(f32::NEG_INFINITY, f32::max);
    let memories = members
        .into_iter()
        .map(|c| ScoredMemory {
            memory: c.memory,
            score: c.final_score,
        })
        .collect();
    SessionGroup {
        session_id,
        session_time,
        memories,
        group_score,
    }
}
/// Appends stored observations to the session groups they belong to.
///
/// Observations are fetched in a single storage call, and only for the episodes
/// that appear as `session_id` on `groups`. Observations from episodes outside
/// the recalled set therefore never leak in.
///
/// Each attached observation:
/// - is scored with its group's `group_score`;
/// - is appended after the group's existing memories;
/// - is skipped if it is already in the group. This covers observations that
///   retrieval surfaced as candidates, as well as duplicates returned by storage.
///
/// Groups without a `session_id` are returned untouched.
///
/// This is best-effort: if the storage lookup fails, a warning is logged and
/// `groups` is returned unchanged.
pub fn attach_observations_to_groups(
    storage: &dyn crate::storage::StorageTrait,
    groups: Vec<SessionGroup>,
) -> Vec<SessionGroup> {
    let episode_ids: Vec<Uuid> = groups.iter().filter_map(|g| g.session_id).collect();
    if episode_ids.is_empty() {
        return groups;
    }
    // NOTE(review): 1024 is passed through as the storage-side limit. Whether it
    // caps the total across all episodes or applies per episode depends on the
    // StorageTrait implementation — confirm.
    let observations = match storage.list_observations_by_episode_ids(&episode_ids, 1024) {
        Ok(obs) => obs,
        Err(e) => {
            tracing::warn!(
                target: "pensyve::observation",
                error = %e,
                "failed to load observations for session groups — returning groups unchanged"
            );
            return groups;
        }
    };
    if observations.is_empty() {
        return groups;
    }
    let mut by_episode: HashMap<Uuid, Vec<crate::types::ObservationMemory>> = HashMap::new();
    for obs in observations {
        by_episode.entry(obs.episode_id).or_default().push(obs);
    }
    groups
        .into_iter()
        .map(|mut g| {
            let Some(obs_for_group) = g.session_id.and_then(|sid| by_episode.remove(&sid)) else {
                return g;
            };
            // `group_by_session` buckets `Memory::Observation` candidates into
            // their episode's group, so some of these may already be present.
            // `HashSet::insert` returns false for ids already seen, which filters
            // out both pre-existing and repeated observations.
            let mut seen: std::collections::HashSet<Uuid> = g
                .memories
                .iter()
                .filter_map(|m| match &m.memory {
                    Memory::Observation(o) => Some(o.id),
                    _ => None,
                })
                .collect();
            let score = g.group_score;
            g.memories.extend(
                obs_for_group
                    .into_iter()
                    .filter(|obs| seen.insert(obs.id))
                    .map(|obs| ScoredMemory {
                        memory: Memory::Observation(obs),
                        score,
                    }),
            );
            g
        })
        .collect()
}
#[cfg(test)]
mod tests {
    //! Unit tests for session grouping and for attaching stored observations.

    use super::*;
    use crate::storage::StorageTrait;
    use crate::storage::sqlite::SqliteBackend;
    use crate::types::{
        EpisodicMemory, Namespace, ObservationMemory, Outcome, ProceduralMemory, SemanticMemory,
    };
    use chrono::TimeZone;
    use std::collections::HashMap as StdHashMap;
    use tempfile::TempDir;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Wraps `memory` as a retrieval candidate whose only non-trivial score is
    /// `final_score`. Component scores are zero and `type_boost` is 1.0.
    fn scored(memory: Memory, final_score: f32) -> ScoredCandidate {
        ScoredCandidate {
            memory_id: memory.id(),
            memory,
            vector_score: 0.0,
            bm25_score: 0.0,
            graph_score: 0.0,
            intent_score: 0.0,
            recency_score: 0.0,
            access_score: 0.0,
            confidence_score: 0.0,
            entity_score: 0.0,
            type_boost: 1.0,
            final_score,
        }
    }

    /// Episodic memory in `episode_id` with an explicit `event_time`. The
    /// encoding `timestamp` is pinned to 2099, so a test that accidentally reads
    /// it instead of `event_time` fails loudly.
    fn ep_at(episode_id: Uuid, event_time: DateTime<Utc>, content: &str) -> Memory {
        let ns = Uuid::nil();
        let mut m = EpisodicMemory::new(ns, episode_id, Uuid::new_v4(), Uuid::new_v4(), content);
        m.event_time = Some(event_time);
        m.timestamp = Utc.with_ymd_and_hms(2099, 1, 1, 0, 0, 0).unwrap();
        Memory::Episodic(m)
    }

    /// Episodic memory with no `event_time`, so only `timestamp` is available.
    fn ep_no_event_time(episode_id: Uuid, timestamp: DateTime<Utc>, content: &str) -> Memory {
        let ns = Uuid::nil();
        let mut m = EpisodicMemory::new(ns, episode_id, Uuid::new_v4(), Uuid::new_v4(), content);
        m.event_time = None;
        m.timestamp = timestamp;
        Memory::Episodic(m)
    }

    /// Semantic fact `subject predicate object` with confidence 0.9.
    fn sem(subject: Uuid, predicate: &str, object: &str) -> Memory {
        Memory::Semantic(SemanticMemory::new(
            Uuid::nil(),
            subject,
            predicate,
            object,
            0.9,
        ))
    }

    /// Successful procedural memory mapping `trigger` to `action`.
    fn proc(trigger: &str, action: &str) -> Memory {
        Memory::Procedural(ProceduralMemory::new(
            Uuid::nil(),
            trigger,
            action,
            Outcome::Success,
            StdHashMap::new(),
        ))
    }

    /// Noon UTC on the given date.
    fn t(y: i32, mo: u32, d: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(y, mo, d, 12, 0, 0).unwrap()
    }

    /// Human-readable text of a memory, used for order assertions.
    fn memory_content(memory: &Memory) -> String {
        match memory {
            Memory::Episodic(e) => e.content.clone(),
            Memory::Semantic(s) => format!("{} {}", s.predicate, s.object),
            Memory::Procedural(p) => format!("{}:{}", p.trigger, p.action),
            Memory::Observation(o) => o.content.clone(),
        }
    }

    /// Fresh SQLite store with one saved namespace. The `TempDir` must stay
    /// alive for as long as the backend is used.
    fn setup_storage_with_namespace() -> (TempDir, SqliteBackend, Namespace) {
        let dir = TempDir::new().unwrap();
        let db = SqliteBackend::open(dir.path()).unwrap();
        let ns = Namespace::new("test-attach");
        db.save_namespace(&ns).unwrap();
        (dir, db, ns)
    }

    /// Persists a "game_played" observation for `episode_id` and returns its id.
    fn save_obs(db: &SqliteBackend, ns: Uuid, episode_id: Uuid, instance: &str) -> Uuid {
        let obs = ObservationMemory::new(
            ns,
            episode_id,
            "game_played",
            instance,
            "played",
            format!("played {instance}"),
        );
        let id = obs.id;
        db.save_observation(&obs).unwrap();
        id
    }

    // ---------------------------------------------------------------------
    // group_by_session
    // ---------------------------------------------------------------------

    #[test]
    fn empty_input_yields_empty_output() {
        let out = group_by_session(Vec::new(), OrderBy::Chronological, None);
        assert!(out.is_empty());
    }

    #[test]
    fn single_episode_collapses_to_one_group_sorted_by_event_time() {
        let ep = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(ep, t(2026, 1, 3), "third"), 0.5),
            scored(ep_at(ep, t(2026, 1, 1), "first"), 0.9),
            scored(ep_at(ep, t(2026, 1, 2), "second"), 0.7),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 1);
        let g = &groups[0];
        assert_eq!(g.session_id, Some(ep));
        assert_eq!(g.session_time, t(2026, 1, 1));
        assert!((g.group_score - 0.9).abs() < f32::EPSILON);
        let contents: Vec<_> = g
            .memories
            .iter()
            .map(|m| memory_content(&m.memory))
            .collect();
        assert_eq!(contents, vec!["first", "second", "third"]);
    }

    #[test]
    fn per_member_scores_survive_grouping() {
        let ep = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(ep, t(2026, 1, 1), "first"), 0.92),
            scored(ep_at(ep, t(2026, 1, 2), "second"), 0.11),
            scored(ep_at(ep, t(2026, 1, 3), "third"), 0.45),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 1);
        let g = &groups[0];
        assert_eq!(g.memories.len(), 3);
        assert!((g.group_score - 0.92).abs() < f32::EPSILON);
        assert!((g.memories[0].score - 0.92).abs() < f32::EPSILON);
        assert!((g.memories[1].score - 0.11).abs() < f32::EPSILON);
        assert!((g.memories[2].score - 0.45).abs() < f32::EPSILON);
        let contents: Vec<_> = g
            .memories
            .iter()
            .map(|m| memory_content(&m.memory))
            .collect();
        assert_eq!(contents, vec!["first", "second", "third"]);
    }

    #[test]
    fn chronological_ordering_sorts_groups_by_earliest_event_time() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(c, t(2026, 3, 1), "c1"), 0.5),
            scored(ep_at(a, t(2026, 1, 1), "a1"), 0.5),
            scored(ep_at(b, t(2026, 2, 1), "b1"), 0.5),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].session_id, Some(a));
        assert_eq!(groups[1].session_id, Some(b));
        assert_eq!(groups[2].session_id, Some(c));
    }

    #[test]
    fn relevance_ordering_sorts_groups_by_max_score_descending() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(a, t(2026, 1, 1), "a"), 0.2),
            scored(ep_at(b, t(2026, 2, 1), "b"), 0.9),
            scored(ep_at(c, t(2026, 3, 1), "c"), 0.5),
        ];
        let groups = group_by_session(candidates, OrderBy::Relevance, None);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].session_id, Some(b));
        assert_eq!(groups[1].session_id, Some(c));
        assert_eq!(groups[2].session_id, Some(a));
    }

    #[test]
    fn relevance_ties_keep_first_seen_order() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(b, t(2026, 2, 1), "b"), 0.5),
            scored(ep_at(a, t(2026, 1, 1), "a"), 0.5),
        ];
        let groups = group_by_session(candidates, OrderBy::Relevance, None);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].session_id, Some(b));
        assert_eq!(groups[1].session_id, Some(a));
    }

    #[test]
    fn group_score_is_max_across_members() {
        let ep = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(ep, t(2026, 1, 1), "low"), 0.2),
            scored(ep_at(ep, t(2026, 1, 2), "high"), 0.8),
            scored(ep_at(ep, t(2026, 1, 3), "mid"), 0.5),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 1);
        assert!((groups[0].group_score - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn semantic_memories_become_singleton_groups_with_no_session() {
        let subj = Uuid::new_v4();
        let candidates = vec![
            scored(sem(subj, "knows", "Rust"), 0.9),
            scored(sem(subj, "likes", "Python"), 0.8),
        ];
        let groups = group_by_session(candidates, OrderBy::Relevance, None);
        assert_eq!(groups.len(), 2);
        for g in &groups {
            assert_eq!(g.session_id, None);
            assert_eq!(g.memories.len(), 1);
        }
        assert!((groups[0].group_score - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn procedural_memories_become_singleton_groups() {
        let candidates = vec![scored(proc("on_error", "retry"), 0.5)];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].session_id, None);
        assert_eq!(groups[0].memories.len(), 1);
    }

    #[test]
    fn mixed_episodic_and_semantic_clusters_episodes_and_keeps_semantics_singleton() {
        let ep = Uuid::new_v4();
        let subj = Uuid::new_v4();
        let candidates = vec![
            scored(ep_at(ep, t(2026, 1, 1), "a"), 0.5),
            scored(sem(subj, "is", "cool"), 0.3),
            scored(ep_at(ep, t(2026, 1, 2), "b"), 0.6),
        ];
        let groups = group_by_session(candidates, OrderBy::Relevance, None);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].session_id, Some(ep));
        assert_eq!(groups[0].memories.len(), 2);
        assert_eq!(groups[1].session_id, None);
        assert_eq!(groups[1].memories.len(), 1);
    }

    #[test]
    fn max_groups_caps_result_preserving_order() {
        let eps: Vec<Uuid> = (0..5).map(|_| Uuid::new_v4()).collect();
        let candidates: Vec<_> = eps
            .iter()
            .enumerate()
            .map(|(i, ep)| scored(ep_at(*ep, t(2026, 1, (i + 1) as u32), "x"), 0.1 * i as f32))
            .collect();
        let groups = group_by_session(candidates, OrderBy::Chronological, Some(3));
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].session_id, Some(eps[0]));
        assert_eq!(groups[1].session_id, Some(eps[1]));
        assert_eq!(groups[2].session_id, Some(eps[2]));
    }

    #[test]
    fn max_groups_zero_returns_empty() {
        let ep = Uuid::new_v4();
        let candidates = vec![scored(ep_at(ep, t(2026, 1, 1), "a"), 0.5)];
        let groups = group_by_session(candidates, OrderBy::Chronological, Some(0));
        assert!(groups.is_empty());
    }

    #[test]
    fn null_event_time_falls_back_to_encoding_timestamp() {
        let ep = Uuid::new_v4();
        let ts = t(2025, 6, 15);
        let candidates = vec![scored(ep_no_event_time(ep, ts, "a"), 0.5)];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].session_time, ts);
    }

    #[test]
    fn default_config_matches_spec() {
        let cfg = RecallGroupedConfig::default();
        assert_eq!(cfg.limit, 50);
        assert_eq!(cfg.order, OrderBy::Chronological);
        assert_eq!(cfg.max_groups, None);
    }

    // ---------------------------------------------------------------------
    // attach_observations_to_groups
    // ---------------------------------------------------------------------

    #[test]
    fn attach_appends_observations_after_episodic_memories() {
        let (_dir, db, ns) = setup_storage_with_namespace();
        let ep = Uuid::new_v4();
        let obs_id = save_obs(&db, ns.id, ep, "AC Odyssey");
        let candidates = vec![
            scored(ep_at(ep, t(2026, 1, 1), "turn 1"), 0.9),
            scored(ep_at(ep, t(2026, 1, 2), "turn 2"), 0.8),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        let attached = attach_observations_to_groups(&db, groups);
        assert_eq!(attached.len(), 1);
        let g = &attached[0];
        assert_eq!(g.memories.len(), 3);
        match &g.memories[0].memory {
            Memory::Episodic(e) => assert_eq!(e.content, "turn 1"),
            _ => panic!("expected episodic first"),
        }
        match &g.memories[1].memory {
            Memory::Episodic(e) => assert_eq!(e.content, "turn 2"),
            _ => panic!("expected episodic second"),
        }
        match &g.memories[2].memory {
            Memory::Observation(o) => {
                assert_eq!(o.id, obs_id);
                assert_eq!(o.instance, "AC Odyssey");
            }
            _ => panic!("expected observation last"),
        }
        assert!((g.memories[2].score - g.group_score).abs() < f32::EPSILON);
    }

    #[test]
    fn attach_scopes_observations_to_their_own_episode() {
        let (_dir, db, ns) = setup_storage_with_namespace();
        let ep_a = Uuid::new_v4();
        let ep_b = Uuid::new_v4();
        save_obs(&db, ns.id, ep_a, "game A");
        save_obs(&db, ns.id, ep_b, "game B");
        let candidates = vec![
            scored(ep_at(ep_a, t(2026, 1, 1), "a1"), 0.5),
            scored(ep_at(ep_b, t(2026, 1, 2), "b1"), 0.5),
        ];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        let attached = attach_observations_to_groups(&db, groups);
        assert_eq!(attached.len(), 2);
        for g in &attached {
            let obs_count = g
                .memories
                .iter()
                .filter(|m| matches!(m.memory, Memory::Observation(_)))
                .count();
            assert_eq!(obs_count, 1, "each group gets exactly its own obs");
            let matching_instance = g.memories.iter().find_map(|m| match &m.memory {
                Memory::Observation(o) => Some(o.instance.clone()),
                _ => None,
            });
            let expected = if g.session_id == Some(ep_a) {
                "game A"
            } else {
                "game B"
            };
            assert_eq!(matching_instance.as_deref(), Some(expected));
        }
    }

    #[test]
    fn attach_leaks_no_observations_from_non_topk_episodes() {
        let (_dir, db, ns) = setup_storage_with_namespace();
        let topk_ep = Uuid::new_v4();
        let unseen_ep = Uuid::new_v4();
        save_obs(&db, ns.id, topk_ep, "visible");
        save_obs(&db, ns.id, unseen_ep, "LEAKED");
        let candidates = vec![scored(ep_at(topk_ep, t(2026, 1, 1), "x"), 0.5)];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        let attached = attach_observations_to_groups(&db, groups);
        assert_eq!(attached.len(), 1);
        let leaked = attached[0].memories.iter().any(|m| match &m.memory {
            Memory::Observation(o) => o.instance == "LEAKED",
            _ => false,
        });
        assert!(!leaked, "observations from non-top-k episodes leaked through");
    }

    #[test]
    fn attach_is_noop_when_no_observations_stored() {
        let (_dir, db, _ns) = setup_storage_with_namespace();
        let ep = Uuid::new_v4();
        let candidates = vec![scored(ep_at(ep, t(2026, 1, 1), "x"), 0.5)];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        let attached = attach_observations_to_groups(&db, groups);
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].memories.len(), 1);
    }

    #[test]
    fn attach_skips_singleton_semantic_groups() {
        let (_dir, db, _ns) = setup_storage_with_namespace();
        let subj = Uuid::new_v4();
        let candidates = vec![scored(sem(subj, "knows", "Rust"), 0.9)];
        let groups = group_by_session(candidates, OrderBy::Chronological, None);
        let attached = attach_observations_to_groups(&db, groups);
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].session_id, None);
        assert_eq!(attached[0].memories.len(), 1);
    }
}