#[cfg(test)]
mod tests {
    use super::super::procedural_memory::ProceduralMemory;
    use super::super::reinforcement::FixedRate;
    use super::super::ttl::MemoryTtl;
    use crate::Database;
    use std::sync::Arc;
    use tempfile::{tempdir, TempDir};

    /// Builds `n` placeholder step descriptions: `"step 1"`, …, `"step n"`.
    fn steps(n: usize) -> Vec<String> {
        (1..=n).map(|i| format!("step {i}")).collect()
    }

    /// Unit embedding along the first axis; the default storage/query vector.
    fn unit_x() -> Vec<f32> {
        vec![1.0, 0.0, 0.0, 0.0]
    }

    /// Opens a `Database` in a fresh temporary directory.
    ///
    /// The returned `TempDir` must stay bound (e.g. `let (_dir, db) = …`, never
    /// `let (_, db) = …`) for the whole test: dropping it deletes the directory
    /// backing the database.
    fn temp_db() -> (TempDir, Arc<Database>) {
        let dir = tempdir().expect("failed to create temp dir");
        let db = Database::open(dir.path()).expect("Database::open failed");
        (dir, Arc::new(db))
    }

    /// Creates a `ProceduralMemory` over `db` with embedding dimension 4 and a
    /// fresh `MemoryTtl`.
    fn make_procedural(db: Arc<Database>) -> ProceduralMemory {
        ProceduralMemory::new(db, 4, Arc::new(MemoryTtl::new()))
            .expect("ProceduralMemory::new failed")
    }

