// rust_memex/embeddings/mod.rs
1//! Universal embedding client with config-driven provider cascade.
2//!
3//! Supports any OpenAI-compatible embedding API (Ollama, vLLM, TEI, etc.)
4//! Providers are tried in priority order until one responds.
5//!
6//! # Example config.toml
7//! ```toml
8//! [embeddings]
9//! required_dimension = 2560
10//! max_batch_chars = 32000
11//! max_batch_items = 16
12//!
13//! [[embeddings.providers]]
14//! name = "ollama-local"
15//! base_url = "http://localhost:11434"
16//! model = "qwen3-embedding:4b"
17//! priority = 1
18//!
19//! [[embeddings.providers]]
20//! name = "dragon"
21//! base_url = "http://dragon:12345"
22//! model = "Qwen/Qwen3-Embedding-4B"
23//! priority = 2
24//! ```
25
26use anyhow::{Result, anyhow};
27use reqwest::Client;
28use serde::{Deserialize, Serialize};
29use std::time::Duration;
30
/// Vector dimension expected from providers when no dimension is configured.
pub const DEFAULT_REQUIRED_DIMENSION: usize = 2560;
/// Model used by the built-in local Ollama provider in the default cascade.
pub const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "qwen3-embedding:4b";
/// Default attempt cap for batch requests (env-overridable, see embed_batch_internal).
const DEFAULT_MAX_BATCH_RETRIES: usize = 10;
/// Default ceiling for batch retry backoff, in seconds (env-overridable).
const DEFAULT_MAX_BATCH_BACKOFF_SECS: u64 = 30;
35
36// =============================================================================
37// REQUEST/RESPONSE TYPES (OpenAI-compatible)
38// =============================================================================
39
/// OpenAI-compatible `/v1/embeddings` request payload.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    // Texts to embed in one call.
    input: Vec<String>,
    // Model identifier understood by the provider.
    model: String,
}

/// OpenAI-compatible embeddings response envelope.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    // One entry per input text; the batching code maps results back
    // positionally, so it assumes the provider preserves request order.
    data: Vec<EmbeddingData>,
}

/// A single embedding vector from the response.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

/// OpenAI-compatible `/v1/rerank` request payload.
#[derive(Debug, Serialize)]
struct RerankRequest {
    query: String,
    documents: Vec<String>,
    model: String,
}

/// Rerank response envelope.
#[derive(Debug, Deserialize)]
struct RerankResponse {
    results: Vec<RerankResult>,
}

/// Relevance result for one document.
#[derive(Debug, Deserialize)]
struct RerankResult {
    // Index into the request's `documents`.
    index: usize,
    // Relevance score — presumably higher = more relevant; confirm
    // against the reranker provider's convention.
    score: f32,
}
73
74// =============================================================================
75// PROVIDER CONFIGURATION
76// =============================================================================
77
/// Single embedding provider configuration
///
/// Deserialized from one `[[embeddings.providers]]` entry in config.toml.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct ProviderConfig {
    /// Human-readable name for logging
    #[serde(default)]
    pub name: String,
    /// Base URL (e.g., "http://localhost:11434")
    #[serde(default)]
    pub base_url: String,
    /// Model name to use
    #[serde(default)]
    pub model: String,
    /// Priority (1 = highest, tried first; unset defaults to 10)
    #[serde(default = "default_priority")]
    pub priority: u8,
    /// Embedding endpoint path (default: /v1/embeddings)
    #[serde(default = "default_embeddings_endpoint")]
    pub endpoint: String,
}
97
/// Serde default for `ProviderConfig::priority` — deliberately low so
/// providers with an explicit priority are tried first.
fn default_priority() -> u8 {
    10
}

/// Serde default for `ProviderConfig::endpoint`.
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
105
/// Parse a strictly-positive integer of type `T` from environment
/// variable `name`, falling back to `default` when the variable is
/// unset, unparsable, or zero.
fn env_positive<T>(name: &str, default: T) -> T
where
    T: std::str::FromStr + PartialOrd + From<u8>,
{
    std::env::var(name)
        .ok()
        .and_then(|value| value.parse::<T>().ok())
        // Zero would disable retries/backoff entirely; treat it as unset.
        .filter(|value| *value > T::from(0u8))
        .unwrap_or(default)
}

/// Positive `usize` from the environment, or `default`.
fn env_usize(name: &str, default: usize) -> usize {
    env_positive(name, default)
}

/// Positive `u64` from the environment, or `default`.
fn env_u64(name: &str, default: u64) -> u64 {
    env_positive(name, default)
}
121
122/// Reranker configuration (optional, separate from embedders)
123#[derive(Debug, Clone, Deserialize, Serialize, Default)]
124pub struct RerankerConfig {
125    /// Base URL for reranker service
126    pub base_url: Option<String>,
127    /// Model name
128    pub model: Option<String>,
129    /// Endpoint path (default: /v1/rerank)
130    #[serde(default = "default_rerank_endpoint")]
131    pub endpoint: String,
132}
133
/// Serde default for `RerankerConfig::endpoint`.
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}
137
/// Serde default for `EmbeddingConfig::required_dimension`.
fn default_dimension() -> usize {
    DEFAULT_REQUIRED_DIMENSION
}
141
/// Serde default for `EmbeddingConfig::max_batch_chars`.
/// Raised 4x from the original 32k for better GPU utilization.
fn default_max_batch_chars() -> usize {
    128_000
}

