1use anyhow::{Result, anyhow};
27use reqwest::Client;
28use serde::{Deserialize, Serialize};
29use std::time::Duration;
30
/// Embedding width required when the configuration does not specify one.
pub const DEFAULT_REQUIRED_DIMENSION: usize = 2_560;
/// Model tag used by the default local Ollama provider.
pub const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "qwen3-embedding:4b";
33
/// Guesses the embedding width of a known model family from its name.
///
/// Matching is case-insensitive and substring based. Returns `None` when the
/// family is unknown, or when a Qwen3 embedding model carries no recognised
/// size tag (`0.6b`, `4b`, `8b`).
pub fn infer_embedding_dimension(model: &str) -> Option<usize> {
    let name = model.to_ascii_lowercase();
    let has = |needle: &str| name.contains(needle);

    if has("qwen3-vl-embedding") {
        return Some(2048);
    }

    if has("qwen3-embedding") {
        // Size tags are checked in this order; the first one present wins.
        for (tag, dim) in [("0.6b", 1024), ("4b", 2560), ("8b", 4096)] {
            if has(tag) {
                return Some(dim);
            }
        }
        return None;
    }

    if has("bge-m3") || has("mxbai-embed") {
        return Some(1024);
    }
    if has("nomic-embed") {
        return Some(768);
    }
    if has("all-minilm") {
        return Some(384);
    }
    None
}
59
60#[derive(Debug, Serialize)]
65struct EmbeddingRequest {
66 input: Vec<String>,
67 model: String,
68}
69
70#[derive(Debug, Deserialize)]
71struct EmbeddingResponse {
72 data: Vec<EmbeddingData>,
73}
74
75#[derive(Debug, Deserialize)]
76struct EmbeddingData {
77 embedding: Vec<f32>,
78}
79
80#[derive(Debug, Serialize)]
81struct RerankRequest {
82 query: String,
83 documents: Vec<String>,
84 model: String,
85}
86
87#[derive(Debug, Deserialize)]
88struct RerankResponse {
89 results: Vec<RerankResult>,
90}
91
92#[derive(Debug, Deserialize)]
93struct RerankResult {
94 index: usize,
95 score: f32,
96}
97
98#[derive(Debug, Clone, Deserialize, Serialize, Default)]
104pub struct ProviderConfig {
105 #[serde(default)]
107 pub name: String,
108 #[serde(default)]
110 pub base_url: String,
111 #[serde(default)]
113 pub model: String,
114 #[serde(default = "default_priority")]
116 pub priority: u8,
117 #[serde(default = "default_embeddings_endpoint")]
119 pub endpoint: String,
120}
121
/// Serde default for [`ProviderConfig::priority`].
fn default_priority() -> u8 {
    const FALLBACK_PRIORITY: u8 = 10;
    FALLBACK_PRIORITY
}
125
/// Serde default for [`ProviderConfig::endpoint`].
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
129
130#[derive(Debug, Clone, Deserialize, Serialize, Default)]
132pub struct RerankerConfig {
133 pub base_url: Option<String>,
135 pub model: Option<String>,
137 #[serde(default = "default_rerank_endpoint")]
139 pub endpoint: String,
140}
141
/// Serde default for [`RerankerConfig::endpoint`].
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}
145
146fn default_dimension() -> usize {
147 DEFAULT_REQUIRED_DIMENSION
148}
149
/// Serde default for [`EmbeddingConfig::max_batch_chars`].
fn default_max_batch_chars() -> usize {
    128_000
}
153
/// Serde default for [`EmbeddingConfig::max_batch_items`].
fn default_max_batch_items() -> usize {
    const ITEMS: usize = 64;
    ITEMS
}
157
/// Joins a base URL and an endpoint path with exactly one `/` between them.
///
/// Trailing slashes on `base_url` and surrounding whitespace on `endpoint`
/// are removed before joining.
fn build_provider_endpoint(base_url: &str, endpoint: &str) -> String {
    let root = base_url.trim_end_matches('/');
    let path = endpoint.trim();
    // Only insert a separator when the path does not already bring its own.
    let separator = if path.starts_with('/') { "" } else { "/" };
    [root, separator, path].concat()
}
167
168#[derive(Debug, Clone, Deserialize, Serialize)]
170pub struct EmbeddingConfig {
171 #[serde(default = "default_dimension")]
173 pub required_dimension: usize,
174 #[serde(default = "default_max_batch_chars")]
176 pub max_batch_chars: usize,
177 #[serde(default = "default_max_batch_items")]
179 pub max_batch_items: usize,
180 #[serde(default)]
182 pub providers: Vec<ProviderConfig>,
183 #[serde(default)]
185 pub reranker: RerankerConfig,
186}
187
188impl Default for EmbeddingConfig {
189 fn default() -> Self {
190 Self {
191 required_dimension: default_dimension(),
192 max_batch_chars: default_max_batch_chars(),
193 max_batch_items: default_max_batch_items(),
194 providers: vec![
195 ProviderConfig {
196 name: "ollama-local".to_string(),
197 base_url: "http://localhost:11434".to_string(),
198 model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
199 priority: 1,
200 endpoint: default_embeddings_endpoint(),
201 },
202 ProviderConfig {
203 name: "dragon".to_string(),
204 base_url: "http://dragon:12345".to_string(),
205 model: "Qwen/Qwen3-Embedding-4B".to_string(),
206 priority: 2,
207 endpoint: default_embeddings_endpoint(),
208 },
209 ],
210 reranker: RerankerConfig::default(),
211 }
212 }
213}
214
215impl EmbeddingConfig {
216 pub fn provider_name(&self) -> String {
218 self.providers
219 .first()
220 .map(|p| p.name.clone())
221 .unwrap_or_else(|| "none".to_string())
222 }
223
224 pub fn model_name(&self) -> String {
226 self.providers
227 .first()
228 .map(|p| p.model.clone())
229 .unwrap_or_else(|| "none".to_string())
230 }
231
232 pub fn dimension(&self) -> usize {
234 self.required_dimension
235 }
236}
237
/// Legacy `[mlx]`-style configuration, convertible to [`EmbeddingConfig`].
#[derive(Clone, Debug)]
pub struct MlxConfig {
    /// When true, `EmbeddingClient::from_legacy` refuses to start.
    pub disabled: bool,
    /// Port of the embedder on localhost.
    pub local_port: u16,
    /// Base URL (scheme and host, no port) of the `dragon` machine.
    pub dragon_url: String,
    /// Port of the embedder on `dragon`.
    pub dragon_port: u16,
    /// Embedding model identifier.
    pub embedder_model: String,
    /// Reranker model identifier.
    pub reranker_model: String,
    /// Added to `local_port` to obtain the reranker port.
    pub reranker_port_offset: u16,
    /// Maximum total characters per batch request.
    pub max_batch_chars: usize,
    /// Maximum number of texts per batch request.
    pub max_batch_items: usize,
}
255
/// Optional overrides applied on top of an [`MlxConfig`]; `None` leaves the
/// corresponding field untouched.
#[derive(Clone, Debug, Default)]
pub struct MlxMergeOptions {
    /// Override for `MlxConfig::disabled`.
    pub disabled: Option<bool>,
    /// Override for `MlxConfig::local_port`.
    pub local_port: Option<u16>,
    /// Override for `MlxConfig::dragon_url`.
    pub dragon_url: Option<String>,
    /// Override for `MlxConfig::dragon_port`.
    pub dragon_port: Option<u16>,
    /// Override for `MlxConfig::embedder_model`.
    pub embedder_model: Option<String>,
    /// Override for `MlxConfig::reranker_model`.
    pub reranker_model: Option<String>,
    /// Override for `MlxConfig::reranker_port_offset`.
    pub reranker_port_offset: Option<u16>,
}
267
268impl Default for MlxConfig {
269 fn default() -> Self {
270 Self {
271 disabled: false,
272 local_port: 12345,
273 dragon_url: "http://dragon".to_string(),
274 dragon_port: 12345,
275 embedder_model: "Qwen/Qwen3-Embedding-4B".to_string(),
276 reranker_model: "Qwen/Qwen3-Reranker-4B".to_string(),
277 reranker_port_offset: 1,
278 max_batch_chars: default_max_batch_chars(),
279 max_batch_items: default_max_batch_items(),
280 }
281 }
282}
283
284impl MlxConfig {
285 pub fn from_env() -> Self {
287 let disabled = std::env::var("DISABLE_MLX")
288 .map(|v| v == "1" || v.to_lowercase() == "true")
289 .unwrap_or(false);
290
291 let local_port = std::env::var("EMBEDDER_PORT")
292 .ok()
293 .and_then(|s| s.parse().ok())
294 .unwrap_or(12345);
295
296 let dragon_url =
297 std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
298
299 let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
300 .ok()
301 .and_then(|s| s.parse().ok())
302 .unwrap_or(local_port);
303
304 let reranker_port_offset = std::env::var("RERANKER_PORT")
305 .ok()
306 .and_then(|s| s.parse::<u16>().ok())
307 .map(|rp| rp.saturating_sub(local_port))
308 .unwrap_or(1);
309
310 let embedder_model = std::env::var("EMBEDDER_MODEL")
311 .unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
312
313 let reranker_model = std::env::var("RERANKER_MODEL")
314 .unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
315
316 let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
317 .ok()
318 .and_then(|s| s.parse().ok())
319 .unwrap_or(32000);
320
321 let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
322 .ok()
323 .and_then(|s| s.parse().ok())
324 .unwrap_or(16);
325
326 Self {
327 disabled,
328 local_port,
329 dragon_url,
330 dragon_port,
331 embedder_model,
332 reranker_model,
333 reranker_port_offset,
334 max_batch_chars,
335 max_batch_items,
336 }
337 }
338
339 pub fn merge_file_config(&mut self, opts: MlxMergeOptions) {
341 if let Some(v) = opts.disabled {
342 self.disabled = v;
343 }
344 if let Some(v) = opts.local_port {
345 self.local_port = v;
346 }
347 if let Some(v) = opts.dragon_url {
348 self.dragon_url = v;
349 }
350 if let Some(v) = opts.dragon_port {
351 self.dragon_port = v;
352 }
353 if let Some(v) = opts.embedder_model {
354 self.embedder_model = v;
355 }
356 if let Some(v) = opts.reranker_model {
357 self.reranker_model = v;
358 }
359 if let Some(v) = opts.reranker_port_offset {
360 self.reranker_port_offset = v;
361 }
362 }
363
364 pub fn to_embedding_config(&self) -> EmbeddingConfig {
366 let reranker_port = self.local_port + self.reranker_port_offset;
367 let required_dimension =
368 infer_embedding_dimension(&self.embedder_model).unwrap_or(DEFAULT_REQUIRED_DIMENSION);
369
370 EmbeddingConfig {
371 required_dimension,
372 max_batch_chars: self.max_batch_chars,
373 max_batch_items: self.max_batch_items,
374 providers: vec![
375 ProviderConfig {
376 name: "local".to_string(),
377 base_url: format!("http://localhost:{}", self.local_port),
378 model: self.embedder_model.clone(),
379 priority: 1,
380 endpoint: default_embeddings_endpoint(),
381 },
382 ProviderConfig {
383 name: "dragon".to_string(),
384 base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
385 model: self.embedder_model.clone(),
386 priority: 2,
387 endpoint: default_embeddings_endpoint(),
388 },
389 ],
390 reranker: RerankerConfig {
391 base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
392 model: Some(self.reranker_model.clone()),
393 endpoint: default_rerank_endpoint(),
394 },
395 }
396 }
397
398 pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
400 self.max_batch_chars = max_chars;
401 self.max_batch_items = max_items;
402 self
403 }
404}
405
406pub struct EmbeddingClient {
412 client: Client,
413 embedder_url: String,
414 embedder_model: String,
415 reranker_url: Option<String>,
416 reranker_model: Option<String>,
417 connected_to: String,
419 required_dimension: usize,
421 max_batch_chars: usize,
423 max_batch_items: usize,
425}
426
427pub type MLXBridge = EmbeddingClient;
429
430impl EmbeddingClient {
431 pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
433 if config.providers.is_empty() {
434 return Err(anyhow!(
435 "No embedding providers configured! Add providers to [embeddings.providers]"
436 ));
437 }
438
439 let client = Client::builder()
441 .timeout(Duration::from_secs(300))
442 .connect_timeout(Duration::from_secs(10))
443 .build()?;
444
445 let mut providers = config.providers.clone();
447 providers.sort_by_key(|p| p.priority);
448
449 let mut tried = Vec::new();
451 for provider in &providers {
452 let base_url = provider.base_url.trim_end_matches('/');
453 let provider_name = if provider.name.trim().is_empty() {
454 "<unnamed-provider>"
455 } else {
456 provider.name.as_str()
457 };
458 let model = provider.model.trim();
459 let embedder_url = build_provider_endpoint(base_url, &provider.endpoint);
460
461 match probe_provider_dimension(&client, provider).await {
462 Ok(actual_dim) if actual_dim == config.required_dimension => {
463 tracing::info!(
464 "Embedding: Connected to {} ({}) with model '{}' [{} dims]",
465 provider_name,
466 embedder_url,
467 model,
468 actual_dim
469 );
470
471 let (reranker_url, reranker_model) =
473 if let Some(ref rr_base) = config.reranker.base_url {
474 (
475 Some(format!(
476 "{}{}",
477 rr_base.trim_end_matches('/'),
478 config.reranker.endpoint
479 )),
480 config.reranker.model.clone(),
481 )
482 } else {
483 (None, None)
484 };
485
486 return Ok(Self {
487 client,
488 embedder_url,
489 embedder_model: provider.model.clone(),
490 reranker_url,
491 reranker_model,
492 connected_to: provider.name.clone(),
493 required_dimension: config.required_dimension,
494 max_batch_chars: config.max_batch_chars,
495 max_batch_items: config.max_batch_items,
496 });
497 }
498 Ok(actual_dim) => {
499 let failure = format!(
500 "- {} ({} model='{}'): the configured embedding endpoint returned {} dims, but config.required_dimension={}.\n Action: set [embeddings].required_dimension = {} or choose a {}-dim model.",
501 provider_name,
502 embedder_url,
503 model,
504 actual_dim,
505 config.required_dimension,
506 actual_dim,
507 config.required_dimension
508 );
509 tracing::error!("Embedding: validation failed: {}", failure);
510 tried.push(failure);
511 }
512 Err(e) => {
513 let failure = format!(
514 "- {} ({} model='{}'): {}",
515 provider_name, embedder_url, model, e
516 );
517 tracing::warn!("Embedding: provider probe failed: {}", failure);
518 tried.push(failure);
519 }
520 }
521 }
522
523 Err(anyhow!(
525 "No embedding provider passed validation for required_dimension={}. \
526 Each provider must succeed on its configured embedding endpoint before rmcp-memex will start.\nTried:\n{}",
527 config.required_dimension,
528 tried.join("\n")
529 ))
530 }
531
532 pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
534 if config.disabled {
535 return Err(anyhow!(
536 "Embedding disabled via config. No fallback available!"
537 ));
538 }
539 tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
540 let embedding_config = config.to_embedding_config();
541 Self::new(&embedding_config).await
542 }
543
544 pub async fn from_env() -> Result<Self> {
546 let config = MlxConfig::from_env();
547 Self::from_legacy(&config).await
548 }
549
550 pub fn connected_to(&self) -> &str {
552 &self.connected_to
553 }
554
555 pub fn required_dimension(&self) -> usize {
557 self.required_dimension
558 }
559
560 #[cfg(test)]
564 pub(crate) fn stub_for_tests() -> Self {
565 Self {
566 client: reqwest::Client::new(),
567 embedder_url: "http://stub:0/v1/embeddings".to_string(),
568 embedder_model: "stub".to_string(),
569 reranker_url: None,
570 reranker_model: None,
571 connected_to: "stub-test".to_string(),
572 required_dimension: 4096,
573 max_batch_chars: 32000,
574 max_batch_items: 16,
575 }
576 }
577
578 pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
579 let text_preview: String = text.chars().take(100).collect();
580 tracing::debug!(
581 "Embedding single text ({} chars): {}{}",
582 text.chars().count(),
583 text_preview,
584 if text.chars().count() > 100 {
585 "..."
586 } else {
587 ""
588 }
589 );
590
591 let request = EmbeddingRequest {
592 input: vec![text.to_string()],
593 model: self.embedder_model.clone(),
594 };
595
596 let response = match self
597 .client
598 .post(&self.embedder_url)
599 .json(&request)
600 .send()
601 .await
602 {
603 Ok(resp) => resp,
604 Err(e) => {
605 tracing::error!(
606 "Embedding request failed: {:?}\n URL: {}\n Model: {}",
607 e,
608 self.embedder_url,
609 self.embedder_model
610 );
611 return Err(anyhow!("Embedding request failed: {}", e));
612 }
613 };
614
615 let status = response.status();
616 let response_text = response.text().await.unwrap_or_else(|e| {
617 tracing::warn!("Failed to read response body: {:?}", e);
618 "<failed to read body>".to_string()
619 });
620
621 if !status.is_success() {
622 tracing::error!(
623 "Embedding API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
624 status,
625 self.embedder_url,
626 self.embedder_model,
627 response_text
628 );
629 return Err(anyhow!(
630 "Embedding API error (HTTP {}): {}",
631 status,
632 response_text
633 ));
634 }
635
636 let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
637 Ok(r) => r,
638 Err(e) => {
639 tracing::error!(
640 "Failed to parse embedding response: {:?}\n Response body: {}",
641 e,
642 response_text
643 );
644 return Err(anyhow!("Failed to parse embedding response: {}", e));
645 }
646 };
647
648 let embedding = parsed
649 .data
650 .into_iter()
651 .next()
652 .map(|d| d.embedding)
653 .ok_or_else(|| {
654 tracing::error!("No embedding returned in response: {}", response_text);
655 anyhow!("No embedding returned")
656 })?;
657
658 if embedding.len() != self.required_dimension {
660 tracing::error!(
661 "Dimension mismatch! Expected {}, got {}. Model: {}",
662 self.required_dimension,
663 embedding.len(),
664 self.embedder_model
665 );
666 return Err(anyhow!(
667 "Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
668 self.required_dimension,
669 embedding.len()
670 ));
671 }
672
673 tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
674 Ok(embedding)
675 }
676
677 pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
683 if texts.is_empty() {
684 return Ok(vec![]);
685 }
686
687 let mut all_embeddings = Vec::with_capacity(texts.len());
688 let mut current_batch: Vec<String> = Vec::new();
689 let mut current_batch_indices: Vec<usize> = Vec::new();
690 let mut current_chars = 0;
691
692 let max_text_chars = self.max_batch_chars / 2;
694
695 let prepared_texts: Vec<String> = texts
697 .iter()
698 .map(|text| {
699 let char_count = text.chars().count();
700 if char_count > max_text_chars {
701 tracing::debug!(
702 "Text too large ({} chars), truncating to {} chars",
703 char_count,
704 max_text_chars
705 );
706 truncate_at_boundary(text, max_text_chars)
707 } else {
708 text.clone()
709 }
710 })
711 .collect();
712
713 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
715 let mut failed_indices: Vec<usize> = Vec::new();
716
717 for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
718 let text_len = text_to_embed.chars().count();
719
720 if !current_batch.is_empty()
722 && (current_chars + text_len > self.max_batch_chars
723 || current_batch.len() >= self.max_batch_items)
724 {
725 match self.embed_batch_internal(¤t_batch).await {
727 Ok(batch_embeddings) => {
728 for (i, emb) in batch_embeddings.into_iter().enumerate() {
729 if let Some(orig_idx) = current_batch_indices.get(i) {
730 results[*orig_idx] = Some(emb);
731 }
732 }
733 }
734 Err(e) => {
735 tracing::warn!(
736 "Batch embedding failed for {} texts, will retry individually: {}",
737 current_batch.len(),
738 e
739 );
740 failed_indices.extend(current_batch_indices.iter().copied());
741 }
742 }
743 current_batch.clear();
744 current_batch_indices.clear();
745 current_chars = 0;
746 }
747
748 current_batch.push(text_to_embed.clone());
749 current_batch_indices.push(idx);
750 current_chars += text_len;
751 }
752
753 if !current_batch.is_empty() {
755 match self.embed_batch_internal(¤t_batch).await {
756 Ok(batch_embeddings) => {
757 for (i, emb) in batch_embeddings.into_iter().enumerate() {
758 if let Some(orig_idx) = current_batch_indices.get(i) {
759 results[*orig_idx] = Some(emb);
760 }
761 }
762 }
763 Err(e) => {
764 tracing::warn!(
765 "Batch embedding failed for {} texts, will retry individually: {}",
766 current_batch.len(),
767 e
768 );
769 failed_indices.extend(current_batch_indices.iter().copied());
770 }
771 }
772 }
773
774 const MAX_RETRIES: usize = 3;
776 for idx in failed_indices {
777 let text = &prepared_texts[idx];
778 let mut attempts = 0;
779 let mut last_error = String::new();
780
781 while attempts < MAX_RETRIES {
782 match self.embed(text).await {
783 Ok(embedding) => {
784 results[idx] = Some(embedding);
785 tracing::info!(
786 "Retry succeeded for chunk {} after {} attempts",
787 idx,
788 attempts + 1
789 );
790 break;
791 }
792 Err(e) => {
793 attempts += 1;
794 last_error = e.to_string();
795 tracing::warn!(
796 "Embed attempt {}/{} failed for chunk {}: {}",
797 attempts,
798 MAX_RETRIES,
799 idx,
800 e
801 );
802 if attempts < MAX_RETRIES {
803 let delay_ms = 100 * (1 << attempts);
805 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
806 }
807 }
808 }
809 }
810
811 if results[idx].is_none() {
812 tracing::error!(
813 "Chunk {} failed after {} retries: {}",
814 idx,
815 MAX_RETRIES,
816 last_error
817 );
818 return Err(anyhow!(
819 "Failed to embed chunk {} after {} retries: {}",
820 idx,
821 MAX_RETRIES,
822 last_error
823 ));
824 }
825 }
826
827 for (idx, opt) in results.iter().enumerate() {
829 match opt {
830 Some(emb) => all_embeddings.push(emb.clone()),
831 None => {
832 return Err(anyhow!(
833 "Internal error: missing embedding for chunk {}",
834 idx
835 ));
836 }
837 }
838 }
839
840 Ok(all_embeddings)
841 }
842
843 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
845 let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();
846
847 tracing::debug!(
848 "Embedding batch: {} texts, {} chars total",
849 texts.len(),
850 total_chars
851 );
852
853 for (i, text) in texts.iter().enumerate() {
855 let preview: String = text.chars().take(50).collect();
856 tracing::trace!(
857 " Batch[{}]: {} chars - {}{}",
858 i,
859 text.chars().count(),
860 preview,
861 if text.chars().count() > 50 { "..." } else { "" }
862 );
863 }
864
865 let request = EmbeddingRequest {
866 input: texts.to_vec(),
867 model: self.embedder_model.clone(),
868 };
869
870 const MAX_BATCH_RETRIES: usize = 10;
872 const MAX_BACKOFF_SECS: u64 = 30;
873 let mut attempt = 0;
874
875 loop {
876 attempt += 1;
877 let response = match self
878 .client
879 .post(&self.embedder_url)
880 .json(&request)
881 .send()
882 .await
883 {
884 Ok(resp) => resp,
885 Err(e) => {
886 if attempt >= MAX_BATCH_RETRIES {
887 tracing::error!(
888 "Batch embedding failed after {} retries: {:?}\n URL: {}\n Model: {}",
889 MAX_BATCH_RETRIES,
890 e,
891 self.embedder_url,
892 self.embedder_model
893 );
894 return Err(anyhow!(
895 "Embedding request failed after {} retries: {}",
896 MAX_BATCH_RETRIES,
897 e
898 ));
899 }
900
901 let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
903 tracing::warn!(
904 "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
905 attempt,
906 MAX_BATCH_RETRIES,
907 backoff_secs,
908 e
909 );
910 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
911 continue;
912 }
913 };
914
915 if !response.status().is_success() {
917 let status = response.status();
918 let body = response.text().await.unwrap_or_default();
919
920 if attempt >= MAX_BATCH_RETRIES {
921 tracing::error!(
922 "Embedding API error after {} retries: {} - {}",
923 MAX_BATCH_RETRIES,
924 status,
925 body
926 );
927 return Err(anyhow!("Embedding API error: {} - {}", status, body));
928 }
929
930 let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
931 tracing::warn!(
932 "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
933 attempt,
934 MAX_BATCH_RETRIES,
935 backoff_secs,
936 status,
937 body
938 );
939 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
940 continue;
941 }
942
943 let embedding_response: EmbeddingResponse = match response.json().await {
945 Ok(r) => r,
946 Err(e) => {
947 if attempt >= MAX_BATCH_RETRIES {
948 return Err(anyhow!("Failed to parse embedding response: {}", e));
949 }
950 let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
951 tracing::warn!(
952 "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
953 attempt,
954 MAX_BATCH_RETRIES,
955 backoff_secs,
956 e
957 );
958 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
959 continue;
960 }
961 };
962
963 let embeddings: Vec<Vec<f32>> = embedding_response
965 .data
966 .into_iter()
967 .map(|d| d.embedding)
968 .collect();
969
970 if embeddings.len() != texts.len() {
971 return Err(anyhow!(
972 "Embedding count mismatch: got {} embeddings for {} texts",
973 embeddings.len(),
974 texts.len()
975 ));
976 }
977
978 if let Some(first) = embeddings.first()
979 && first.len() != self.required_dimension
980 {
981 return Err(anyhow!(
982 "Dimension mismatch: expected {}, got {}",
983 self.required_dimension,
984 first.len()
985 ));
986 }
987
988 return Ok(embeddings);
989 }
990 }
991
992 pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
993 let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
994 anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
995 })?;
996 let reranker_model = self
997 .reranker_model
998 .as_ref()
999 .ok_or_else(|| anyhow!("Reranker model not configured."))?;
1000
1001 let query_preview: String = query.chars().take(100).collect();
1002 tracing::debug!(
1003 "Reranking {} documents for query: {}{}",
1004 documents.len(),
1005 query_preview,
1006 if query.chars().count() > 100 {
1007 "..."
1008 } else {
1009 ""
1010 }
1011 );
1012
1013 let request = RerankRequest {
1014 query: query.to_string(),
1015 documents: documents.to_vec(),
1016 model: reranker_model.clone(),
1017 };
1018
1019 let response = match self.client.post(reranker_url).json(&request).send().await {
1020 Ok(resp) => resp,
1021 Err(e) => {
1022 tracing::error!(
1023 "Rerank request failed: {:?}\n URL: {}\n Model: {}\n Query: {}\n Documents: {}",
1024 e,
1025 reranker_url,
1026 reranker_model,
1027 query_preview,
1028 documents.len()
1029 );
1030 return Err(anyhow!("Rerank request failed: {}", e));
1031 }
1032 };
1033
1034 let status = response.status();
1035 let response_text = response.text().await.unwrap_or_else(|e| {
1036 tracing::warn!("Failed to read rerank response body: {:?}", e);
1037 "<failed to read body>".to_string()
1038 });
1039
1040 if !status.is_success() {
1041 tracing::error!(
1042 "Rerank API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
1043 status,
1044 reranker_url,
1045 reranker_model,
1046 response_text
1047 );
1048 return Err(anyhow!(
1049 "Rerank API error (HTTP {}): {}",
1050 status,
1051 response_text
1052 ));
1053 }
1054
1055 let parsed: RerankResponse = match serde_json::from_str(&response_text) {
1056 Ok(r) => r,
1057 Err(e) => {
1058 tracing::error!(
1059 "Failed to parse rerank response: {:?}\n Response body: {}",
1060 e,
1061 response_text
1062 );
1063 return Err(anyhow!("Failed to parse rerank response: {}", e));
1064 }
1065 };
1066
1067 tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());
1068
1069 Ok(parsed
1070 .results
1071 .into_iter()
1072 .map(|r| (r.index, r.score))
1073 .collect())
1074 }
1075}
1076
1077pub(crate) async fn probe_provider_dimension(
1078 client: &Client,
1079 provider: &ProviderConfig,
1080) -> Result<usize> {
1081 let base_url = provider.base_url.trim_end_matches('/');
1082 if base_url.is_empty() {
1083 return Err(anyhow!("provider base_url is empty"));
1084 }
1085
1086 let endpoint = provider.endpoint.trim();
1087 if endpoint.is_empty() {
1088 return Err(anyhow!("provider endpoint is empty"));
1089 }
1090
1091 let model = provider.model.trim();
1092 if model.is_empty() {
1093 return Err(anyhow!("provider model is empty"));
1094 }
1095
1096 let embedder_url = build_provider_endpoint(base_url, endpoint);
1097 let request = EmbeddingRequest {
1098 input: vec!["dimension probe".to_string()],
1099 model: model.to_string(),
1100 };
1101
1102 let response = client
1103 .post(&embedder_url)
1104 .json(&request)
1105 .timeout(Duration::from_secs(30))
1106 .send()
1107 .await
1108 .map_err(|e| anyhow!("POST {} failed: {}", embedder_url, e))?;
1109
1110 let status = response.status();
1111 let body = response.text().await.unwrap_or_default();
1112 if !status.is_success() {
1113 let hint = if status.as_u16() == 404 {
1114 " Check provider.endpoint; Ollama and OpenAI-compatible servers typically use /v1/embeddings."
1115 } else {
1116 ""
1117 };
1118 return Err(anyhow!(
1119 "POST {} returned {} for model '{}': {}{}",
1120 embedder_url,
1121 status,
1122 model,
1123 body.chars().take(300).collect::<String>(),
1124 hint
1125 ));
1126 }
1127
1128 let embed_response: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
1129 anyhow!(
1130 "POST {} returned non-embedding JSON for model '{}': {} (body: {})",
1131 embedder_url,
1132 model,
1133 e,
1134 body.chars().take(200).collect::<String>()
1135 )
1136 })?;
1137
1138 embed_response
1139 .data
1140 .first()
1141 .map(|d| d.embedding.len())
1142 .ok_or_else(|| {
1143 anyhow!(
1144 "POST {} returned no embeddings for model '{}'",
1145 embedder_url,
1146 model
1147 )
1148 })
1149}
1150
/// Shortens `text` to at most `max_chars` characters, preferring to cut at
/// a natural boundary.
///
/// Preference order:
/// 1. just after the last `.`, `!`, `?` or newline in the allowed prefix,
/// 2. just before the last space, tab or newline in the allowed prefix,
/// 3. a hard cut at exactly `max_chars` characters.
///
/// A boundary is only used when it lies past the half-way point of the
/// allowed prefix. Without that guard the whitespace fallback could throw
/// away almost everything (or return an empty string when the only
/// whitespace is at the very start).
///
/// All offsets are taken on `char` boundaries, so multi-byte text is safe.
fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
    // Byte offset of the first character beyond the limit; if there is no
    // such character the text already fits.
    let Some((byte_idx, _)) = text.char_indices().nth(max_chars) else {
        return text.to_string();
    };

    let truncated = &text[..byte_idx];

    let half_byte_idx = text
        .char_indices()
        .nth(max_chars / 2)
        .map_or(0, |(idx, _)| idx);

    // Every boundary character is single-byte ASCII, so slicing at `pos`
    // (inclusive or exclusive) always lands on a char boundary.
    let sentence_end = truncated
        .rfind(['.', '!', '?', '\n'])
        .filter(|&pos| pos > half_byte_idx);
    if let Some(pos) = sentence_end {
        return text[..=pos].to_string();
    }

    let word_break = truncated
        .rfind([' ', '\t', '\n'])
        .filter(|&pos| pos > half_byte_idx);
    if let Some(pos) = word_break {
        return text[..pos].to_string();
    }

    truncated.to_string()
}
1188
1189#[derive(Debug, Clone)]
1199pub struct TokenConfig {
1200 pub max_tokens: usize,
1202 pub chars_per_token: f32,
1205}
1206
1207impl Default for TokenConfig {
1208 fn default() -> Self {
1209 Self {
1210 max_tokens: 8192, chars_per_token: 3.0, }
1213 }
1214}
1215
1216impl TokenConfig {
1217 pub fn english() -> Self {
1219 Self {
1220 max_tokens: 8192,
1221 chars_per_token: 4.0,
1222 }
1223 }
1224
1225 pub fn for_multilingual_text() -> Self {
1227 Self {
1228 max_tokens: 8192,
1229 chars_per_token: 2.5,
1230 }
1231 }
1232
1233 pub fn with_max_tokens(mut self, max: usize) -> Self {
1235 self.max_tokens = max;
1236 self
1237 }
1238}
1239
1240pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
1245 let char_count = text.chars().count();
1246 (char_count as f32 / config.chars_per_token).ceil() as usize
1247}
1248
1249pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
1253 let estimated = estimate_tokens(chunk, config);
1254
1255 if estimated > config.max_tokens {
1256 return Err(anyhow!(
1257 "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
1258 Consider reducing chunk_size or enabling truncation.",
1259 estimated,
1260 config.max_tokens,
1261 chunk.chars().count()
1262 ));
1263 }
1264
1265 Ok(())
1266}
1267
1268pub fn safe_chunk_size(config: &TokenConfig) -> usize {
1270 let safe_tokens = (config.max_tokens as f32 * 0.8) as usize;
1272 (safe_tokens as f32 * config.chars_per_token) as usize
1273}
1274
1275pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
1277 let safe_chars = safe_chunk_size(config);
1278
1279 if text.chars().count() <= safe_chars {
1280 return text.to_string();
1281 }
1282
1283 truncate_at_boundary(text, safe_chars)
1284}
1285
1286pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
1288 texts
1289 .iter()
1290 .enumerate()
1291 .filter_map(|(idx, text)| {
1292 let estimated = estimate_tokens(text, config);
1293 if estimated > config.max_tokens {
1294 Some((idx, estimated))
1295 } else {
1296 None
1297 }
1298 })
1299 .collect()
1300}
1301
1302#[cfg(test)]
1303mod tests {
1304 use super::*;
1305 use axum::{Json, Router, extract::State, routing::post};
1306 use serde_json::json;
1307
1308 async fn mock_embeddings(State(dim): State<usize>) -> Json<serde_json::Value> {
1309 Json(json!({
1310 "data": [{
1311 "embedding": vec![0.25_f32; dim]
1312 }]
1313 }))
1314 }
1315
1316 async fn spawn_mock_embedding_server(dim: usize) -> String {
1317 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1318 let addr = listener.local_addr().unwrap();
1319 let app = Router::new()
1320 .route("/v1/embeddings", post(mock_embeddings))
1321 .with_state(dim);
1322
1323 tokio::spawn(async move {
1324 axum::serve(listener, app).await.unwrap();
1325 });
1326
1327 tokio::time::sleep(Duration::from_millis(10)).await;
1328
1329 format!("http://{}", addr)
1330 }
1331
1332 #[test]
1333 fn test_provider_sorting() {
1334 let mut providers = [
1335 ProviderConfig {
1336 name: "low".into(),
1337 base_url: "http://a".into(),
1338 model: "m".into(),
1339 priority: 10,
1340 endpoint: "/v1/embeddings".into(),
1341 },
1342 ProviderConfig {
1343 name: "high".into(),
1344 base_url: "http://b".into(),
1345 model: "m".into(),
1346 priority: 1,
1347 endpoint: "/v1/embeddings".into(),
1348 },
1349 ];
1350 providers.sort_by_key(|p| p.priority);
1351 assert_eq!(providers[0].name, "high");
1352 assert_eq!(providers[1].name, "low");
1353 }
1354
1355 #[test]
1356 fn test_legacy_conversion() {
1357 let legacy = MlxConfig {
1358 disabled: false,
1359 local_port: 12345,
1360 dragon_url: "http://dragon".into(),
1361 dragon_port: 12345,
1362 embedder_model: "test-model".into(),
1363 reranker_model: "rerank-model".into(),
1364 reranker_port_offset: 1,
1365 max_batch_chars: 32000,
1366 max_batch_items: 16,
1367 };
1368 let config = legacy.to_embedding_config();
1369 assert_eq!(config.providers.len(), 2);
1370 assert_eq!(config.providers[0].base_url, "http://localhost:12345");
1371 assert!(config.reranker.base_url.is_some());
1372 assert_eq!(config.max_batch_chars, 32000);
1373 assert_eq!(config.max_batch_items, 16);
1374 }
1375
1376 #[test]
1377 fn test_default_config() {
1378 let config = EmbeddingConfig::default();
1379 assert_eq!(config.required_dimension, DEFAULT_REQUIRED_DIMENSION);
1380 assert_eq!(config.max_batch_chars, 128000); assert_eq!(config.max_batch_items, 64); assert!(!config.providers.is_empty());
1383 assert_eq!(config.providers[0].model, DEFAULT_OLLAMA_EMBEDDING_MODEL);
1384 }
1385
1386 #[test]
1387 fn test_infer_embedding_dimension() {
1388 assert_eq!(
1389 infer_embedding_dimension("qwen3-embedding:0.6b"),
1390 Some(1024)
1391 );
1392 assert_eq!(infer_embedding_dimension("qwen3-embedding:4b"), Some(2560));
1393 assert_eq!(infer_embedding_dimension("qwen3-embedding:8b"), Some(4096));
1394 assert_eq!(
1395 infer_embedding_dimension("MedAIBase/Qwen3-VL-Embedding:2b-q8_0"),
1396 Some(2048)
1397 );
1398 assert_eq!(infer_embedding_dimension("nomic-embed-text"), Some(768));
1399 assert_eq!(infer_embedding_dimension("qwen3-embedding"), None);
1400 assert_eq!(infer_embedding_dimension("unknown-model"), None);
1401 }
1402
1403 #[tokio::test]
1404 async fn test_probe_provider_dimension_reads_actual_dimension() {
1405 let base_url = spawn_mock_embedding_server(2560).await;
1406 let client = Client::new();
1407 let provider = ProviderConfig {
1408 name: "mock".into(),
1409 base_url,
1410 model: "mock-embedder".into(),
1411 priority: 1,
1412 endpoint: "/v1/embeddings".into(),
1413 };
1414
1415 let dim = probe_provider_dimension(&client, &provider).await.unwrap();
1416 assert_eq!(dim, 2560);
1417 }
1418
1419 #[tokio::test]
1420 async fn test_embedding_client_fails_fast_on_dimension_mismatch() {
1421 let base_url = spawn_mock_embedding_server(2560).await;
1422 let config = EmbeddingConfig {
1423 required_dimension: 1024,
1424 providers: vec![ProviderConfig {
1425 name: "mock".into(),
1426 base_url,
1427 model: "mock-embedder".into(),
1428 priority: 1,
1429 endpoint: "/v1/embeddings".into(),
1430 }],
1431 ..EmbeddingConfig::default()
1432 };
1433
1434 let err = EmbeddingClient::new(&config)
1435 .await
1436 .err()
1437 .expect("dimension mismatch should fail")
1438 .to_string();
1439 assert!(err.contains("returned 2560 dims"));
1440 assert!(err.contains("required_dimension=1024"));
1441 }
1442
1443 #[test]
1444 fn test_truncate_at_boundary() {
1445 let text = "Hello world. This is a test.";
1447 let truncated = truncate_at_boundary(text, 15);
1448 assert_eq!(truncated, "Hello world.");
1449
1450 let text = "Hello world this is a test";
1452 let truncated = truncate_at_boundary(text, 15);
1453 assert_eq!(truncated, "Hello world");
1454
1455 let text = "Short text";
1457 let truncated = truncate_at_boundary(text, 100);
1458 assert_eq!(truncated, "Short text");
1459 }
1460
1461 #[test]
1462 fn test_token_estimation() {
1463 let config = TokenConfig::default();
1464
1465 let text = "Hello world"; let tokens = estimate_tokens(text, &config);
1468 assert!((3..=5).contains(&tokens));
1469
1470 let english_config = TokenConfig::english();
1472 let tokens = estimate_tokens(text, &english_config);
1473 assert!((2..=4).contains(&tokens));
1474 }
1475
1476 #[test]
1477 fn test_chunk_validation() {
1478 let config = TokenConfig::default().with_max_tokens(100);
1479
1480 let short = "Hello world";
1482 assert!(validate_chunk_tokens(short, &config).is_ok());
1483
1484 let long = "a".repeat(1000); assert!(validate_chunk_tokens(&long, &config).is_err());
1487 }
1488
1489 #[test]
1490 fn test_safe_chunk_size() {
1491 let config = TokenConfig::default(); let safe = safe_chunk_size(&config);
1494 assert!(safe > 15000 && safe < 25000);
1496 }
1497
1498 #[test]
1499 fn test_batch_validation() {
1500 let config = TokenConfig::default().with_max_tokens(10);
1501
1502 let texts = vec![
1503 "short".to_string(), "a".repeat(100), "also short".to_string(), "b".repeat(200), ];
1508
1509 let failures = validate_batch_tokens(&texts, &config);
1510 assert_eq!(failures.len(), 2);
1511 assert_eq!(failures[0].0, 1); assert_eq!(failures[1].0, 3); }
1514}
1515
/// Converts embedding vectors from one dimensionality to another.
///
/// Shorter vectors are zero-padded. Longer vectors are average-pooled when
/// both lengths are powers of two, and truncated otherwise. Every path that
/// changes the length of a vector re-normalizes the result to unit L2 norm
/// (all-zero vectors are left as zeros), so adapted vectors share one scale.
#[derive(Debug, Clone)]
pub struct DimensionAdapter {
    /// Dimension the incoming embeddings are declared to have.
    pub source_dim: usize,
    /// Dimension every adapted embedding is resized to.
    pub target_dim: usize,
}

