#![allow(
clippy::expect_used,
clippy::inefficient_to_string,
clippy::similar_names,
clippy::field_reassign_with_default,
clippy::uninlined_format_args,
clippy::vec_init_then_push
)]
use do_memory_core::episode::Episode;
use do_memory_core::retrieval::{CacheKey, QueryCache};
use do_memory_core::types::{TaskContext, TaskType};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Build a minimal `Episode` for tests, tagged with `domain`, wrapped in `Arc`.
///
/// The `id` string is parsed as a UUID; a malformed `id` falls back to a
/// freshly generated random UUID instead of panicking.
fn create_episode(id: &str, domain: &str) -> Arc<Episode> {
    let episode_id = Uuid::parse_str(id).unwrap_or_else(|_| Uuid::new_v4());
    let mut ctx = TaskContext::default();
    ctx.domain = domain.to_string();
    Arc::new(Episode {
        episode_id,
        task_type: TaskType::CodeGeneration,
        task_description: format!("Task in {} domain", domain),
        context: ctx,
        start_time: chrono::Utc::now(),
        end_time: None,
        steps: Vec::new(),
        outcome: None,
        reward: None,
        reflection: None,
        patterns: Vec::new(),
        heuristics: Vec::new(),
        applied_patterns: Vec::new(),
        salient_features: None,
        tags: Vec::new(),
        checkpoints: Vec::new(),
        metadata: HashMap::new(),
    })
}
#[test]
fn test_multi_domain_workflow() {
let cache = QueryCache::new();
let web_query =
CacheKey::new("implement REST API".to_string()).with_domain(Some("web-api".to_string()));
let data_query = CacheKey::new("process CSV file".to_string())
.with_domain(Some("data-processing".to_string()));
let ml_query =
CacheKey::new("train model".to_string()).with_domain(Some("machine-learning".to_string()));
let web_episodes = vec![
create_episode("00000000-0000-0000-0000-000000000001", "web-api"),
create_episode("00000000-0000-0000-0000-000000000002", "web-api"),
];
let data_episodes = vec![create_episode(
"00000000-0000-0000-0000-000000000003",
"data-processing",
)];
let ml_episodes = vec![create_episode(
"00000000-0000-0000-0000-000000000004",
"machine-learning",
)];
cache.put(web_query.clone(), web_episodes.clone());
cache.put(data_query.clone(), data_episodes.clone());
cache.put(ml_query.clone(), ml_episodes.clone());
assert_eq!(cache.size(), 3);
assert!(cache.get(&web_query).is_some());
assert!(cache.get(&data_query).is_some());
assert!(cache.get(&ml_query).is_some());
cache.invalidate_domain("web-api");
assert!(cache.get(&web_query).is_none());
assert!(cache.get(&data_query).is_some());
assert!(cache.get(&ml_query).is_some());
assert_eq!(cache.size(), 3); assert_eq!(cache.effective_size(), 2);
let metrics = cache.metrics();
assert_eq!(metrics.invalidations, 1);
assert!(metrics.hits > 0); }
/// Repeated put/invalidate churn on one domain must leave the other domain's
/// entries intact and retrievable.
#[test]
fn test_high_frequency_invalidation() {
    let cache = QueryCache::new();

    let keys_a: Vec<_> = (0..10)
        .map(|i| CacheKey::new(format!("query-a-{}", i)).with_domain(Some("domain-a".to_string())))
        .collect();
    let keys_b: Vec<_> = (0..10)
        .map(|i| CacheKey::new(format!("query-b-{}", i)).with_domain(Some("domain-b".to_string())))
        .collect();

    // Store one single-episode entry per key for the given domain.
    let fill = |keys: &[CacheKey], id: &str, domain: &str| {
        for key in keys {
            cache.put(key.clone(), vec![create_episode(id, domain)]);
        }
    };

    fill(&keys_a, "00000000-0000-0000-0000-000000000001", "domain-a");
    fill(&keys_b, "00000000-0000-0000-0000-000000000002", "domain-b");
    assert_eq!(cache.size(), 20);

    // Five rounds of refill-then-invalidate, confined to domain-a.
    for _ in 0..5 {
        fill(&keys_a, "00000000-0000-0000-0000-000000000001", "domain-a");
        cache.invalidate_domain("domain-a");
    }

    // Every domain-a lookup misses; every domain-b lookup still hits.
    for key in &keys_a {
        assert!(cache.get(key).is_none());
    }
    for key in &keys_b {
        assert!(cache.get(key).is_some());
    }

    // Raw size counts all 20 slots; only domain-b's 10 remain effective.
    assert_eq!(cache.size(), 20);
    assert_eq!(cache.effective_size(), 10);
}
/// After invalidating a single domain, overall hit rate should stay healthy:
/// misses are confined to the invalidated domain's lookups.
#[test]
fn test_cache_hit_rate_improvement() {
    let cache = QueryCache::new();
    let domains = ["web-api", "data-processing", "machine-learning"];

    // One canonical way to build the i-th key for a domain, shared by every
    // phase of the test so lookups always match what was stored.
    let make_key = |domain: &str, i: i32| {
        CacheKey::new(format!("query-{}-{}", domain, i)).with_domain(Some(domain.to_string()))
    };

    // Seed five entries per domain.
    for domain in &domains {
        for i in 0..5 {
            cache.put(
                make_key(domain, i),
                vec![create_episode(
                    &format!("00000000-0000-0000-0000-00000000000{}", i),
                    domain,
                )],
            );
        }
    }
    assert_eq!(cache.size(), 15);

    // First lookup pass over every key.
    for domain in &domains {
        for i in 0..5 {
            let _ = cache.get(&make_key(domain, i));
        }
    }
    let hit_rate_before = cache.metrics().hit_rate();

    cache.invalidate_domain("web-api");

    // Second pass: only web-api lookups should now miss.
    for domain in &domains {
        for i in 0..5 {
            let _ = cache.get(&make_key(domain, i));
        }
    }
    let hit_rate_after = cache.metrics().hit_rate();

    // One invalidated domain out of three leaves the rate above 0.6.
    assert!(hit_rate_after > 0.6);
    println!(
        "Hit rate before: {:.1}%, after: {:.1}%",
        hit_rate_before * 100.0,
        hit_rate_after * 100.0
    );
}
/// Two entries with identical query text but different domains must be
/// stored, retrieved, and invalidated independently of each other.
#[test]
fn test_domain_isolation_correctness() {
    let cache = QueryCache::new();
    let query_text = "implement feature X".to_string();

    let key_web = CacheKey::new(query_text.clone()).with_domain(Some("web-api".to_string()));
    let key_data =
        CacheKey::new(query_text.clone()).with_domain(Some("data-processing".to_string()));

    cache.put(
        key_web.clone(),
        vec![create_episode(
            "00000000-0000-0000-0000-000000000001",
            "web-api",
        )],
    );
    cache.put(
        key_data.clone(),
        vec![create_episode(
            "00000000-0000-0000-0000-000000000002",
            "data-processing",
        )],
    );

    // Same text, different domain => two distinct cache entries.
    let cached_web = cache.get(&key_web).expect("web-api should be cached");
    let cached_data = cache.get(&key_data).expect("data should be cached");
    assert_eq!(cached_web.len(), 1);
    assert_eq!(cached_data.len(), 1);
    assert_eq!(cached_web[0].context.domain, "web-api");
    assert_eq!(cached_data[0].context.domain, "data-processing");

    // Invalidating web-api must not evict the data-processing entry.
    cache.invalidate_domain("web-api");
    assert!(cache.get(&key_web).is_none());
    let cached_data_after = cache.get(&key_data).expect("data should still be cached");
    assert_eq!(cached_data_after.len(), 1);
}