/// Serde default for `EmbeddingConfig::max_batch_items`.
/// Raised 4x from the original 16 — fewer API calls, better throughput.
fn default_max_batch_items() -> usize {
    64
}
149
/// Join a provider base URL and endpoint path into a full request URL,
/// tolerating a trailing slash on the base and a missing leading slash
/// on the (whitespace-trimmed) endpoint.
fn build_provider_endpoint(base_url: &str, endpoint: &str) -> String {
    let base = base_url.trim_end_matches('/');
    let path = endpoint.trim();
    let sep = if path.starts_with('/') { "" } else { "/" };
    format!("{}{}{}", base, sep, path)
}
159
/// Complete embedding configuration
///
/// Maps to the `[embeddings]` table of config.toml.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingConfig {
    /// Required vector dimension (mismatch corrupts database!)
    #[serde(default = "default_dimension")]
    pub required_dimension: usize,
    /// Maximum characters per embedding batch to avoid OOM (default: 128000)
    #[serde(default = "default_max_batch_chars")]
    pub max_batch_chars: usize,
    /// Maximum items per embedding batch (default: 64)
    #[serde(default = "default_max_batch_items")]
    pub max_batch_items: usize,
    /// List of providers to try in priority order
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,
    /// Optional reranker configuration
    #[serde(default)]
    pub reranker: RerankerConfig,
}
179
180impl Default for EmbeddingConfig {
181    fn default() -> Self {
182        Self {
183            required_dimension: default_dimension(),
184            max_batch_chars: default_max_batch_chars(),
185            max_batch_items: default_max_batch_items(),
186            providers: vec![
187                ProviderConfig {
188                    name: "ollama-local".to_string(),
189                    base_url: "http://localhost:11434".to_string(),
190                    model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
191                    priority: 1,
192                    endpoint: default_embeddings_endpoint(),
193                },
194                ProviderConfig {
195                    name: "dragon".to_string(),
196                    base_url: "http://dragon:12345".to_string(),
197                    model: "Qwen/Qwen3-Embedding-4B".to_string(),
198                    priority: 2,
199                    endpoint: default_embeddings_endpoint(),
200                },
201            ],
202            reranker: RerankerConfig::default(),
203        }
204    }
205}
206
207impl EmbeddingConfig {
208    /// Returns the name of the first (highest priority) provider
209    pub fn provider_name(&self) -> String {
210        self.providers
211            .first()
212            .map(|p| p.name.clone())
213            .unwrap_or_else(|| "none".to_string())
214    }
215
216    /// Returns the model name of the first (highest priority) provider
217    pub fn model_name(&self) -> String {
218        self.providers
219            .first()
220            .map(|p| p.model.clone())
221            .unwrap_or_else(|| "none".to_string())
222    }
223
224    /// Alias for required_dimension for API compatibility
225    pub fn dimension(&self) -> usize {
226        self.required_dimension
227    }
228}
229
230// =============================================================================
231// LEGACY CONFIG (backward compatibility)
232// =============================================================================
233
/// Legacy MLX configuration - deprecated, use EmbeddingConfig instead
#[derive(Debug, Clone)]
pub struct MlxConfig {
    /// When true, embedding is disabled entirely — `from_legacy` refuses
    /// to construct a client (there is no fallback).
    pub disabled: bool,
    /// Port of the local embedding server (localhost).
    pub local_port: u16,
    /// Base URL of the remote "dragon" host.
    pub dragon_url: String,
    /// Embedder port on the dragon host.
    pub dragon_port: u16,
    /// Embedding model name used for both providers.
    pub embedder_model: String,
    /// Reranker model name.
    pub reranker_model: String,
    /// Reranker port = local_port + this offset (see to_embedding_config).
    pub reranker_port_offset: u16,
    /// Forwarded to EmbeddingConfig::max_batch_chars.
    pub max_batch_chars: usize,
    /// Forwarded to EmbeddingConfig::max_batch_items.
    pub max_batch_items: usize,
}
247
/// Options for merging file config into MlxConfig
///
/// Each `Some` field overrides the corresponding `MlxConfig` field in
/// `MlxConfig::merge_file_config`; `None` leaves it untouched.
#[derive(Debug, Clone, Default)]
pub struct MlxMergeOptions {
    pub disabled: Option<bool>,
    pub local_port: Option<u16>,
    pub dragon_url: Option<String>,
    pub dragon_port: Option<u16>,
    pub embedder_model: Option<String>,
    pub reranker_model: Option<String>,
    pub reranker_port_offset: Option<u16>,
}
259
260impl Default for MlxConfig {
261    fn default() -> Self {
262        Self {
263            disabled: false,
264            local_port: 12345,
265            dragon_url: "http://dragon".to_string(),
266            dragon_port: 12345,
267            embedder_model: "Qwen/Qwen3-Embedding-4B".to_string(),
268            reranker_model: "Qwen/Qwen3-Reranker-4B".to_string(),
269            reranker_port_offset: 1,
270            max_batch_chars: default_max_batch_chars(),
271            max_batch_items: default_max_batch_items(),
272        }
273    }
274}
275
276impl MlxConfig {
277    /// Create config from environment variables (legacy support)
278    pub fn from_env() -> Self {
279        let disabled = std::env::var("DISABLE_MLX")
280            .map(|v| v == "1" || v.to_lowercase() == "true")
281            .unwrap_or(false);
282
283        let local_port = std::env::var("EMBEDDER_PORT")
284            .ok()
285            .and_then(|s| s.parse().ok())
286            .unwrap_or(12345);
287
288        let dragon_url =
289            std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
290
291        let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
292            .ok()
293            .and_then(|s| s.parse().ok())
294            .unwrap_or(local_port);
295
296        let reranker_port_offset = std::env::var("RERANKER_PORT")
297            .ok()
298            .and_then(|s| s.parse::<u16>().ok())
299            .map(|rp| rp.saturating_sub(local_port))
300            .unwrap_or(1);
301
302        let embedder_model = std::env::var("EMBEDDER_MODEL")
303            .unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
304
305        let reranker_model = std::env::var("RERANKER_MODEL")
306            .unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
307
308        let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
309            .ok()
310            .and_then(|s| s.parse().ok())
311            .unwrap_or(32000);
312
313        let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
314            .ok()
315            .and_then(|s| s.parse().ok())
316            .unwrap_or(16);
317
318        Self {
319            disabled,
320            local_port,
321            dragon_url,
322            dragon_port,
323            embedder_model,
324            reranker_model,
325            reranker_port_offset,
326            max_batch_chars,
327            max_batch_items,
328        }
329    }
330
331    /// Merge with values from file config
332    pub fn merge_file_config(&mut self, opts: MlxMergeOptions) {
333        if let Some(v) = opts.disabled {
334            self.disabled = v;
335        }
336        if let Some(v) = opts.local_port {
337            self.local_port = v;
338        }
339        if let Some(v) = opts.dragon_url {
340            self.dragon_url = v;
341        }
342        if let Some(v) = opts.dragon_port {
343            self.dragon_port = v;
344        }
345        if let Some(v) = opts.embedder_model {
346            self.embedder_model = v;
347        }
348        if let Some(v) = opts.reranker_model {
349            self.reranker_model = v;
350        }
351        if let Some(v) = opts.reranker_port_offset {
352            self.reranker_port_offset = v;
353        }
354    }
355
356    /// Convert legacy config to new EmbeddingConfig
357    pub fn to_embedding_config(&self) -> EmbeddingConfig {
358        let reranker_port = self.local_port + self.reranker_port_offset;
359        let required_dimension = DEFAULT_REQUIRED_DIMENSION;
360
361        EmbeddingConfig {
362            required_dimension,
363            max_batch_chars: self.max_batch_chars,
364            max_batch_items: self.max_batch_items,
365            providers: vec![
366                ProviderConfig {
367                    name: "local".to_string(),
368                    base_url: format!("http://localhost:{}", self.local_port),
369                    model: self.embedder_model.clone(),
370                    priority: 1,
371                    endpoint: default_embeddings_endpoint(),
372                },
373                ProviderConfig {
374                    name: "dragon".to_string(),
375                    base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
376                    model: self.embedder_model.clone(),
377                    priority: 2,
378                    endpoint: default_embeddings_endpoint(),
379                },
380            ],
381            reranker: RerankerConfig {
382                base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
383                model: Some(self.reranker_model.clone()),
384                endpoint: default_rerank_endpoint(),
385            },
386        }
387    }
388
389    /// Set batch limits
390    pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
391        self.max_batch_chars = max_chars;
392        self.max_batch_items = max_items;
393        self
394    }
395}
396
397// =============================================================================
398// EMBEDDING CLIENT
399// =============================================================================
400
/// Universal embedding client with provider cascade
///
/// Built by [`EmbeddingClient::new`], which probes configured providers
/// and binds to the first whose embedding dimension matches.
#[derive(Clone)]
pub struct EmbeddingClient {
    // Shared HTTP client (configured with long timeouts for big batches).
    client: Client,
    // Full URL of the embedding endpoint we bound to.
    embedder_url: String,
    // Model name sent with every embedding request.
    embedder_model: String,
    // Full reranker URL, if a reranker is configured.
    reranker_url: Option<String>,
    // Reranker model name, if configured.
    reranker_model: Option<String>,
    /// Which provider we're connected to
    connected_to: String,
    /// Expected dimension (for validation)
    required_dimension: usize,
    /// Maximum characters per embedding batch
    max_batch_chars: usize,
    /// Maximum items per embedding batch
    max_batch_items: usize,
}

