do-memory-mcp 0.1.31

Model Context Protocol (MCP) server for AI agents
Documentation
//! Embedding tools tests.

#![allow(unused_imports)]

use std::sync::Arc;

use do_memory_core::SelfLearningMemory;

use crate::mcp::tools::embeddings::tool::{
    EmbeddingTools, configure_embeddings_tool, embedding_provider_status_tool,
    generate_embedding_tool, query_semantic_memory_tool, search_by_embedding_tool,
    test_embeddings_tool,
};
use crate::mcp::tools::embeddings::types::{
    ConfigureEmbeddingsInput, EmbeddingProviderStatusInput, GenerateEmbeddingInput,
    QuerySemanticMemoryInput, SearchByEmbeddingInput,
};

#[test]
fn test_configure_embeddings_tool_definition() {
    let definition = configure_embeddings_tool();
    assert_eq!(definition.name, "configure_embeddings");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    // The schema must expose a `provider` property...
    let props = definition
        .input_schema
        .as_object()
        .and_then(|schema| schema.get("properties"))
        .and_then(|value| value.as_object())
        .expect("schema should declare a properties object");

    // ...whose enum lists exactly the five supported backends.
    let provider_enum = props
        .get("provider")
        .and_then(|provider| provider.as_object())
        .and_then(|provider| provider.get("enum"))
        .and_then(|values| values.as_array())
        .expect("provider property should declare an enum");
    assert_eq!(provider_enum.len(), 5);
    for expected in ["openai", "local", "mistral", "azure", "cohere"] {
        assert!(
            provider_enum.contains(&serde_json::json!(expected)),
            "missing provider {expected}"
        );
    }
}

#[test]
fn test_query_semantic_memory_tool_definition() {
    let definition = query_semantic_memory_tool();
    assert_eq!(definition.name, "query_semantic_memory");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    // `query` is the only mandatory input.
    let required_fields = definition
        .input_schema
        .get("required")
        .and_then(|value| value.as_array())
        .expect("schema should list its required fields");
    assert_eq!(required_fields.len(), 1);
    assert!(required_fields.contains(&serde_json::json!("query")));
}

#[test]
fn test_test_embeddings_tool_definition() {
    let definition = test_embeddings_tool();
    assert_eq!(definition.name, "test_embeddings");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    // The tool takes no arguments, so its `properties` map is empty.
    let has_no_inputs = definition.input_schema["properties"]
        .as_object()
        .map(|props| props.is_empty())
        .expect("schema should declare a properties object");
    assert!(has_no_inputs);
}

#[tokio::test]
async fn test_configure_embeddings_local() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    // The local provider is configured here without an API key.
    let request = ConfigureEmbeddingsInput {
        provider: "local".to_string(),
        model: Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
        api_key_env: None,
        similarity_threshold: Some(0.75),
        batch_size: Some(16),
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let output = tools
        .execute_configure_embeddings(request)
        .await
        .expect("configuring the local provider should succeed");

    assert!(output.success);
    assert_eq!(output.provider, "local");
    // The configured local model is expected to report 384-dimensional vectors.
    assert_eq!(output.dimension, 384);
    assert!(output.warnings.is_empty());
}