impl DimensionAdapter {
    /// Creates an adapter that converts from `source_dim` to `target_dim`.
    pub fn new(source_dim: usize, target_dim: usize) -> Self {
        Self {
            source_dim,
            target_dim,
        }
    }

    /// Returns `true` when the configured source and target dimensions differ.
    pub fn needs_adaptation(&self) -> bool {
        self.source_dim != self.target_dim
    }

    /// Resizes `embedding` to `target_dim`.
    ///
    /// The decision is based on the actual length of `embedding`, not on
    /// `source_dim`: a vector that already has `target_dim` elements is
    /// returned untouched, a shorter one goes through [`Self::expand`], a
    /// longer one through [`Self::contract`].
    pub fn adapt(&self, embedding: Vec<f32>) -> Vec<f32> {
        match embedding.len().cmp(&self.target_dim) {
            std::cmp::Ordering::Equal => embedding,
            std::cmp::Ordering::Less => self.expand(embedding),
            std::cmp::Ordering::Greater => self.contract(embedding),
        }
    }

    /// Zero-pads `embedding` up to `target_dim` and re-normalizes it.
    ///
    /// An input that already has `target_dim` elements is returned unchanged.
    /// An input that is longer than `target_dim` is truncated and
    /// re-normalized instead of being expanded.
    pub fn expand(&self, mut embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() == self.target_dim {
            return embedding;
        }
        if embedding.len() > self.target_dim {
            return self.truncate_normalized(embedding);
        }

        embedding.resize(self.target_dim, 0.0);
        self.normalize(&mut embedding);
        embedding
    }