// Type alias for backward compatibility
pub type MLXBridge = EmbeddingClient;
421
422impl EmbeddingClient {
    /// Create client with config-driven provider cascade
    ///
    /// Providers are sorted by `priority` (1 = highest) and probed in
    /// order; the first whose probed dimension equals
    /// `config.required_dimension` is selected.
    ///
    /// # Errors
    /// Fails when no providers are configured, when the HTTP client
    /// cannot be built, or when every provider either errors or returns
    /// a mismatched dimension — the final error lists every attempt.
    pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
        if config.providers.is_empty() {
            return Err(anyhow!(
                "No embedding providers configured! Add providers to [embeddings.providers]"
            ));
        }

        // Long timeout for large embedding batches (100+ chunks can take minutes)
        let client = Client::builder()
            .timeout(Duration::from_secs(300))
            .connect_timeout(Duration::from_secs(10))
            .build()?;

        // Sort providers by priority
        let mut providers = config.providers.clone();
        providers.sort_by_key(|p| p.priority);

        // Try each provider in order using the real embedding endpoint.
        let mut tried = Vec::new();
        for provider in &providers {
            let base_url = provider.base_url.trim_end_matches('/');
            let provider_name = if provider.name.trim().is_empty() {
                "<unnamed-provider>"
            } else {
                provider.name.as_str()
            };
            let model = provider.model.trim();
            let embedder_url = build_provider_endpoint(base_url, &provider.endpoint);

            // probe_provider_dimension is defined elsewhere in this module;
            // presumably it issues a real embedding request and returns the
            // vector length — confirm there.
            match probe_provider_dimension(&client, provider).await {
                Ok(actual_dim) if actual_dim == config.required_dimension => {
                    tracing::info!(
                        "Embedding: Connected to {} ({}) with model '{}' [{} dims]",
                        provider_name,
                        embedder_url,
                        model,
                        actual_dim
                    );

                    // Build reranker URL if configured
                    let (reranker_url, reranker_model) =
                        if let Some(ref rr_base) = config.reranker.base_url {
                            (
                                Some(format!(
                                    "{}{}",
                                    rr_base.trim_end_matches('/'),
                                    config.reranker.endpoint
                                )),
                                config.reranker.model.clone(),
                            )
                        } else {
                            (None, None)
                        };

                    return Ok(Self {
                        client,
                        embedder_url,
                        embedder_model: provider.model.clone(),
                        reranker_url,
                        reranker_model,
                        connected_to: provider.name.clone(),
                        required_dimension: config.required_dimension,
                        max_batch_chars: config.max_batch_chars,
                        max_batch_items: config.max_batch_items,
                    });
                }
                // Provider responded but with the wrong dimension: record an
                // actionable message telling the operator how to reconcile.
                Ok(actual_dim) => {
                    let failure = format!(
                        "- {} ({} model='{}'): the configured embedding endpoint returned {} dims, but config.required_dimension={}.\n  Action: set [embeddings].required_dimension = {} or choose a {}-dim model.",
                        provider_name,
                        embedder_url,
                        model,
                        actual_dim,
                        config.required_dimension,
                        actual_dim,
                        config.required_dimension
                    );
                    tracing::error!("Embedding: validation failed: {}", failure);
                    tried.push(failure);
                }
                // Probe failed outright (connection refused, timeout, ...).
                Err(e) => {
                    let failure = format!(
                        "- {} ({} model='{}'): {}",
                        provider_name, embedder_url, model, e
                    );
                    tracing::warn!("Embedding: provider probe failed: {}", failure);
                    tried.push(failure);
                }
            }
        }

        // All providers failed
        Err(anyhow!(
            "No embedding provider passed validation for required_dimension={}. \
             Each provider must succeed on its configured embedding endpoint before rust-memex will start.\nTried:\n{}",
            config.required_dimension,
            tried.join("\n")
        ))
    }
523
524    /// Create from legacy MlxConfig (backward compatibility)
525    pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
526        if config.disabled {
527            return Err(anyhow!(
528                "Embedding disabled via config. No fallback available!"
529            ));
530        }
531        tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
532        let embedding_config = config.to_embedding_config();
533        Self::new(&embedding_config).await
534    }
535
536    /// Legacy constructor from env vars only
537    pub async fn from_env() -> Result<Self> {
538        let config = MlxConfig::from_env();
539        Self::from_legacy(&config).await
540    }
541
542    /// Get which provider we're connected to
543    pub fn connected_to(&self) -> &str {
544        &self.connected_to
545    }
546
    /// Get required dimension (vectors of any other length are rejected).
    pub fn required_dimension(&self) -> usize {
        self.required_dimension
    }
