Skip to main content

rmcp_memex/embeddings/
mod.rs

1//! Universal embedding client with config-driven provider cascade.
2//!
3//! Supports any OpenAI-compatible embedding API (Ollama, vLLM, TEI, etc.)
4//! Providers are tried in priority order until one responds.
5//!
6//! # Example config.toml
7//! ```toml
8//! [embeddings]
9//! required_dimension = 2560
10//! max_batch_chars = 32000
11//! max_batch_items = 16
12//!
13//! [[embeddings.providers]]
14//! name = "ollama-local"
15//! base_url = "http://localhost:11434"
16//! model = "qwen3-embedding:4b"
17//! priority = 1
18//!
19//! [[embeddings.providers]]
20//! name = "dragon"
21//! base_url = "http://dragon:12345"
22//! model = "Qwen/Qwen3-Embedding-4B"
23//! priority = 2
24//! ```
25
26use anyhow::{Result, anyhow};
27use reqwest::Client;
28use serde::{Deserialize, Serialize};
29use std::time::Duration;
30
/// Vector dimension expected when the configuration does not set one
/// (the output size of `qwen3-embedding:4b`).
pub const DEFAULT_REQUIRED_DIMENSION: usize = 2560;

/// Ollama model tag used by the built-in default `ollama-local` provider.
pub const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "qwen3-embedding:4b";
33
/// Best-effort lookup of the output dimension of well-known embedding models.
///
/// Matching is case-insensitive and substring-based, so both Ollama tags
/// (`qwen3-embedding:4b`) and Hugging Face ids (`Qwen/Qwen3-Embedding-4B`) are
/// recognised. Returns `None` for unknown models, including Qwen3 text
/// embedders whose size cannot be determined from the name.
pub fn infer_embedding_dimension(model: &str) -> Option<usize> {
    let name = model.to_ascii_lowercase();
    let mentions = |needle: &str| name.contains(needle);

    if mentions("qwen3-vl-embedding") {
        return Some(2048);
    }

    if mentions("qwen3-embedding") {
        // Sizes are checked in this order; an unrecognised size is reported
        // as unknown rather than guessed.
        const QWEN3_SIZES: [(&str, usize); 3] = [("0.6b", 1024), ("4b", 2560), ("8b", 4096)];
        return QWEN3_SIZES
            .iter()
            .find(|&&(size, _)| mentions(size))
            .map(|&(_, dim)| dim);
    }

    // Remaining known families; earlier entries win when several match.
    const FAMILIES: [(&str, usize); 4] = [
        ("bge-m3", 1024),
        ("mxbai-embed", 1024),
        ("nomic-embed", 768),
        ("all-minilm", 384),
    ];
    FAMILIES
        .iter()
        .find(|&&(family, _)| mentions(family))
        .map(|&(_, dim)| dim)
}
59
60// =============================================================================
61// REQUEST/RESPONSE TYPES (OpenAI-compatible)
62// =============================================================================
63
64#[derive(Debug, Serialize)]
65struct EmbeddingRequest {
66    input: Vec<String>,
67    model: String,
68}
69
70#[derive(Debug, Deserialize)]
71struct EmbeddingResponse {
72    data: Vec<EmbeddingData>,
73}
74
75#[derive(Debug, Deserialize)]
76struct EmbeddingData {
77    embedding: Vec<f32>,
78}
79
80#[derive(Debug, Serialize)]
81struct RerankRequest {
82    query: String,
83    documents: Vec<String>,
84    model: String,
85}
86
87#[derive(Debug, Deserialize)]
88struct RerankResponse {
89    results: Vec<RerankResult>,
90}
91
92#[derive(Debug, Deserialize)]
93struct RerankResult {
94    index: usize,
95    score: f32,
96}
97
98// =============================================================================
99// PROVIDER CONFIGURATION
100// =============================================================================
101
102/// Single embedding provider configuration
103#[derive(Debug, Clone, Deserialize, Serialize, Default)]
104pub struct ProviderConfig {
105    /// Human-readable name for logging
106    #[serde(default)]
107    pub name: String,
108    /// Base URL (e.g., "http://localhost:11434")
109    #[serde(default)]
110    pub base_url: String,
111    /// Model name to use
112    #[serde(default)]
113    pub model: String,
114    /// Priority (1 = highest, tried first)
115    #[serde(default = "default_priority")]
116    pub priority: u8,
117    /// Embedding endpoint path (default: /v1/embeddings)
118    #[serde(default = "default_embeddings_endpoint")]
119    pub endpoint: String,
120}
121
/// Serde default for `ProviderConfig::priority`.
///
/// Providers without an explicit priority sort after the built-in defaults,
/// which use priorities 1 and 2.
fn default_priority() -> u8 {
    const UNSPECIFIED_PRIORITY: u8 = 10;
    UNSPECIFIED_PRIORITY
}
125
/// Serde default for `ProviderConfig::endpoint`: the OpenAI-compatible path.
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
129
130/// Reranker configuration (optional, separate from embedders)
131#[derive(Debug, Clone, Deserialize, Serialize, Default)]
132pub struct RerankerConfig {
133    /// Base URL for reranker service
134    pub base_url: Option<String>,
135    /// Model name
136    pub model: Option<String>,
137    /// Endpoint path (default: /v1/rerank)
138    #[serde(default = "default_rerank_endpoint")]
139    pub endpoint: String,
140}
141
/// Serde default for `RerankerConfig::endpoint`.
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}
145
146fn default_dimension() -> usize {
147    DEFAULT_REQUIRED_DIMENSION
148}
149
/// Default character budget per embedding request.
fn default_max_batch_chars() -> usize {
    // Raised 4x from 32_000 for better GPU utilization
    128_000
}
153
/// Default number of texts per embedding request.
fn default_max_batch_items() -> usize {
    // Raised 4x from 16: fewer API calls, better throughput
    16 * 4
}
157
/// Join a provider base URL and an endpoint path into a full request URL.
///
/// Trailing slashes are stripped from `base_url` and surrounding whitespace
/// from `endpoint`. A `/` is inserted between the two only when `endpoint`
/// does not already start with one. A blank `endpoint` yields the bare
/// `base_url`, rather than a URL with a dangling `/` that some servers route
/// differently.
fn build_provider_endpoint(base_url: &str, endpoint: &str) -> String {
    let base_url = base_url.trim_end_matches('/');
    let endpoint = endpoint.trim();
    if endpoint.is_empty() {
        base_url.to_string()
    } else if endpoint.starts_with('/') {
        format!("{}{}", base_url, endpoint)
    } else {
        format!("{}/{}", base_url, endpoint)
    }
}
167
168/// Complete embedding configuration
169#[derive(Debug, Clone, Deserialize, Serialize)]
170pub struct EmbeddingConfig {
171    /// Required vector dimension (mismatch corrupts database!)
172    #[serde(default = "default_dimension")]
173    pub required_dimension: usize,
174    /// Maximum characters per embedding batch to avoid OOM (default: 32000)
175    #[serde(default = "default_max_batch_chars")]
176    pub max_batch_chars: usize,
177    /// Maximum items per embedding batch (default: 16)
178    #[serde(default = "default_max_batch_items")]
179    pub max_batch_items: usize,
180    /// List of providers to try in priority order
181    #[serde(default)]
182    pub providers: Vec<ProviderConfig>,
183    /// Optional reranker configuration
184    #[serde(default)]
185    pub reranker: RerankerConfig,
186}
187
188impl Default for EmbeddingConfig {
189    fn default() -> Self {
190        Self {
191            required_dimension: default_dimension(),
192            max_batch_chars: default_max_batch_chars(),
193            max_batch_items: default_max_batch_items(),
194            providers: vec![
195                ProviderConfig {
196                    name: "ollama-local".to_string(),
197                    base_url: "http://localhost:11434".to_string(),
198                    model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
199                    priority: 1,
200                    endpoint: default_embeddings_endpoint(),
201                },
202                ProviderConfig {
203                    name: "dragon".to_string(),
204                    base_url: "http://dragon:12345".to_string(),
205                    model: "Qwen/Qwen3-Embedding-4B".to_string(),
206                    priority: 2,
207                    endpoint: default_embeddings_endpoint(),
208                },
209            ],
210            reranker: RerankerConfig::default(),
211        }
212    }
213}
214
215impl EmbeddingConfig {
216    /// Returns the name of the first (highest priority) provider
217    pub fn provider_name(&self) -> String {
218        self.providers
219            .first()
220            .map(|p| p.name.clone())
221            .unwrap_or_else(|| "none".to_string())
222    }
223
224    /// Returns the model name of the first (highest priority) provider
225    pub fn model_name(&self) -> String {
226        self.providers
227            .first()
228            .map(|p| p.model.clone())
229            .unwrap_or_else(|| "none".to_string())
230    }
231
232    /// Alias for required_dimension for API compatibility
233    pub fn dimension(&self) -> usize {
234        self.required_dimension
235    }
236}
237
238// =============================================================================
239// LEGACY CONFIG (backward compatibility)
240// =============================================================================
241
/// Legacy MLX configuration - deprecated, use EmbeddingConfig instead
///
/// Converted to the provider-based format by `MlxConfig::to_embedding_config`.
#[derive(Clone, Debug)]
pub struct MlxConfig {
    /// When true, `EmbeddingClient::from_legacy` refuses to build a client.
    pub disabled: bool,
    /// Port of the embedder on `localhost`.
    pub local_port: u16,
    /// Base URL of the `dragon` host, without a port.
    pub dragon_url: String,
    /// Port of the embedder on the `dragon` host.
    pub dragon_port: u16,
    /// Embedding model, used for both the local and the `dragon` provider.
    pub embedder_model: String,
    /// Reranker model name.
    pub reranker_model: String,
    /// Reranker port, expressed as an offset added to `local_port`.
    pub reranker_port_offset: u16,
    /// Maximum characters per embedding batch.
    pub max_batch_chars: usize,
    /// Maximum items per embedding batch.
    pub max_batch_items: usize,
}
255
/// Optional overrides applied to an `MlxConfig` by
/// `MlxConfig::merge_file_config`; `None` leaves the existing value untouched.
#[derive(Default, Clone, Debug)]
pub struct MlxMergeOptions {
    /// Override for `MlxConfig::disabled`.
    pub disabled: Option<bool>,
    /// Override for `MlxConfig::local_port`.
    pub local_port: Option<u16>,
    /// Override for `MlxConfig::dragon_url`.
    pub dragon_url: Option<String>,
    /// Override for `MlxConfig::dragon_port`.
    pub dragon_port: Option<u16>,
    /// Override for `MlxConfig::embedder_model`.
    pub embedder_model: Option<String>,
    /// Override for `MlxConfig::reranker_model`.
    pub reranker_model: Option<String>,
    /// Override for `MlxConfig::reranker_port_offset`.
    pub reranker_port_offset: Option<u16>,
}
267
268impl Default for MlxConfig {
269    fn default() -> Self {
270        Self {
271            disabled: false,
272            local_port: 12345,
273            dragon_url: "http://dragon".to_string(),
274            dragon_port: 12345,
275            embedder_model: "Qwen/Qwen3-Embedding-4B".to_string(),
276            reranker_model: "Qwen/Qwen3-Reranker-4B".to_string(),
277            reranker_port_offset: 1,
278            max_batch_chars: default_max_batch_chars(),
279            max_batch_items: default_max_batch_items(),
280        }
281    }
282}
283
284impl MlxConfig {
285    /// Create config from environment variables (legacy support)
286    pub fn from_env() -> Self {
287        let disabled = std::env::var("DISABLE_MLX")
288            .map(|v| v == "1" || v.to_lowercase() == "true")
289            .unwrap_or(false);
290
291        let local_port = std::env::var("EMBEDDER_PORT")
292            .ok()
293            .and_then(|s| s.parse().ok())
294            .unwrap_or(12345);
295
296        let dragon_url =
297            std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
298
299        let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
300            .ok()
301            .and_then(|s| s.parse().ok())
302            .unwrap_or(local_port);
303
304        let reranker_port_offset = std::env::var("RERANKER_PORT")
305            .ok()
306            .and_then(|s| s.parse::<u16>().ok())
307            .map(|rp| rp.saturating_sub(local_port))
308            .unwrap_or(1);
309
310        let embedder_model = std::env::var("EMBEDDER_MODEL")
311            .unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
312
313        let reranker_model = std::env::var("RERANKER_MODEL")
314            .unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
315
316        let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
317            .ok()
318            .and_then(|s| s.parse().ok())
319            .unwrap_or(32000);
320
321        let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
322            .ok()
323            .and_then(|s| s.parse().ok())
324            .unwrap_or(16);
325
326        Self {
327            disabled,
328            local_port,
329            dragon_url,
330            dragon_port,
331            embedder_model,
332            reranker_model,
333            reranker_port_offset,
334            max_batch_chars,
335            max_batch_items,
336        }
337    }
338
339    /// Merge with values from file config
340    pub fn merge_file_config(&mut self, opts: MlxMergeOptions) {
341        if let Some(v) = opts.disabled {
342            self.disabled = v;
343        }
344        if let Some(v) = opts.local_port {
345            self.local_port = v;
346        }
347        if let Some(v) = opts.dragon_url {
348            self.dragon_url = v;
349        }
350        if let Some(v) = opts.dragon_port {
351            self.dragon_port = v;
352        }
353        if let Some(v) = opts.embedder_model {
354            self.embedder_model = v;
355        }
356        if let Some(v) = opts.reranker_model {
357            self.reranker_model = v;
358        }
359        if let Some(v) = opts.reranker_port_offset {
360            self.reranker_port_offset = v;
361        }
362    }
363
364    /// Convert legacy config to new EmbeddingConfig
365    pub fn to_embedding_config(&self) -> EmbeddingConfig {
366        let reranker_port = self.local_port + self.reranker_port_offset;
367        let required_dimension =
368            infer_embedding_dimension(&self.embedder_model).unwrap_or(DEFAULT_REQUIRED_DIMENSION);
369
370        EmbeddingConfig {
371            required_dimension,
372            max_batch_chars: self.max_batch_chars,
373            max_batch_items: self.max_batch_items,
374            providers: vec![
375                ProviderConfig {
376                    name: "local".to_string(),
377                    base_url: format!("http://localhost:{}", self.local_port),
378                    model: self.embedder_model.clone(),
379                    priority: 1,
380                    endpoint: default_embeddings_endpoint(),
381                },
382                ProviderConfig {
383                    name: "dragon".to_string(),
384                    base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
385                    model: self.embedder_model.clone(),
386                    priority: 2,
387                    endpoint: default_embeddings_endpoint(),
388                },
389            ],
390            reranker: RerankerConfig {
391                base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
392                model: Some(self.reranker_model.clone()),
393                endpoint: default_rerank_endpoint(),
394            },
395        }
396    }
397
398    /// Set batch limits
399    pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
400        self.max_batch_chars = max_chars;
401        self.max_batch_items = max_items;
402        self
403    }
404}
405
406// =============================================================================
407// EMBEDDING CLIENT
408// =============================================================================
409
410/// Universal embedding client with provider cascade
411pub struct EmbeddingClient {
412    client: Client,
413    embedder_url: String,
414    embedder_model: String,
415    reranker_url: Option<String>,
416    reranker_model: Option<String>,
417    /// Which provider we're connected to
418    connected_to: String,
419    /// Expected dimension (for validation)
420    required_dimension: usize,
421    /// Maximum characters per embedding batch
422    max_batch_chars: usize,
423    /// Maximum items per embedding batch
424    max_batch_items: usize,
425}
426
427// Type alias for backward compatibility
428pub type MLXBridge = EmbeddingClient;
429
430impl EmbeddingClient {
431    /// Create client with config-driven provider cascade
432    pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
433        if config.providers.is_empty() {
434            return Err(anyhow!(
435                "No embedding providers configured! Add providers to [embeddings.providers]"
436            ));
437        }
438
439        // Long timeout for large embedding batches (100+ chunks can take minutes)
440        let client = Client::builder()
441            .timeout(Duration::from_secs(300))
442            .connect_timeout(Duration::from_secs(10))
443            .build()?;
444
445        // Sort providers by priority
446        let mut providers = config.providers.clone();
447        providers.sort_by_key(|p| p.priority);
448
449        // Try each provider in order using the real embedding endpoint.
450        let mut tried = Vec::new();
451        for provider in &providers {
452            let base_url = provider.base_url.trim_end_matches('/');
453            let provider_name = if provider.name.trim().is_empty() {
454                "<unnamed-provider>"
455            } else {
456                provider.name.as_str()
457            };
458            let model = provider.model.trim();
459            let embedder_url = build_provider_endpoint(base_url, &provider.endpoint);
460
461            match probe_provider_dimension(&client, provider).await {
462                Ok(actual_dim) if actual_dim == config.required_dimension => {
463                    tracing::info!(
464                        "Embedding: Connected to {} ({}) with model '{}' [{} dims]",
465                        provider_name,
466                        embedder_url,
467                        model,
468                        actual_dim
469                    );
470
471                    // Build reranker URL if configured
472                    let (reranker_url, reranker_model) =
473                        if let Some(ref rr_base) = config.reranker.base_url {
474                            (
475                                Some(format!(
476                                    "{}{}",
477                                    rr_base.trim_end_matches('/'),
478                                    config.reranker.endpoint
479                                )),
480                                config.reranker.model.clone(),
481                            )
482                        } else {
483                            (None, None)
484                        };
485
486                    return Ok(Self {
487                        client,
488                        embedder_url,
489                        embedder_model: provider.model.clone(),
490                        reranker_url,
491                        reranker_model,
492                        connected_to: provider.name.clone(),
493                        required_dimension: config.required_dimension,
494                        max_batch_chars: config.max_batch_chars,
495                        max_batch_items: config.max_batch_items,
496                    });
497                }
498                Ok(actual_dim) => {
499                    let failure = format!(
500                        "- {} ({} model='{}'): the configured embedding endpoint returned {} dims, but config.required_dimension={}.\n  Action: set [embeddings].required_dimension = {} or choose a {}-dim model.",
501                        provider_name,
502                        embedder_url,
503                        model,
504                        actual_dim,
505                        config.required_dimension,
506                        actual_dim,
507                        config.required_dimension
508                    );
509                    tracing::error!("Embedding: validation failed: {}", failure);
510                    tried.push(failure);
511                }
512                Err(e) => {
513                    let failure = format!(
514                        "- {} ({} model='{}'): {}",
515                        provider_name, embedder_url, model, e
516                    );
517                    tracing::warn!("Embedding: provider probe failed: {}", failure);
518                    tried.push(failure);
519                }
520            }
521        }
522
523        // All providers failed
524        Err(anyhow!(
525            "No embedding provider passed validation for required_dimension={}. \
526             Each provider must succeed on its configured embedding endpoint before rmcp-memex will start.\nTried:\n{}",
527            config.required_dimension,
528            tried.join("\n")
529        ))
530    }
531
532    /// Create from legacy MlxConfig (backward compatibility)
533    pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
534        if config.disabled {
535            return Err(anyhow!(
536                "Embedding disabled via config. No fallback available!"
537            ));
538        }
539        tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
540        let embedding_config = config.to_embedding_config();
541        Self::new(&embedding_config).await
542    }
543
544    /// Legacy constructor from env vars only
545    pub async fn from_env() -> Result<Self> {
546        let config = MlxConfig::from_env();
547        Self::from_legacy(&config).await
548    }
549
550    /// Get which provider we're connected to
551    pub fn connected_to(&self) -> &str {
552        &self.connected_to
553    }
554
555    /// Get required dimension
556    pub fn required_dimension(&self) -> usize {
557        self.required_dimension
558    }
559
560    /// Create a stub client for tests that don't need real embeddings.
561    /// The client will fail on any actual embed() call, but lets McpCore
562    /// be constructed and dispatch protocol-level requests.
563    #[cfg(test)]
564    pub(crate) fn stub_for_tests() -> Self {
565        Self {
566            client: reqwest::Client::new(),
567            embedder_url: "http://stub:0/v1/embeddings".to_string(),
568            embedder_model: "stub".to_string(),
569            reranker_url: None,
570            reranker_model: None,
571            connected_to: "stub-test".to_string(),
572            required_dimension: 4096,
573            max_batch_chars: 32000,
574            max_batch_items: 16,
575        }
576    }
577
578    pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
579        let text_preview: String = text.chars().take(100).collect();
580        tracing::debug!(
581            "Embedding single text ({} chars): {}{}",
582            text.chars().count(),
583            text_preview,
584            if text.chars().count() > 100 {
585                "..."
586            } else {
587                ""
588            }
589        );
590
591        let request = EmbeddingRequest {
592            input: vec![text.to_string()],
593            model: self.embedder_model.clone(),
594        };
595
596        let response = match self
597            .client
598            .post(&self.embedder_url)
599            .json(&request)
600            .send()
601            .await
602        {
603            Ok(resp) => resp,
604            Err(e) => {
605                tracing::error!(
606                    "Embedding request failed: {:?}\n  URL: {}\n  Model: {}",
607                    e,
608                    self.embedder_url,
609                    self.embedder_model
610                );
611                return Err(anyhow!("Embedding request failed: {}", e));
612            }
613        };
614
615        let status = response.status();
616        let response_text = response.text().await.unwrap_or_else(|e| {
617            tracing::warn!("Failed to read response body: {:?}", e);
618            "<failed to read body>".to_string()
619        });
620
621        if !status.is_success() {
622            tracing::error!(
623                "Embedding API error (HTTP {}):\n  URL: {}\n  Model: {}\n  Response: {}",
624                status,
625                self.embedder_url,
626                self.embedder_model,
627                response_text
628            );
629            return Err(anyhow!(
630                "Embedding API error (HTTP {}): {}",
631                status,
632                response_text
633            ));
634        }
635
636        let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
637            Ok(r) => r,
638            Err(e) => {
639                tracing::error!(
640                    "Failed to parse embedding response: {:?}\n  Response body: {}",
641                    e,
642                    response_text
643                );
644                return Err(anyhow!("Failed to parse embedding response: {}", e));
645            }
646        };
647
648        let embedding = parsed
649            .data
650            .into_iter()
651            .next()
652            .map(|d| d.embedding)
653            .ok_or_else(|| {
654                tracing::error!("No embedding returned in response: {}", response_text);
655                anyhow!("No embedding returned")
656            })?;
657
658        // Validate dimension
659        if embedding.len() != self.required_dimension {
660            tracing::error!(
661                "Dimension mismatch! Expected {}, got {}. Model: {}",
662                self.required_dimension,
663                embedding.len(),
664                self.embedder_model
665            );
666            return Err(anyhow!(
667                "Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
668                self.required_dimension,
669                embedding.len()
670            ));
671        }
672
673        tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
674        Ok(embedding)
675    }
676
677    /// Embed a batch of texts with intelligent batching to avoid OOM.
678    ///
679    /// Large texts are chunked and only the first chunk is embedded.
680    /// Batches are split to stay under max_batch_chars and max_batch_items.
681    /// Failed chunks are retried individually with exponential backoff.
682    pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
683        if texts.is_empty() {
684            return Ok(vec![]);
685        }
686
687        let mut all_embeddings = Vec::with_capacity(texts.len());
688        let mut current_batch: Vec<String> = Vec::new();
689        let mut current_batch_indices: Vec<usize> = Vec::new();
690        let mut current_chars = 0;
691
692        // Max chars per individual text (half of batch limit for safety)
693        let max_text_chars = self.max_batch_chars / 2;
694
695        // Prepare all texts first
696        let prepared_texts: Vec<String> = texts
697            .iter()
698            .map(|text| {
699                let char_count = text.chars().count();
700                if char_count > max_text_chars {
701                    tracing::debug!(
702                        "Text too large ({} chars), truncating to {} chars",
703                        char_count,
704                        max_text_chars
705                    );
706                    truncate_at_boundary(text, max_text_chars)
707                } else {
708                    text.clone()
709                }
710            })
711            .collect();
712
713        // Pre-allocate result vector with None
714        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
715        let mut failed_indices: Vec<usize> = Vec::new();
716
717        for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
718            let text_len = text_to_embed.chars().count();
719
720            // Check if we need to flush current batch
721            if !current_batch.is_empty()
722                && (current_chars + text_len > self.max_batch_chars
723                    || current_batch.len() >= self.max_batch_items)
724            {
725                // Flush current batch with retry
726                match self.embed_batch_internal(&current_batch).await {
727                    Ok(batch_embeddings) => {
728                        for (i, emb) in batch_embeddings.into_iter().enumerate() {
729                            if let Some(orig_idx) = current_batch_indices.get(i) {
730                                results[*orig_idx] = Some(emb);
731                            }
732                        }
733                    }
734                    Err(e) => {
735                        tracing::warn!(
736                            "Batch embedding failed for {} texts, will retry individually: {}",
737                            current_batch.len(),
738                            e
739                        );
740                        failed_indices.extend(current_batch_indices.iter().copied());
741                    }
742                }
743                current_batch.clear();
744                current_batch_indices.clear();
745                current_chars = 0;
746            }
747
748            current_batch.push(text_to_embed.clone());
749            current_batch_indices.push(idx);
750            current_chars += text_len;
751        }
752
753        // Flush remaining batch
754        if !current_batch.is_empty() {
755            match self.embed_batch_internal(&current_batch).await {
756                Ok(batch_embeddings) => {
757                    for (i, emb) in batch_embeddings.into_iter().enumerate() {
758                        if let Some(orig_idx) = current_batch_indices.get(i) {
759                            results[*orig_idx] = Some(emb);
760                        }
761                    }
762                }
763                Err(e) => {
764                    tracing::warn!(
765                        "Batch embedding failed for {} texts, will retry individually: {}",
766                        current_batch.len(),
767                        e
768                    );
769                    failed_indices.extend(current_batch_indices.iter().copied());
770                }
771            }
772        }
773
774        // Retry failed chunks individually with exponential backoff
775        const MAX_RETRIES: usize = 3;
776        for idx in failed_indices {
777            let text = &prepared_texts[idx];
778            let mut attempts = 0;
779            let mut last_error = String::new();
780
781            while attempts < MAX_RETRIES {
782                match self.embed(text).await {
783                    Ok(embedding) => {
784                        results[idx] = Some(embedding);
785                        tracing::info!(
786                            "Retry succeeded for chunk {} after {} attempts",
787                            idx,
788                            attempts + 1
789                        );
790                        break;
791                    }
792                    Err(e) => {
793                        attempts += 1;
794                        last_error = e.to_string();
795                        tracing::warn!(
796                            "Embed attempt {}/{} failed for chunk {}: {}",
797                            attempts,
798                            MAX_RETRIES,
799                            idx,
800                            e
801                        );
802                        if attempts < MAX_RETRIES {
803                            // Exponential backoff: 100ms, 200ms, 400ms
804                            let delay_ms = 100 * (1 << attempts);
805                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
806                        }
807                    }
808                }
809            }
810
811            if results[idx].is_none() {
812                tracing::error!(
813                    "Chunk {} failed after {} retries: {}",
814                    idx,
815                    MAX_RETRIES,
816                    last_error
817                );
818                return Err(anyhow!(
819                    "Failed to embed chunk {} after {} retries: {}",
820                    idx,
821                    MAX_RETRIES,
822                    last_error
823                ));
824            }
825        }
826
827        // Collect all results - all should be Some at this point
828        for (idx, opt) in results.iter().enumerate() {
829            match opt {
830                Some(emb) => all_embeddings.push(emb.clone()),
831                None => {
832                    return Err(anyhow!(
833                        "Internal error: missing embedding for chunk {}",
834                        idx
835                    ));
836                }
837            }
838        }
839
840        Ok(all_embeddings)
841    }
842
843    /// Internal batch embedding - sends directly to server
844    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
845        let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();
846
847        tracing::debug!(
848            "Embedding batch: {} texts, {} chars total",
849            texts.len(),
850            total_chars
851        );
852
853        // Log first few chars of each text in trace mode for debugging
854        for (i, text) in texts.iter().enumerate() {
855            let preview: String = text.chars().take(50).collect();
856            tracing::trace!(
857                "  Batch[{}]: {} chars - {}{}",
858                i,
859                text.chars().count(),
860                preview,
861                if text.chars().count() > 50 { "..." } else { "" }
862            );
863        }
864
865        let request = EmbeddingRequest {
866            input: texts.to_vec(),
867            model: self.embedder_model.clone(),
868        };
869
870        // Retry with exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (max)
871        const MAX_BATCH_RETRIES: usize = 10;
872        const MAX_BACKOFF_SECS: u64 = 30;
873        let mut attempt = 0;
874
875        loop {
876            attempt += 1;
877            let response = match self
878                .client
879                .post(&self.embedder_url)
880                .json(&request)
881                .send()
882                .await
883            {
884                Ok(resp) => resp,
885                Err(e) => {
886                    if attempt >= MAX_BATCH_RETRIES {
887                        tracing::error!(
888                            "Batch embedding failed after {} retries: {:?}\n  URL: {}\n  Model: {}",
889                            MAX_BATCH_RETRIES,
890                            e,
891                            self.embedder_url,
892                            self.embedder_model
893                        );
894                        return Err(anyhow!(
895                            "Embedding request failed after {} retries: {}",
896                            MAX_BATCH_RETRIES,
897                            e
898                        ));
899                    }
900
901                    // Exponential backoff with cap
902                    let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
903                    tracing::warn!(
904                        "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
905                        attempt,
906                        MAX_BATCH_RETRIES,
907                        backoff_secs,
908                        e
909                    );
910                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
911                    continue;
912                }
913            };
914
915            // Success - process response
916            if !response.status().is_success() {
917                let status = response.status();
918                let body = response.text().await.unwrap_or_default();
919
920                if attempt >= MAX_BATCH_RETRIES {
921                    tracing::error!(
922                        "Embedding API error after {} retries: {} - {}",
923                        MAX_BATCH_RETRIES,
924                        status,
925                        body
926                    );
927                    return Err(anyhow!("Embedding API error: {} - {}", status, body));
928                }
929
930                let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
931                tracing::warn!(
932                    "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
933                    attempt,
934                    MAX_BATCH_RETRIES,
935                    backoff_secs,
936                    status,
937                    body
938                );
939                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
940                continue;
941            }
942
943            // Parse response
944            let embedding_response: EmbeddingResponse = match response.json().await {
945                Ok(r) => r,
946                Err(e) => {
947                    if attempt >= MAX_BATCH_RETRIES {
948                        return Err(anyhow!("Failed to parse embedding response: {}", e));
949                    }
950                    let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
951                    tracing::warn!(
952                        "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
953                        attempt,
954                        MAX_BATCH_RETRIES,
955                        backoff_secs,
956                        e
957                    );
958                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
959                    continue;
960                }
961            };
962
963            // Validate dimensions
964            let embeddings: Vec<Vec<f32>> = embedding_response
965                .data
966                .into_iter()
967                .map(|d| d.embedding)
968                .collect();
969
970            if embeddings.len() != texts.len() {
971                return Err(anyhow!(
972                    "Embedding count mismatch: got {} embeddings for {} texts",
973                    embeddings.len(),
974                    texts.len()
975                ));
976            }
977
978            if let Some(first) = embeddings.first()
979                && first.len() != self.required_dimension
980            {
981                return Err(anyhow!(
982                    "Dimension mismatch: expected {}, got {}",
983                    self.required_dimension,
984                    first.len()
985                ));
986            }
987
988            return Ok(embeddings);
989        }
990    }
991
992    pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
993        let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
994            anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
995        })?;
996        let reranker_model = self
997            .reranker_model
998            .as_ref()
999            .ok_or_else(|| anyhow!("Reranker model not configured."))?;
1000
1001        let query_preview: String = query.chars().take(100).collect();
1002        tracing::debug!(
1003            "Reranking {} documents for query: {}{}",
1004            documents.len(),
1005            query_preview,
1006            if query.chars().count() > 100 {
1007                "..."
1008            } else {
1009                ""
1010            }
1011        );
1012
1013        let request = RerankRequest {
1014            query: query.to_string(),
1015            documents: documents.to_vec(),
1016            model: reranker_model.clone(),
1017        };
1018
1019        let response = match self.client.post(reranker_url).json(&request).send().await {
1020            Ok(resp) => resp,
1021            Err(e) => {
1022                tracing::error!(
1023                    "Rerank request failed: {:?}\n  URL: {}\n  Model: {}\n  Query: {}\n  Documents: {}",
1024                    e,
1025                    reranker_url,
1026                    reranker_model,
1027                    query_preview,
1028                    documents.len()
1029                );
1030                return Err(anyhow!("Rerank request failed: {}", e));
1031            }
1032        };
1033
1034        let status = response.status();
1035        let response_text = response.text().await.unwrap_or_else(|e| {
1036            tracing::warn!("Failed to read rerank response body: {:?}", e);
1037            "<failed to read body>".to_string()
1038        });
1039
1040        if !status.is_success() {
1041            tracing::error!(
1042                "Rerank API error (HTTP {}):\n  URL: {}\n  Model: {}\n  Response: {}",
1043                status,
1044                reranker_url,
1045                reranker_model,
1046                response_text
1047            );
1048            return Err(anyhow!(
1049                "Rerank API error (HTTP {}): {}",
1050                status,
1051                response_text
1052            ));
1053        }
1054
1055        let parsed: RerankResponse = match serde_json::from_str(&response_text) {
1056            Ok(r) => r,
1057            Err(e) => {
1058                tracing::error!(
1059                    "Failed to parse rerank response: {:?}\n  Response body: {}",
1060                    e,
1061                    response_text
1062                );
1063                return Err(anyhow!("Failed to parse rerank response: {}", e));
1064            }
1065        };
1066
1067        tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());
1068
1069        Ok(parsed
1070            .results
1071            .into_iter()
1072            .map(|r| (r.index, r.score))
1073            .collect())
1074    }
1075}
1076
1077pub(crate) async fn probe_provider_dimension(
1078    client: &Client,
1079    provider: &ProviderConfig,
1080) -> Result<usize> {
1081    let base_url = provider.base_url.trim_end_matches('/');
1082    if base_url.is_empty() {
1083        return Err(anyhow!("provider base_url is empty"));
1084    }
1085
1086    let endpoint = provider.endpoint.trim();
1087    if endpoint.is_empty() {
1088        return Err(anyhow!("provider endpoint is empty"));
1089    }
1090
1091    let model = provider.model.trim();
1092    if model.is_empty() {
1093        return Err(anyhow!("provider model is empty"));
1094    }
1095
1096    let embedder_url = build_provider_endpoint(base_url, endpoint);
1097    let request = EmbeddingRequest {
1098        input: vec!["dimension probe".to_string()],
1099        model: model.to_string(),
1100    };
1101
1102    let response = client
1103        .post(&embedder_url)
1104        .json(&request)
1105        .timeout(Duration::from_secs(30))
1106        .send()
1107        .await
1108        .map_err(|e| anyhow!("POST {} failed: {}", embedder_url, e))?;
1109
1110    let status = response.status();
1111    let body = response.text().await.unwrap_or_default();
1112    if !status.is_success() {
1113        let hint = if status.as_u16() == 404 {
1114            " Check provider.endpoint; Ollama and OpenAI-compatible servers typically use /v1/embeddings."
1115        } else {
1116            ""
1117        };
1118        return Err(anyhow!(
1119            "POST {} returned {} for model '{}': {}{}",
1120            embedder_url,
1121            status,
1122            model,
1123            body.chars().take(300).collect::<String>(),
1124            hint
1125        ));
1126    }
1127
1128    let embed_response: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
1129        anyhow!(
1130            "POST {} returned non-embedding JSON for model '{}': {} (body: {})",
1131            embedder_url,
1132            model,
1133            e,
1134            body.chars().take(200).collect::<String>()
1135        )
1136    })?;
1137
1138    embed_response
1139        .data
1140        .first()
1141        .map(|d| d.embedding.len())
1142        .ok_or_else(|| {
1143            anyhow!(
1144                "POST {} returned no embeddings for model '{}'",
1145                embedder_url,
1146                model
1147            )
1148        })
1149}
1150
/// Truncate text at a word/sentence boundary to avoid cutting mid-word (UTF-8 safe).
///
/// Returns `text` unchanged when it has at most `max_chars` characters.
/// Otherwise only the first `max_chars` characters are kept, and the cut is
/// made at the first option below that applies:
/// 1. just after the last sentence terminator (`.`, `!`, `?`, newline),
/// 2. just before the last whitespace (space, tab, newline),
/// 3. at exactly `max_chars` characters.
///
/// A boundary is only used when it lies past the midpoint (`max_chars / 2`
/// characters); otherwise the hard cut keeps more content. This stops a
/// boundary near the start, such as leading whitespace, from discarding most
/// or all of the text.
fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
    // Byte offset of the `n`-th character (or the end of `text`), so slicing
    // never splits a multi-byte UTF-8 sequence.
    let byte_offset = |n: usize| {
        text.char_indices()
            .nth(n)
            .map_or(text.len(), |(idx, _)| idx)
    };

    if text.chars().count() <= max_chars {
        return text.to_string();
    }

    let truncated = &text[..byte_offset(max_chars)];
    let midpoint = byte_offset(max_chars / 2);

    // Prefer ending on a complete sentence (terminator kept; all are 1-byte ASCII).
    if let Some(pos) = truncated
        .rfind(['.', '!', '?', '\n'])
        .filter(|&pos| pos > midpoint)
    {
        return text[..=pos].to_string();
    }

    // Fall back to the last word boundary (whitespace itself dropped).
    if let Some(pos) = truncated
        .rfind([' ', '\t', '\n'])
        .filter(|&pos| pos > midpoint)
    {
        return text[..pos].to_string();
    }

    // Last resort: hard truncate
    truncated.to_string()
}
1188
1189// =============================================================================
1190// TOKEN-AWARE VALIDATION
1191// =============================================================================
1192//
1193// Embedding models have token limits (e.g., 8192 for qwen3-embedding).
1194// These utilities estimate token counts and validate chunks before embedding.
1195// =============================================================================
1196
/// Heuristic settings used to estimate token counts before embedding.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Maximum number of tokens the embedding model accepts for one input.
    pub max_tokens: usize,
    /// Average characters per token (varies by language).
    /// English: ~4 chars/token, Polish/multilingual: ~2-3 chars/token
    pub chars_per_token: f32,
}