#[tokio::test]
async fn test_configure_embeddings_openai() {
    let memory = Arc::new(SelfLearningMemory::new());
    let tools = EmbeddingTools::new(memory);

    let input = ConfigureEmbeddingsInput {
        provider: "openai".to_string(),
        model: Some("text-embedding-3-small".to_string()),
        api_key_env: Some("OPENAI_API_KEY".to_string()),
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    // The outcome depends on whether OPENAI_API_KEY is set in the test
    // environment, so both branches are accepted. Each branch must still be
    // well-formed: the previous `is_ok() || is_err()` check was always true
    // and therefore verified nothing.
    match tools.execute_configure_embeddings(input).await {
        // On success the output should echo the requested provider.
        Ok(output) => assert_eq!(output.provider, "openai"),
        // On failure the error must carry a usable diagnostic.
        Err(err) => assert!(
            !err.to_string().is_empty(),
            "configuration errors should carry a diagnostic message"
        ),
    }
}

#[tokio::test]
async fn test_configure_embeddings_invalid_provider() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    // Every optional field is left unset so only the provider name is under test.
    let request = ConfigureEmbeddingsInput {
        provider: "invalid-provider".to_string(),
        model: None,
        api_key_env: None,
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let message = tools
        .execute_configure_embeddings(request)
        .await
        .expect_err("an unknown provider must be rejected")
        .to_string();
    assert!(message.contains("Unsupported provider"));
}

#[tokio::test]
async fn test_query_semantic_memory() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    let query = QuerySemanticMemoryInput {
        query: "implement REST API".to_string(),
        limit: Some(5),
        similarity_threshold: Some(0.8),
        domain: Some("web-api".to_string()),
        task_type: Some("code_generation".to_string()),
    };

    let output = tools
        .execute_query_semantic_memory(query)
        .await
        .expect("semantic memory query should succeed");

    assert!(output.query_time_ms >= 0.0);
    // 384 is the default embedding dimension used throughout these tests.
    assert_eq!(output.embedding_dimension, 384);
}

#[tokio::test]
async fn test_test_embeddings() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    let report = tools
        .execute_test_embeddings()
        .await
        .expect("test_embeddings should return a report");

    // A fresh memory has no semantic service: nothing is available and
    // no sample vector is produced.
    assert!(!report.available);
    assert!(report.sample_embedding.is_empty());
    // The message must explain why embeddings are unavailable.
    assert!(!report.message.is_empty());
    assert!(report.message.contains("not yet configured"));
}

#[tokio::test]
async fn test_configure_embeddings_azure_missing_fields() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    // Azure is given neither a resource name nor a deployment name. The API key
    // is also omitted because this test targets field validation only.
    let request = ConfigureEmbeddingsInput {
        provider: "azure".to_string(),
        model: None,
        api_key_env: None,
        similarity_threshold: None,
        batch_size: None,
        base_url: None,
        api_version: None,
        resource_name: None,
        deployment_name: None,
    };

    let error_msg = tools
        .execute_configure_embeddings(request)
        .await
        .expect_err("azure without resource/deployment names must be rejected")
        .to_string();
    let mentions_missing_field = ["deployment_name", "resource_name"]
        .into_iter()
        .any(|field| error_msg.contains(field));
    assert!(
        mentions_missing_field,
        "Expected error about missing deployment_name or resource_name, got: {}",
        error_msg
    );
}

#[test]
fn test_generate_embedding_tool_definition() {
    let definition = generate_embedding_tool();
    assert_eq!(definition.name, "generate_embedding");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = &definition.input_schema;

    // `text` is the sole required argument.
    let required = schema["required"]
        .as_array()
        .expect("schema should list required fields");
    assert_eq!(required.len(), 1);
    assert!(required.contains(&serde_json::json!("text")));

    // `normalize` is declared as a property alongside `text`.
    let properties = schema["properties"]
        .as_object()
        .expect("schema should declare properties");
    for key in ["text", "normalize"] {
        assert!(properties.contains_key(key), "missing property {key}");
    }
}

#[test]
fn test_search_by_embedding_tool_definition() {
    let definition = search_by_embedding_tool();
    assert_eq!(definition.name, "search_by_embedding");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = &definition.input_schema;

    // Only the query vector itself is mandatory.
    let required = schema["required"]
        .as_array()
        .expect("schema should list required fields");
    assert_eq!(required.len(), 1);
    assert!(required.contains(&serde_json::json!("embedding")));

    // Result count and similarity cut-off are declared alongside the vector.
    let properties = schema["properties"]
        .as_object()
        .expect("schema should declare properties");
    for key in ["embedding", "limit", "similarity_threshold"] {
        assert!(properties.contains_key(key), "missing property {key}");
    }
}