551
552    /// Get the current runtime batch limits used for embedding requests.
553    pub fn batch_limits(&self) -> (usize, usize) {
554        (self.max_batch_chars, self.max_batch_items)
555    }
556
557    /// Clone the client while overriding the runtime batch limits.
558    pub fn clone_with_batch_limits(&self, max_chars: usize, max_items: usize) -> Self {
559        let mut cloned = self.clone();
560        cloned.max_batch_chars = max_chars.max(1);
561        cloned.max_batch_items = max_items.max(1);
562        cloned
563    }
564
565    /// Create a stub client for tests that don't need real embeddings.
566    /// The client will fail on any actual embed() call, but lets McpCore
567    /// be constructed and dispatch protocol-level requests.
568    #[doc(hidden)]
569    pub fn stub_for_tests() -> Self {
570        Self {
571            client: reqwest::Client::new(),
572            embedder_url: "http://stub:0/v1/embeddings".to_string(),
573            embedder_model: "stub".to_string(),
574            reranker_url: None,
575            reranker_model: None,
576            connected_to: "stub-test".to_string(),
577            required_dimension: 4096,
578            max_batch_chars: 32000,
579            max_batch_items: 16,
580        }
581    }
582
    /// Embed a single text via the bound provider.
    ///
    /// Sends one OpenAI-compatible embeddings request and validates the
    /// returned vector length against `required_dimension` before
    /// returning it (a mismatch would corrupt stored vectors).
    ///
    /// NOTE(review): takes `&mut self` but mutates nothing visible in
    /// this body — presumably kept for signature symmetry; confirm
    /// before relaxing to `&self`.
    ///
    /// # Errors
    /// Transport failure, non-2xx status, unparsable body, empty `data`,
    /// or a dimension mismatch.
    pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
        // Log only a short preview; inputs can be very large.
        let text_preview: String = text.chars().take(100).collect();
        tracing::debug!(
            "Embedding single text ({} chars): {}{}",
            text.chars().count(),
            text_preview,
            if text.chars().count() > 100 {
                "..."
            } else {
                ""
            }
        );

        let request = EmbeddingRequest {
            input: vec![text.to_string()],
            model: self.embedder_model.clone(),
        };

        let response = match self
            .client
            .post(&self.embedder_url)
            .json(&request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                tracing::error!(
                    "Embedding request failed: {:?}\n  URL: {}\n  Model: {}",
                    e,
                    self.embedder_url,
                    self.embedder_model
                );
                return Err(anyhow!("Embedding request failed: {}", e));
            }
        };

        // Read the body before checking status so error responses can be logged.
        let status = response.status();
        let response_text = response.text().await.unwrap_or_else(|e| {
            tracing::warn!("Failed to read response body: {:?}", e);
            "<failed to read body>".to_string()
        });

        if !status.is_success() {
            tracing::error!(
                "Embedding API error (HTTP {}):\n  URL: {}\n  Model: {}\n  Response: {}",
                status,
                self.embedder_url,
                self.embedder_model,
                response_text
            );
            return Err(anyhow!(
                "Embedding API error (HTTP {}): {}",
                status,
                response_text
            ));
        }

        let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(
                    "Failed to parse embedding response: {:?}\n  Response body: {}",
                    e,
                    response_text
                );
                return Err(anyhow!("Failed to parse embedding response: {}", e));
            }
        };

        // A single input was sent, so only the first data entry is used.
        let embedding = parsed
            .data
            .into_iter()
            .next()
            .map(|d| d.embedding)
            .ok_or_else(|| {
                tracing::error!("No embedding returned in response: {}", response_text);
                anyhow!("No embedding returned")
            })?;

        // Validate dimension
        if embedding.len() != self.required_dimension {
            tracing::error!(
                "Dimension mismatch! Expected {}, got {}. Model: {}",
                self.required_dimension,
                embedding.len(),
                self.embedder_model
            );
            return Err(anyhow!(
                "Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
                self.required_dimension,
                embedding.len()
            ));
        }

        tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
        Ok(embedding)
    }
681
    /// Embed a batch of texts with intelligent batching to avoid OOM.
    ///
    /// Over-long texts are truncated (via `truncate_at_boundary`) to half
    /// of `max_batch_chars` before embedding.
    /// Batches are split to stay under max_batch_chars and max_batch_items.
    /// Failed batches fall back to per-text retries with exponential backoff.
    ///
    /// Returns one embedding per input text, in input order.
    ///
    /// # Errors
    /// Fails if any text still cannot be embedded after individual retries.
    pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut current_batch: Vec<String> = Vec::new();
        // Original index of each text in `current_batch`, kept in lockstep.
        let mut current_batch_indices: Vec<usize> = Vec::new();
        let mut current_chars = 0;

        // Max chars per individual text (half of batch limit for safety)
        let max_text_chars = self.max_batch_chars / 2;

        // Prepare all texts first
        let prepared_texts: Vec<String> = texts
            .iter()
            .map(|text| {
                let char_count = text.chars().count();
                if char_count > max_text_chars {
                    tracing::debug!(
                        "Text too large ({} chars), truncating to {} chars",
                        char_count,
                        max_text_chars
                    );
                    truncate_at_boundary(text, max_text_chars)
                } else {
                    text.clone()
                }
            })
            .collect();

        // Pre-allocate result vector with None
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut failed_indices: Vec<usize> = Vec::new();

        for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
            let text_len = text_to_embed.chars().count();

            // Check if we need to flush current batch
            if !current_batch.is_empty()
                && (current_chars + text_len > self.max_batch_chars
                    || current_batch.len() >= self.max_batch_items)
            {
                // Flush current batch with retry
                match self.embed_batch_internal(&current_batch).await {
                    Ok(batch_embeddings) => {
                        // Map results back positionally — assumes the server
                        // preserves request order in its response.
                        for (i, emb) in batch_embeddings.into_iter().enumerate() {
                            if let Some(orig_idx) = current_batch_indices.get(i) {
                                results[*orig_idx] = Some(emb);
                            }
                        }
                    }
                    Err(e) => {
                        // Don't fail the whole call yet — queue the batch's
                        // texts for individual retries below.
                        tracing::warn!(
                            "Batch embedding failed for {} texts, will retry individually: {}",
                            current_batch.len(),
                            e
                        );
                        failed_indices.extend(current_batch_indices.iter().copied());
                    }
                }
                current_batch.clear();
                current_batch_indices.clear();
                current_chars = 0;
            }

            current_batch.push(text_to_embed.clone());
            current_batch_indices.push(idx);
            current_chars += text_len;
        }

        // Flush remaining batch
        if !current_batch.is_empty() {
            match self.embed_batch_internal(&current_batch).await {
                Ok(batch_embeddings) => {
                    for (i, emb) in batch_embeddings.into_iter().enumerate() {
                        if let Some(orig_idx) = current_batch_indices.get(i) {
                            results[*orig_idx] = Some(emb);
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        "Batch embedding failed for {} texts, will retry individually: {}",
                        current_batch.len(),
                        e
                    );
                    failed_indices.extend(current_batch_indices.iter().copied());
                }
            }
        }

        // Retry failed chunks individually with exponential backoff
        const MAX_RETRIES: usize = 3;
        for idx in failed_indices {
            let text = &prepared_texts[idx];
            let mut attempts = 0;
            let mut last_error = String::new();

            while attempts < MAX_RETRIES {
                match self.embed(text).await {
                    Ok(embedding) => {
                        results[idx] = Some(embedding);
                        tracing::info!(
                            "Retry succeeded for chunk {} after {} attempts",
                            idx,
                            attempts + 1
                        );
                        break;
                    }
                    Err(e) => {
                        attempts += 1;
                        last_error = e.to_string();
                        tracing::warn!(
                            "Embed attempt {}/{} failed for chunk {}: {}",
                            attempts,
                            MAX_RETRIES,
                            idx,
                            e
                        );
                        if attempts < MAX_RETRIES {
                            // Exponential backoff: 200ms after the first
                            // failure, then 400ms (`attempts` was already
                            // incremented above, so the shift starts at 1).
                            let delay_ms = 100 * (1 << attempts);
                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
                        }
                    }
                }
            }

            // Per-text retries exhausted: give up on the whole batch call.
            if results[idx].is_none() {
                tracing::error!(
                    "Chunk {} failed after {} retries: {}",
                    idx,
                    MAX_RETRIES,
                    last_error
                );
                return Err(anyhow!(
                    "Failed to embed chunk {} after {} retries: {}",
                    idx,
                    MAX_RETRIES,
                    last_error
                ));
            }
        }

        // Collect all results - all should be Some at this point
        for (idx, opt) in results.iter().enumerate() {
            match opt {
                Some(emb) => all_embeddings.push(emb.clone()),
                None => {
                    return Err(anyhow!(
                        "Internal error: missing embedding for chunk {}",
                        idx
                    ));
                }
            }
        }

        Ok(all_embeddings)
    }