impl TokenConfig {
    /// Token limit shared by every preset (qwen3-embedding default).
    const PRESET_MAX_TOKENS: usize = 8192;

    /// Build a preset using the shared token limit and the given character density.
    fn preset(chars_per_token: f32) -> Self {
        Self {
            max_tokens: Self::PRESET_MAX_TOKENS,
            chars_per_token,
        }
    }

    /// Preset for English-only content (~4 chars/token).
    pub fn english() -> Self {
        Self::preset(4.0)
    }

    /// Preset for multilingual/Polish content (~2.5 chars/token).
    pub fn for_multilingual_text() -> Self {
        Self::preset(2.5)
    }

    /// Return this config with its token limit replaced by `max`.
    pub fn with_max_tokens(self, max: usize) -> Self {
        Self {
            max_tokens: max,
            ..self
        }
    }
}

impl Default for TokenConfig {
    /// 8192 tokens at 3.0 chars/token, a conservative figure for multilingual text.
    fn default() -> Self {
        Self::preset(3.0)
    }
}
1239
1240/// Estimate token count for text
1241///
1242/// This is a heuristic approximation. For precise counting,
1243/// use the actual tokenizer (tiktoken, sentencepiece, etc.)
1244pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
1245    let char_count = text.chars().count();
1246    (char_count as f32 / config.chars_per_token).ceil() as usize
1247}
1248
1249/// Validate that a chunk fits within token limits
1250///
1251/// Returns Ok(()) if chunk is within limits, Err with details otherwise.
1252pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
1253    let estimated = estimate_tokens(chunk, config);
1254
1255    if estimated > config.max_tokens {
1256        return Err(anyhow!(
1257            "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
1258             Consider reducing chunk_size or enabling truncation.",
1259            estimated,
1260            config.max_tokens,
1261            chunk.chars().count()
1262        ));
1263    }
1264
1265    Ok(())
1266}
1267
1268/// Calculate safe chunk size in characters for given token limit
1269pub fn safe_chunk_size(config: &TokenConfig) -> usize {
1270    // Use 80% of max to leave room for context prefix
1271    let safe_tokens = (config.max_tokens as f32 * 0.8) as usize;
1272    (safe_tokens as f32 * config.chars_per_token) as usize
1273}
1274
1275/// Truncate text to fit within token limit
1276pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
1277    let safe_chars = safe_chunk_size(config);
1278
1279    if text.chars().count() <= safe_chars {
1280        return text.to_string();
1281    }
1282
1283    truncate_at_boundary(text, safe_chars)
1284}
1285
1286/// Validate a batch of texts and return which ones exceed limits
1287pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
1288    texts
1289        .iter()
1290        .enumerate()
1291        .filter_map(|(idx, text)| {
1292            let estimated = estimate_tokens(text, config);
1293            if estimated > config.max_tokens {
1294                Some((idx, estimated))
1295            } else {
1296                None
1297            }
1298        })
1299        .collect()
1300}
1301
#[cfg(test)]
mod tests {
    // Unit tests for provider config handling, dimension inference/probing,
    // text truncation and the token-estimation helpers in this module.
    use super::*;
    use axum::{Json, Router, extract::State, routing::post};
    use serde_json::json;