    /// Shrinks `embedding` down to `target_dim`.
    ///
    /// Inputs that are not longer than `target_dim` are returned unchanged.
    /// When both lengths are powers of two, consecutive groups of elements are
    /// averaged; otherwise the leading `target_dim` elements are kept. In both
    /// cases the result is re-normalized to unit L2 norm.
    pub fn contract(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }

        if self.is_power_of_two_reduction(embedding.len()) {
            self.average_reduction(embedding)
        } else {
            self.truncate_normalized(embedding)
        }
    }

    /// Returns `true` when `source_len` is larger than `target_dim` and both
    /// are powers of two.
    fn is_power_of_two_reduction(&self, source_len: usize) -> bool {
        // A larger power of two is always an exact multiple of a smaller one,
        // so no separate divisibility check is needed.
        source_len > self.target_dim
            && source_len.is_power_of_two()
            && self.target_dim.is_power_of_two()
    }

    /// Averages consecutive groups of `embedding.len() / target_dim` elements
    /// and re-normalizes the pooled vector.
    ///
    /// Callers must ensure the length is a non-zero multiple of `target_dim`
    /// (guaranteed by [`Self::is_power_of_two_reduction`]).
    fn average_reduction(&self, embedding: Vec<f32>) -> Vec<f32> {
        let factor = embedding.len() / self.target_dim;
        let mut pooled: Vec<f32> = embedding
            .chunks_exact(factor)
            .map(|group| group.iter().sum::<f32>() / factor as f32)
            .collect();

        self.normalize(&mut pooled);
        pooled
    }

    /// Keeps the first `target_dim` elements in place (no reallocation) and
    /// re-normalizes, so truncated vectors have the same unit scale as the
    /// padded and pooled ones.
    fn truncate_normalized(&self, mut embedding: Vec<f32>) -> Vec<f32> {
        embedding.truncate(self.target_dim);
        self.normalize(&mut embedding);
        embedding
    }

    /// Scales `vec` in place to unit L2 norm.
    ///
    /// Vectors whose norm is at most `1e-10` are left untouched to avoid
    /// dividing by (almost) zero.
    fn normalize(&self, vec: &mut [f32]) {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for v in vec.iter_mut() {
                *v /= norm;
            }
        }
    }
}
1640
1641pub fn cross_dimension_search_adapt(query_embedding: Vec<f32>, target_dim: usize) -> Vec<f32> {
1643 let adapter = DimensionAdapter::new(query_embedding.len(), target_dim);
1644 adapter.adapt(query_embedding)
1645}
1646
#[cfg(test)]
mod dimension_adapter_tests {
    use super::*;