847
    /// Internal batch embedding - sends directly to server
    ///
    /// Posts all `texts` in a single OpenAI-compatible request and retries
    /// the whole batch on transport errors, non-success statuses, and JSON
    /// parse failures, using capped exponential backoff (1s, 2s, 4s, 8s,
    /// 16s, then `max_backoff_secs`; 30s by default).
    ///
    /// # Errors
    /// Fails once the retry budget is exhausted, when the server returns a
    /// different number of embeddings than texts, or when the first
    /// embedding's length differs from `self.required_dimension`.
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();

        tracing::debug!(
            "Embedding batch: {} texts, {} chars total",
            texts.len(),
            total_chars
        );

        // Log first few chars of each text in trace mode for debugging
        for (i, text) in texts.iter().enumerate() {
            let preview: String = text.chars().take(50).collect();
            tracing::trace!(
                "  Batch[{}]: {} chars - {}{}",
                i,
                text.chars().count(),
                preview,
                if text.chars().count() > 50 { "..." } else { "" }
            );
        }

        let request = EmbeddingRequest {
            input: texts.to_vec(),
            model: self.embedder_model.clone(),
        };

        // Retry with exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (max by default).
        // Operators can lower this for deterministic failure-policy tests or short smokes.
        let max_batch_retries = env_usize(
            "RUST_MEMEX_EMBED_BATCH_MAX_RETRIES",
            DEFAULT_MAX_BATCH_RETRIES,
        );
        let max_backoff_secs = env_u64(
            "RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS",
            DEFAULT_MAX_BATCH_BACKOFF_SECS,
        );
        let mut attempt = 0;

        loop {
            attempt += 1;
            let response = match self
                .client
                .post(&self.embedder_url)
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    // Transport-level failure (connect refused, timeout, ...).
                    if attempt >= max_batch_retries {
                        tracing::error!(
                            "Batch embedding failed after {} retries: {:?}\n  URL: {}\n  Model: {}",
                            max_batch_retries,
                            e,
                            self.embedder_url,
                            self.embedder_model
                        );
                        return Err(anyhow!(
                            "Embedding request failed after {} retries: {}",
                            max_batch_retries,
                            e
                        ));
                    }

                    // Exponential backoff with cap: shift is clamped at 2^5 = 32s,
                    // then further limited by max_backoff_secs.
                    let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                    tracing::warn!(
                        "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
                        attempt,
                        max_batch_retries,
                        backoff_secs,
                        e
                    );
                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                    continue;
                }
            };

            // Transport succeeded, but the server may still have reported an
            // HTTP-level error; those are retried as well.
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();

                if attempt >= max_batch_retries {
                    tracing::error!(
                        "Embedding API error after {} retries: {} - {}",
                        max_batch_retries,
                        status,
                        body
                    );
                    return Err(anyhow!("Embedding API error: {} - {}", status, body));
                }

                let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                tracing::warn!(
                    "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
                    attempt,
                    max_batch_retries,
                    backoff_secs,
                    status,
                    body
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }

            // Parse response (a malformed body is treated as transient and retried)
            let embedding_response: EmbeddingResponse = match response.json().await {
                Ok(r) => r,
                Err(e) => {
                    if attempt >= max_batch_retries {
                        return Err(anyhow!("Failed to parse embedding response: {}", e));
                    }
                    let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                    tracing::warn!(
                        "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
                        attempt,
                        max_batch_retries,
                        backoff_secs,
                        e
                    );
                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                    continue;
                }
            };

            // Validate dimensions
            let embeddings: Vec<Vec<f32>> = embedding_response
                .data
                .into_iter()
                .map(|d| d.embedding)
                .collect();

            if embeddings.len() != texts.len() {
                return Err(anyhow!(
                    "Embedding count mismatch: got {} embeddings for {} texts",
                    embeddings.len(),
                    texts.len()
                ));
            }

            // Only the first embedding's length is checked; assumes the server
            // returns uniform dimensions across the batch — TODO confirm.
            if let Some(first) = embeddings.first()
                && first.len() != self.required_dimension
            {
                return Err(anyhow!(
                    "Dimension mismatch: expected {}, got {}",
                    self.required_dimension,
                    first.len()
                ));
            }

            return Ok(embeddings);
        }
    }