    /// Mock handler: always answers with a single embedding of `dim` copies of 0.25.
    async fn mock_embeddings(State(dim): State<usize>) -> Json<serde_json::Value> {
        Json(json!({
            "data": [{
                "embedding": vec![0.25_f32; dim]
            }]
        }))
    }

    /// Start an axum server on an ephemeral localhost port that serves
    /// `POST /v1/embeddings` via `mock_embeddings`, and return its base URL.
    async fn spawn_mock_embedding_server(dim: usize) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = Router::new()
            .route("/v1/embeddings", post(mock_embeddings))
            .with_state(dim);

        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        // Brief pause before handing out the URL. The listener is already bound,
        // so this is presumably just a safety margin for the spawned task.
        tokio::time::sleep(Duration::from_millis(10)).await;

        format!("http://{}", addr)
    }

    // Sorting by `priority` ascending puts the lowest number (highest priority) first.
    #[test]
    fn test_provider_sorting() {
        let mut providers = [
            ProviderConfig {
                name: "low".into(),
                base_url: "http://a".into(),
                model: "m".into(),
                priority: 10,
                endpoint: "/v1/embeddings".into(),
            },
            ProviderConfig {
                name: "high".into(),
                base_url: "http://b".into(),
                model: "m".into(),
                priority: 1,
                endpoint: "/v1/embeddings".into(),
            },
        ];
        providers.sort_by_key(|p| p.priority);
        assert_eq!(providers[0].name, "high");
        assert_eq!(providers[1].name, "low");
    }

    // Legacy MLX config converts into two providers (localhost first), sets a
    // reranker URL, and carries over the batch limits.
    #[test]
    fn test_legacy_conversion() {
        let legacy = MlxConfig {
            disabled: false,
            local_port: 12345,
            dragon_url: "http://dragon".into(),
            dragon_port: 12345,
            embedder_model: "test-model".into(),
            reranker_model: "rerank-model".into(),
            reranker_port_offset: 1,
            max_batch_chars: 32000,
            max_batch_items: 16,
        };
        let config = legacy.to_embedding_config();
        assert_eq!(config.providers.len(), 2);
        assert_eq!(config.providers[0].base_url, "http://localhost:12345");
        assert!(config.reranker.base_url.is_some());
        assert_eq!(config.max_batch_chars, 32000);
        assert_eq!(config.max_batch_items, 16);
    }

    // Pins the default EmbeddingConfig values.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.required_dimension, DEFAULT_REQUIRED_DIMENSION);
        assert_eq!(config.max_batch_chars, 128000); // 4x larger for GPU efficiency
        assert_eq!(config.max_batch_items, 64); // 4x more items per batch
        assert!(!config.providers.is_empty());
        assert_eq!(config.providers[0].model, DEFAULT_OLLAMA_EMBEDDING_MODEL);
    }

    // Known model names map to their dimensions; an ambiguous or unknown name yields None.
    #[test]
    fn test_infer_embedding_dimension() {
        assert_eq!(
            infer_embedding_dimension("qwen3-embedding:0.6b"),
            Some(1024)
        );
        assert_eq!(infer_embedding_dimension("qwen3-embedding:4b"), Some(2560));
        assert_eq!(infer_embedding_dimension("qwen3-embedding:8b"), Some(4096));
        assert_eq!(
            infer_embedding_dimension("MedAIBase/Qwen3-VL-Embedding:2b-q8_0"),
            Some(2048)
        );
        assert_eq!(infer_embedding_dimension("nomic-embed-text"), Some(768));
        assert_eq!(infer_embedding_dimension("qwen3-embedding"), None);
        assert_eq!(infer_embedding_dimension("unknown-model"), None);
    }

    // The probe reports the length of the vector the server actually returns.
    #[tokio::test]
    async fn test_probe_provider_dimension_reads_actual_dimension() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let client = Client::new();
        let provider = ProviderConfig {
            name: "mock".into(),
            base_url,
            model: "mock-embedder".into(),
            priority: 1,
            endpoint: "/v1/embeddings".into(),
        };

        let dim = probe_provider_dimension(&client, &provider).await.unwrap();
        assert_eq!(dim, 2560);
    }

    // EmbeddingClient::new must fail when the served dimension (2560) differs
    // from required_dimension (1024), naming both values in the error.
    #[tokio::test]
    async fn test_embedding_client_fails_fast_on_dimension_mismatch() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let config = EmbeddingConfig {
            required_dimension: 1024,
            providers: vec![ProviderConfig {
                name: "mock".into(),
                base_url,
                model: "mock-embedder".into(),
                priority: 1,
                endpoint: "/v1/embeddings".into(),
            }],
            ..EmbeddingConfig::default()
        };

        let err = EmbeddingClient::new(&config)
            .await
            .err()
            .expect("dimension mismatch should fail")
            .to_string();
        assert!(err.contains("returned 2560 dims"));
        assert!(err.contains("required_dimension=1024"));
    }

    #[test]
    fn test_truncate_at_boundary() {
        // Test sentence boundary
        let text = "Hello world. This is a test.";
        let truncated = truncate_at_boundary(text, 15);
        assert_eq!(truncated, "Hello world.");

        // Test word boundary fallback
        let text = "Hello world this is a test";
        let truncated = truncate_at_boundary(text, 15);
        assert_eq!(truncated, "Hello world");

        // Test no truncation needed
        let text = "Short text";
        let truncated = truncate_at_boundary(text, 100);
        assert_eq!(truncated, "Short text");
    }

    #[test]
    fn test_token_estimation() {
        let config = TokenConfig::default();

        // ~3 chars per token (default multilingual)
        let text = "Hello world"; // 11 chars -> ~4 tokens
        let tokens = estimate_tokens(text, &config);
        assert!((3..=5).contains(&tokens));

        // English config (4 chars per token)
        let english_config = TokenConfig::english();
        let tokens = estimate_tokens(text, &english_config);
        assert!((2..=4).contains(&tokens));
    }

    #[test]
    fn test_chunk_validation() {
        let config = TokenConfig::default().with_max_tokens(100);

        // Short text should pass
        let short = "Hello world";
        assert!(validate_chunk_tokens(short, &config).is_ok());

        // Long text should fail
        let long = "a".repeat(1000); // Way more than 100 * 3 = 300 chars
        assert!(validate_chunk_tokens(&long, &config).is_err());
    }

    #[test]
    fn test_safe_chunk_size() {
        let config = TokenConfig::default(); // 8192 tokens, 3 chars/token

        let safe = safe_chunk_size(&config);
        // 8192 * 0.8 * 3 = 19660 chars
        assert!(safe > 15000 && safe < 25000);
    }

    #[test]
    fn test_batch_validation() {
        let config = TokenConfig::default().with_max_tokens(10);

        let texts = vec![
            "short".to_string(),      // OK
            "a".repeat(100),          // Too long
            "also short".to_string(), // OK
            "b".repeat(200),          // Too long
        ];

        let failures = validate_batch_tokens(&texts, &config);
        assert_eq!(failures.len(), 2);
        assert_eq!(failures[0].0, 1); // Index 1
        assert_eq!(failures[1].0, 3); // Index 3
    }
}
1515
1516// =============================================================================
1517// DIMENSION ADAPTER - Cross-dimension embedding compatibility
1518// =============================================================================
1519
/// Adapter for cross-dimension embedding compatibility.
///
/// Enables searching across databases with different embedding dimensions
/// (e.g., 1024, 2048, 4096) by expanding or contracting embeddings.
///
/// Every vector the adapter actually reshapes is re-normalized to unit L2
/// length (all-zero vectors stay zero). Vectors that already have
/// `target_dim` components are returned untouched.
///
/// # Strategies
/// - **Expand**: Zero-pad smaller embeddings to target dimension
/// - **Contract**: Average consecutive groups (power-of-two reductions) or
///   truncate, then re-normalize
///
/// # Example
/// ```rust,ignore
/// let adapter = DimensionAdapter::new(1024, 4096);
/// let expanded = adapter.expand(small_embedding);  // 1024 -> 4096
///
/// let adapter = DimensionAdapter::new(4096, 1024);
/// let contracted = adapter.contract(large_embedding);  // 4096 -> 1024
/// ```
#[derive(Debug, Clone)]
pub struct DimensionAdapter {
    /// Source embedding dimension
    pub source_dim: usize,
    /// Target embedding dimension
    pub target_dim: usize,
}

