use do_memory_core::Episode;
use do_memory_core::spatiotemporal::{HierarchicalRetriever, RetrievalQuery};
use do_memory_core::types::{ComplexityLevel, TaskContext, TaskOutcome, TaskType};
use std::sync::Arc;
/// Builds a completed, successful [`Episode`] wrapped in an [`Arc`] for retrieval tests.
///
/// The context uses `ComplexityLevel::Moderate`, no framework and no tags;
/// `language`, when given, is copied into the context. The episode is then
/// completed with a `TaskOutcome::Success` carrying a fixed verdict and a
/// single `output.txt` artifact.
fn create_test_episode(
    domain: &str,
    task_type: TaskType,
    description: &str,
    language: Option<&str>,
) -> Arc<Episode> {
    let outcome = TaskOutcome::Success {
        verdict: String::from("Task completed successfully"),
        artifacts: vec![String::from("output.txt")],
    };
    let mut episode = Episode::new(
        description.to_owned(),
        TaskContext {
            language: language.map(str::to_owned),
            framework: None,
            complexity: ComplexityLevel::Moderate,
            domain: domain.to_owned(),
            tags: Vec::new(),
        },
        task_type,
    );
    episode.complete(outcome);
    Arc::new(episode)
}
/// Retrieval with a query embedding plus domain and task-type filters.
///
/// Expects exactly `limit` (2) results ordered by descending relevance, each
/// mapping back to a `web-api` / `CodeGeneration` episode, and a positive
/// level-4 score on the top result. `episode_embeddings` is left empty here.
#[tokio::test]
async fn test_retrieval_with_embeddings_enabled() {
    let episodes = vec![
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Implement REST API authentication with JWT tokens",
            Some("rust"),
        ),
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Create REST endpoints for user management",
            Some("rust"),
        ),
        create_test_episode(
            "data-science",
            TaskType::Analysis,
            "Analyze user behavior patterns",
            Some("python"),
        ),
    ];
    let retriever = HierarchicalRetriever::new();
    let query = RetrievalQuery {
        query_text: "Implement authentication for REST API".to_string(),
        query_embedding: Some(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
        domain: Some("web-api".to_string()),
        task_type: Some(TaskType::CodeGeneration),
        limit: 2,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let results = retriever.retrieve(&query, &episodes).await.unwrap();

    assert_eq!(results.len(), 2, "Should return exactly 2 results");
    let (top, runner_up) = (&results[0], &results[1]);
    assert!(
        top.relevance_score >= runner_up.relevance_score,
        "Results should be sorted by relevance"
    );

    // Each returned id must resolve to one of the filtered-in episodes.
    results
        .iter()
        .map(|result| {
            episodes
                .iter()
                .find(|episode| episode.episode_id == result.episode_id)
                .unwrap()
        })
        .for_each(|episode| {
            assert_eq!(episode.context.domain, "web-api");
            assert_eq!(episode.task_type, TaskType::CodeGeneration);
        });

    assert!(
        top.level_4_score > 0.0,
        "Level 4 score should be non-zero when using embeddings"
    );
}
/// Retrieval without a query embedding, filtered only by domain.
///
/// With `query_embedding: None`, level-4 scoring is expected to fall back to
/// text similarity. The test checks that:
/// - results are non-empty and sorted by descending relevance (every adjacent pair);
/// - every level-4 score lies in `[0, 1]`;
/// - the top result's description contains the query keyword "authentication".
#[tokio::test]
async fn test_retrieval_with_embeddings_disabled() {
    let episodes = vec![
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Implement authentication endpoint",
            Some("rust"),
        ),
        create_test_episode(
            "web-api",
            TaskType::Testing,
            "Test authentication flow",
            Some("rust"),
        ),
        create_test_episode(
            "data-science",
            TaskType::Analysis,
            "Analyze security patterns",
            Some("python"),
        ),
    ];
    let retriever = HierarchicalRetriever::new();
    let query = RetrievalQuery {
        query_text: "authentication endpoint".to_string(),
        // No query embedding: exercises the text-similarity path.
        query_embedding: None,
        domain: Some("web-api".to_string()),
        task_type: None,
        limit: 3,
        episode_embeddings: std::collections::HashMap::new(),
    };
    let results = retriever.retrieve(&query, &episodes).await.unwrap();
    assert!(!results.is_empty(), "Should return results");

    // Compare every adjacent pair; checking only first vs. last would miss a
    // mis-ordered element in the middle of the list.
    for pair in results.windows(2) {
        assert!(
            pair[0].relevance_score >= pair[1].relevance_score,
            "Results should be sorted by relevance"
        );
    }

    for result in &results {
        assert!(
            (0.0..=1.0).contains(&result.level_4_score),
            "Level 4 score {} should be in [0, 1] when calculated using text similarity",
            result.level_4_score
        );
    }

    let first_episode = episodes
        .iter()
        .find(|e| e.episode_id == results[0].episode_id)
        .unwrap();
    assert!(
        first_episode
            .task_description
            .to_lowercase()
            .contains("authentication"),
        "Top result should match query keywords"
    );
}
/// Retrieval must still succeed and produce in-range level-4 scores when no
/// query embedding is available.
///
/// NOTE(review): despite the name, no embedding *generation* failure is
/// simulated here — the test only passes `query_embedding: None`. Confirm
/// that this is the intended stand-in for a generation failure.
#[tokio::test]
async fn test_fallback_when_embedding_generation_fails() {
    let episodes = vec![
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Build REST API for user auth",
            Some("rust"),
        ),
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Create OAuth2 integration",
            Some("rust"),
        ),
    ];
    let retriever = HierarchicalRetriever::new();
    let query = RetrievalQuery {
        query_text: "user authentication".to_string(),
        query_embedding: None,
        domain: Some("web-api".to_string()),
        task_type: Some(TaskType::CodeGeneration),
        limit: 2,
        episode_embeddings: std::collections::HashMap::new(),
    };

    // `expect` asserts success and, on failure, includes the underlying error
    // in the panic message.
    let results = retriever
        .retrieve(&query, &episodes)
        .await
        .expect("Retrieval should succeed without embeddings");

    assert!(
        !results.is_empty(),
        "Should return results using text fallback"
    );
    for result in &results {
        assert!(
            result.level_4_score >= 0.0 && result.level_4_score <= 1.0,
            "Level 4 score should be in valid range [0, 1]"
        );
    }
}
/// A query embedding whose length (5) differs from the 10-dimensional vectors
/// used elsewhere in this file must not make retrieval return an error.
#[tokio::test]
async fn test_embedding_dimension_mismatch_handling() {
    let retriever = HierarchicalRetriever::new();
    let episodes = vec![create_test_episode(
        "web-api",
        TaskType::CodeGeneration,
        "Implement API",
        Some("rust"),
    )];
    let short_embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
    let query = RetrievalQuery {
        query_text: "API implementation".to_string(),
        query_embedding: Some(short_embedding),
        domain: Some("web-api".to_string()),
        task_type: None,
        limit: 1,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let outcome = retriever.retrieve(&query, &episodes).await;

    assert!(
        outcome.is_ok(),
        "Should handle dimension mismatch gracefully"
    );
}
/// Runs the same domain/task-filtered query with and without a query
/// embedding.
///
/// Both runs must return exactly 2 `web-api` results. The top result of each
/// run is printed for manual comparison; no ranking comparison is asserted.
#[tokio::test]
async fn test_compare_accuracy_embeddings_vs_keywords() {
    let episodes = vec![
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "JWT authentication implementation",
            Some("rust"),
        ),
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "Session-based login system",
            Some("rust"),
        ),
        create_test_episode(
            "web-api",
            TaskType::Testing,
            "Test user authentication flows",
            Some("rust"),
        ),
        create_test_episode(
            "data-science",
            TaskType::Analysis,
            "Security analysis",
            Some("python"),
        ),
    ];
    let retriever = HierarchicalRetriever::new();
    let query_text = "implement user login and authentication";

    // The two queries differ only in whether an embedding is supplied.
    let build_query = |embedding| RetrievalQuery {
        query_text: query_text.to_string(),
        query_embedding: embedding,
        domain: Some("web-api".to_string()),
        task_type: Some(TaskType::CodeGeneration),
        limit: 2,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let query_with_emb = build_query(Some(vec![
        0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0,
    ]));
    let results_with_emb = retriever
        .retrieve(&query_with_emb, &episodes)
        .await
        .unwrap();

    let query_without_emb = build_query(None);
    let results_without_emb = retriever
        .retrieve(&query_without_emb, &episodes)
        .await
        .unwrap();

    assert_eq!(results_with_emb.len(), 2);
    assert_eq!(results_without_emb.len(), 2);

    for result in results_with_emb.iter().chain(&results_without_emb) {
        let matched = episodes
            .iter()
            .find(|candidate| candidate.episode_id == result.episode_id)
            .unwrap();
        assert_eq!(matched.context.domain, "web-api");
    }

    println!(
        "With embeddings - Top result: {}",
        results_with_emb[0].episode_id
    );
    println!(
        "Without embeddings - Top result: {}",
        results_without_emb[0].episode_id
    );
}
/// An empty (zero-length) query embedding must not make retrieval return an
/// error.
#[tokio::test]
async fn test_empty_query_embedding() {
    let retriever = HierarchicalRetriever::new();
    let episodes = vec![create_test_episode(
        "web-api",
        TaskType::CodeGeneration,
        "Build API",
        Some("rust"),
    )];
    let query = RetrievalQuery {
        query_text: "API".to_string(),
        query_embedding: Some(Vec::new()),
        domain: None,
        task_type: None,
        limit: 1,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let outcome = retriever.retrieve(&query, &episodes).await;
    assert!(outcome.is_ok(), "Should handle empty embedding vector");
}
/// Uses a one-hot query embedding and unrelated query text. Expects retrieval
/// to still return a result whose level-4 score lies in `[0, 1]`.
#[tokio::test]
async fn test_zero_embedding_similarity() {
    let retriever = HierarchicalRetriever::new();
    let episodes = vec![create_test_episode(
        "web-api",
        TaskType::CodeGeneration,
        "API implementation",
        Some("rust"),
    )];
    let one_hot = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let query = RetrievalQuery {
        query_text: "completely different task".to_string(),
        query_embedding: Some(one_hot),
        domain: None,
        task_type: None,
        limit: 1,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let results = retriever.retrieve(&query, &episodes).await.unwrap();
    assert!(
        !results.is_empty(),
        "Should return results even with low similarity"
    );

    let score = results[0].level_4_score;
    assert!(
        (0.0..=1.0).contains(&score),
        "Score should be in valid range"
    );
}
/// A single episode matching the query's domain, task type and keyword.
/// Expects the top result's relevance score to exceed 0.7.
#[tokio::test]
async fn test_perfect_embedding_match() {
    let retriever = HierarchicalRetriever::new();
    let episodes = vec![create_test_episode(
        "web-api",
        TaskType::CodeGeneration,
        "Implement authentication",
        Some("rust"),
    )];
    let query = RetrievalQuery {
        query_text: "authentication".to_string(),
        query_embedding: Some(vec![0.5, 0.9, 0.5, 1.0, 1.0, 0.1, 0.5, 0.5, 0.2, 1.0]),
        domain: Some("web-api".to_string()),
        task_type: Some(TaskType::CodeGeneration),
        limit: 1,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let results = retriever.retrieve(&query, &episodes).await.unwrap();
    assert!(!results.is_empty());

    let best = &results[0];
    assert!(
        best.relevance_score > 0.7,
        "Perfect match should have high relevance score"
    );
}
/// Retrieving against an empty episode list must succeed with no results.
#[tokio::test]
async fn test_no_episodes() {
    let retriever = HierarchicalRetriever::new();
    let episodes: Vec<Arc<Episode>> = Vec::new();
    let query = RetrievalQuery {
        query_text: "some query".to_string(),
        query_embedding: Some(vec![0.5; 10]),
        domain: None,
        task_type: None,
        limit: 5,
        episode_embeddings: std::collections::HashMap::new(),
    };

    let results = retriever.retrieve(&query, &episodes).await.unwrap();

    assert!(
        results.is_empty(),
        "Should return empty results for no episodes"
    );
}
/// For a range of query embeddings, every returned level-4 score must stay
/// within `[0, 1]`.
///
/// The embeddings cover all zeros, all ones, a constant, a ramp, and a vector
/// with negative components.
#[tokio::test]
async fn test_level_4_score_range() {
    let episodes = vec![
        create_test_episode(
            "web-api",
            TaskType::CodeGeneration,
            "API implementation",
            Some("rust"),
        ),
        create_test_episode(
            "data-science",
            TaskType::Analysis,
            "Data analysis",
            Some("python"),
        ),
        create_test_episode(
            "mobile-app",
            TaskType::Debugging,
            "Fix crash bug",
            Some("swift"),
        ),
    ];
    let retriever = HierarchicalRetriever::new();
    let query_embeddings = [
        vec![0.0; 10],
        vec![1.0; 10],
        vec![0.5; 10],
        vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        vec![-1.0, -0.5, 0.0, 0.5, 1.0, -1.0, -0.5, 0.0, 0.5, 1.0],
    ];

    for (i, embedding) in query_embeddings.iter().enumerate() {
        let query = RetrievalQuery {
            query_text: format!("query {i}"),
            query_embedding: Some(embedding.to_vec()),
            domain: None,
            task_type: None,
            limit: 3,
            episode_embeddings: std::collections::HashMap::new(),
        };
        let results = retriever.retrieve(&query, &episodes).await.unwrap();

        results.iter().for_each(|result| {
            assert!(
                result.level_4_score >= 0.0 && result.level_4_score <= 1.0,
                "Level 4 score {} should be in [0, 1] range for query {}",
                result.level_4_score,
                i
            );
        });
    }
}