use roboticus_core::config::MemoryConfig;
use roboticus_db::Database;
use serde::Serialize;
use std::collections::HashSet;
use crate::context::{ComplexityLevel, token_budget};
use crate::memory::MemoryBudgetManager;
/// Summary statistics for a single retrieval pass.
///
/// Built by `MemoryRetriever::retrieve_with_metrics` and returned next to the rendered text.
#[derive(Debug, Clone, Default, Serialize)]
pub struct RetrievalMetrics {
/// Sum of all per-tier counts in `tiers`.
pub retrieval_count: usize,
/// `true` when `retrieval_count` is greater than zero.
pub retrieval_hit: bool,
/// Mean similarity of the search results left after filtering and re-ranking; `0.0` when none remain.
pub avg_similarity: f64,
/// Estimated tokens of the rendered text divided by the total token budget; `0.0` when that budget is zero.
pub budget_utilization: f64,
/// Per-tier breakdown of what was counted.
pub tiers: MemoryTierBreakdown,
}
/// Number of memory items counted per tier during one retrieval pass.
#[derive(Debug, Clone, Default, Serialize)]
pub struct MemoryTierBreakdown {
/// Bullet lines rendered in the working-memory section.
pub working: usize,
/// Bullet lines of recent activity plus search results whose source table is `episodic_memory`.
pub episodic: usize,
/// Search results whose source table is `semantic_memory`.
pub semantic: usize,
/// Bullet lines rendered in the tool-experience section.
pub procedural: usize,
/// Bullet lines rendered in the known-entities section.
pub relationship: usize,
}
/// Result of `MemoryRetriever::retrieve_with_metrics`: the text to inject plus its metrics.
///
/// `Debug` and `Clone` are derived so callers can log or duplicate an output; the
/// `metrics` field type already derives both.
#[derive(Debug, Clone)]
pub struct RetrievalOutput {
    /// Rendered memory block; empty when nothing was selected for injection.
    pub text: String,
    /// Statistics describing the retrieval pass that produced `text`.
    pub metrics: RetrievalMetrics,
}
/// Gathers memory from the database tiers and renders it into a prompt block.
pub struct MemoryRetriever {
/// Splits the total token budget into per-tier budgets (`allocate_budgets`).
budget_manager: MemoryBudgetManager,
/// Weight forwarded unchanged to `roboticus_db::embeddings::hybrid_search`.
hybrid_weight: f64,
/// Search results with a lower similarity are dropped; values `<= 0.0` disable the filter.
similarity_threshold: f64,
/// Half-life in days for the decay applied to episodic results; values `<= 0.0` disable it.
decay_half_life_days: f64,
}
impl MemoryRetriever {
/// Creates a retriever from `config`.
///
/// The hybrid weight, similarity threshold and decay half-life are copied out of
/// `config`, which is then handed to `MemoryBudgetManager::new`.
pub fn new(config: MemoryConfig) -> Self {
    Self {
        // Field initialisers run in the order written, so the scalar settings are
        // read before `config` is moved by the last initialiser.
        hybrid_weight: config.hybrid_weight,
        similarity_threshold: config.similarity_threshold,
        decay_half_life_days: config.decay_half_life_days,
        budget_manager: MemoryBudgetManager::new(config),
    }
}
pub fn with_decay_half_life(mut self, days: f64) -> Self {
self.decay_half_life_days = days;
self
}
/// Retrieves the memory text for `query` without an ANN index.
///
/// Shorthand for `retrieve_with_ann` with `ann_index` set to `None`.
pub fn retrieve(
    &self,
    db: &Database,
    session_id: &str,
    query: &str,
    query_embedding: Option<&[f32]>,
    complexity: ComplexityLevel,
) -> String {
    let ann_index = None;
    self.retrieve_with_ann(db, session_id, query, query_embedding, complexity, ann_index)
}
pub fn retrieve_with_ann(
&self,
db: &Database,
session_id: &str,
query: &str,
query_embedding: Option<&[f32]>,
complexity: ComplexityLevel,
ann_index: Option<&roboticus_db::ann::AnnIndex>,
) -> String {
self.retrieve_with_metrics(
db,
session_id,
query,
query_embedding,
complexity,
ann_index,
)
.text
}
pub fn retrieve_with_metrics(
&self,
db: &Database,
session_id: &str,
query: &str,
query_embedding: Option<&[f32]>,
complexity: ComplexityLevel,
ann_index: Option<&roboticus_db::ann::AnnIndex>,
) -> RetrievalOutput {
let total_budget = token_budget(complexity);
let budgets = self.budget_manager.allocate_budgets(total_budget);
let mut sections = Vec::new();
let mut tiers = MemoryTierBreakdown::default();
let working_count = if let Some(s) = self.retrieve_working(db, session_id, budgets.working)
{
let count = s.lines().filter(|l| l.starts_with("- ")).count();
sections.push(s);
count
} else {
0
};
tiers.working = working_count;
let ambient_count = if let Some(s) = self.retrieve_recent_ambient(db, budgets.episodic / 3)
{
let count = s.lines().filter(|l| l.starts_with("- ")).count();
sections.push(s);
count
} else {
0
};
tiers.episodic += ambient_count;
let relevant = if let (Some(ann), Some(emb)) = (ann_index, query_embedding) {
ann.search(emb, 10).map(|results| {
results
.into_iter()
.map(|r| roboticus_db::embeddings::SearchResult {
source_table: r.source_table,
source_id: r.source_id,
content_preview: r.content_preview,
similarity: r.similarity,
})
.collect::<Vec<_>>()
})
} else {
None
};
let mut relevant = relevant.unwrap_or_else(|| {
roboticus_db::embeddings::hybrid_search(
db,
query,
query_embedding,
10,
self.hybrid_weight,
)
.unwrap_or_default()
});
if self.similarity_threshold > 0.0 {
relevant.retain(|r| r.similarity >= self.similarity_threshold);
}
if !query_requests_inactive_memories(query) {
self.filter_inactive_memories(db, &mut relevant);
}
if self.decay_half_life_days > 0.0 {
self.rerank_episodic_by_decay(db, &mut relevant);
}
let avg_similarity = if relevant.is_empty() {
0.0
} else {
let sum: f64 = relevant.iter().map(|r| r.similarity).sum();
sum / relevant.len() as f64
};
for r in &relevant {
match r.source_table.as_str() {
"episodic_memory" => tiers.episodic += 1,
"semantic_memory" => tiers.semantic += 1,
_ => {} }
}
if let Some(s) = self.format_relevant(&relevant, budgets.episodic + budgets.semantic) {
sections.push(s);
}
let procedural_count = if let Some(s) = self.retrieve_procedural(db, budgets.procedural) {
let count = s.lines().filter(|l| l.starts_with("- ")).count();
sections.push(s);
count
} else {
0
};
tiers.procedural = procedural_count;
let relationship_count =
if let Some(s) = self.retrieve_relationships(db, query, budgets.relationship) {
let count = s.lines().filter(|l| l.starts_with("- ")).count();
sections.push(s);
count
} else {
0
};
tiers.relationship = relationship_count;
let index_entries = roboticus_db::memory_index::top_entries(db, 20).unwrap_or_default();
let index_text = roboticus_db::memory_index::format_index_for_injection(&index_entries);
let direct_sections: Vec<&String> = sections
.iter()
.filter(|s| s.starts_with("[Working Memory]") || s.starts_with("[Recent Activity]"))
.collect();
let text = if direct_sections.is_empty() && index_text.is_empty() {
String::new()
} else {
let mut block = String::new();
for section in &direct_sections {
block.push_str(section);
block.push_str("\n\n");
}
if !index_text.is_empty() {
block.push_str(&index_text);
}
block.trim_end().to_string()
};
let memory_tokens = estimate_tokens(&text);
let retrieval_count =
tiers.working + tiers.episodic + tiers.semantic + tiers.procedural + tiers.relationship;
let metrics = RetrievalMetrics {
retrieval_count,
retrieval_hit: retrieval_count > 0,
avg_similarity,
budget_utilization: if total_budget > 0 {
memory_tokens as f64 / total_budget as f64
} else {
0.0
},
tiers,
};
RetrievalOutput { text, metrics }
}
fn retrieve_working(
&self,
db: &Database,
session_id: &str,
budget_tokens: usize,
) -> Option<String> {
if budget_tokens == 0 {
return None;
}
let entries = roboticus_db::memory::retrieve_working(db, session_id)
.inspect_err(
|e| tracing::warn!(error = %e, session_id, "working memory retrieval failed"),
)
.ok()?;
if entries.is_empty() {
return None;
}
let mut text = String::from("[Working Memory]\n");
let mut used = estimate_tokens(&text);
for entry in &entries {
if entry.entry_type.eq_ignore_ascii_case("turn_summary") {
continue;
}
let line = format!("- [{}] {}\n", entry.entry_type, entry.content);
let line_tokens = estimate_tokens(&line);
if used + line_tokens > budget_tokens {
break;
}
text.push_str(&line);
used += line_tokens;
}
if text.len() > "[Working Memory]\n".len() {
Some(text)
} else {
None
}
}
fn retrieve_recent_ambient(&self, db: &Database, budget_tokens: usize) -> Option<String> {
if budget_tokens == 0 {
return None;
}
let entries = roboticus_db::memory::retrieve_recent_episodic(db, 2, 10)
.inspect_err(|e| tracing::warn!(error = %e, "recent ambient memory retrieval failed"))
.ok()?;
if entries.is_empty() {
return None;
}
let mut text = String::from("[Recent Activity]\n");
let mut used = estimate_tokens(&text);
for entry in &entries {
let time_label = entry.created_at.get(11..16).unwrap_or("??:??");
let classification = if entry.classification.is_empty() {
"note"
} else {
&entry.classification
};
let line = format!(
"- [{}] ({}) {}\n",
time_label, classification, entry.content,
);
let line_tokens = estimate_tokens(&line);
if used + line_tokens > budget_tokens {
break;
}
text.push_str(&line);
used += line_tokens;
}
if text.len() > "[Recent Activity]\n".len() {
Some(text)
} else {
None
}
}
/// Renders search results as a `[Relevant Memories]` section within `budget_tokens`.
///
/// Results are emitted in the given order until the next line would exceed the
/// budget. Returns `None` for a zero budget, an empty slice, or when no line fits.
fn format_relevant(
    &self,
    results: &[roboticus_db::embeddings::SearchResult],
    budget_tokens: usize,
) -> Option<String> {
    const HEADER: &str = "[Relevant Memories]\n";

    if results.is_empty() || budget_tokens == 0 {
        return None;
    }

    let mut section = String::from(HEADER);
    let mut spent = estimate_tokens(HEADER);
    for hit in results.iter() {
        let line = format!(
            "- [{} | sim={:.2}] {}\n",
            hit.source_table, hit.similarity, hit.content_preview,
        );
        let cost = estimate_tokens(&line);
        if spent + cost > budget_tokens {
            break;
        }
        section.push_str(&line);
        spent += cost;
    }

    if section.len() == HEADER.len() {
        None
    } else {
        Some(section)
    }
}
fn rerank_episodic_by_decay(
&self,
db: &Database,
results: &mut [roboticus_db::embeddings::SearchResult],
) {
let now = chrono::Utc::now();
let episodic_ids: Vec<&str> = results
.iter()
.filter(|r| r.source_table == "episodic_memory")
.map(|r| r.source_id.as_str())
.collect();
if episodic_ids.is_empty() {
return;
}
let age_map: std::collections::HashMap<String, f64> = {
let conn = db.conn();
let placeholders: Vec<String> =
(1..=episodic_ids.len()).map(|i| format!("?{i}")).collect();
let sql = format!(
"SELECT id, created_at FROM episodic_memory WHERE id IN ({})",
placeholders.join(", ")
);
let mut stmt = match conn.prepare(&sql) {
Ok(s) => s,
Err(_) => return,
};
let rows = match stmt
.query_map(roboticus_db::params_from_iter(episodic_ids.iter()), |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
}) {
Ok(r) => r,
Err(_) => return,
};
rows.filter_map(|r| {
r.inspect_err(|e| tracing::warn!("skipping corrupted episodic row: {e}"))
.ok()
})
.filter_map(|(id, ts)| {
chrono::DateTime::parse_from_rfc3339(&ts)
.ok()
.map(|created| {
let age = (now - created.with_timezone(&chrono::Utc))
.to_std()
.map(|d| d.as_secs_f64() / 86_400.0)
.unwrap_or(0.0);
(id, age)
})
})
.collect()
};
for result in results.iter_mut() {
if result.source_table != "episodic_memory" {
continue;
}
if result.source_id.is_empty() {
result.similarity *= 0.5;
continue;
}
if let Some(&age) = age_map.get(&result.source_id) {
let decay_factor = (0.5_f64).powf(age / self.decay_half_life_days);
let clamped = decay_factor.max(0.05);
result.similarity *= clamped;
}
}
results.sort_by(|a, b| {
b.similarity
.partial_cmp(&a.similarity)
.unwrap_or(std::cmp::Ordering::Equal)
});
}
/// Removes episodic and semantic results whose stored `memory_state` is not active.
///
/// Ids are gathered per source table (empty ids are ignored), the inactive ones
/// are loaded through `load_inactive_ids`, and matching results are dropped.
/// Results from any other table are always kept.
fn filter_inactive_memories(
    &self,
    db: &Database,
    results: &mut Vec<roboticus_db::embeddings::SearchResult>,
) {
    // Non-empty source ids of the results that come from `table`.
    fn ids_from<'a>(
        hits: &'a [roboticus_db::embeddings::SearchResult],
        table: &str,
    ) -> Vec<&'a str> {
        hits.iter()
            .filter(|hit| hit.source_table == table && !hit.source_id.is_empty())
            .map(|hit| hit.source_id.as_str())
            .collect()
    }

    let episodic_inactive = {
        let ids = ids_from(results.as_slice(), "episodic_memory");
        self.load_inactive_ids(db, "episodic_memory", &ids)
    };
    let semantic_inactive = {
        let ids = ids_from(results.as_slice(), "semantic_memory");
        self.load_inactive_ids(db, "semantic_memory", &ids)
    };

    results.retain(|hit| {
        let id = hit.source_id.as_str();
        match hit.source_table.as_str() {
            "episodic_memory" => !episodic_inactive.contains(id),
            "semantic_memory" => !semantic_inactive.contains(id),
            _ => true,
        }
    });
}
fn load_inactive_ids(&self, db: &Database, table: &str, ids: &[&str]) -> HashSet<String> {
if ids.is_empty() {
return HashSet::new();
}
let conn = db.conn();
let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{i}")).collect();
let sql = format!(
"SELECT id, memory_state FROM {table} WHERE id IN ({})",
placeholders.join(", ")
);
let mut stmt = match conn.prepare(&sql) {
Ok(stmt) => stmt,
Err(e) => {
tracing::warn!(error = %e, table, "failed to prepare inactive-memory query");
return HashSet::new();
}
};
let rows = match stmt.query_map(roboticus_db::params_from_iter(ids.iter()), |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
}) {
Ok(rows) => rows,
Err(e) => {
tracing::warn!(error = %e, table, "failed to query inactive memories");
return HashSet::new();
}
};
let mut inactive = HashSet::new();
for row in rows {
match row {
Ok((id, state)) if !state.eq_ignore_ascii_case("active") => {
inactive.insert(id);
}
Ok(_) => {}
Err(e) => tracing::warn!(error = %e, table, "skipping invalid memory-state row"),
}
}
inactive
}
/// Renders the five most-used procedures as a `[Tool Experience]` section.
///
/// Only procedures with at least one recorded success or failure are selected,
/// ordered by total use. Each line shows `successes/total` and the success rate
/// as a whole percentage (truncated). Returns `None` for a zero budget, on a
/// database failure (logged), when no rows exist, or when no line fits.
fn retrieve_procedural(&self, db: &Database, budget_tokens: usize) -> Option<String> {
    const HEADER: &str = "[Tool Experience]\n";

    if budget_tokens == 0 {
        return None;
    }
    let conn = db.conn();
    let mut stmt = conn
        .prepare(
            "SELECT name, steps, success_count, failure_count FROM procedural_memory \
             WHERE success_count > 0 OR failure_count > 0 \
             ORDER BY success_count + failure_count DESC LIMIT 5",
        )
        .inspect_err(|e| tracing::warn!("failed to prepare tool experience query: {e}"))
        .ok()?;
    // Column 1 (`steps`) is never rendered, so it is deliberately not decoded:
    // a NULL or non-text value there must not cause the row to be discarded.
    let rows: Vec<(String, i64, i64)> = stmt
        .query_map([], |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, i64>(2)?,
                row.get::<_, i64>(3)?,
            ))
        })
        .inspect_err(|e| tracing::warn!("failed to query tool experience: {e}"))
        .ok()?
        .filter_map(|r| {
            r.inspect_err(|e| tracing::warn!("skipping corrupted tool experience row: {e}"))
                .ok()
        })
        .collect();
    if rows.is_empty() {
        return None;
    }

    let mut text = String::from(HEADER);
    let mut used = estimate_tokens(&text);
    for (name, successes, failures) in &rows {
        let total = successes.saturating_add(*failures);
        let rate = if total > 0 {
            (*successes as f64 / total as f64 * 100.0) as u32
        } else {
            0
        };
        let line = format!("- {name}: {successes}/{total} success ({rate}%)\n");
        let line_tokens = estimate_tokens(&line);
        if used + line_tokens > budget_tokens {
            break;
        }
        text.push_str(&line);
        used += line_tokens;
    }

    // A section holding nothing but the header is treated as empty.
    if text.len() > HEADER.len() {
        Some(text)
    } else {
        None
    }
}
/// Renders known entities as a `[Known Entities]` section.
///
/// The five entities with the highest interaction count are loaded. One is kept
/// when it has more than two interactions or when its id or name occurs
/// (case-insensitively) in `query`. Blank ids and names never count as a mention,
/// because an empty needle would otherwise match every query. A blank name falls
/// back to the entity id for display.
///
/// Returns `None` for a zero budget, on a database failure (logged), when nothing
/// qualifies, or when no line fits the budget.
fn retrieve_relationships(
    &self,
    db: &Database,
    query: &str,
    budget_tokens: usize,
) -> Option<String> {
    const HEADER: &str = "[Known Entities]\n";

    if budget_tokens == 0 {
        return None;
    }
    let conn = db.conn();
    // NOTE(review): mentions are only checked among these top five rows, so an
    // entity named in the query but outside the top five is not found.
    let mut stmt = conn
        .prepare(
            "SELECT entity_id, entity_name, trust_score, interaction_count \
             FROM relationship_memory ORDER BY interaction_count DESC LIMIT 5",
        )
        .inspect_err(|e| tracing::warn!("failed to prepare relationship memory query: {e}"))
        .ok()?;
    let rows: Vec<(String, Option<String>, f64, i64)> = stmt
        .query_map([], |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, Option<String>>(1)?,
                row.get::<_, f64>(2)?,
                row.get::<_, i64>(3)?,
            ))
        })
        .inspect_err(|e| tracing::warn!("failed to query relationship memory: {e}"))
        .ok()?
        .filter_map(|r| {
            r.inspect_err(|e| tracing::warn!("skipping corrupted relationship row: {e}"))
                .ok()
        })
        .collect();
    if rows.is_empty() {
        return None;
    }

    let query_lower = query.to_lowercase();
    // `str::contains("")` is always true, so blank needles are rejected up front.
    let mentioned =
        |needle: &str| !needle.is_empty() && query_lower.contains(&needle.to_lowercase());
    let relevant: Vec<_> = rows
        .into_iter()
        .filter(|(id, name, _, count)| {
            *count > 2
                || mentioned(id.as_str())
                || name.as_deref().is_some_and(|n| mentioned(n))
        })
        .collect();
    if relevant.is_empty() {
        return None;
    }

    let mut text = String::from(HEADER);
    let mut used = estimate_tokens(&text);
    for (entity_id, name, trust, count) in &relevant {
        let display = name
            .as_deref()
            .filter(|n| !n.is_empty())
            .unwrap_or(entity_id.as_str());
        let line = format!("- {display}: trust={trust:.1}, interactions={count}\n");
        let line_tokens = estimate_tokens(&line);
        if used + line_tokens > budget_tokens {
            break;
        }
        text.push_str(&line);
        used += line_tokens;
    }

    if text.len() > HEADER.len() {
        Some(text)
    } else {
        None
    }
}
}
/// Returns `true` when `query` contains a word that asks for historical,
/// stale or archived memories.
///
/// The query is ASCII-lowercased and split on non-alphanumeric characters, and
/// each resulting word is compared against a fixed list. Matching whole words,
/// rather than substrings, avoids false positives such as "old" inside
/// "threshold" or "past" inside "paste". Common inflections that the former
/// substring match accepted ("older", "archives", ...) are listed explicitly.
fn query_requests_inactive_memories(query: &str) -> bool {
    const HISTORY_TERMS: &[&str] = &[
        "history",
        "histories",
        "historical",
        "historically",
        "previous",
        "previously",
        "earlier",
        "before",
        "past",
        "old",
        "older",
        "oldest",
        "resolved",
        "stale",
        "archive",
        "archives",
        "archived",
    ];

    let lower = query.to_ascii_lowercase();
    lower
        .split(|c: char| !c.is_alphanumeric())
        .any(|word| HISTORY_TERMS.contains(&word))
}
/// Rough token estimate: one token per four bytes of UTF-8, rounded up.
fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    let whole = bytes / 4;
    // Any leftover bytes count as one additional token.
    if bytes % 4 == 0 {
        whole
    } else {
        whole + 1
    }
}
/// Size limits used by `chunk_text`.
///
/// Token counts are converted to bytes with a fixed factor of four bytes per token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkConfig {
    /// Maximum size of one chunk, in estimated tokens. Zero disables chunking
    /// (`chunk_text` returns no chunks).
    pub max_tokens: usize,
    /// Estimated tokens shared between consecutive chunks.
    pub overlap_tokens: usize,
}
impl Default for ChunkConfig {
fn default() -> Self {
Self {
max_tokens: 512,
overlap_tokens: 64,
}
}
}
/// One piece of text produced by `chunk_text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// The chunk's text, copied from the source.
    pub text: String,
    /// Zero-based position of this chunk in the output.
    pub index: usize,
    /// Start of the chunk in the source text. Despite the name this is a *byte*
    /// offset (it is used to slice the source `str`).
    pub start_char: usize,
    /// Exclusive end of the chunk in the source text, also a byte offset.
    pub end_char: usize,
}
/// Returns the largest UTF-8 character boundary in `text` that is `<= pos`.
///
/// Positions at or beyond the end of `text` map to `text.len()`.
fn floor_char_boundary(text: &str, pos: usize) -> usize {
    if pos >= text.len() {
        return text.len();
    }
    // Offset 0 is always a boundary, so the search cannot come up empty.
    (0..=pos)
        .rev()
        .find(|&offset| text.is_char_boundary(offset))
        .unwrap_or(0)
}
/// Splits `text` into overlapping chunks of at most `config.max_tokens` tokens.
///
/// Sizes are measured in bytes at four bytes per token. Text that fits in one
/// chunk is returned whole. Otherwise each chunk ends at the break point chosen
/// by `find_break_point`, and the next chunk starts `max - overlap` bytes after
/// the previous start (never beyond the previous end), snapped to a character
/// boundary. `start_char` / `end_char` of the returned chunks are byte offsets.
///
/// Returns an empty vector when `text` is empty or `config.max_tokens` is zero.
///
/// Forward progress is guaranteed even when `overlap_tokens >= max_tokens`:
/// the effective step is then a single byte, and if flooring that to a character
/// boundary would not move past a multi-byte character, the start is advanced to
/// the next boundary instead (previously this looped forever).
pub fn chunk_text(text: &str, config: &ChunkConfig) -> Vec<Chunk> {
    if text.is_empty() || config.max_tokens == 0 {
        return Vec::new();
    }
    // Saturating arithmetic: an absurdly large token count must not overflow.
    let max_bytes = config.max_tokens.saturating_mul(4);
    let overlap_bytes = config.overlap_tokens.saturating_mul(4);
    if text.len() <= max_bytes {
        return vec![Chunk {
            text: text.to_string(),
            index: 0,
            start_char: 0,
            end_char: text.len(),
        }];
    }
    let step = max_bytes.saturating_sub(overlap_bytes).max(1);
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        let raw_end =
            floor_char_boundary(text, start.saturating_add(max_bytes).min(text.len()));
        let end = find_break_point(text, start, raw_end);
        chunks.push(Chunk {
            text: text[start..end].to_string(),
            index: chunks.len(),
            start_char: start,
            end_char: end,
        });
        if end >= text.len() {
            break;
        }
        // Never step past the end of the chunk just emitted, so no text is skipped.
        let advance = step.min(end - start).max(1);
        let mut next = floor_char_boundary(text, start + advance);
        if next <= start {
            // The step landed inside the character at `start`; flooring would
            // stall, so move to the boundary right after that character.
            next = start + 1;
            while !text.is_char_boundary(next) {
                next += 1;
            }
        }
        start = next;
    }
    chunks
}
/// Picks where a chunk starting at `start` should end, at or before `raw_end`.
///
/// Only the second half of `start..raw_end` is searched. Preference order:
/// the last paragraph break (`"\n\n"`), then the last occurrence of the first
/// sentence delimiter found among `". "`, `".\n"`, `"? "`, `"! "` (tried in that
/// order), then the last space. The returned offset lies just after the chosen
/// delimiter. Falls back to `raw_end`; returns `text.len()` when `raw_end`
/// already reaches the end of the text.
fn find_break_point(text: &str, start: usize, raw_end: usize) -> usize {
    if raw_end >= text.len() {
        return text.len();
    }
    let search_start = floor_char_boundary(text, start + (raw_end - start) / 2);
    let window = &text[search_start..raw_end];

    let after_paragraph = window.rfind("\n\n").map(|at| at + 2);
    let after_sentence = || {
        [". ", ".\n", "? ", "! "]
            .iter()
            .find_map(|&delim| window.rfind(delim).map(|at| at + delim.len()))
    };
    let after_space = || window.rfind(' ').map(|at| at + 1);

    after_paragraph
        .or_else(after_sentence)
        .or_else(after_space)
        .map_or(raw_end, |offset| search_start + offset)
}
#[cfg(test)]
mod tests {
use super::*;
// Opens a fresh database at the `":memory:"` path for a single test.
fn test_db() -> Database {
Database::new(":memory:").unwrap()
}
// Baseline memory configuration; individual tests override fields as needed.
fn default_config() -> MemoryConfig {
MemoryConfig::default()
}
// With nothing stored, retrieval yields an empty string.
#[test]
fn retriever_empty_db_returns_empty() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
let result = retriever.retrieve(&db, &session_id, "hello", None, ComplexityLevel::L1);
assert!(result.is_empty());
}
// A stored working-memory entry appears in the output under the working-memory header.
#[test]
fn retriever_returns_working_memory() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_working(&db, &session_id, "goal", "find documentation", 8)
.unwrap();
let result = retriever.retrieve(&db, &session_id, "hello", None, ComplexityLevel::L2);
assert!(result.contains("Working Memory"));
assert!(result.contains("find documentation"));
}
// Working entries of type `turn_summary` are excluded while other entries are kept.
#[test]
fn retriever_skips_turn_summary_working_entries() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_working(
&db,
&session_id,
"turn_summary",
"Good to be back on familiar ground.",
9,
)
.unwrap();
roboticus_db::memory::store_working(&db, &session_id, "goal", "fix Telegram loop", 8)
.unwrap();
let result = retriever.retrieve(&db, &session_id, "telegram", None, ComplexityLevel::L2);
assert!(result.contains("Working Memory"));
assert!(result.contains("fix Telegram loop"));
assert!(!result.contains("Good to be back on familiar ground."));
}
// A semantic memory registered in the memory index is surfaced through the
// index section of the output (header and summary text).
#[test]
fn retriever_returns_relevant_memories() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
let id = roboticus_db::memory::store_semantic(&db, "facts", "sky", "the sky is blue", 0.9)
.unwrap();
roboticus_db::memory_index::upsert_index_entry(
&db,
"semantic_memory",
&id,
"the sky is blue",
Some("facts"),
)
.unwrap();
let result = retriever.retrieve(&db, &session_id, "sky", None, ComplexityLevel::L2);
assert!(
result.contains("[Memory Index"),
"index-only injection should contain the memory index header"
);
assert!(
result.contains("the sky is blue"),
"index entry summary should appear in output"
);
}
// A procedure with recorded successes is counted in the procedural tier.
#[test]
fn retriever_returns_procedural_experience() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_procedural(&db, "web_search", "search the web").unwrap();
roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"search",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.procedural >= 1,
"procedural tier should count the stored tool experience"
);
}
// An entity whose name is the query itself is counted in the relationship tier.
#[test]
fn retriever_returns_relationships() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_relationship(&db, "user-1", "Jon", 0.9).unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"Jon",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.relationship >= 1,
"relationship tier should count the stored entity"
);
}
// With the working-memory budget share set to zero, stored working entries
// must not be rendered.
#[test]
fn retriever_respects_zero_budget() {
let config = MemoryConfig {
working_budget_pct: 0.0,
episodic_budget_pct: 0.0,
semantic_budget_pct: 0.0,
procedural_budget_pct: 0.0,
relationship_budget_pct: 100.0,
..default_config()
};
let db = test_db();
let retriever = MemoryRetriever::new(config);
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_working(&db, &session_id, "goal", "test", 5).unwrap();
let result = retriever.retrieve(&db, &session_id, "test", None, ComplexityLevel::L0);
assert!(!result.contains("Working Memory"));
}
// Stores two semantic memories with opposite embeddings ([1, 0] and [-1, 0]) and
// queries with [1, 0] under a 0.4 similarity threshold. Checks that the average
// similarity of what remains is at least the threshold and that at least one
// semantic result is counted.
#[test]
fn retriever_similarity_threshold_filters_low_similarity_results() {
let config = MemoryConfig {
similarity_threshold: 0.4,
..default_config()
};
let db = test_db();
let retriever = MemoryRetriever::new(config);
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
let active_id = roboticus_db::memory::store_semantic(
&db,
"facts",
"high-match",
"deployment rollback stabilizes the incident",
0.9,
)
.unwrap();
let low_id = roboticus_db::memory::store_semantic(
&db,
"facts",
"low-match",
"botanical orchids in alpine valleys",
0.9,
)
.unwrap();
roboticus_db::embeddings::store_embedding(
&db,
"emb-high",
"semantic_memory",
&active_id,
"deployment rollback stabilizes the incident",
&[1.0, 0.0],
)
.unwrap();
roboticus_db::embeddings::store_embedding(
&db,
"emb-low",
"semantic_memory",
&low_id,
"botanical orchids in alpine valleys",
&[-1.0, 0.0],
)
.unwrap();
roboticus_db::memory_index::upsert_index_entry(
&db,
"semantic_memory",
&active_id,
"deployment rollback stabilizes the incident",
Some("facts"),
)
.unwrap();
roboticus_db::memory_index::upsert_index_entry(
&db,
"semantic_memory",
&low_id,
"botanical orchids in alpine valleys",
Some("facts"),
)
.unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"deployment rollback stabilizes the incident",
Some(&[1.0, 0.0]),
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.avg_similarity >= 0.4,
"avg similarity should be above the configured threshold"
);
assert!(
output.metrics.tiers.semantic >= 1,
"at least the high-similarity match should be counted"
);
}
// Empty input produces no chunks.
#[test]
fn chunk_empty_text() {
    let config = ChunkConfig::default();
    assert!(chunk_text("", &config).is_empty());
}
// Text shorter than the chunk limit comes back as a single chunk with index 0.
#[test]
fn chunk_short_text() {
    let input = "This is a short sentence.";
    let produced = chunk_text(input, &ChunkConfig::default());
    assert_eq!(produced.len(), 1);
    let only = &produced[0];
    assert_eq!(only.text, input);
    assert_eq!(only.index, 0);
}
// Long text is split into several non-empty, sequentially indexed chunks, and
// each chunk starts before the previous one ends.
#[test]
fn chunk_long_text_produces_overlapping_chunks() {
    let input = "word ".repeat(1000);
    let config = ChunkConfig {
        max_tokens: 50,
        overlap_tokens: 10,
    };
    let produced = chunk_text(&input, &config);
    assert!(produced.len() > 1);
    for (position, chunk) in produced.iter().enumerate() {
        assert_eq!(chunk.index, position);
        assert!(!chunk.text.is_empty());
    }
    for pair in produced.windows(2) {
        assert!(pair[1].start_char < pair[0].end_char);
    }
}
// Every chunk that does not reach the end of the text ends on a sentence or
// word boundary.
#[test]
fn chunk_respects_sentence_boundaries() {
    let input = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. \
                 Sixth sentence. Seventh sentence. Eighth sentence. Ninth sentence. Tenth sentence.";
    let config = ChunkConfig {
        max_tokens: 20,
        overlap_tokens: 5,
    };
    let produced = chunk_text(input, &config);
    for chunk in produced.iter().filter(|c| c.end_char < input.len()) {
        let body = chunk.text.as_str();
        let ends_at_boundary =
            body.ends_with(". ") || body.ends_with('.') || body.ends_with(' ');
        assert!(
            ends_at_boundary,
            "chunk should end at a boundary: {:?}",
            &chunk.text[chunk.text.len().saturating_sub(10)..]
        );
    }
}
// The first chunk starts at offset 0 and the last chunk ends at the text's end.
#[test]
fn chunk_covers_full_text() {
    let input = "a ".repeat(500);
    let config = ChunkConfig {
        max_tokens: 25,
        overlap_tokens: 5,
    };
    let produced = chunk_text(&input, &config);
    let head = produced.first().unwrap();
    let tail = produced.last().unwrap();
    assert_eq!(head.start_char, 0);
    assert_eq!(tail.end_char, input.len());
}
// A zero token limit produces no chunks at all.
#[test]
fn chunk_zero_max_tokens() {
    let config = ChunkConfig {
        max_tokens: 0,
        overlap_tokens: 0,
    };
    let produced = chunk_text("some text", &config);
    assert!(produced.is_empty());
}
// Token estimates for an empty string, exactly four bytes and twelve bytes.
#[test]
fn estimate_tokens_basic() {
    for (input, expected) in [("", 0), ("abcd", 1), ("hello world!", 3)] {
        assert_eq!(estimate_tokens(input), expected);
    }
}
// Chunking text that contains four-byte emoji must not panic on slicing and
// must yield several non-empty chunks.
#[test]
fn chunk_multibyte_does_not_panic() {
    let input = "Hello \u{1F600} world. ".repeat(200);
    let config = ChunkConfig {
        max_tokens: 20,
        overlap_tokens: 5,
    };
    let produced = chunk_text(&input, &config);
    assert!(produced.len() > 1);
    for chunk in produced.iter() {
        assert!(!chunk.text.is_empty());
        let _ = chunk.text.as_bytes();
    }
}
// Three-byte CJK text is chunked into several pieces covering the whole input.
#[test]
fn chunk_cjk_text() {
    let input = "\u{4F60}\u{597D}\u{4E16}\u{754C} ".repeat(300);
    let config = ChunkConfig {
        max_tokens: 15,
        overlap_tokens: 3,
    };
    let produced = chunk_text(&input, &config);
    assert!(produced.len() > 1);
    let head = produced.first().unwrap();
    let tail = produced.last().unwrap();
    assert_eq!(head.start_char, 0);
    assert_eq!(tail.end_char, input.len());
}
// In ASCII text every in-range offset is a boundary; out-of-range clamps to the length.
#[test]
fn floor_char_boundary_ascii() {
    let sample = "hello world";
    for (pos, expected) in [(5, 5), (0, 0), (100, sample.len())] {
        assert_eq!(floor_char_boundary(sample, pos), expected);
    }
}
// An offset inside the two-byte "é" floors to the start of that character.
#[test]
fn floor_char_boundary_multibyte() {
    let sample = "caf\u{00E9}";
    assert_eq!(sample.len(), 5);
    for (pos, expected) in [(4, 3), (3, 3), (5, 5)] {
        assert_eq!(floor_char_boundary(sample, pos), expected);
    }
}
// An offset inside a four-byte emoji floors to the emoji's first byte.
#[test]
fn floor_char_boundary_emoji() {
    let sample = "a\u{1F600}b";
    assert_eq!(sample.len(), 6);
    assert_eq!(floor_char_boundary(sample, 2), 1);
    assert_eq!(floor_char_boundary(sample, 5), 5);
}
// Partial groups of four bytes round up to a whole token.
#[test]
fn estimate_tokens_rounding() {
    for (input, expected) in [("a", 1), ("abcde", 2), ("abcdefgh", 2)] {
        assert_eq!(estimate_tokens(input), expected);
    }
}
// A stored procedure with no recorded success or failure must not be rendered.
#[test]
fn retriever_with_procedural_no_history() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_procedural(&db, "unused_tool", "a tool").unwrap();
let result = retriever.retrieve(&db, &session_id, "test", None, ComplexityLevel::L2);
assert!(
!result.contains("Tool Experience"),
"tools with no success/failure should not appear"
);
}
// Chunks that stop before the end of a paragraph-separated text end near a
// newline, sentence end or space.
#[test]
fn chunk_with_paragraph_breaks() {
    let input = "Paragraph one content.\n\nParagraph two content.\n\nParagraph three content.\n\n\
                 Paragraph four content.\n\nParagraph five content.";
    let config = ChunkConfig {
        max_tokens: 15,
        overlap_tokens: 3,
    };
    let produced = chunk_text(input, &config);
    for chunk in produced.iter().filter(|c| c.end_char < input.len()) {
        let tail = &chunk.text[chunk.text.len().saturating_sub(5)..];
        let has_good_break = tail.contains('\n') || tail.contains(". ") || tail.ends_with(' ');
        assert!(has_good_break, "chunk should end at a reasonable boundary");
    }
}
// The default configuration is 512 tokens per chunk with a 64-token overlap.
#[test]
fn chunk_config_default() {
    let ChunkConfig {
        max_tokens,
        overlap_tokens,
    } = ChunkConfig::default();
    assert_eq!(max_tokens, 512);
    assert_eq!(overlap_tokens, 64);
}
// When the raw end already reaches the end of the text, that end is returned.
#[test]
fn find_break_point_at_end_of_text() {
    let sample = "Hello world.";
    let end = sample.len();
    assert_eq!(find_break_point(sample, 0, end), end);
}
// Stores the same entity four times and checks it is counted in the relationship
// tier for a query that does not mention it.
// NOTE(review): presumably each `store_relationship` call bumps the interaction
// count — confirm against the DB helper.
#[test]
fn retriever_relationships_high_interaction_count() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
for _ in 0..4 {
roboticus_db::memory::store_relationship(&db, "alice", "Alice Smith", 0.8).unwrap();
}
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"some random query",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.relationship >= 1,
"high interaction count entity should be retrieved into relationship tier"
);
}
// Stores one digest, marks it stale (via the helper and an explicit UPDATE),
// stores a second active digest, then retrieves with a query that contains no
// history keywords.
// NOTE(review): the assertion only requires at least one episodic item, so it
// would still pass if the stale digest were not suppressed.
#[test]
fn retriever_suppresses_stale_digests_by_default() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
let stale_id = roboticus_db::memory::store_episodic_with_meta(
&db,
"digest",
"[Session Digest] alpha rollout incident resolved",
9,
Some("agent-1"),
"active",
None,
)
.unwrap();
roboticus_db::memory::mark_episodic_digests_stale_for_owner(
&db,
"agent-1",
"newer-digest",
"superseded",
)
.unwrap();
let conn = db.conn();
conn.execute(
"UPDATE episodic_memory SET memory_state = 'stale' WHERE id = ?1",
[stale_id],
)
.unwrap();
// Release the connection before the next DB helper call.
drop(conn);
roboticus_db::memory::store_episodic_with_meta(
&db,
"digest",
"[Session Digest] beta stabilization plan active",
9,
Some("agent-1"),
"active",
None,
)
.unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"alpha beta digest",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.episodic >= 1,
"active digest should be retrieved"
);
}
// Stores one stale and one active digest and queries with history keywords
// ("previous", "history"); both digests are expected in the episodic count.
#[test]
fn retriever_includes_stale_digests_when_history_requested() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
roboticus_db::memory::store_episodic_with_meta(
&db,
"digest",
"[Session Digest] alpha rollout incident resolved",
9,
Some("agent-1"),
"stale",
Some("superseded"),
)
.unwrap();
roboticus_db::memory::store_episodic_with_meta(
&db,
"digest",
"[Session Digest] beta stabilization plan active",
9,
Some("agent-1"),
"active",
None,
)
.unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"show previous history for the alpha beta digest",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.episodic >= 2,
"history query should include both stale and active digests: got {}",
output.metrics.tiers.episodic
);
}
// Stores two session summaries, marks the older one stale through the
// prefix-based helper (passing the newer id), and checks that a query without
// history keywords counts at most one semantic result.
#[test]
fn retriever_suppresses_stale_semantic_summaries_by_default() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
roboticus_db::memory::store_semantic(
&db,
"learned",
"session:agent-1:alpha",
"alpha policy was retired after the incident",
0.8,
)
.unwrap();
let active_id = roboticus_db::memory::store_semantic(
&db,
"learned",
"session:agent-1:beta",
"beta policy is active with the latest safeguards",
0.9,
)
.unwrap();
roboticus_db::memory::mark_semantic_stale_by_category_and_key_prefix(
&db,
"learned",
"session:agent-1:",
&active_id,
"superseded_by_newer_session_summary",
)
.unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"alpha beta policy safeguards",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.semantic <= 1,
"stale semantic summaries should be suppressed: got {} semantic",
output.metrics.tiers.semantic
);
}
// Same fixture as the suppression test, but the query contains "history";
// both the stale and the active summary are expected in the semantic count.
#[test]
fn retriever_includes_stale_semantic_summaries_when_history_requested() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
roboticus_db::memory::store_semantic(
&db,
"learned",
"session:agent-1:alpha",
"alpha policy was retired after the incident",
0.8,
)
.unwrap();
let active_id = roboticus_db::memory::store_semantic(
&db,
"learned",
"session:agent-1:beta",
"beta policy is active with the latest safeguards",
0.9,
)
.unwrap();
roboticus_db::memory::mark_semantic_stale_by_category_and_key_prefix(
&db,
"learned",
"session:agent-1:",
&active_id,
"superseded_by_newer_session_summary",
)
.unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"show history of the alpha beta policy change",
None,
ComplexityLevel::L2,
None,
);
assert!(
output.metrics.tiers.semantic >= 2,
"history query should include stale semantic summaries: got {}",
output.metrics.tiers.semantic
);
}
// With nothing stored, the text is empty and every metric is at its zero value.
#[test]
fn retrieve_with_metrics_empty_db() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"hello",
None,
ComplexityLevel::L1,
None,
);
assert!(output.text.is_empty());
assert!(!output.metrics.retrieval_hit);
assert_eq!(output.metrics.retrieval_count, 0);
assert_eq!(output.metrics.avg_similarity, 0.0);
assert_eq!(output.metrics.budget_utilization, 0.0);
}
// Two stored working entries are reflected in the working tier, the total count,
// the hit flag and the budget utilisation; the tier breakdown also serialises
// to JSON with a `working` field.
#[test]
fn retrieve_with_metrics_working_memory_counted() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_working(&db, &session_id, "goal", "fix the pipeline", 8)
.unwrap();
roboticus_db::memory::store_working(&db, &session_id, "note", "version 0.11", 7).unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"hello",
None,
ComplexityLevel::L2,
None,
);
assert!(output.metrics.retrieval_hit);
assert!(
output.metrics.tiers.working >= 2,
"working tier count should reflect stored entries"
);
assert!(output.metrics.retrieval_count >= 2);
assert!(output.metrics.budget_utilization > 0.0);
let json = serde_json::to_string(&output.metrics.tiers).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert!(parsed["working"].as_u64().unwrap() >= 2);
}
// A procedure with a single recorded success is counted in the procedural tier.
#[test]
fn retrieve_with_metrics_procedural_counted() {
let db = test_db();
let retriever = MemoryRetriever::new(default_config());
let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
roboticus_db::memory::store_procedural(&db, "web_search", "search the web").unwrap();
roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
let output = retriever.retrieve_with_metrics(
&db,
&session_id,
"search",
None,
ComplexityLevel::L2,
None,
);
assert!(output.metrics.tiers.procedural >= 1);
}
}