impl DimensionAdapter {
    /// Create a new dimension adapter
    pub fn new(source_dim: usize, target_dim: usize) -> Self {
        Self {
            source_dim,
            target_dim,
        }
    }

    /// Check if adaptation is needed (`source_dim != target_dim`).
    pub fn needs_adaptation(&self) -> bool {
        self.source_dim != self.target_dim
    }

    /// Adapt embedding to target dimension (auto-detect expand/contract).
    ///
    /// The decision is based on `embedding.len()`, not on `source_dim`.
    pub fn adapt(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() == self.target_dim {
            return embedding;
        }

        if embedding.len() < self.target_dim {
            self.expand(embedding)
        } else {
            self.contract(embedding)
        }
    }

    /// Expand smaller embeddings to target dimension via zero-padding.
    ///
    /// Zero-padding leaves the L2 norm unchanged, and the padded vector is
    /// re-normalized to unit length. Input that already has `target_dim`
    /// components is returned unchanged. Longer input is truncated and
    /// re-normalized, the same way as the `contract` truncation path.
    pub fn expand(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() == self.target_dim {
            return embedding;
        }
        if embedding.len() > self.target_dim {
            return self.truncate_and_normalize(embedding);
        }

        let mut padded = embedding;
        padded.resize(self.target_dim, 0.0);

        // Re-normalize to unit length for cosine similarity
        self.normalize(&mut padded);
        padded
    }