#[test]
fn test_embedding_provider_status_tool_definition() {
    let definition = embedding_provider_status_tool();
    assert_eq!(definition.name, "embedding_provider_status");
    assert!(!definition.description.is_empty());
    assert!(definition.input_schema.is_object());

    let schema = &definition.input_schema;

    // No argument is mandatory: `required` is either absent or an empty list.
    let required_is_empty = match schema.get("required") {
        None => true,
        Some(list) => list
            .as_array()
            .expect("`required` should be an array")
            .is_empty(),
    };
    assert!(required_is_empty);

    let properties = schema["properties"]
        .as_object()
        .expect("schema should declare properties");
    assert!(properties.contains_key("test_connectivity"));
}

#[tokio::test]
async fn test_generate_embedding_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    let request = GenerateEmbeddingInput {
        text: "test text".to_string(),
        normalize: true,
    };

    // Without a configured provider, embedding generation must be refused.
    let error_msg = match tools.execute_generate_embedding(request).await {
        Ok(_) => panic!("expected generate_embedding to fail when no provider is configured"),
        Err(err) => err.to_string(),
    };
    assert!(
        error_msg.contains("not configured"),
        "Expected error about not configured, got: {}",
        error_msg
    );
}

#[tokio::test]
async fn test_search_by_embedding_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    // The vector uses the default dimension (384), so the expected failure is
    // the missing provider configuration.
    let request = SearchByEmbeddingInput {
        embedding: vec![0.1; 384],
        limit: 10,
        similarity_threshold: 0.7,
        domain: None,
        task_type: None,
    };

    let error_msg = tools
        .execute_search_by_embedding(request)
        .await
        .expect_err("searching without a configured provider should fail")
        .to_string();
    assert!(
        error_msg.contains("not configured"),
        "Expected error about not configured, got: {}",
        error_msg
    );
}

#[tokio::test]
async fn test_embedding_provider_status_not_configured() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    let status = tools
        .execute_embedding_provider_status(EmbeddingProviderStatusInput {
            test_connectivity: false,
        })
        .await
        .expect("status query should succeed");

    // A fresh memory reports the placeholder provider and at least one warning.
    assert!(!status.configured);
    assert!(!status.available);
    assert_eq!(status.provider, "not-configured");
    assert!(!status.warnings.is_empty());
}

#[tokio::test]
async fn test_embedding_provider_status_with_test() {
    let tools = EmbeddingTools::new(Arc::new(SelfLearningMemory::new()));

    // Request a connectivity probe even though nothing is configured.
    let request = EmbeddingProviderStatusInput {
        test_connectivity: true,
    };
    let status = tools
        .execute_embedding_provider_status(request)
        .await
        .expect("status query should succeed");

    assert!(!status.configured);
    assert!(!status.available);
    // With no provider configured, no test result is attached.
    assert!(status.test_result.is_none());
}

#[tokio::test]
async fn test_search_by_embedding_wrong_dimension() {
    let memory = Arc::new(SelfLearningMemory::new());
    let tools = EmbeddingTools::new(memory);

    let input = SearchByEmbeddingInput {
        embedding: vec![0.1; 128], // Wrong dimension (the default is 384)
        limit: 10,
        similarity_threshold: 0.7,
        domain: None,
        task_type: None,
    };

    let result = tools.execute_search_by_embedding(input).await;
    // With no provider configured, the call is expected to fail on the missing
    // configuration first; a dimension-mismatch error is also acceptable.
    // Either way, pin the failure reason instead of accepting any error at all.
    let error_msg = result
        .expect_err("a 128-dimensional query vector must be rejected")
        .to_string();
    assert!(
        error_msg.contains("not configured") || error_msg.contains("dimension"),
        "Expected error about missing configuration or embedding dimension, got: {}",
        error_msg
    );
}