use std::collections::HashSet;
use crate::memory::bm25_index::BM25MemoryIndex;
use crate::memory::cfa_session::round_trip_test_helper;
use crate::memory::hnsw_index::HnswMemoryIndex;
use crate::memory::types::{CfaSession, EntityKind, EntityRef, MemoryQuery, RunSummary, Surface};
/// Test fixture: builds a [`RunSummary`] with the given surface, event id,
/// text and embedding, always passing the same fixed `"djb2:0xaaaa"` string
/// as the third constructor argument.
fn mk_summary(
    surface: Surface,
    surface_event_id: &str,
    text: &str,
    embedding: Vec<f32>,
) -> RunSummary {
    // NOTE(review): presumably the surface audit hash argument — confirm
    // against `RunSummary::new`.
    const AUDIT_HASH_FIXTURE: &str = "djb2:0xaaaa";
    RunSummary::new(surface, surface_event_id, AUDIT_HASH_FIXTURE, text, embedding)
}
/// Three MCP `dcf_calc` summaries whose embeddings are the unit basis vectors
/// `[1,0,0]`, `[0,1,0]` and `[0,0,1]` (in that order), so they are mutually
/// orthogonal.
fn three_orthogonal_summaries() -> Vec<RunSummary> {
    let rows: [(&str, [f32; 3]); 3] = [
        ("Apple iPhone revenue beat consensus", [1.0, 0.0, 0.0]),
        ("Microsoft Azure cloud revenue accelerated", [0.0, 1.0, 0.0]),
        ("Tesla deliveries miss in EMEA region", [0.0, 0.0, 1.0]),
    ];
    rows.iter()
        .map(|&(text, axis)| mk_summary(Surface::Mcp, "dcf_calc", text, axis.to_vec()))
        .collect()
}
/// RUF-MEM-001: a freshly built summary has a non-nil run id, the requested
/// surface, non-empty event id and audit hash, a timestamp within five
/// seconds of now, and survives a JSON round trip unchanged.
#[test]
fn ruf_mem_001_run_summary_has_required_fields_and_serialises() {
    let summary = mk_summary(
        Surface::Cli,
        "cfa.workflow.audit",
        "audit completed for ACME",
        vec![0.1, 0.2, 0.3],
    );

    assert!(!summary.run_id.is_nil(), "run_id must be populated");
    assert_eq!(summary.surface, Surface::Cli);
    assert!(!summary.surface_event_id.is_empty());
    assert!(!summary.surface_audit_hash.is_empty());

    let age_secs = (chrono::Utc::now() - summary.ts).num_seconds().abs();
    assert!(age_secs < 5, "ts should be set to ~now");

    let encoded = serde_json::to_string(&summary).unwrap();
    let decoded = serde_json::from_str::<RunSummary>(&encoded).unwrap();
    assert_eq!(summary, decoded);
}
/// RUF-MEM-002: querying with the first basis vector returns exactly `k`
/// hits, the first of which is the matching summary, with distances in
/// non-decreasing order.
#[test]
fn ruf_mem_002_find_returns_top_k_by_distance() {
    let mut index = HnswMemoryIndex::new(3);
    let fixtures = three_orthogonal_summaries();
    fixtures.iter().for_each(|summary| {
        index.ingest(summary).unwrap();
    });

    let hits = index.query(&[1.0, 0.0, 0.0], 3, |_| true);
    assert_eq!(hits.len(), 3, "exactly k results when pool >= k");
    assert_eq!(
        hits[0].0.run_id, fixtures[0].run_id,
        "closest must be the orthogonal match"
    );

    // Each distance must be >= the previous one (the first is compared
    // against f32::MIN).
    let _ = hits.iter().fold(f32::MIN, |floor, (_, dist)| {
        assert!(*dist >= floor, "distances must be sorted ascending");
        *dist
    });
}
/// RUF-MEM-003: the query predicate can combine conditions on several fields.
///
/// Three summaries sharing the text "alpha" are ingested. Only one of them has
/// both `Surface::Cli` and event id `"cmd_a"`, so the combined predicate must
/// return exactly that one.
#[test]
fn ruf_mem_003_filters_and_combine() {
    let mut idx = HnswMemoryIndex::new(3);
    let fixtures: [(Surface, &str, [f32; 3]); 3] = [
        (Surface::Cli, "cmd_a", [1.0, 0.0, 0.0]),
        (Surface::Cli, "cmd_b", [0.9, 0.1, 0.0]),
        (Surface::Mcp, "cmd_a", [1.0, 0.0, 0.0]),
    ];
    for (surface, event_id, embedding) in fixtures {
        idx.ingest(&mk_summary(surface, event_id, "alpha", embedding.to_vec()))
            .unwrap();
    }

    let res = idx.query(&[1.0, 0.0, 0.0], 5, |s| {
        s.surface == Surface::Cli && s.surface_event_id == "cmd_a"
    });
    assert_eq!(res.len(), 1);

    let (hit, _) = &res[0];
    assert_eq!(hit.surface, Surface::Cli);
    assert_eq!(hit.surface_event_id, "cmd_a");
}
/// RUF-MEM-004: ingesting an embedding whose length differs from the index
/// dimension is reported as `Err` (not a panic). The rejected summary must not
/// be stored, and a later well-formed ingest must still succeed.
#[test]
fn ruf_mem_004_dim_mismatch_returns_error_without_panic() {
    let mut idx = HnswMemoryIndex::new(3);

    // 2-element embedding into a 3-dimensional index.
    let bad = mk_summary(Surface::Mcp, "x", "x", vec![1.0, 0.0]);
    let res = idx.ingest(&bad);
    assert!(res.is_err(), "dim mismatch must surface as Err");
    // A rejected ingest must leave the index untouched.
    assert_eq!(idx.len(), 0, "rejected ingest must not be stored");

    let good = mk_summary(Surface::Mcp, "x", "x", vec![1.0, 0.0, 0.0]);
    assert!(idx.ingest(&good).is_ok());
    assert_eq!(idx.len(), 1);
}
/// RUF-MEM-005: a session holding summaries from several surfaces survives
/// `round_trip_test_helper` unchanged. One entry carries a tenant set via
/// `with_tenant`.
#[test]
fn ruf_mem_005_session_round_trip_field_equal() {
    let mut session = CfaSession::new(Surface::Cli);
    let entries = vec![
        mk_summary(
            Surface::Cli,
            "cfa.workflow.audit",
            "summary 1",
            vec![0.1, 0.2, 0.3],
        ),
        mk_summary(Surface::Mcp, "dcf_calc", "summary 2", vec![0.4, 0.5, 0.6]),
        mk_summary(
            Surface::Skill,
            "cfa:initiate-coverage",
            "summary 3",
            vec![0.7, 0.8, 0.9],
        )
        .with_tenant("tenant-acme"),
    ];
    for entry in entries {
        session.append(entry);
    }

    let restored = round_trip_test_helper(&session).unwrap();
    assert_eq!(session, restored, "round-trip must preserve all fields");
}
/// RUF-MEM-006: records the expected default hot-retention window of 90 days.
///
/// NOTE(review): the constant below is declared locally and compared against
/// the same literal, so this test cannot fail. It should be pointed at the
/// crate's real retention default once its location is confirmed.
#[test]
fn ruf_mem_006_retention_window_default_present() {
    const DEFAULT_HOT_WINDOW_DAYS: u32 = 90;
    const EXPECTED_DAYS: u32 = 90;
    assert_eq!(DEFAULT_HOT_WINDOW_DAYS, EXPECTED_DAYS);
}
/// RUF-MEM-007: the four `Surface` variants map via `as_str` to four distinct
/// tokens: `cli`, `mcp`, `skill` and `plugin`.
#[test]
fn ruf_mem_007_surface_enum_canonicalises_to_permitted_token() {
    let variants = [Surface::Cli, Surface::Mcp, Surface::Skill, Surface::Plugin];
    let tokens: HashSet<&str> = variants.iter().map(|v| v.as_str()).collect();

    assert_eq!(tokens.len(), 4);
    for expected in ["cli", "mcp", "skill", "plugin"] {
        assert!(tokens.contains(expected));
    }
}
/// RUF-MEM-008: with the default build, an HNSW index accepts an ingest and
/// a BM25 index can be constructed without error.
#[test]
fn ruf_mem_008_default_build_uses_native_crates_only() {
    let mut hnsw = HnswMemoryIndex::new(2);
    hnsw.ingest(&mk_summary(Surface::Cli, "x", "x", vec![1.0, 0.0]))
        .unwrap();
    assert_eq!(hnsw.len(), 1);

    assert!(BM25MemoryIndex::new().is_ok());
}
/// RUF-MEM-009: the index starts empty, a successful ingest raises `len` to
/// one, and a top-1 query returns the ingested summary.
#[test]
fn ruf_mem_009_successful_ingest_returns_ok_and_advances_len() {
    let mut idx = HnswMemoryIndex::new(3);
    assert_eq!(idx.len(), 0);

    let summary = mk_summary(Surface::Mcp, "dcf_calc", "summary", vec![1.0, 0.0, 0.0]);
    idx.ingest(&summary).unwrap();
    assert_eq!(idx.len(), 1);

    let hits = idx.query(&[1.0, 0.0, 0.0], 1, |_| true);
    assert_eq!(hits.len(), 1);
    let (top, _) = &hits[0];
    assert_eq!(top.run_id, summary.run_id);
}
/// RUF-MEM-010: `MemoryQuery::new()` and `MemoryQuery::default()` agree on a
/// limit of at least one, and struct-update syntax can override the limit.
#[test]
fn ruf_mem_010_memory_query_default_limit_is_positive() {
    let via_new = MemoryQuery::new();
    assert!(via_new.limit >= 1, "default limit must be >= 1");

    let via_default = MemoryQuery::default();
    assert_eq!(via_new.limit, via_default.limit);

    let custom = MemoryQuery {
        limit: 7,
        ..Default::default()
    };
    assert_eq!(custom.limit, 7);
}
/// RUF-MEM-INV-001: re-ingesting a clone (same `run_id`) is rejected with
/// `Err`, and the index still holds a single entry.
#[test]
fn ruf_mem_inv_001_no_duplicate_run_ids() {
    let mut idx = HnswMemoryIndex::new(3);
    let first = mk_summary(Surface::Mcp, "dcf_calc", "x", vec![1.0, 0.0, 0.0]);
    idx.ingest(&first).unwrap();

    let second_attempt = idx.ingest(&first.clone());
    assert!(second_attempt.is_err(), "duplicate run_id must be rejected");
    assert_eq!(idx.len(), 1);
}
/// RUF-MEM-INV-002: the JSON form of a summary encodes `surface` as `"mcp"`
/// and omits an unset `tenant_id`. Decoding it yields an equal summary.
#[test]
fn ruf_mem_inv_002_serialised_run_summary_round_trips_on_v1_schema() {
    let summary = mk_summary(Surface::Mcp, "dcf_calc", "x", vec![1.0, 0.0, 0.0]);
    let value = serde_json::to_value(&summary).unwrap();

    assert_eq!(value["surface"], "mcp");
    assert!(value.get("tenant_id").is_none(), "None tenant_id is omitted");

    let restored = serde_json::from_value::<RunSummary>(value).unwrap();
    assert_eq!(restored, summary);
}
/// RUF-MEM-INV-003: `round_trip_test_helper` preserves both an empty session
/// and one holding five summaries.
#[test]
fn ruf_mem_inv_003_session_round_trip_invariant_holds_for_empty_and_populated() {
    let empty = CfaSession::new(Surface::Mcp);
    let empty_back = round_trip_test_helper(&empty).unwrap();
    assert_eq!(empty, empty_back);

    let mut populated = CfaSession::new(Surface::Cli);
    (0..5).for_each(|i| {
        populated.append(mk_summary(
            Surface::Mcp,
            "dcf_calc",
            &format!("note {i}"),
            vec![i as f32, 0.0, 0.0],
        ));
    });
    let populated_back = round_trip_test_helper(&populated).unwrap();
    assert_eq!(populated, populated_back);
}
/// RUF-MEM-INV-004 (requires the `multi_agent` feature): entity extraction
/// comes from `crate::multi_agent`. A sample sentence yields a ticker and a
/// CUSIP, and linking every pair of extracted entities produces a graph with
/// at least two nodes.
#[cfg(feature = "multi_agent")]
#[test]
fn ruf_mem_inv_004_entity_extraction_lives_in_multi_agent_module() {
    use crate::multi_agent::entity_graph::{extract_entities_from_text, EntityGraph};
    use crate::multi_agent::types::{EntityKind as AgentEntityKind, RelationKind};

    let ents = extract_entities_from_text("AAPL traded with CUSIP 037833100 today");
    let kinds: HashSet<AgentEntityKind> = ents.iter().map(|e| e.kind).collect();
    assert!(kinds.contains(&AgentEntityKind::Ticker));
    assert!(kinds.contains(&AgentEntityKind::Cusip));

    // Add one relation for every unordered pair (left before right).
    let mut graph = EntityGraph::new();
    for (pos, left) in ents.iter().enumerate() {
        for right in &ents[pos + 1..] {
            graph.add_relation(
                left.clone(),
                right.clone(),
                RelationKind::MentionedTogether,
            );
        }
    }
    assert!(graph.node_count() >= 2);
}
/// RUF-MEM-INV-005: `query` never returns more than `limit` hits.
///
/// Twenty summaries are ingested and the predicate accepts everything, so for
/// each limit below the pool size the result must contain exactly `limit`
/// hits. Checking only `<=` would let an empty or truncated result pass.
#[test]
fn ruf_mem_inv_005_query_caps_results_at_limit() {
    const POOL: usize = 20;
    let mut idx = HnswMemoryIndex::new(3);
    for i in 0..POOL {
        idx.ingest(&mk_summary(
            Surface::Mcp,
            "dcf_calc",
            &format!("note {i}"),
            vec![i as f32, 0.0, 0.0],
        ))
        .unwrap();
    }
    for limit in [1usize, 3, 5, 10] {
        let res = idx.query(&[0.0, 0.0, 0.0], limit, |_| true);
        assert!(res.len() <= limit, "query() must never exceed limit");
        // The pool exceeds every limit and the filter accepts all, so the
        // cap must actually be reached.
        assert_eq!(
            res.len(),
            limit,
            "query() must return exactly `limit` hits when the pool is larger"
        );
    }
}
/// BM25 keyword search: after ingesting an Apple note and a Tesla note, the
/// query "apple revenue" ranks the Apple note first with a positive score.
/// The embeddings are all-zero placeholders.
#[test]
fn bm25_keyword_query_supports_hybrid_retrieval() {
    let mut bm = BM25MemoryIndex::new().unwrap();
    for text in ["Apple iPhone revenue beat", "Tesla EMEA deliveries miss"] {
        bm.ingest(&mk_summary(Surface::Mcp, "dcf_calc", text, vec![0.0; 3]))
            .unwrap();
    }

    let hits = bm.query("apple revenue", 5, None).unwrap();
    assert!(!hits.is_empty());
    let top = &hits[0];
    assert!(top.run_summary.summary_text.contains("Apple"));
    assert!(top.bm25_score > 0.0);
}
/// An `EntityRef` for ticker "AAPL" serialises to JSON and decodes back to an
/// equal value.
#[test]
fn entity_ref_round_trips_through_serde() {
    let original = EntityRef {
        kind: EntityKind::Ticker,
        value: "AAPL".into(),
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<EntityRef>(&encoded).unwrap();
    assert_eq!(original, decoded);
}