use do_memory_core::embeddings::{EmbeddingProvider, LocalEmbeddingProvider};
use do_memory_core::episode::Episode;
use do_memory_core::spatiotemporal::embeddings::{ContextAwareEmbeddings, ContrastivePair};
use do_memory_core::types::{ComplexityLevel, TaskContext, TaskOutcome, TaskType};
use std::sync::Arc;
/// Builds a completed [`Episode`] fixture for the given task type, description and domain.
///
/// The task context is fixed apart from `domain`: language `"rust"`, no framework,
/// moderate complexity and no tags. The episode is completed with a `Success`
/// outcome whose verdict is `"Success"` and which carries no artifacts.
fn create_test_episode(task_type: TaskType, description: &str, domain: &str) -> Episode {
    let mut episode = Episode::new(
        description.to_owned(),
        TaskContext {
            language: Some(String::from("rust")),
            framework: None,
            complexity: ComplexityLevel::Moderate,
            domain: domain.to_owned(),
            tags: Vec::new(),
        },
        task_type,
    );

    let outcome = TaskOutcome::Success {
        verdict: String::from("Success"),
        artifacts: Vec::new(),
    };
    episode.complete(outcome);

    episode
}
/// Exercises adapter training and adapted-embedding lookup against the local provider.
///
/// The test returns early (after printing a notice to stderr) when the default
/// embedding configuration is not the local variant, or when the local provider
/// cannot be constructed.
#[tokio::test]
async fn test_context_aware_embeddings_integration() {
    use do_memory_core::embeddings::{EmbeddingConfig, ProviderConfig};

    let config = EmbeddingConfig::default();
    let local_config = match &config.provider {
        ProviderConfig::Local(cfg) => cfg.clone(),
        _ => {
            eprintln!("Skipping test - default config is not local");
            return;
        }
    };
    let base = match LocalEmbeddingProvider::new(local_config).await {
        Ok(provider) => Arc::new(provider),
        Err(_) => {
            eprintln!("Skipping test - embedding model not available");
            return;
        }
    };

    let mut embeddings = ContextAwareEmbeddings::new(base.clone());

    // Before any training there must be no adapters registered.
    assert_eq!(embeddings.adapter_count(), 0);
    assert!(!embeddings.has_adapter(TaskType::CodeGeneration));

    // Two contrastive pairs: anchor and positive share the code-generation task
    // type, the negative uses a different task type in the same domain.
    let web_pair = ContrastivePair {
        anchor: create_test_episode(TaskType::CodeGeneration, "implement API", "web"),
        positive: create_test_episode(TaskType::CodeGeneration, "build REST service", "web"),
        negative: create_test_episode(TaskType::Debugging, "fix crash", "web"),
    };
    let backend_pair = ContrastivePair {
        anchor: create_test_episode(TaskType::CodeGeneration, "create database", "backend"),
        positive: create_test_episode(TaskType::CodeGeneration, "add schema", "backend"),
        negative: create_test_episode(TaskType::Testing, "write tests", "backend"),
    };
    let coding_pairs = vec![web_pair, backend_pair];

    let result = embeddings
        .train_adapter(TaskType::CodeGeneration, &coding_pairs)
        .await;
    assert!(result.is_ok());

    // Training registers exactly one adapter, recorded as trained on both pairs.
    assert_eq!(embeddings.adapter_count(), 1);
    assert!(embeddings.has_adapter(TaskType::CodeGeneration));
    let adapter = embeddings.get_adapter(TaskType::CodeGeneration).unwrap();
    assert_eq!(adapter.trained_on_count, 2);

    let query = "implement authentication";
    let base_embedding = base.embed_text(query).await.unwrap();

    // The adapted embedding must keep the base embedding's length.
    let adapted_embedding = embeddings
        .get_adapted_embedding(query, Some(TaskType::CodeGeneration))
        .await
        .unwrap();
    assert_eq!(base_embedding.len(), adapted_embedding.len());

    // No adapter was trained for debugging, so the result must equal the base embedding.
    let debugging_embedding = embeddings
        .get_adapted_embedding(query, Some(TaskType::Debugging))
        .await
        .unwrap();
    assert_eq!(debugging_embedding, base_embedding);
}
/// Trains one adapter per task type and checks that each is registered independently.
///
/// The test returns early (after printing a notice to stderr) when the default
/// embedding configuration is not the local variant, or when the local provider
/// cannot be constructed.
#[tokio::test]
async fn test_multiple_task_adapters() {
    use do_memory_core::embeddings::{EmbeddingConfig, ProviderConfig};

    let config = EmbeddingConfig::default();
    let local_config = match &config.provider {
        ProviderConfig::Local(cfg) => cfg.clone(),
        _ => {
            eprintln!("Skipping test - default config is not local");
            return;
        }
    };
    let base = match LocalEmbeddingProvider::new(local_config).await {
        Ok(provider) => Arc::new(provider),
        Err(_) => {
            eprintln!("Skipping test - embedding model not available");
            return;
        }
    };

    let mut embeddings = ContextAwareEmbeddings::new(base);

    // One single-pair training run per task type; the negative is always an
    // analysis episode so it never shares the task type being trained.
    for task_type in [
        TaskType::CodeGeneration,
        TaskType::Debugging,
        TaskType::Refactoring,
    ] {
        let anchor = create_test_episode(task_type, "task 1", "domain");
        let positive = create_test_episode(task_type, "task 2", "domain");
        let negative = create_test_episode(TaskType::Analysis, "analyze", "domain");
        let pairs = vec![ContrastivePair {
            anchor,
            positive,
            negative,
        }];
        embeddings.train_adapter(task_type, &pairs).await.unwrap();
    }

    // Exactly the three trained task types have adapters; an untrained one does not.
    assert_eq!(embeddings.adapter_count(), 3);
    assert!(embeddings.has_adapter(TaskType::CodeGeneration));
    assert!(embeddings.has_adapter(TaskType::Debugging));
    assert!(embeddings.has_adapter(TaskType::Refactoring));
    assert!(!embeddings.has_adapter(TaskType::Testing));
}
/// Training an adapter with an empty slice of contrastive pairs must fail, and the
/// error's display text must contain the word "empty".
///
/// Runs against `MockLocalModel` (constructed with the name "mock" and the
/// argument 128 — presumably the embedding dimension; confirm against the mock's
/// constructor), so no local provider setup is performed here.
#[tokio::test]
async fn test_empty_training_pairs_error() {
    let mock = do_memory_core::embeddings::MockLocalModel::new("mock".to_string(), 128);
    let mut embeddings = ContextAwareEmbeddings::new(Arc::new(mock));

    // `expect_err` both asserts the failure and yields the error for inspection.
    let err = embeddings
        .train_adapter(TaskType::CodeGeneration, &[])
        .await
        .expect_err("training with no contrastive pairs should return an error");

    // Include the actual text so a wording change is easy to diagnose on failure.
    let message = err.to_string();
    assert!(
        message.contains("empty"),
        "expected the error to mention \"empty\", got: {}",
        message
    );
}
/// With no adapters trained, adapted embeddings must equal the base provider's
/// embedding, both when no task type is given and when one is given.
///
/// The test returns early (after printing a notice to stderr) when the default
/// embedding configuration is not the local variant, or when the local provider
/// cannot be constructed.
#[tokio::test]
async fn test_backward_compatibility_no_adapters() {
    use do_memory_core::embeddings::{EmbeddingConfig, ProviderConfig};

    let config = EmbeddingConfig::default();
    let local_config = match &config.provider {
        ProviderConfig::Local(cfg) => cfg.clone(),
        _ => {
            eprintln!("Skipping test - default config is not local");
            return;
        }
    };
    let base = match LocalEmbeddingProvider::new(local_config).await {
        Ok(provider) => Arc::new(provider),
        Err(_) => {
            eprintln!("Skipping test - embedding model not available");
            return;
        }
    };

    let embeddings = ContextAwareEmbeddings::new(base.clone());
    let query = "test task";

    // Lookup without a task type.
    let without_task = embeddings.get_adapted_embedding(query, None).await.unwrap();

    // Lookup with a task type for which nothing has been trained.
    let with_untrained_task = embeddings
        .get_adapted_embedding(query, Some(TaskType::CodeGeneration))
        .await
        .unwrap();

    // Reference embedding straight from the underlying provider.
    let reference = base.embed_text(query).await.unwrap();

    assert_eq!(without_task, reference);
    assert_eq!(with_untrained_task, reference);
}