1003
1004    pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
1005        let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
1006            anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
1007        })?;
1008        let reranker_model = self
1009            .reranker_model
1010            .as_ref()
1011            .ok_or_else(|| anyhow!("Reranker model not configured."))?;
1012
1013        let query_preview: String = query.chars().take(100).collect();
1014        tracing::debug!(
1015            "Reranking {} documents for query: {}{}",
1016            documents.len(),
1017            query_preview,
1018            if query.chars().count() > 100 {
1019                "..."
1020            } else {
1021                ""
1022            }
1023        );
1024
1025        let request = RerankRequest {
1026            query: query.to_string(),
1027            documents: documents.to_vec(),
1028            model: reranker_model.clone(),
1029        };
1030
1031        let response = match self.client.post(reranker_url).json(&request).send().await {
1032            Ok(resp) => resp,
1033            Err(e) => {
1034                tracing::error!(
1035                    "Rerank request failed: {:?}\n  URL: {}\n  Model: {}\n  Query: {}\n  Documents: {}",
1036                    e,
1037                    reranker_url,
1038                    reranker_model,
1039                    query_preview,
1040                    documents.len()
1041                );
1042                return Err(anyhow!("Rerank request failed: {}", e));
1043            }
1044        };
1045
1046        let status = response.status();
1047        let response_text = response.text().await.unwrap_or_else(|e| {
1048            tracing::warn!("Failed to read rerank response body: {:?}", e);
1049            "<failed to read body>".to_string()
1050        });
1051
1052        if !status.is_success() {
1053            tracing::error!(
1054                "Rerank API error (HTTP {}):\n  URL: {}\n  Model: {}\n  Response: {}",
1055                status,
1056                reranker_url,
1057                reranker_model,
1058                response_text
1059            );
1060            return Err(anyhow!(
1061                "Rerank API error (HTTP {}): {}",
1062                status,
1063                response_text
1064            ));
1065        }
1066
1067        let parsed: RerankResponse = match serde_json::from_str(&response_text) {
1068            Ok(r) => r,
1069            Err(e) => {
1070                tracing::error!(
1071                    "Failed to parse rerank response: {:?}\n  Response body: {}",
1072                    e,
1073                    response_text
1074                );
1075                return Err(anyhow!("Failed to parse rerank response: {}", e));
1076            }
1077        };
1078
1079        tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());
1080
1081        Ok(parsed
1082            .results
1083            .into_iter()
1084            .map(|r| (r.index, r.score))
1085            .collect())
1086    }
1087}
1088
1089pub(crate) async fn probe_provider_dimension(
1090    client: &Client,
1091    provider: &ProviderConfig,
1092) -> Result<usize> {
1093    let base_url = provider.base_url.trim_end_matches('/');
1094    if base_url.is_empty() {
1095        return Err(anyhow!("provider base_url is empty"));
1096    }
1097
1098    let endpoint = provider.endpoint.trim();
1099    if endpoint.is_empty() {
1100        return Err(anyhow!("provider endpoint is empty"));
1101    }
1102
1103    let model = provider.model.trim();
1104    if model.is_empty() {
1105        return Err(anyhow!("provider model is empty"));
1106    }
1107
1108    let embedder_url = build_provider_endpoint(base_url, endpoint);
1109    let request = EmbeddingRequest {
1110        input: vec!["dimension probe".to_string()],
1111        model: model.to_string(),
1112    };
1113
1114    let response = client
1115        .post(&embedder_url)
1116        .json(&request)
1117        .timeout(Duration::from_secs(30))
1118        .send()
1119        .await
1120        .map_err(|e| anyhow!("POST {} failed: {}", embedder_url, e))?;
1121
1122    let status = response.status();
1123    let body = response.text().await.unwrap_or_default();
1124    if !status.is_success() {
1125        let hint = if status.as_u16() == 404 {
1126            " Check provider.endpoint; Ollama and OpenAI-compatible servers typically use /v1/embeddings."
1127        } else {
1128            ""
1129        };
1130        return Err(anyhow!(
1131            "POST {} returned {} for model '{}': {}{}",
1132            embedder_url,
1133            status,
1134            model,
1135            body.chars().take(300).collect::<String>(),
1136            hint
1137        ));
1138    }
1139
1140    let embed_response: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
1141        anyhow!(
1142            "POST {} returned non-embedding JSON for model '{}': {} (body: {})",
1143            embedder_url,
1144            model,
1145            e,
1146            body.chars().take(200).collect::<String>()
1147        )
1148    })?;
1149
1150    embed_response
1151        .data
1152        .first()
1153        .map(|d| d.embedding.len())
1154        .ok_or_else(|| {
1155            anyhow!(
1156                "POST {} returned no embeddings for model '{}'",
1157                embedder_url,
1158                model
1159            )
1160        })
1161}
1162
1163/// Truncate text at a word/sentence boundary to avoid cutting mid-word (UTF-8 safe)
1164fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
1165    let char_count = text.chars().count();
1166    if char_count <= max_chars {
1167        return text.to_string();
1168    }
1169
1170    // Get byte index of max_chars-th character (UTF-8 safe)
1171    let byte_idx = text
1172        .char_indices()
1173        .nth(max_chars)
1174        .map(|(idx, _)| idx)
1175        .unwrap_or(text.len());
1176
1177    let truncated = &text[..byte_idx];
1178
1179    // Try to find a sentence boundary first (prefer complete sentences)
1180    let half_byte_idx = text
1181        .char_indices()
1182        .nth(max_chars / 2)
1183        .map(|(idx, _)| idx)
1184        .unwrap_or(0);
1185
1186    if let Some(pos) = truncated.rfind(['.', '!', '?', '\n'])
1187        && pos > half_byte_idx
1188    {
1189        return text[..=pos].to_string();
1190    }
1191
1192    // Fall back to word boundary
1193    if let Some(pos) = truncated.rfind([' ', '\t', '\n']) {
1194        return text[..pos].to_string();
1195    }
1196
1197    // Last resort: hard truncate
1198    truncated.to_string()
1199}
1200
1201// =============================================================================
1202// TOKEN-AWARE VALIDATION
1203// =============================================================================
1204//
1205// Embedding models have token limits. qwen3-embedding:8b ships with a
1206// 40 960-token context window (verified via `ollama show qwen3-embedding:8b`).
1207// Earlier defaults (8192) were a conservative carry-over from older models and
1208// caused premature truncation of long transcripts. We keep a 6 000-token margin
1209// under the real limit (~35 000) for prompt overhead and language drift.
1210// These utilities estimate token counts and validate chunks before embedding.
1211// =============================================================================
1212
/// Default safe ceiling for qwen3-embedding context window.
/// Real model limit is 40 960 tokens; we leave ~6k margin for safety
/// (prompt/context-prefix overhead and multilingual token-density drift).
pub const DEFAULT_MAX_TOKENS: usize = 35_000;
1216
/// Token estimation configuration
///
/// Pairs a hard token ceiling with a chars-per-token heuristic used to
/// approximate token counts without running a real tokenizer.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Maximum tokens for the embedding model
    pub max_tokens: usize,
    /// Average characters per token (varies by language)
    /// English: ~4 chars/token, Polish/multilingual: ~2-3 chars/token
    pub chars_per_token: f32,
}
1226
1227impl Default for TokenConfig {
1228    fn default() -> Self {
1229        Self {
1230            max_tokens: DEFAULT_MAX_TOKENS,
1231            chars_per_token: 3.0,
1232        }
1233    }
1234}
1235
1236impl TokenConfig {
1237    /// Create config for English-only content
1238    pub fn english() -> Self {
1239        Self {
1240            max_tokens: DEFAULT_MAX_TOKENS,
1241            chars_per_token: 4.0,
1242        }
1243    }
1244
1245    /// Create config for multilingual/Polish content
1246    pub fn for_multilingual_text() -> Self {
1247        Self {
1248            max_tokens: DEFAULT_MAX_TOKENS,
1249            chars_per_token: 2.5,
1250        }
1251    }
1252
1253    /// Create config with custom max tokens
1254    pub fn with_max_tokens(mut self, max: usize) -> Self {
1255        self.max_tokens = max;
1256        self
1257    }
1258}
1259
1260/// Estimate token count for text
1261///
1262/// This is a heuristic approximation. For precise counting,
1263/// use the actual tokenizer (tiktoken, sentencepiece, etc.)
1264pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
1265    let char_count = text.chars().count();
1266    (char_count as f32 / config.chars_per_token).ceil() as usize
1267}
1268
1269/// Validate that a chunk fits within token limits
1270///
1271/// Returns Ok(()) if chunk is within limits, Err with details otherwise.
1272pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
1273    let estimated = estimate_tokens(chunk, config);
1274
1275    if estimated > config.max_tokens {
1276        return Err(anyhow!(
1277            "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
1278             Consider reducing chunk_size or enabling truncation.",
1279            estimated,
1280            config.max_tokens,
1281            chunk.chars().count()
1282        ));
1283    }
1284
1285    Ok(())
1286}
1287
1288/// Calculate safe chunk size in characters for given token limit
1289pub fn safe_chunk_size(config: &TokenConfig) -> usize {
1290    // Use 80% of max to leave room for context prefix
1291    let safe_tokens = (config.max_tokens as f32 * 0.8) as usize;
1292    (safe_tokens as f32 * config.chars_per_token) as usize
1293}
1294
1295/// Truncate text to fit within token limit
1296pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
1297    let safe_chars = safe_chunk_size(config);
1298
1299    if text.chars().count() <= safe_chars {
1300        return text.to_string();
1301    }
1302
1303    truncate_at_boundary(text, safe_chars)
1304}
1305
1306/// Validate a batch of texts and return which ones exceed limits
1307pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
1308    texts
1309        .iter()
1310        .enumerate()
1311        .filter_map(|(idx, text)| {
1312            let estimated = estimate_tokens(text, config);
1313            if estimated > config.max_tokens {
1314                Some((idx, estimated))
1315            } else {
1316                None
1317            }
1318        })
1319        .collect()
1320}
1321
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, extract::State, routing::post};
    use serde_json::json;

    // Minimal OpenAI-compatible /v1/embeddings handler: always returns one
    // constant vector whose length is the router state `dim`.
    async fn mock_embeddings(State(dim): State<usize>) -> Json<serde_json::Value> {
        Json(json!({
            "data": [{
                "embedding": vec![0.25_f32; dim]
            }]
        }))
    }

    // Spins up an axum server on an ephemeral local port and returns its
    // base URL. The serve task is detached and lives for the test run.
    async fn spawn_mock_embedding_server(dim: usize) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = Router::new()
            .route("/v1/embeddings", post(mock_embeddings))
            .with_state(dim);

        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        // Brief pause so the listener task gets polled before callers connect.
        tokio::time::sleep(Duration::from_millis(10)).await;

        format!("http://{}", addr)
    }

    // Lower priority number wins after sorting.
    #[test]
    fn test_provider_sorting() {
        let mut providers = [
            ProviderConfig {
                name: "low".into(),
                base_url: "http://a".into(),
                model: "m".into(),
                priority: 10,
                endpoint: "/v1/embeddings".into(),
            },
            ProviderConfig {
                name: "high".into(),
                base_url: "http://b".into(),
                model: "m".into(),
                priority: 1,
                endpoint: "/v1/embeddings".into(),
            },
        ];
        providers.sort_by_key(|p| p.priority);
        assert_eq!(providers[0].name, "high");
        assert_eq!(providers[1].name, "low");
    }

    // Legacy MlxConfig should convert to a two-provider EmbeddingConfig
    // (local first) with reranker and batch limits carried over.
    #[test]
    fn test_legacy_conversion() {
        let legacy = MlxConfig {
            disabled: false,
            local_port: 12345,
            dragon_url: "http://dragon".into(),
            dragon_port: 12345,
            embedder_model: "test-model".into(),
            reranker_model: "rerank-model".into(),
            reranker_port_offset: 1,
            max_batch_chars: 32000,
            max_batch_items: 16,
        };
        let config = legacy.to_embedding_config();
        assert_eq!(config.providers.len(), 2);
        assert_eq!(config.providers[0].base_url, "http://localhost:12345");
        assert!(config.reranker.base_url.is_some());
        assert_eq!(config.max_batch_chars, 32000);
        assert_eq!(config.max_batch_items, 16);
    }

    // Pins the default config's dimension, batch limits, and default model.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.required_dimension, DEFAULT_REQUIRED_DIMENSION);
        assert_eq!(config.max_batch_chars, 128000); // 4x larger for GPU efficiency
        assert_eq!(config.max_batch_items, 64); // 4x more items per batch
        assert!(!config.providers.is_empty());
        assert_eq!(config.providers[0].model, DEFAULT_OLLAMA_EMBEDDING_MODEL);
    }

    // The probe should report whatever dimension the server actually returns.
    #[tokio::test]
    async fn test_probe_provider_dimension_reads_actual_dimension() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let client = Client::new();
        let provider = ProviderConfig {
            name: "mock".into(),
            base_url,
            model: "mock-embedder".into(),
            priority: 1,
            endpoint: "/v1/embeddings".into(),
        };

        let dim = probe_provider_dimension(&client, &provider).await.unwrap();
        assert_eq!(dim, 2560);
    }

    // Client construction must fail (with a descriptive message) when the
    // probed dimension disagrees with required_dimension.
    #[tokio::test]
    async fn test_embedding_client_fails_fast_on_dimension_mismatch() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let config = EmbeddingConfig {
            required_dimension: 1024,
            providers: vec![ProviderConfig {
                name: "mock".into(),
                base_url,
                model: "mock-embedder".into(),
                priority: 1,
                endpoint: "/v1/embeddings".into(),
            }],
            ..EmbeddingConfig::default()
        };

        let err = EmbeddingClient::new(&config)
            .await
            .err()
            .expect("dimension mismatch should fail")
            .to_string();
        assert!(err.contains("returned 2560 dims"));
        assert!(err.contains("required_dimension=1024"));
    }

    #[test]
    fn test_truncate_at_boundary() {
        // Test sentence boundary
        let text = "Hello world. This is a test.";
        let truncated = truncate_at_boundary(text, 15);
        assert_eq!(truncated, "Hello world.");

        // Test word boundary fallback
        let text = "Hello world this is a test";
        let truncated = truncate_at_boundary(text, 15);
        assert_eq!(truncated, "Hello world");

        // Test no truncation needed
        let text = "Short text";
        let truncated = truncate_at_boundary(text, 100);
        assert_eq!(truncated, "Short text");
    }

    #[test]
    fn test_token_estimation() {
        let config = TokenConfig::default();

        // ~3 chars per token (default multilingual)
        let text = "Hello world"; // 11 chars -> ~4 tokens
        let tokens = estimate_tokens(text, &config);
        assert!((3..=5).contains(&tokens));

        // English config (4 chars per token)
        let english_config = TokenConfig::english();
        let tokens = estimate_tokens(text, &english_config);
        assert!((2..=4).contains(&tokens));
    }

    // Guards against the ceiling regressing below the long-transcript floor
    // documented in the TOKEN-AWARE VALIDATION section above.
    #[test]
    fn default_token_ceiling_stays_above_long_transcript_floor() {
        let config = TokenConfig::default();

        assert_eq!(DEFAULT_MAX_TOKENS, 35_000);
        assert_eq!(config.max_tokens, DEFAULT_MAX_TOKENS);
        assert!(config.max_tokens >= 35_000);
    }

    #[test]
    fn test_chunk_validation() {
        let config = TokenConfig::default().with_max_tokens(100);

        // Short text should pass
        let short = "Hello world";
        assert!(validate_chunk_tokens(short, &config).is_ok());

        // Long text should fail
        let long = "a".repeat(1000); // Way more than 100 * 3 = 300 chars
        assert!(validate_chunk_tokens(&long, &config).is_err());
    }

    #[test]
    fn test_safe_chunk_size() {
        let config = TokenConfig::default(); // 35_000 tokens, 3 chars/token

        let safe = safe_chunk_size(&config);
        // 35_000 * 0.8 * 3 = 84_000 chars
        assert!(safe > 80_000 && safe < 90_000);
    }

    #[test]
    fn test_batch_validation() {
        let config = TokenConfig::default().with_max_tokens(10);

        let texts = vec![
            "short".to_string(),      // OK
            "a".repeat(100),          // Too long
            "also short".to_string(), // OK
            "b".repeat(200),          // Too long
        ];

        let failures = validate_batch_tokens(&texts, &config);
        assert_eq!(failures.len(), 2);
        assert_eq!(failures[0].0, 1); // Index 1
        assert_eq!(failures[1].0, 3); // Index 3
    }
}
1527
1528// =============================================================================
1529// DIMENSION ADAPTER - Cross-dimension embedding compatibility
1530// =============================================================================
1531
/// Adapter for cross-dimension embedding compatibility.
///
/// Lets searches span databases whose embeddings were produced at different
/// dimensions (e.g. 1024, 2048, 4096) by growing or shrinking vectors to a
/// common target size.
///
/// # Strategies
/// - **Expand**: zero-pad a smaller embedding to the target dimension, then
///   re-normalize to unit length.
/// - **Contract**: average consecutive elements for clean power-of-two
///   reductions; otherwise truncate.
///
/// # Example
/// ```rust,ignore
/// let adapter = DimensionAdapter::new(1024, 4096);
/// let expanded = adapter.expand(small_embedding);  // 1024 -> 4096
/// ```
#[derive(Debug, Clone)]
pub struct DimensionAdapter {
    /// Source embedding dimension
    pub source_dim: usize,
    /// Target embedding dimension
    pub target_dim: usize,
}