    /// Contract larger embeddings to target dimension.
    ///
    /// Power-of-two reductions average consecutive elements. All other
    /// reductions keep the leading `target_dim` components. Both results are
    /// re-normalized. Input no longer than `target_dim` is returned unchanged.
    pub fn contract(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }

        // For power-of-2 reductions (4096->2048, 2048->1024), use averaging
        // This preserves more information than truncation
        if self.is_power_of_two_reduction(embedding.len()) {
            self.average_reduction(embedding)
        } else {
            self.truncate_and_normalize(embedding)
        }
    }

    /// Check if this is a clean power-of-2 reduction (e.g., 4096->2048)
    fn is_power_of_two_reduction(&self, source_len: usize) -> bool {
        // Two powers of two with source > target always divide evenly, so no
        // separate divisibility check is needed.
        source_len > self.target_dim
            && source_len.is_power_of_two()
            && self.target_dim.is_power_of_two()
    }

    /// Reduce by averaging each group of `len / target_dim` consecutive
    /// elements, then re-normalize.
    ///
    /// Only called when `is_power_of_two_reduction` holds, so `target_dim > 0`
    /// and the length divides evenly.
    fn average_reduction(&self, embedding: Vec<f32>) -> Vec<f32> {
        let factor = embedding.len() / self.target_dim;
        let mut result: Vec<f32> = embedding
            .chunks_exact(factor)
            .map(|group| group.iter().sum::<f32>() / factor as f32)
            .collect();

        // Re-normalize
        self.normalize(&mut result);
        result
    }

    /// Keep the leading `target_dim` components and re-normalize, so truncated
    /// output has the same unit length as the other reshaping paths.
    fn truncate_and_normalize(&self, mut embedding: Vec<f32>) -> Vec<f32> {
        embedding.truncate(self.target_dim);
        self.normalize(&mut embedding);
        embedding
    }

    /// Normalize vector to unit length (L2 norm); near-zero vectors are left as-is.
    fn normalize(&self, vec: &mut [f32]) {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for v in vec.iter_mut() {
                *v /= norm;
            }
        }
    }
}
1640
1641/// Perform cross-dimension search by adapting query embedding
1642pub fn cross_dimension_search_adapt(query_embedding: Vec<f32>, target_dim: usize) -> Vec<f32> {
1643    let adapter = DimensionAdapter::new(query_embedding.len(), target_dim);
1644    adapter.adapt(query_embedding)
1645}
1646
#[cfg(test)]
mod dimension_adapter_tests {
    use super::*;

