// do_memory_mcp/mcp/tools/embeddings/tool/execute/configure.rs
1//! Configure embeddings tool implementation.
2
3use super::super::definitions::EmbeddingTools;
4use crate::mcp::tools::embeddings::types::{ConfigureEmbeddingsInput, ConfigureEmbeddingsOutput};
5use anyhow::{Result, anyhow};
6use do_memory_core::embeddings::config::{
7    AzureOpenAIConfig, CustomConfig, EmbeddingConfig, EmbeddingProvider, LocalConfig,
8    ProviderConfig,
9};
10use tracing::{debug, info, instrument};
11
12impl EmbeddingTools {
13    /// Execute the configure_embeddings tool
14    #[instrument(skip(self, input), fields(provider = %input.provider))]
15    pub async fn execute_configure_embeddings(
16        &self,
17        input: ConfigureEmbeddingsInput,
18    ) -> Result<ConfigureEmbeddingsOutput> {
19        info!("Configuring embedding provider: {}", input.provider);
20
21        let mut warnings = Vec::new();
22
23        // Parse provider type
24        let provider_type = match input.provider.to_lowercase().as_str() {
25            "openai" => EmbeddingProvider::OpenAI,
26            "local" => EmbeddingProvider::Local,
27            "mistral" => EmbeddingProvider::Mistral,
28            "azure" => EmbeddingProvider::AzureOpenAI,
29            "cohere" => {
30                warnings.push(
31                    "Cohere provider not yet implemented, using Local as fallback".to_string(),
32                );
33                EmbeddingProvider::Local
34            }
35            _ => {
36                return Err(anyhow!(
37                    "Unsupported provider: {}. Supported providers: openai, local, mistral, azure, cohere",
38                    input.provider
39                ));
40            }
41        };
42
43        // Validate API key for cloud providers
44        if matches!(
45            provider_type,
46            EmbeddingProvider::OpenAI | EmbeddingProvider::Mistral | EmbeddingProvider::AzureOpenAI
47        ) {
48            if let Some(api_key_env) = &input.api_key_env {
49                if std::env::var(api_key_env).is_err() {
50                    return Err(anyhow!(
51                        "Environment variable '{}' not set. Please set the API key.",
52                        api_key_env
53                    ));
54                }
55            } else {
56                warnings.push(format!(
57                    "No api_key_env specified for {}. Make sure API key is set in standard environment variable.",
58                    input.provider
59                ));
60            }
61        }
62
63        // Build model configuration based on provider
64        let provider_config =
65            match provider_type {
66                EmbeddingProvider::OpenAI => {
67                    let model_name = input.model.as_deref().unwrap_or("text-embedding-3-small");
68                    match model_name {
69                        "text-embedding-3-small" => ProviderConfig::openai_3_small(),
70                        "text-embedding-3-large" => ProviderConfig::openai_3_large(),
71                        "text-embedding-ada-002" => ProviderConfig::openai_ada_002(),
72                        _ => {
73                            warnings.push(format!(
74                                "Unknown OpenAI model '{}', using text-embedding-3-small",
75                                model_name
76                            ));
77                            ProviderConfig::openai_3_small()
78                        }
79                    }
80                }
81                EmbeddingProvider::Mistral => {
82                    let model_name = input.model.as_deref().unwrap_or("mistral-embed");
83                    if model_name != "mistral-embed" {
84                        warnings.push(format!(
85                            "Unknown Mistral model '{}', using mistral-embed",
86                            model_name
87                        ));
88                    }
89                    ProviderConfig::mistral_embed()
90                }
91                EmbeddingProvider::AzureOpenAI => {
92                    let deployment = input.deployment_name.as_ref().ok_or_else(|| {
93                        anyhow!("deployment_name required for Azure OpenAI provider")
94                    })?;
95                    let resource = input.resource_name.as_ref().ok_or_else(|| {
96                        anyhow!("resource_name required for Azure OpenAI provider")
97                    })?;
98                    let api_version = input.api_version.as_deref().unwrap_or("2023-05-15");
99
100                    // Azure dimension depends on the underlying model
101                    let dimension = 1536; // Default for ada-002 and text-embedding-3-small
102                    ProviderConfig::AzureOpenAI(AzureOpenAIConfig::new(
103                        deployment,
104                        resource,
105                        api_version,
106                        dimension,
107                    ))
108                }
109                EmbeddingProvider::Local => {
110                    let model_name = input
111                        .model
112                        .as_deref()
113                        .unwrap_or("sentence-transformers/all-MiniLM-L6-v2");
114                    let dimension = 384; // Default for MiniLM
115                    ProviderConfig::Local(LocalConfig::new(model_name, dimension))
116                }
117                EmbeddingProvider::Custom(_) => {
118                    let model_name = input.model.as_deref().unwrap_or("custom-model");
119                    let base_url = input
120                        .base_url
121                        .as_deref()
122                        .ok_or_else(|| anyhow!("base_url required for custom provider"))?;
123                    ProviderConfig::Custom(CustomConfig::new(model_name, 384, base_url))
124                }
125            };
126
127        // Build embedding configuration
128        let embedding_config = EmbeddingConfig {
129            provider: provider_config.clone(),
130            similarity_threshold: input.similarity_threshold.unwrap_or(0.7),
131            batch_size: input.batch_size.unwrap_or(32),
132            cache_embeddings: true,
133            timeout_seconds: 30,
134        };
135
136        // NOTE: In a real implementation, you would update the memory system's
137        // semantic_service here. Since semantic_service is private and Option,
138        // we simulate the configuration response.
139
140        debug!(
141            "Configured embedding provider: {:?} with model: {}",
142            embedding_config.provider,
143            embedding_config.provider.model_name()
144        );
145
146        let provider_name = input.provider.clone();
147        Ok(ConfigureEmbeddingsOutput {
148            success: true,
149            provider: input.provider,
150            model: provider_config.model_name(),
151            dimension: provider_config.effective_dimension(),
152            message: format!(
153                "Successfully configured {} provider with model {} (dimension: {})",
154                provider_name,
155                provider_config.model_name(),
156                provider_config.effective_dimension()
157            ),
158            warnings,
159        })
160    }
161}