    /// Euclidean (L2) norm of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Expanding keeps the original leading values and pads the tail with zeros.
    #[test]
    fn test_expand_1024_to_4096() {
        let expanded = DimensionAdapter::new(1024, 4096).expand(vec![0.1f32; 1024]);

        assert_eq!(expanded.len(), 4096);
        // Original region is non-zero, padded region is zero.
        assert!(expanded[0].abs() > 1e-10);
        assert!(expanded[4095].abs() < 1e-10);
    }

    /// Contracting between powers of two yields a unit-norm vector.
    #[test]
    fn test_contract_4096_to_1024() {
        let contracted = DimensionAdapter::new(4096, 1024).contract(vec![0.1f32; 4096]);

        assert_eq!(contracted.len(), 1024);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }

    /// `adapt` picks expansion or contraction from the input length.
    #[test]
    fn test_adapt_auto_detect() {
        // Shorter than the target: expanded.
        let grown = DimensionAdapter::new(1024, 4096).adapt(vec![0.1f32; 1024]);
        assert_eq!(grown.len(), 4096);

        // Longer than the target: contracted.
        let shrunk = DimensionAdapter::new(4096, 1024).adapt(vec![0.1f32; 4096]);
        assert_eq!(shrunk.len(), 1024);
    }

    /// Equal dimensions mean no adaptation and an unchanged vector.
    #[test]
    fn test_no_adaptation_needed() {
        let identity = DimensionAdapter::new(4096, 4096);
        assert!(!identity.needs_adaptation());

        let original = vec![0.1f32; 4096];
        let adapted = identity.adapt(original.clone());
        assert_eq!(adapted, original);
    }

    /// Average pooling of a ramp signal yields a unit-norm vector of the
    /// target length.
    #[test]
    fn test_average_reduction_preserves_info() {
        let ramp: Vec<f32> = (0..4096).map(|i| i as f32 / 4096.0).collect();

        let contracted = DimensionAdapter::new(4096, 2048).contract(ramp);

        assert_eq!(contracted.len(), 2048);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }
}