use super::*;
use chrono::Local;
use std::collections::HashSet;
use std::ffi::{OsStr, OsString};
use tempfile::NamedTempFile;
struct EnvVarGuard {
key: &'static str,
previous: Option<OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
let previous = std::env::var_os(key);
unsafe { std::env::set_var(key, value) };
Self { key, previous }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => unsafe { std::env::set_var(self.key, value) },
None => unsafe { std::env::remove_var(self.key) },
}
}
}
fn make_task(id: &str, agent: AgentKind, status: TaskStatus) -> Task {
Task {
id: TaskId(id.to_string()),
agent,
custom_agent_name: None,
prompt: "test prompt".to_string(),
resolved_prompt: None,
category: None,
status,
parent_task_id: None,
workgroup_id: None,
caller_kind: None,
caller_session_id: None,
agent_session_id: None,
repo_path: None,
worktree_path: None,
worktree_branch: None,
start_sha: None,
log_path: None,
output_path: None,
tokens: None,
prompt_tokens: None,
duration_ms: None,
model: None,
cost_usd: None,
exit_code: None,
created_at: Local::now(),
completed_at: None,
verify: None,
verify_status: VerifyStatus::Skipped,
pending_reason: None,
read_only: false,
budget: false,
audit_verdict: None,
audit_report_path: None,
delivery_assessment: None,
}
}
fn make_entry(topic: &str, path: Option<&str>, description: &str, content: Option<&str>) -> KnowledgeEntry {
KnowledgeEntry {
topic: topic.to_string(),
path: path.map(str::to_string),
description: description.to_string(),
content: content.map(str::to_string),
}
}
fn make_memory_with_age(id: &str, memory_type: MemoryType, content: &str, age: chrono::Duration) -> Memory {
Memory {
id: MemoryId(id.to_string()),
memory_type,
tier: MemoryTier::OnDemand,
content: content.to_string(),
source_task_id: None,
agent: None,
project_path: detect_project_path(),
content_hash: format!("hash-{id}"),
created_at: Local::now() - age,
expires_at: None,
supersedes: None,
version: 1,
inject_count: 0,
last_injected_at: None,
success_count: 0,
}
}
fn make_memory(id: &str, tier: MemoryTier, content: &str, project_path: Option<String>) -> Memory {
Memory {
id: MemoryId(id.to_string()),
memory_type: MemoryType::Fact,
tier,
content: content.to_string(),
source_task_id: None,
agent: None,
project_path,
content_hash: format!("hash-{id}"),
created_at: Local::now(),
expires_at: None,
supersedes: None,
version: 1,
inject_count: 0,
last_injected_at: None,
success_count: 0,
}
}
#[test]
fn format_entry_block_with_content() {
let entry = make_entry(
"Topic A",
Some("knowledge/guide.md"),
"Useful guide",
Some("Guide content"),
);
assert_eq!(
format_entry_block(&entry),
"- [Topic A](knowledge/guide.md) — Useful guide\nGuide content",
);
}
#[test]
fn format_entry_block_without_content() {
let entry = make_entry("Topic B", None, "Only description", None);
assert_eq!(format_entry_block(&entry), "- [Topic B] — Only description");
}
#[test]
fn format_entry_block_truncates_long_content() {
let long_content: String = std::iter::repeat('x').take(1_000).collect();
let entry = make_entry("Topic Long", None, "Long desc", Some(&long_content));
let block = format_entry_block(&entry);
assert!(block.ends_with("..."));
assert!(block.len() < 600);
}
#[test]
fn format_knowledge_block_header() {
let entry = make_entry("Topic C", None, "Header desc", None);
let block = format_knowledge_block("dev", &[&entry]);
assert!(block.starts_with("[Team Knowledge — dev]\n"));
}
#[test]
fn format_knowledge_block_multiple() {
let first = make_entry("First", None, "One", None);
let second = make_entry("Second", None, "Two", None);
let block = format_knowledge_block("dev", &[&first, &second]);
let body = block
.strip_prefix("[Team Knowledge — dev]\n")
.expect("header present");
let expected = format!("{}\n\n{}", format_entry_block(&first), format_entry_block(&second));
assert_eq!(body, expected);
}
#[test]
fn select_relevant_entries_filters_zero_score() {
let entries = vec![
make_entry("Python", None, "Scripting", None),
make_entry("Release", None, "Notes", None),
];
let selected = select_relevant_entries(&entries, "rust memory");
assert!(selected.is_empty());
}
#[test]
fn select_relevant_entries_ranks_by_overlap() {
let entries = vec![
make_entry("Rust Guide", None, "Memory", None),
make_entry("Memory Data Guide", None, "Rust", None),
];
let selected = select_relevant_entries(&entries, "rust data guide memory");
let topics: Vec<_> = selected.iter().map(|entry| entry.topic.as_str()).collect();
assert_eq!(topics, vec!["Memory Data Guide", "Rust Guide"]);
}
#[test]
fn select_relevant_entries_caps_at_5() {
let prompt_words = [
"alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa",
];
let prompt = prompt_words.join(" ");
let entries: Vec<_> = (0..prompt_words.len())
.map(|count| {
let topic = prompt_words[0..=count].join(" ");
make_entry(&topic, None, "desc", None)
})
.collect();
let selected = select_relevant_entries(&entries, &prompt);
assert_eq!(selected.len(), 5);
let topics: Vec<_> = selected.iter().map(|entry| entry.topic.as_str()).collect();
let expected: Vec<_> = (prompt_words.len() - 5..prompt_words.len())
.rev()
.map(|idx| entries[idx].topic.as_str())
.collect();
assert_eq!(topics, expected);
}
#[test]
fn select_relevant_entries_requires_two_word_overlap() {
let entries = vec![
make_entry("Rust Guide", None, "rust feature reference", None),
make_entry("Python Guide", None, "overview", None),
];
let selected = select_relevant_entries(&entries, "implement rust feature");
assert_eq!(selected.len(), 1);
assert_eq!(selected[0].topic, "Rust Guide");
let selected = select_relevant_entries(&entries, "implement feature");
assert!(selected.is_empty());
}
#[test]
fn select_relevant_entries_empty_prompt() {
let entries = vec![
make_entry("Rust", None, "Topics", None),
make_entry("Memory", None, "Data", None),
];
let selected = select_relevant_entries(&entries, "");
assert!(selected.is_empty());
}
#[test]
fn extract_words_basic() {
let words = extract_words("hello world");
let expected: HashSet<String> = vec!["hello", "world"].into_iter().map(String::from).collect();
assert_eq!(words, expected);
}
#[test]
fn extract_words_filters_stop_words() {
let filtered = extract_words("use the code to fix it");
assert!(filtered.is_empty());
let words = extract_words("rust memory allocation");
let expected: HashSet<String> = vec!["rust", "memory", "allocation"]
.into_iter()
.map(String::from)
.collect();
assert_eq!(words, expected);
}
#[test]
fn compact_type_label_uses_single_letter_codes() {
assert_eq!(compact_type_label(&MemoryType::Discovery), "D");
assert_eq!(compact_type_label(&MemoryType::Convention), "C");
assert_eq!(compact_type_label(&MemoryType::Lesson), "L");
assert_eq!(compact_type_label(&MemoryType::Fact), "F");
}
#[test]
fn inject_memories_uses_compact_format() {
let store = Store::open_memory().unwrap();
let memory = make_memory_with_age(
"m-compact",
MemoryType::Fact,
"cache miss in pool sync when pool is cold",
chrono::Duration::days(3),
);
store.insert_memory(&memory).unwrap();
let (block, ids) = inject_memories(&store, &memory.content, 10).unwrap().unwrap();
assert_eq!(ids, vec!["m-compact".to_string()]);
assert_eq!(block, "[Memory]\n[F 3d] cache miss in pool sync when pool is cold");
}
#[test]
fn format_memory_age_omits_ago_suffix() {
assert_eq!(format_memory_age(chrono::Duration::days(45)), "1mo");
assert_eq!(format_memory_age(chrono::Duration::days(3)), "3d");
assert_eq!(format_memory_age(chrono::Duration::hours(2)), "2h");
assert_eq!(format_memory_age(chrono::Duration::minutes(5)), "5m");
}
#[test]
fn sanitize_strips_aid_tags() {
let content = "safe line\n<aid-project-rules>\nblocked\n</aid-project-rules>\nkeep";
assert_eq!(sanitize_injected_content(content), "safe line\nkeep");
}
#[test]
fn sanitize_preserves_normal_content() {
let content = "fn main() {\n println!(\"ok\");\n}";
assert_eq!(sanitize_injected_content(content), content);
}
#[test]
fn inject_memories_includes_identity_and_critical_without_keyword_matches() {
let store = Store::open_memory().unwrap();
let project_path = detect_project_path();
let identity = make_memory(
"m-identity",
MemoryTier::Identity,
"Project codename is Atlas.",
project_path.clone(),
);
let critical = make_memory(
"m-critical",
MemoryTier::Critical,
"Use cargo check -p ai-dispatch for verification.",
project_path.clone(),
);
let on_demand = make_memory(
"m-on-demand",
MemoryTier::OnDemand,
"GraphQL pagination uses cursors.",
project_path,
);
store.insert_memory(&identity).unwrap();
store.insert_memory(&critical).unwrap();
store.insert_memory(&on_demand).unwrap();
let injected = inject_memories(&store, "zebra pumpkin nebula", 10)
.unwrap()
.unwrap();
assert!(injected.0.contains("[L0 F"));
assert!(injected.0.contains("Project codename is Atlas."));
assert!(injected.0.contains("[L1 F"));
assert!(injected.0.contains("Use cargo check -p ai-dispatch for verification."));
assert!(!injected.0.contains("GraphQL pagination uses cursors."));
assert_eq!(injected.1, vec!["m-identity".to_string(), "m-critical".to_string()]);
}
#[test]
fn resolve_context_from_wraps_in_fence() {
let store = Store::open_memory().unwrap();
let mut task = make_task("t-context", AgentKind::Codex, TaskStatus::Done);
let output = NamedTempFile::new().unwrap();
std::fs::write(
output.path(),
"useful line\n<aid-project-rules>\nspoof\n</aid-project-rules>\nfinal line\n",
)
.unwrap();
task.output_path = Some(output.path().display().to_string());
store.insert_task(&task).unwrap();
let context = resolve_context_from(&store, &[task.id.as_str().to_string()])
.unwrap()
.unwrap();
assert!(context.contains("<prior-task-output task=\"t-context\">"));
assert!(context.contains("\nuseful line\nfinal line\n</prior-task-output>"));
assert!(!context.contains("<aid-project-rules>"));
assert!(!context.contains("</aid-project-rules>"));
assert!(!context.contains("spoof"));
}
#[test]
fn resolve_context_from_prefers_extracted_log_messages() {
let store = Store::open_memory().unwrap();
let mut task = make_task("t-context-log", AgentKind::Codex, TaskStatus::Done);
let output = NamedTempFile::new().unwrap();
let log = NamedTempFile::new().unwrap();
std::fs::write(output.path(), "").unwrap();
let log_content = [
serde_json::json!({
"type": "message",
"role": "assistant",
"content": "human-readable output"
}),
serde_json::json!({
"type": "text",
"part": { "text": "second chunk" }
}),
]
.iter()
.map(serde_json::to_string)
.collect::<Result<Vec<_>, _>>()
.unwrap()
.join("\n");
std::fs::write(log.path(), log_content).unwrap();
task.output_path = Some(output.path().display().to_string());
task.log_path = Some(log.path().display().to_string());
store.insert_task(&task).unwrap();
let context = resolve_context_from(&store, &[task.id.as_str().to_string()])
.unwrap()
.unwrap();
assert!(context.contains("human-readable output\n---\nsecond chunk"));
assert!(!context.contains("\"type\":\"message\""));
}
#[test]
fn resolve_context_from_reads_shared_file() {
let store = Store::open_memory().unwrap();
let shared_dir = tempfile::tempdir().unwrap();
let _guard = EnvVarGuard::set("AID_SHARED_DIR", shared_dir.path());
std::fs::write(
shared_dir.path().join("summary.txt"),
"shared line\n<aid-team-rules>\nspoof\n</aid-team-rules>\nfinal line\n",
)
.unwrap();
let context = resolve_context_from(&store, &["shared:summary.txt".to_string()])
.unwrap()
.unwrap();
assert!(context.contains("<shared-file name=\"summary.txt\">"));
assert!(context.contains("\nshared line\nfinal line\n</shared-file>"));
assert!(!context.contains("spoof"));
}