    /// The backing collection name carries the `_procedural` prefix.
    #[test]
    fn test_collection_name_prefixed() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        assert!(pm.collection_name().starts_with("_procedural"));
    }

    /// A learned procedure is recalled with its id, name, steps and confidence intact.
    #[test]
    fn test_learn_and_recall() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "greet_user", &steps(3), Some(&emb), 0.8)
            .unwrap();

        let results = pm.recall(&emb, 1, 0.0).unwrap();
        assert_eq!(results.len(), 1);
        let top = &results[0];
        assert_eq!(top.id, 1);
        assert_eq!(top.name, "greet_user");
        assert_eq!(top.steps.len(), 3);
        assert!((top.confidence - 0.8).abs() < 0.01);
    }

    /// A procedure learned without an embedding is still recallable via a zero query.
    #[test]
    fn test_recall_without_embedding_uses_zero_vector() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        pm.learn(1, "no-vec procedure", &steps(2), None, 0.6)
            .unwrap();

        let zero = vec![0.0_f32; 4];
        let results = pm.recall(&zero, 1, 0.0).unwrap();
        assert_eq!(results.len(), 1);
    }

    /// `min_confidence` excludes procedures below the threshold and keeps those above.
    #[test]
    fn test_recall_min_confidence_filters_below_threshold() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "low-conf", &steps(1), Some(&emb), 0.2).unwrap();

        let results_high = pm.recall(&emb, 5, 0.5).unwrap();
        assert!(
            results_high.iter().all(|r| r.id != 1),
            "procedure below min_confidence must not appear"
        );

        let results_low = pm.recall(&emb, 5, 0.1).unwrap();
        assert!(
            results_low.iter().any(|r| r.id == 1),
            "procedure above min_confidence must appear"
        );
    }

    /// With equal confidence, the procedure whose embedding matches the query ranks first.
    #[test]
    fn test_recall_ranks_most_similar_first() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb_target = unit_x();
        let emb_other = vec![0.0_f32, 1.0, 0.0, 0.0];
        pm.learn(1, "target", &steps(1), Some(&emb_target), 0.9)
            .unwrap();
        pm.learn(2, "other", &steps(1), Some(&emb_other), 0.9)
            .unwrap();

        let results = pm.recall(&emb_target, 2, 0.0).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 1, "most similar procedure must rank first");
    }

    /// Successful reinforcement raises confidence above its learned value.
    #[test]
    fn test_reinforce_success_raises_confidence() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "trainable", &steps(1), Some(&emb), 0.5)
            .unwrap();
        pm.reinforce(1, true).unwrap();

        let results = pm.recall(&emb, 1, 0.0).unwrap();
        // Fail with a clear message rather than an index panic if recall regresses.
        let top = results
            .first()
            .expect("reinforced procedure should still be recalled");
        assert!(
            top.confidence > 0.5,
            "confidence should increase after positive reinforcement"
        );
    }

    /// Failed reinforcement lowers confidence below its learned value.
    #[test]
    fn test_reinforce_failure_lowers_confidence() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "punishable", &steps(1), Some(&emb), 0.8)
            .unwrap();
        pm.reinforce(1, false).unwrap();

        let results = pm.recall(&emb, 1, 0.0).unwrap();
        let top = results
            .first()
            .expect("reinforced procedure should still be recalled");
        assert!(
            top.confidence < 0.8,
            "confidence should decrease after negative reinforcement"
        );
    }

    /// A caller-supplied strategy is applied instead of the default one.
    #[test]
    fn test_reinforce_with_custom_strategy() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "custom-strategy", &steps(1), Some(&emb), 0.5)
            .unwrap();

        let strategy = FixedRate::new(0.2, 0.1);
        pm.reinforce_with_strategy(1, true, &strategy).unwrap();

        let results = pm.recall(&emb, 1, 0.0).unwrap();
        let top = results
            .first()
            .expect("reinforced procedure should still be recalled");
        assert!(
            (top.confidence - 0.7).abs() < 0.01,
            "FixedRate(0.2) on confidence 0.5 should produce 0.7"
        );
    }

    /// `list_all` returns every stored procedure.
    #[test]
    fn test_list_all_returns_all_stored_procedures() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        for i in 1u64..=4 {
            pm.learn(i, &format!("proc_{i}"), &steps(1), Some(&emb), 0.5)
                .unwrap();
        }

        let all = pm.list_all().unwrap();
        assert_eq!(all.len(), 4);
    }

    /// A deleted procedure no longer appears in `list_all`.
    #[test]
    fn test_delete_removes_procedure() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "to delete", &steps(1), Some(&emb), 0.7)
            .unwrap();
        pm.delete(1).unwrap();

        let all = pm.list_all().unwrap();
        assert!(all.iter().all(|p| p.id != 1));
    }

    /// `learn` rejects an embedding whose length differs from the configured dimension.
    #[test]
    fn test_learn_dimension_mismatch_rejected() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let bad_emb = vec![1.0_f32, 0.0];
        let result = pm.learn(1, "bad", &steps(1), Some(&bad_emb), 0.5);
        assert!(result.is_err());
    }

    /// `recall` rejects a query whose length differs from the configured dimension.
    #[test]
    fn test_recall_dimension_mismatch_rejected() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let bad_query = vec![0.5_f32];
        let result = pm.recall(&bad_query, 1, 0.0);
        assert!(result.is_err());
    }

    /// Reopening an existing collection with a different dimension is an error.
    #[test]
    fn test_new_detects_dimension_mismatch_on_existing_collection() {
        let (_dir, db) = temp_db();
        // The first handle establishes the collection at dimension 4; keep it alive.
        let _pm = ProceduralMemory::new_from_db(Arc::clone(&db), 4).unwrap();
        let result = ProceduralMemory::new_from_db(Arc::clone(&db), 8);
        assert!(result.is_err());
    }

    /// A TTL of 0 hides the procedure from `list_all` straight away.
    #[test]
    fn test_learn_with_ttl_zero_expires_immediately() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn_with_ttl(77, "ephemeral", &steps(1), Some(&emb), 0.8, 0)
            .unwrap();

        let all = pm.list_all().unwrap();
        assert!(
            all.iter().all(|p| p.id != 77),
            "TTL-0 procedure must not appear in list_all()"
        );
    }

    /// A large positive TTL keeps the procedure visible.
    #[test]
    fn test_learn_with_positive_ttl_still_visible() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn_with_ttl(8, "long-lived", &steps(1), Some(&emb), 0.6, 9_999)
            .unwrap();

        let all = pm.list_all().unwrap();
        assert!(all.iter().any(|p| p.id == 8));
    }

    /// Procedures serialized from one database can be restored into a fresh one.
    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let (_src_dir, src_db) = temp_db();
        let src = make_procedural(src_db);
        let emb = unit_x();
        src.learn(1, "proc_a", &steps(2), Some(&emb), 0.7).unwrap();
        src.learn(2, "proc_b", &steps(3), Some(&emb), 0.9).unwrap();
        let bytes = src.serialize().unwrap();

        let (_dst_dir, dst_db) = temp_db();
        let dst = make_procedural(dst_db);
        dst.deserialize(&bytes).unwrap();

        let all = dst.list_all().unwrap();
        assert_eq!(all.len(), 2);
        let ids: Vec<u64> = all.iter().map(|p| p.id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
    }

    /// Deserializing an empty byte slice leaves existing procedures untouched.
    #[test]
    fn test_deserialize_empty_bytes_is_noop() {
        let (_dir, db) = temp_db();
        let pm = make_procedural(db);
        let emb = unit_x();
        pm.learn(1, "existing", &steps(1), Some(&emb), 0.5).unwrap();

        pm.deserialize(&[]).unwrap();

        let all = pm.list_all().unwrap();
        assert_eq!(all.len(), 1);
    }
}