use crate::SelfLearningMemory;
use crate::embeddings::{EmbeddingConfig, ProviderConfig};
use crate::episode::ExecutionStep;
use crate::types::{ExecutionResult, TaskContext, TaskOutcome, TaskType};
/// Checks that a default-constructed `SelfLearningMemory` carries a similarity
/// threshold in the half-open range `(0.0, 1.0]` whenever a semantic service is present.
///
/// `semantic_service` is optional; when it is `None` the test has nothing to
/// validate and returns early.
#[tokio::test]
pub async fn test_semantic_service_initialization() {
    let memory = SelfLearningMemory::new();

    // Guard clause: without a semantic service there is no threshold contract to check.
    if memory.semantic_service.is_none() {
        return;
    }

    let threshold = memory.semantic_config.similarity_threshold;
    assert!(threshold > 0.0);
    assert!(threshold <= 1.0);
}
/// Verifies that `SelfLearningMemory::with_semantic_config` stores the supplied
/// `EmbeddingConfig` values unchanged in `memory.semantic_config`.
#[tokio::test]
pub async fn test_with_semantic_config() {
    let custom_config = EmbeddingConfig {
        provider: ProviderConfig::local_default(),
        similarity_threshold: 0.8,
        batch_size: 16,
        cache_embeddings: false,
        timeout_seconds: 60,
    };

    // A clone is handed over so `custom_config` remains available as the expected value.
    let memory = SelfLearningMemory::with_semantic_config(
        crate::MemoryConfig::default(),
        custom_config.clone(),
    );
    let stored = &memory.semantic_config;

    assert_eq!(stored.similarity_threshold, custom_config.similarity_threshold);
    assert_eq!(stored.batch_size, custom_config.batch_size);
    assert_eq!(stored.cache_embeddings, custom_config.cache_embeddings);
    assert_eq!(stored.timeout_seconds, custom_config.timeout_seconds);
}
/// Logs twenty successful steps to a fresh episode, completes it, and checks
/// that the stored episode reports `is_complete()`.
///
/// NOTE(review): despite the test name, the config below sets
/// `enable_embeddings: false`, so this only exercises the completion path;
/// embedding generation itself is not asserted here.
#[tokio::test(flavor = "multi_thread")]
pub async fn test_embedding_generation_on_completion() {
    // Summarization and embeddings are switched off for this run.
    let config = crate::MemoryConfig {
        quality_threshold: 0.5,
        pattern_extraction_threshold: 1.0,
        enable_summarization: false,
        enable_embeddings: false,
        ..Default::default()
    };
    let memory = SelfLearningMemory::with_config(config);

    let episode_id = memory
        .start_episode(
            "Test embedding generation".to_string(),
            TaskContext::default(),
            TaskType::CodeGeneration,
        )
        .await;

    // Twenty successful steps, numbered from 1, cycling through six tool names.
    let steps = (0..20).map(|n| {
        let mut step =
            ExecutionStep::new(n + 1, format!("tool_{}", n % 6), format!("Test step {n}"));
        step.result = Some(ExecutionResult::Success {
            output: format!("Step {n} passed"),
        });
        step
    });
    for step in steps {
        memory.log_step(episode_id, step).await;
    }

    let outcome = TaskOutcome::Success {
        verdict: "Test completed".to_string(),
        artifacts: vec![],
    };
    memory
        .complete_episode(episode_id, outcome)
        .await
        .expect("Episode completion should succeed");

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(episode.is_complete());
}
/// Completes a "web-api" episode with embeddings disabled, then queries
/// `retrieve_relevant_context`. This exercises retrieval when no embeddings are
/// generated, which the test name describes as a keyword fallback.
///
/// The original final assertion (`!x.is_empty() || x.is_empty()`) was a
/// tautology and could never fail. The test now asserts properties that can
/// fail:
/// - the episode was actually stored as complete;
/// - retrieval honours the requested limit of 5.
///
/// NOTE(review): whether the keyword path is expected to match "Create API"
/// against "Implement REST API" is not established here. Tighten this to
/// `!relevant.is_empty()` once that behaviour is confirmed.
#[tokio::test(flavor = "multi_thread")]
pub async fn test_semantic_fallback_to_keyword() {
    // Embeddings are disabled, so retrieval cannot rely on semantic search.
    let test_config = crate::MemoryConfig {
        quality_threshold: 0.5,
        pattern_extraction_threshold: 1.0,
        enable_summarization: false,
        enable_embeddings: false,
        ..Default::default()
    };
    let memory = SelfLearningMemory::with_config(test_config);

    let episode1 = memory
        .start_episode(
            "Implement REST API".to_string(),
            TaskContext {
                domain: "web-api".to_string(),
                ..Default::default()
            },
            TaskType::CodeGeneration,
        )
        .await;

    for i in 0..20 {
        let mut step = ExecutionStep::new(i + 1, format!("tool_{}", i % 6), format!("Step {i}"));
        step.result = Some(ExecutionResult::Success {
            output: "Success".to_string(),
        });
        memory.log_step(episode1, step).await;
    }

    memory
        .complete_episode(
            episode1,
            TaskOutcome::Success {
                verdict: "API implemented".to_string(),
                artifacts: vec![],
            },
        )
        .await
        .expect("Episode completion should succeed");

    // Precondition: the episode we expect retrieval to consider really is complete.
    let stored = memory
        .get_episode(episode1)
        .await
        .expect("completed episode should be retrievable by id");
    assert!(stored.is_complete());

    let relevant = memory
        .retrieve_relevant_context("Create API".to_string(), TaskContext::default(), 5)
        .await;

    // Replaces the former tautology: retrieval must never exceed the requested limit.
    assert!(
        relevant.len() <= 5,
        "retrieve_relevant_context returned {} results for a limit of 5",
        relevant.len()
    );
}