#![allow(unused_imports)]
use std::sync::Arc;
use do_memory_core::SelfLearningMemory;
use crate::mcp::tools::embeddings::tool::{
EmbeddingTools, configure_embeddings_tool, embedding_provider_status_tool,
generate_embedding_tool, query_semantic_memory_tool, search_by_embedding_tool,
test_embeddings_tool,
};
use crate::mcp::tools::embeddings::types::{
ConfigureEmbeddingsInput, EmbeddingProviderStatusInput, GenerateEmbeddingInput,
QuerySemanticMemoryInput, SearchByEmbeddingInput,
};
/// Checks the `configure_embeddings` tool definition: its name, a non-empty
/// description, an object schema, and a `provider` property whose `enum`
/// lists exactly the five expected provider identifiers.
#[test]
fn test_configure_embeddings_tool_definition() {
    let definition = configure_embeddings_tool();
    assert_eq!(definition.name, "configure_embeddings");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let props = definition
        .input_schema
        .as_object()
        .unwrap()
        .get("properties")
        .unwrap()
        .as_object()
        .unwrap();
    assert!(props.contains_key("provider"));

    let allowed = props
        .get("provider")
        .unwrap()
        .as_object()
        .unwrap()
        .get("enum")
        .unwrap()
        .as_array()
        .unwrap();
    // The length check plus the membership checks pin the enum to this exact set.
    assert_eq!(allowed.len(), 5);
    for name in ["openai", "local", "mistral", "azure", "cohere"] {
        assert!(allowed.contains(&serde_json::json!(name)));
    }
}
/// Checks the `query_semantic_memory` tool definition: its name, a non-empty
/// description, and a schema whose only required field is `query`.
#[test]
fn test_query_semantic_memory_tool_definition() {
    let definition = query_semantic_memory_tool();
    assert_eq!(definition.name, "query_semantic_memory");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let required_fields = definition
        .input_schema
        .as_object()
        .unwrap()
        .get("required")
        .unwrap()
        .as_array()
        .unwrap();
    assert_eq!(required_fields.len(), 1);
    assert!(required_fields.contains(&serde_json::json!("query")));
}
/// Checks the `test_embeddings` tool definition: its name, a non-empty
/// description, and a schema that declares no input properties.
#[test]
fn test_test_embeddings_tool_definition() {
    let definition = test_embeddings_tool();
    assert_eq!(definition.name, "test_embeddings");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let props = definition
        .input_schema
        .as_object()
        .unwrap()
        .get("properties")
        .unwrap()
        .as_object()
        .unwrap();
    assert!(props.is_empty());
}
/// Configures the `local` provider with an explicit model, threshold and batch
/// size, and expects success with a 384-dimensional embedding and no warnings.
#[tokio::test]
async fn test_configure_embeddings_local() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = ConfigureEmbeddingsInput {
        provider: String::from("local"),
        model: Some(String::from("sentence-transformers/all-MiniLM-L6-v2")),
        api_key_env: None,
        similarity_threshold: Some(0.75),
        batch_size: Some(16),
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let outcome = tools.execute_configure_embeddings(request).await;
    assert!(outcome.is_ok());

    let config = outcome.unwrap();
    assert!(config.success);
    assert_eq!(config.provider, "local");
    assert_eq!(config.dimension, 384);
    assert!(config.warnings.is_empty());
}
/// Configures the `openai` provider and verifies whichever outcome occurs.
///
/// The previous assertion (`is_ok() || is_err()`) was a tautology and could
/// never fail. Both outcomes are still accepted, but each is now checked:
/// a success must report the requested provider, and a failure must carry a
/// non-empty error message.
///
/// NOTE(review): the outcome presumably depends on whether `OPENAI_API_KEY` is
/// set in the environment running the test. Confirm against
/// `execute_configure_embeddings`.
#[tokio::test]
async fn test_configure_embeddings_openai() {
    let memory = Arc::new(SelfLearningMemory::new());
    let tools = EmbeddingTools::new(memory);
    let input = ConfigureEmbeddingsInput {
        provider: "openai".to_string(),
        model: Some("text-embedding-3-small".to_string()),
        api_key_env: Some("OPENAI_API_KEY".to_string()),
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    match tools.execute_configure_embeddings(input).await {
        // Mirrors the local-provider test, where the output echoes the provider.
        Ok(output) => assert_eq!(output.provider, "openai"),
        Err(error) => assert!(
            !error.to_string().is_empty(),
            "configuration failure should carry a descriptive message"
        ),
    }
}
/// Requests an unknown provider name and expects an error whose message
/// contains "Unsupported provider".
#[tokio::test]
async fn test_configure_embeddings_invalid_provider() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = ConfigureEmbeddingsInput {
        provider: String::from("invalid-provider"),
        model: None,
        api_key_env: None,
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let outcome = tools.execute_configure_embeddings(request).await;
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Unsupported provider"));
}
/// Runs a semantic query with every optional filter populated and expects a
/// successful result with a non-negative query time and a 384-dimensional
/// embedding.
#[tokio::test]
async fn test_query_semantic_memory() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = QuerySemanticMemoryInput {
        query: String::from("implement REST API"),
        limit: Some(5),
        similarity_threshold: Some(0.8),
        domain: Some(String::from("web-api")),
        task_type: Some(String::from("code_generation")),
    };

    let outcome = tools.execute_query_semantic_memory(request).await;
    assert!(outcome.is_ok());

    let response = outcome.unwrap();
    assert!(response.query_time_ms >= 0.0);
    assert_eq!(response.embedding_dimension, 384);
}
/// Runs the embeddings self-test on a fresh `EmbeddingTools` and expects it to
/// report embeddings as unavailable: no sample embedding and a message
/// containing "not yet configured".
#[tokio::test]
async fn test_test_embeddings() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    let outcome = tools.execute_test_embeddings().await;
    assert!(outcome.is_ok());

    let report = outcome.unwrap();
    assert!(!report.available);
    assert_eq!(report.sample_embedding.len(), 0);
    assert!(!report.message.is_empty());
    assert!(report.message.contains("not yet configured"));
}
/// Configures the `azure` provider without `resource_name` or
/// `deployment_name` and expects an error that names one of those fields.
#[tokio::test]
async fn test_configure_embeddings_azure_missing_fields() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = ConfigureEmbeddingsInput {
        provider: String::from("azure"),
        model: None,
        api_key_env: None,
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let outcome = tools.execute_configure_embeddings(request).await;
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    let names_missing_field =
        message.contains("deployment_name") || message.contains("resource_name");
    assert!(
        names_missing_field,
        "Expected error about missing deployment_name or resource_name, got: {}",
        message
    );
}
/// Checks the `generate_embedding` tool definition: its name, a non-empty
/// description, `text` as the only required field, and both `text` and
/// `normalize` declared as properties.
#[test]
fn test_generate_embedding_tool_definition() {
    let definition = generate_embedding_tool();
    assert_eq!(definition.name, "generate_embedding");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = definition.input_schema.as_object().unwrap();

    let required_fields = schema.get("required").unwrap().as_array().unwrap();
    assert_eq!(required_fields.len(), 1);
    assert!(required_fields.contains(&serde_json::json!("text")));

    let props = schema.get("properties").unwrap().as_object().unwrap();
    for key in ["text", "normalize"] {
        assert!(props.contains_key(key));
    }
}
/// Checks the `search_by_embedding` tool definition: its name, a non-empty
/// description, `embedding` as the only required field, and the `embedding`,
/// `limit` and `similarity_threshold` properties.
#[test]
fn test_search_by_embedding_tool_definition() {
    let definition = search_by_embedding_tool();
    assert_eq!(definition.name, "search_by_embedding");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = definition.input_schema.as_object().unwrap();

    let required_fields = schema.get("required").unwrap().as_array().unwrap();
    assert_eq!(required_fields.len(), 1);
    assert!(required_fields.contains(&serde_json::json!("embedding")));

    let props = schema.get("properties").unwrap().as_object().unwrap();
    for key in ["embedding", "limit", "similarity_threshold"] {
        assert!(props.contains_key(key));
    }
}
/// Checks the `embedding_provider_status` tool definition: its name, a
/// non-empty description, no required fields (the `required` key is either
/// absent or an empty array), and a `test_connectivity` property.
#[test]
fn test_embedding_provider_status_tool_definition() {
    let definition = embedding_provider_status_tool();
    assert_eq!(definition.name, "embedding_provider_status");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = definition.input_schema.as_object().unwrap();

    // A missing `required` key and an empty `required` array are both accepted.
    let no_required_fields = match schema.get("required") {
        None => true,
        Some(list) => list.as_array().unwrap().is_empty(),
    };
    assert!(no_required_fields);

    let props = schema.get("properties").unwrap().as_object().unwrap();
    assert!(props.contains_key("test_connectivity"));
}
/// Requests an embedding from a fresh `EmbeddingTools` and expects an error
/// whose message contains "not configured".
#[tokio::test]
async fn test_generate_embedding_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = GenerateEmbeddingInput {
        text: String::from("test text"),
        normalize: true,
    };

    let outcome = tools.execute_generate_embedding(request).await;
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("not configured"),
        "Expected error about not configured, got: {}",
        message
    );
}
/// Searches with a 384-element embedding on a fresh `EmbeddingTools` and
/// expects an error whose message contains "not configured".
#[tokio::test]
async fn test_search_by_embedding_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = SearchByEmbeddingInput {
        embedding: vec![0.1; 384],
        limit: 10,
        similarity_threshold: 0.7,
        domain: None,
        task_type: None,
    };

    let outcome = tools.execute_search_by_embedding(request).await;
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("not configured"),
        "Expected error about not configured, got: {}",
        message
    );
}
/// Queries provider status without a connectivity test on a fresh
/// `EmbeddingTools` and expects an unconfigured, unavailable provider reported
/// as "not-configured" with at least one warning.
#[tokio::test]
async fn test_embedding_provider_status_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = EmbeddingProviderStatusInput {
        test_connectivity: false,
    };

    let outcome = tools.execute_embedding_provider_status(request).await;
    assert!(outcome.is_ok());

    let status = outcome.unwrap();
    assert!(!status.configured);
    assert!(!status.available);
    assert_eq!(status.provider, "not-configured");
    assert!(!status.warnings.is_empty());
}
/// Queries provider status with `test_connectivity` enabled on a fresh
/// `EmbeddingTools` and expects an unconfigured, unavailable provider with no
/// test result.
#[tokio::test]
async fn test_embedding_provider_status_with_test() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = EmbeddingProviderStatusInput {
        test_connectivity: true,
    };

    let outcome = tools.execute_embedding_provider_status(request).await;
    assert!(outcome.is_ok());

    let status = outcome.unwrap();
    assert!(!status.configured);
    assert!(!status.available);
    assert!(status.test_result.is_none());
}
/// Searches with a 128-element embedding on a fresh `EmbeddingTools` and
/// expects the call to fail.
///
/// NOTE(review): only `is_err()` is asserted and no provider is configured
/// here, so this may be exercising the "not configured" error rather than a
/// dimension check. Confirm which path rejects the request.
#[tokio::test]
async fn test_search_by_embedding_wrong_dimension() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));
    let request = SearchByEmbeddingInput {
        embedding: vec![0.1; 128],
        limit: 10,
        similarity_threshold: 0.7,
        domain: None,
        task_type: None,
    };

    let outcome = tools.execute_search_by_embedding(request).await;
    assert!(outcome.is_err());
}