impl DimensionAdapter {
    /// Build an adapter mapping `source_dim` -> `target_dim`.
    pub fn new(source_dim: usize, target_dim: usize) -> Self {
        Self {
            source_dim,
            target_dim,
        }
    }

    /// Returns true when source and target dimensions differ.
    pub fn needs_adaptation(&self) -> bool {
        self.source_dim != self.target_dim
    }

    /// Bring `embedding` to the target dimension, choosing expand or
    /// contract automatically from its actual length. A vector already at
    /// the target size is returned untouched (no re-normalization).
    pub fn adapt(&self, embedding: Vec<f32>) -> Vec<f32> {
        let len = embedding.len();
        if len == self.target_dim {
            embedding
        } else if len > self.target_dim {
            self.contract(embedding)
        } else {
            self.expand(embedding)
        }
    }

    /// Grow a smaller embedding to the target dimension via zero-padding,
    /// then re-normalize to unit length for cosine similarity.
    pub fn expand(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() >= self.target_dim {
            // Already big enough: just cut down to size.
            return embedding[..self.target_dim].to_vec();
        }

        let mut out = embedding;
        out.resize(self.target_dim, 0.0);
        self.normalize(&mut out);
        out
    }

    /// Shrink a larger embedding to the target dimension.
    ///
    /// Clean power-of-two shrinks (4096->2048, 2048->1024) average
    /// consecutive elements, which keeps more signal than dropping the
    /// tail; anything else falls back to truncation.
    pub fn contract(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }

        if self.is_power_of_two_reduction(embedding.len()) {
            self.average_reduction(embedding)
        } else {
            embedding[..self.target_dim].to_vec()
        }
    }

    /// True for a clean power-of-2 reduction (e.g. 4096->2048).
    fn is_power_of_two_reduction(&self, source_len: usize) -> bool {
        source_len > self.target_dim
            && source_len.is_power_of_two()
            && self.target_dim.is_power_of_two()
            && source_len % self.target_dim == 0
    }

    /// Average consecutive element groups down to the target size, then
    /// re-normalize.
    fn average_reduction(&self, embedding: Vec<f32>) -> Vec<f32> {
        let factor = embedding.len() / self.target_dim;
        let mut reduced: Vec<f32> = embedding
            .chunks(factor)
            .map(|group| group.iter().sum::<f32>() / factor as f32)
            .collect();

        self.normalize(&mut reduced);
        reduced
    }

    /// Scale `vec` to unit L2 norm in place; near-zero vectors are left alone.
    fn normalize(&self, vec: &mut [f32]) {
        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            vec.iter_mut().for_each(|v| *v /= norm);
        }
    }
}
1652
1653/// Perform cross-dimension search by adapting query embedding
1654pub fn cross_dimension_search_adapt(query_embedding: Vec<f32>, target_dim: usize) -> Vec<f32> {
1655    let adapter = DimensionAdapter::new(query_embedding.len(), target_dim);
1656    adapter.adapt(query_embedding)
1657}
1658
#[cfg(test)]
mod dimension_adapter_tests {
    use super::*;

    // Expansion zero-pads the tail and re-normalizes the original values.
    #[test]
    fn test_expand_1024_to_4096() {
        let adapter = DimensionAdapter::new(1024, 4096);
        let small = vec![0.1f32; 1024];
        let expanded = adapter.expand(small);

        assert_eq!(expanded.len(), 4096);
        // First 1024 should be non-zero (after normalization)
        assert!(expanded[0].abs() > 1e-10);
        // Last elements should be zero
        assert!(expanded[4095].abs() < 1e-10);
    }

    // 4096->1024 is a power-of-2 reduction, so output is averaged and
    // re-normalized to unit length.
    #[test]
    fn test_contract_4096_to_1024() {
        let adapter = DimensionAdapter::new(4096, 1024);
        let large = vec![0.1f32; 4096];
        let contracted = adapter.contract(large);

        assert_eq!(contracted.len(), 1024);
        // Should be normalized
        let norm: f32 = contracted.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // adapt() picks expand vs contract from the input length.
    #[test]
    fn test_adapt_auto_detect() {
        let adapter = DimensionAdapter::new(1024, 4096);

        // Small to large (expand)
        let small = vec![0.1f32; 1024];
        let result = adapter.adapt(small);
        assert_eq!(result.len(), 4096);

        // Large to small (contract)
        let adapter = DimensionAdapter::new(4096, 1024);
        let large = vec![0.1f32; 4096];
        let result = adapter.adapt(large);
        assert_eq!(result.len(), 1024);
    }

    // Equal dimensions pass through untouched (no re-normalization).
    #[test]
    fn test_no_adaptation_needed() {
        let adapter = DimensionAdapter::new(4096, 4096);
        assert!(!adapter.needs_adaptation());

        let embedding = vec![0.1f32; 4096];
        let result = adapter.adapt(embedding.clone());
        assert_eq!(result, embedding);
    }

    // Averaging reduction on a ramp input keeps a unit-length result.
    #[test]
    fn test_average_reduction_preserves_info() {
        let adapter = DimensionAdapter::new(4096, 2048);

        // Create embedding with distinct values
        let large: Vec<f32> = (0..4096).map(|i| i as f32 / 4096.0).collect();
        let contracted = adapter.contract(large);

        assert_eq!(contracted.len(), 2048);
        // Averaged values should be between min and max of source chunks
        // After normalization, should be unit length
        let norm: f32 = contracted.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
}