1use std::collections::HashMap;
35use std::path::Path;
36use std::sync::Arc;
37use std::time::Duration;
38
39use figment::{
40 Figment,
41 providers::{Format, Toml},
42};
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45use thiserror::Error;
46
/// Timeout passed to `.timeout(..)` on every outbound provider request
/// (embedding and rerank alike).
const PROVIDER_TIMEOUT: Duration = Duration::from_secs(30);
50
51#[derive(Debug, Error)]
53pub enum ProviderError {
54 #[error("embedding provider request failed: {0}")]
56 Http(String),
57 #[error("embedding provider returned a malformed response: {0}")]
59 Parse(String),
60 #[error("api key environment variable {0} is not set")]
62 MissingKey(String),
63 #[error("invalid embedding configuration: {0}")]
66 Config(String),
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73#[serde(rename_all = "snake_case")]
74pub enum ProviderKind {
75 Openai,
77 Ollama,
79 Http,
81 Cohere,
83 Fake,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct EmbeddingConfig {
91 pub provider: ProviderKind,
93 #[serde(default)]
95 pub model: String,
96 #[serde(default)]
99 pub endpoint: String,
100 pub dim: u32,
102 #[serde(default)]
105 pub api_key_env: String,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct RerankConfig {
111 pub provider: ProviderKind,
113 #[serde(default)]
115 pub model: String,
116 #[serde(default)]
118 pub endpoint: String,
119 #[serde(default)]
121 pub api_key_env: String,
122}
123
124pub trait EmbeddingProvider: Send + Sync {
126 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError>;
128 fn dim(&self) -> usize;
130}
131
132pub trait RerankProvider: Send + Sync {
134 fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError>;
136}
137
/// In-process embedder whose output depends only on the input text and the
/// configured dimension; it performs no I/O.
pub struct FakeEmbedder {
    /// Number of components in every vector produced.
    dim: usize,
}

impl FakeEmbedder {
    /// Creates a fake embedder that emits `dim`-component vectors.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        FakeEmbedder { dim }
    }
}
156
/// 64-bit FNV-1a hash of `bytes`: start from the offset basis, then for each
/// byte XOR it in and multiply by the FNV prime (wrapping on overflow).
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
167
168impl EmbeddingProvider for FakeEmbedder {
169 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
170 Ok(texts
171 .iter()
172 .map(|t| {
173 (0..self.dim)
177 .map(|i| {
178 let h = fnv1a(format!("{t}:{i}").as_bytes());
179 (h >> 40) as f32 / f32::from(1u16 << 11) - 1.0
181 })
182 .collect()
183 })
184 .collect())
185 }
186 fn dim(&self) -> usize {
187 self.dim
188 }
189}
190
191pub struct FakeReranker;
194
195impl RerankProvider for FakeReranker {
196 fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError> {
197 let q: std::collections::HashSet<String> =
198 query.split_whitespace().map(|w| w.to_lowercase()).collect();
199 Ok(docs
200 .iter()
201 .map(|d| {
202 let overlap = d
203 .split_whitespace()
204 .filter(|w| q.contains(&w.to_lowercase()))
205 .count();
206 overlap as f32
207 })
208 .collect())
209 }
210}
211
/// Embedder for servers that speak the OpenAI embeddings wire format.
///
/// Built for the `openai`, `ollama` and `http` provider kinds.
pub struct OpenAiCompatEmbedder {
    /// Full URL the request is POSTed to.
    url: String,
    /// Model name sent in the request body.
    model: String,
    /// Bearer token; when `None` no `Authorization` header is sent.
    api_key: Option<String>,
    /// Dimension every returned vector must have.
    dim: usize,
}
223
224fn openai_body(model: &str, texts: &[String]) -> Value {
226 json!({ "model": model, "input": texts })
227}
228
229fn parse_openai(body: &Value) -> Result<Vec<Vec<f32>>, ProviderError> {
231 let data = body
232 .get("data")
233 .and_then(Value::as_array)
234 .ok_or_else(|| ProviderError::Parse("missing `data` array".into()))?;
235 data.iter()
236 .map(|row| {
237 row.get("embedding")
238 .and_then(Value::as_array)
239 .ok_or_else(|| ProviderError::Parse("a `data` row had no `embedding` array".into()))
240 .map(|arr| {
241 arr.iter()
242 .filter_map(|v| v.as_f64().map(|f| f as f32))
243 .collect()
244 })
245 })
246 .collect()
247}
248
249impl EmbeddingProvider for OpenAiCompatEmbedder {
250 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
251 let mut req = ureq::post(&self.url).timeout(PROVIDER_TIMEOUT);
254 if let Some(key) = &self.api_key {
255 req = req.set("Authorization", &format!("Bearer {key}"));
256 }
257 let resp = req
258 .send_json(openai_body(&self.model, texts))
259 .map_err(|e| ProviderError::Http(e.to_string()))?;
260 let body: Value = resp
261 .into_json()
262 .map_err(|e| ProviderError::Parse(e.to_string()))?;
263 let vectors = parse_openai(&body)?;
264 check_dims(&vectors, self.dim)?;
265 Ok(vectors)
266 }
267 fn dim(&self) -> usize {
268 self.dim
269 }
270}
271
/// Embedder that talks to a Cohere-style embed endpoint.
pub struct CohereEmbedder {
    /// Full URL the request is POSTed to.
    url: String,
    /// Model name sent in the request body.
    model: String,
    /// Bearer token sent with every request.
    api_key: String,
    /// Dimension every returned vector must have.
    dim: usize,
}
283
284fn cohere_embed_body(model: &str, texts: &[String]) -> Value {
286 json!({
287 "model": model,
288 "texts": texts,
289 "input_type": "search_document",
290 "embedding_types": ["float"],
291 })
292}
293
294fn parse_cohere_embed(body: &Value) -> Result<Vec<Vec<f32>>, ProviderError> {
296 let floats = body
297 .get("embeddings")
298 .and_then(|e| e.get("float"))
299 .and_then(Value::as_array)
300 .ok_or_else(|| ProviderError::Parse("missing `embeddings.float` array".into()))?;
301 Ok(floats
302 .iter()
303 .map(|row| {
304 row.as_array()
305 .map(|arr| {
306 arr.iter()
307 .filter_map(|v| v.as_f64().map(|f| f as f32))
308 .collect()
309 })
310 .unwrap_or_default()
311 })
312 .collect())
313}
314
315impl EmbeddingProvider for CohereEmbedder {
316 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
317 let resp = ureq::post(&self.url)
318 .timeout(PROVIDER_TIMEOUT)
319 .set("Authorization", &format!("Bearer {}", self.api_key))
320 .send_json(cohere_embed_body(&self.model, texts))
321 .map_err(|e| ProviderError::Http(e.to_string()))?;
322 let body: Value = resp
323 .into_json()
324 .map_err(|e| ProviderError::Parse(e.to_string()))?;
325 let vectors = parse_cohere_embed(&body)?;
326 check_dims(&vectors, self.dim)?;
327 Ok(vectors)
328 }
329 fn dim(&self) -> usize {
330 self.dim
331 }
332}
333
/// Reranker that talks to a Cohere-style rerank endpoint.
pub struct CohereReranker {
    /// Full URL the request is POSTed to.
    url: String,
    /// Model name sent in the request body.
    model: String,
    /// Bearer token sent with every request.
    api_key: String,
}
340
341fn cohere_rerank_body(model: &str, query: &str, docs: &[String]) -> Value {
343 json!({ "model": model, "query": query, "documents": docs })
344}
345
346fn parse_cohere_rerank(body: &Value, n_docs: usize) -> Result<Vec<f32>, ProviderError> {
350 let results = body
351 .get("results")
352 .and_then(Value::as_array)
353 .ok_or_else(|| ProviderError::Parse("missing `results` array".into()))?;
354 let mut scores = vec![0.0_f32; n_docs];
355 for r in results {
356 let idx = r
357 .get("index")
358 .and_then(Value::as_u64)
359 .ok_or_else(|| ProviderError::Parse("a result had no `index`".into()))?
360 as usize;
361 let score = r
362 .get("relevance_score")
363 .and_then(Value::as_f64)
364 .ok_or_else(|| ProviderError::Parse("a result had no `relevance_score`".into()))?
365 as f32;
366 if idx < n_docs {
367 scores[idx] = score;
368 }
369 }
370 Ok(scores)
371}
372
373impl RerankProvider for CohereReranker {
374 fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError> {
375 let resp = ureq::post(&self.url)
376 .timeout(PROVIDER_TIMEOUT)
377 .set("Authorization", &format!("Bearer {}", self.api_key))
378 .send_json(cohere_rerank_body(&self.model, query, docs))
379 .map_err(|e| ProviderError::Http(e.to_string()))?;
380 let body: Value = resp
381 .into_json()
382 .map_err(|e| ProviderError::Parse(e.to_string()))?;
383 parse_cohere_rerank(&body, docs.len())
384 }
385}
386
387fn check_dims(vectors: &[Vec<f32>], dim: usize) -> Result<(), ProviderError> {
390 if let Some(bad) = vectors.iter().find(|v| v.len() != dim) {
391 return Err(ProviderError::Parse(format!(
392 "provider returned a {}-dim vector but the collection expects {dim}",
393 bad.len()
394 )));
395 }
396 Ok(())
397}
398
/// URL used for the `openai` embedding provider when no endpoint is configured.
const OPENAI_DEFAULT: &str = "https://api.openai.com/v1/embeddings";
/// URL used for the `cohere` embedding provider when no endpoint is configured.
const COHERE_EMBED_DEFAULT: &str = "https://api.cohere.com/v2/embed";
/// URL used for the `cohere` rerank provider when no endpoint is configured.
const COHERE_RERANK_DEFAULT: &str = "https://api.cohere.com/v2/rerank";
407
408fn resolve_key(api_key_env: &str) -> Result<Option<String>, ProviderError> {
410 if api_key_env.is_empty() {
411 return Ok(None);
412 }
413 std::env::var(api_key_env)
414 .map(Some)
415 .map_err(|_| ProviderError::MissingKey(api_key_env.to_owned()))
416}
417
418#[derive(Debug, Default, Deserialize)]
421struct ProviderTables {
422 #[serde(default)]
423 embedding: HashMap<String, EmbeddingConfig>,
424 #[serde(default)]
425 rerank: HashMap<String, RerankConfig>,
426}
427
428#[derive(Clone, Default)]
430pub struct EmbedRegistry {
431 embedders: HashMap<String, Arc<dyn EmbeddingProvider>>,
432 rerankers: HashMap<String, Arc<dyn RerankProvider>>,
433}
434
435impl EmbedRegistry {
436 pub fn from_config(
440 embedding: &HashMap<String, EmbeddingConfig>,
441 rerank: &HashMap<String, RerankConfig>,
442 ) -> Result<Self, ProviderError> {
443 let mut embedders: HashMap<String, Arc<dyn EmbeddingProvider>> = HashMap::new();
444 for (collection, cfg) in embedding {
445 embedders.insert(collection.clone(), build_embedder(cfg)?);
446 }
447 let mut rerankers: HashMap<String, Arc<dyn RerankProvider>> = HashMap::new();
448 for (collection, cfg) in rerank {
449 rerankers.insert(collection.clone(), build_reranker(cfg)?);
450 }
451 Ok(Self {
452 embedders,
453 rerankers,
454 })
455 }
456
457 pub fn from_toml_path(path: &Path) -> Result<Self, ProviderError> {
467 let tables: ProviderTables = Figment::from(Toml::file(path))
468 .extract()
469 .map_err(|e| ProviderError::Config(e.to_string()))?;
470 Self::from_config(&tables.embedding, &tables.rerank)
471 }
472
473 #[must_use]
475 pub fn embedder(&self, collection: &str) -> Option<Arc<dyn EmbeddingProvider>> {
476 self.embedders.get(collection).cloned()
477 }
478
479 #[must_use]
481 pub fn reranker(&self, collection: &str) -> Option<Arc<dyn RerankProvider>> {
482 self.rerankers.get(collection).cloned()
483 }
484
485 #[must_use]
488 pub fn is_empty(&self) -> bool {
489 self.embedders.is_empty() && self.rerankers.is_empty()
490 }
491}
492
493fn build_embedder(cfg: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>, ProviderError> {
495 let dim = cfg.dim as usize;
496 match cfg.provider {
497 ProviderKind::Fake => Ok(Arc::new(FakeEmbedder::new(dim))),
498 ProviderKind::Openai => Ok(Arc::new(OpenAiCompatEmbedder {
499 url: if cfg.endpoint.is_empty() {
500 OPENAI_DEFAULT.to_owned()
501 } else {
502 cfg.endpoint.clone()
503 },
504 model: cfg.model.clone(),
505 api_key: resolve_key(&cfg.api_key_env)?,
506 dim,
507 })),
508 ProviderKind::Ollama | ProviderKind::Http => {
509 if cfg.endpoint.is_empty() {
510 return Err(ProviderError::Config(format!(
511 "provider {:?} requires an `endpoint` (e.g. http://localhost:11434/v1/embeddings)",
512 cfg.provider
513 )));
514 }
515 Ok(Arc::new(OpenAiCompatEmbedder {
516 url: cfg.endpoint.clone(),
517 model: cfg.model.clone(),
518 api_key: resolve_key(&cfg.api_key_env)?,
519 dim,
520 }))
521 }
522 ProviderKind::Cohere => Ok(Arc::new(CohereEmbedder {
523 url: if cfg.endpoint.is_empty() {
524 COHERE_EMBED_DEFAULT.to_owned()
525 } else {
526 cfg.endpoint.clone()
527 },
528 model: cfg.model.clone(),
529 api_key: resolve_key(&cfg.api_key_env)?.ok_or_else(|| {
530 ProviderError::Config("cohere embedding requires api_key_env".into())
531 })?,
532 dim,
533 })),
534 }
535}
536
537fn build_reranker(cfg: &RerankConfig) -> Result<Arc<dyn RerankProvider>, ProviderError> {
539 match cfg.provider {
540 ProviderKind::Fake => Ok(Arc::new(FakeReranker)),
541 ProviderKind::Cohere => Ok(Arc::new(CohereReranker {
542 url: if cfg.endpoint.is_empty() {
543 COHERE_RERANK_DEFAULT.to_owned()
544 } else {
545 cfg.endpoint.clone()
546 },
547 model: cfg.model.clone(),
548 api_key: resolve_key(&cfg.api_key_env)?.ok_or_else(|| {
549 ProviderError::Config("cohere rerank requires api_key_env".into())
550 })?,
551 })),
552 other => Err(ProviderError::Config(format!(
553 "rerank provider {other:?} is not supported (use `cohere` or `fake`)"
554 ))),
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding config with model `"m"`, no endpoint and `dim = 4`; only the
    /// provider kind and the key variable name vary between tests.
    fn bare_embedding_cfg(provider: ProviderKind, api_key_env: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider,
            model: "m".into(),
            endpoint: String::new(),
            dim: 4,
            api_key_env: api_key_env.to_owned(),
        }
    }

    /// Writes `contents` (plus a trailing newline) to `quiver.toml` inside
    /// `dir` and returns the file's path.
    fn write_quiver_toml(dir: &tempfile::TempDir, contents: &str) -> std::path::PathBuf {
        use std::io::Write;
        let path = dir.path().join("quiver.toml");
        let mut file = std::fs::File::create(&path).unwrap();
        writeln!(file, "{contents}").unwrap();
        path
    }

    #[test]
    fn fake_embedder_is_deterministic_and_content_dependent() {
        let embedder = FakeEmbedder::new(8);
        let first = embedder.embed(&["hello world".into()]).unwrap();
        let second = embedder.embed(&["hello world".into()]).unwrap();
        let other = embedder.embed(&["different text".into()]).unwrap();
        assert_eq!(first.len(), 1);
        assert_eq!(first[0].len(), 8);
        assert_eq!(first, second, "identical text → identical vector");
        assert_ne!(first, other, "different text → different vector");
        assert_eq!(embedder.dim(), 8);
    }

    #[test]
    fn fake_embedder_batches_in_order() {
        let embedder = FakeEmbedder::new(4);
        let both = embedder.embed(&["a".into(), "b".into()]).unwrap();
        let only_a = embedder.embed(&["a".into()]).unwrap();
        let only_b = embedder.embed(&["b".into()]).unwrap();
        assert_eq!(both[0], only_a[0]);
        assert_eq!(both[1], only_b[0]);
    }

    #[test]
    fn fake_reranker_scores_by_overlap() {
        let docs: [String; 3] = [
            "the quick brown fox".into(),
            "lazy dog".into(),
            "fox".into(),
        ];
        let scores = FakeReranker.rerank("quick brown fox", &docs).unwrap();
        assert_eq!(scores, vec![3.0, 0.0, 1.0]);
    }

    #[test]
    fn openai_body_and_parse_roundtrip() {
        let request = openai_body("text-embedding-3-small", &["hi".into(), "yo".into()]);
        assert_eq!(request["model"], "text-embedding-3-small");
        assert_eq!(request["input"][1], "yo");

        let response = json!({"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]});
        let vectors = parse_openai(&response).unwrap();
        assert_eq!(vectors, vec![vec![0.1_f32, 0.2], vec![0.3, 0.4]]);
    }

    #[test]
    fn parse_openai_rejects_malformed() {
        let no_data = json!({"oops": 1});
        let no_embedding = json!({"data":[{"no_embedding": 1}]});
        assert!(parse_openai(&no_data).is_err());
        assert!(parse_openai(&no_embedding).is_err());
    }

    #[test]
    fn cohere_embed_body_and_parse_roundtrip() {
        let request = cohere_embed_body("embed-v4.0", &["doc".into()]);
        assert_eq!(request["model"], "embed-v4.0");
        assert_eq!(request["input_type"], "search_document");
        assert_eq!(request["texts"][0], "doc");

        let response = json!({"embeddings":{"float":[[1.0,2.0,3.0]]}});
        assert_eq!(
            parse_cohere_embed(&response).unwrap(),
            vec![vec![1.0_f32, 2.0, 3.0]]
        );
        assert!(parse_cohere_embed(&json!({"embeddings":{}})).is_err());
    }

    #[test]
    fn cohere_rerank_scatters_by_index() {
        let request = cohere_rerank_body("rerank-v3.5", "q", &["a".into(), "b".into()]);
        assert_eq!(request["query"], "q");

        // Results arrive in provider order and must be scattered back by index.
        let shuffled = json!({"results":[
            {"index":1,"relevance_score":0.9},
            {"index":0,"relevance_score":0.1},
        ]});
        assert_eq!(
            parse_cohere_rerank(&shuffled, 2).unwrap(),
            vec![0.1_f32, 0.9]
        );

        // An index outside the document range is ignored.
        let out_of_range = json!({"results":[{"index":9,"relevance_score":1.0}]});
        assert_eq!(
            parse_cohere_rerank(&out_of_range, 2).unwrap(),
            vec![0.0_f32, 0.0]
        );

        assert!(parse_cohere_rerank(&json!({"nope":1}), 2).is_err());
        assert!(parse_cohere_rerank(&json!({"results":[{"index":0}]}), 1).is_err());
    }

    #[test]
    fn check_dims_enforces_collection_dim() {
        let matching = [vec![1.0, 2.0]];
        let too_long = [vec![1.0, 2.0, 3.0]];
        assert!(check_dims(&matching, 2).is_ok());
        assert!(check_dims(&too_long, 2).is_err());
    }

    #[test]
    fn registry_builds_fake_and_resolves_emptiness() {
        let embedding = HashMap::from([(
            "docs".to_owned(),
            EmbeddingConfig {
                provider: ProviderKind::Fake,
                model: String::new(),
                endpoint: String::new(),
                dim: 16,
                api_key_env: String::new(),
            },
        )]);
        let rerank = HashMap::from([(
            "docs".to_owned(),
            RerankConfig {
                provider: ProviderKind::Fake,
                model: String::new(),
                endpoint: String::new(),
                api_key_env: String::new(),
            },
        )]);

        let registry = EmbedRegistry::from_config(&embedding, &rerank).unwrap();
        assert!(!registry.is_empty());
        assert_eq!(registry.embedder("docs").unwrap().dim(), 16);
        assert!(registry.embedder("missing").is_none());
        assert!(registry.reranker("docs").is_some());
        assert!(EmbedRegistry::default().is_empty());
    }

    #[test]
    fn from_toml_path_loads_embedding_and_rerank_tables() {
        let dir = tempfile::tempdir().unwrap();
        // The unrelated `[server]` table checks that other sections are tolerated.
        let path = write_quiver_toml(
            &dir,
            r#"
[server]
host = "127.0.0.1"

[embedding.docs]
provider = "fake"
dim = 16

[rerank.docs]
provider = "fake"
"#,
        );
        let registry = EmbedRegistry::from_toml_path(&path).unwrap();
        assert_eq!(registry.embedder("docs").unwrap().dim(), 16);
        assert!(registry.reranker("docs").is_some());
        assert!(registry.embedder("missing").is_none());
    }

    #[test]
    fn from_toml_path_missing_file_is_empty_not_an_error() {
        let absent = Path::new("definitely-not-here.toml");
        let registry = EmbedRegistry::from_toml_path(absent).unwrap();
        assert!(registry.is_empty());
    }

    #[test]
    fn from_toml_path_propagates_a_misconfigured_provider() {
        let dir = tempfile::tempdir().unwrap();
        // `http` without an `endpoint` must be rejected when providers are built.
        let path = write_quiver_toml(
            &dir,
            r#"
[embedding.docs]
provider = "http"
dim = 8
"#,
        );
        assert!(matches!(
            EmbedRegistry::from_toml_path(&path),
            Err(ProviderError::Config(_))
        ));
    }

    #[test]
    fn http_provider_requires_endpoint() {
        let cfg = bare_embedding_cfg(ProviderKind::Http, "");
        assert!(matches!(
            build_embedder(&cfg),
            Err(ProviderError::Config(_))
        ));
    }

    #[test]
    fn missing_api_key_is_a_hard_error() {
        let cfg = bare_embedding_cfg(ProviderKind::Openai, "QUIVER_TEST_DEFINITELY_UNSET_KEY");
        assert!(matches!(
            build_embedder(&cfg),
            Err(ProviderError::MissingKey(_))
        ));
    }

    #[test]
    fn openai_endpoint_defaults_and_overrides() {
        // OpenAI with neither endpoint nor key builds against the default URL.
        let openai = bare_embedding_cfg(ProviderKind::Openai, "");
        assert!(build_embedder(&openai).is_ok());

        // Cohere embedding refuses to build without a key variable.
        let cohere = bare_embedding_cfg(ProviderKind::Cohere, "");
        assert!(matches!(
            build_embedder(&cohere),
            Err(ProviderError::Config(_))
        ));

        // OpenAI is not a supported rerank provider.
        let rerank = RerankConfig {
            provider: ProviderKind::Openai,
            model: "m".into(),
            endpoint: String::new(),
            api_key_env: String::new(),
        };
        assert!(matches!(
            build_reranker(&rerank),
            Err(ProviderError::Config(_))
        ));
    }
}