do_memory_mcp/mcp/tools/embeddings/tool/execute/
configure.rs1use super::super::definitions::EmbeddingTools;
4use crate::mcp::tools::embeddings::types::{ConfigureEmbeddingsInput, ConfigureEmbeddingsOutput};
5use anyhow::{Result, anyhow};
6use do_memory_core::embeddings::config::{
7 AzureOpenAIConfig, CustomConfig, EmbeddingConfig, EmbeddingProvider, LocalConfig,
8 ProviderConfig,
9};
10use tracing::{debug, info, instrument};
11
12impl EmbeddingTools {
13 #[instrument(skip(self, input), fields(provider = %input.provider))]
15 pub async fn execute_configure_embeddings(
16 &self,
17 input: ConfigureEmbeddingsInput,
18 ) -> Result<ConfigureEmbeddingsOutput> {
19 info!("Configuring embedding provider: {}", input.provider);
20
21 let mut warnings = Vec::new();
22
23 let provider_type = match input.provider.to_lowercase().as_str() {
25 "openai" => EmbeddingProvider::OpenAI,
26 "local" => EmbeddingProvider::Local,
27 "mistral" => EmbeddingProvider::Mistral,
28 "azure" => EmbeddingProvider::AzureOpenAI,
29 "cohere" => {
30 warnings.push(
31 "Cohere provider not yet implemented, using Local as fallback".to_string(),
32 );
33 EmbeddingProvider::Local
34 }
35 _ => {
36 return Err(anyhow!(
37 "Unsupported provider: {}. Supported providers: openai, local, mistral, azure, cohere",
38 input.provider
39 ));
40 }
41 };
42
43 if matches!(
45 provider_type,
46 EmbeddingProvider::OpenAI | EmbeddingProvider::Mistral | EmbeddingProvider::AzureOpenAI
47 ) {
48 if let Some(api_key_env) = &input.api_key_env {
49 if std::env::var(api_key_env).is_err() {
50 return Err(anyhow!(
51 "Environment variable '{}' not set. Please set the API key.",
52 api_key_env
53 ));
54 }
55 } else {
56 warnings.push(format!(
57 "No api_key_env specified for {}. Make sure API key is set in standard environment variable.",
58 input.provider
59 ));
60 }
61 }
62
63 let provider_config =
65 match provider_type {
66 EmbeddingProvider::OpenAI => {
67 let model_name = input.model.as_deref().unwrap_or("text-embedding-3-small");
68 match model_name {
69 "text-embedding-3-small" => ProviderConfig::openai_3_small(),
70 "text-embedding-3-large" => ProviderConfig::openai_3_large(),
71 "text-embedding-ada-002" => ProviderConfig::openai_ada_002(),
72 _ => {
73 warnings.push(format!(
74 "Unknown OpenAI model '{}', using text-embedding-3-small",
75 model_name
76 ));
77 ProviderConfig::openai_3_small()
78 }
79 }
80 }
81 EmbeddingProvider::Mistral => {
82 let model_name = input.model.as_deref().unwrap_or("mistral-embed");
83 if model_name != "mistral-embed" {
84 warnings.push(format!(
85 "Unknown Mistral model '{}', using mistral-embed",
86 model_name
87 ));
88 }
89 ProviderConfig::mistral_embed()
90 }
91 EmbeddingProvider::AzureOpenAI => {
92 let deployment = input.deployment_name.as_ref().ok_or_else(|| {
93 anyhow!("deployment_name required for Azure OpenAI provider")
94 })?;
95 let resource = input.resource_name.as_ref().ok_or_else(|| {
96 anyhow!("resource_name required for Azure OpenAI provider")
97 })?;
98 let api_version = input.api_version.as_deref().unwrap_or("2023-05-15");
99
100 let dimension = 1536; ProviderConfig::AzureOpenAI(AzureOpenAIConfig::new(
103 deployment,
104 resource,
105 api_version,
106 dimension,
107 ))
108 }
109 EmbeddingProvider::Local => {
110 let model_name = input
111 .model
112 .as_deref()
113 .unwrap_or("sentence-transformers/all-MiniLM-L6-v2");
114 let dimension = 384; ProviderConfig::Local(LocalConfig::new(model_name, dimension))
116 }
117 EmbeddingProvider::Custom(_) => {
118 let model_name = input.model.as_deref().unwrap_or("custom-model");
119 let base_url = input
120 .base_url
121 .as_deref()
122 .ok_or_else(|| anyhow!("base_url required for custom provider"))?;
123 ProviderConfig::Custom(CustomConfig::new(model_name, 384, base_url))
124 }
125 };
126
127 let embedding_config = EmbeddingConfig {
129 provider: provider_config.clone(),
130 similarity_threshold: input.similarity_threshold.unwrap_or(0.7),
131 batch_size: input.batch_size.unwrap_or(32),
132 cache_embeddings: true,
133 timeout_seconds: 30,
134 };
135
136 debug!(
141 "Configured embedding provider: {:?} with model: {}",
142 embedding_config.provider,
143 embedding_config.provider.model_name()
144 );
145
146 let provider_name = input.provider.clone();
147 Ok(ConfigureEmbeddingsOutput {
148 success: true,
149 provider: input.provider,
150 model: provider_config.model_name(),
151 dimension: provider_config.effective_dimension(),
152 message: format!(
153 "Successfully configured {} provider with model {} (dimension: {})",
154 provider_name,
155 provider_config.model_name(),
156 provider_config.effective_dimension()
157 ),
158 warnings,
159 })
160 }
161}