    /// Euclidean (L2) length of a vector.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_expand_1024_to_4096() {
        let expanded = DimensionAdapter::new(1024, 4096).expand(vec![0.1f32; 1024]);

        assert_eq!(expanded.len(), 4096);
        // Original components survive (rescaled by normalization)...
        assert!(expanded[0].abs() > 1e-10);
        // ...while the padded tail stays zero.
        assert!(expanded[4095].abs() < 1e-10);
    }

    #[test]
    fn test_contract_4096_to_1024() {
        let contracted = DimensionAdapter::new(4096, 1024).contract(vec![0.1f32; 4096]);

        assert_eq!(contracted.len(), 1024);
        // Result must be unit length.
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_adapt_auto_detect() {
        // Shorter than target -> expand.
        let grown = DimensionAdapter::new(1024, 4096).adapt(vec![0.1f32; 1024]);
        assert_eq!(grown.len(), 4096);

        // Longer than target -> contract.
        let shrunk = DimensionAdapter::new(4096, 1024).adapt(vec![0.1f32; 4096]);
        assert_eq!(shrunk.len(), 1024);
    }

    #[test]
    fn test_no_adaptation_needed() {
        let identity = DimensionAdapter::new(4096, 4096);
        assert!(!identity.needs_adaptation());

        let original = vec![0.1f32; 4096];
        assert_eq!(identity.adapt(original.clone()), original);
    }

    #[test]
    fn test_average_reduction_preserves_info() {
        // Strictly increasing ramp so every averaged pair is distinct.
        let ramp: Vec<f32> = (0..4096).map(|i| i as f32 / 4096.0).collect();
        let contracted = DimensionAdapter::new(4096, 2048).contract(ramp);

        assert_eq!(contracted.len(), 2048);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }
}