Skip to main content

quiver_providers/
lib.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2//! Opt-in, provider-agnostic embedding & reranking adapters (ADR-0047/0058).
3//!
4//! The Quiver engine is deliberately model-agnostic: it stores and searches
5//! float vectors and knows nothing about embedding models. This crate is the
6//! **edge** adapter that lets an operator turn *"give me text"* into a
7//! stored/searched vector without the client running an embedding model — the
8//! single biggest RAG friction. It lives in its own lean crate (no axum/tonic)
9//! so it can be shared by both the network server (`quiver-server`) and the
10//! in-process MCP server (`quiver-mcp`) without either pulling the other's
11//! dependency tree (ADR-0058); it is never used by `quiver-core` or the
12//! `quiver-embed` engine crate, so library-mode users pay nothing.
13//!
14//! ## Design (ADR-0047)
15//! - **Provider-agnostic.** [`EmbeddingProvider`] / [`RerankProvider`] are traits;
16//!   OpenAI-compatible servers (OpenAI, Ollama's `/v1` endpoint, vLLM, LM Studio,
17//!   llama.cpp, …) share one HTTP adapter parameterized by base URL + auth, Cohere
18//!   has its own shape, and a deterministic [`FakeEmbedder`]/[`FakeReranker`] backs
19//!   tests and the acceptance script. No vendor is hard-coded; selection is config.
20//! - **Opt-in, per collection, default off.** Configured in the **server config**
21//!   (`[embedding.<collection>]` / `[rerank.<collection>]`), not the on-disk
22//!   descriptor — so the engine and the crash gate are untouched.
23//! - **No secrets on disk.** Config stores the *name* of an environment variable
24//!   ([`EmbeddingConfig::api_key_env`]); the value is resolved at registry-build
25//!   time and never persisted.
26//!
27//! ## Testing honesty
28//! The pure request-build and response-parse functions are unit-tested, and the
29//! `fake` provider exercises the full text-in/text-out path. The methods that make
30//! a live HTTP call ([`OpenAiCompatEmbedder::embed`], [`CohereEmbedder::embed`],
31//! [`CohereReranker::rerank`]) are thin shells around those tested helpers and a
32//! `ureq` call; live network calls are **not** in CI (stated, not faked).
33
34use std::collections::HashMap;
35use std::path::Path;
36use std::sync::Arc;
37use std::time::Duration;
38
39use figment::{
40    Figment,
41    providers::{Format, Toml},
42};
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45use thiserror::Error;
46
/// Upper bound (30 seconds) on any single provider HTTP call. Embedding/reranking
/// is a best-effort convenience; a slow provider must not pin a server thread
/// forever. Every `ureq` request issued in this file sets this timeout.
const PROVIDER_TIMEOUT: Duration = Duration::from_secs(30);
50
/// An error from a provider call or its configuration.
///
/// The same variants are returned by both the embedding and the rerank paths,
/// even though the `Http`/`Parse` display strings say "embedding provider".
#[derive(Debug, Error)]
pub enum ProviderError {
    /// The HTTP request to the provider failed (transport, status, or timeout).
    /// The payload is the underlying error rendered as a string.
    #[error("embedding provider request failed: {0}")]
    Http(String),
    /// The provider returned a body we could not parse into embeddings/scores.
    /// Also used when a returned vector's length does not match the configured
    /// collection dimension.
    #[error("embedding provider returned a malformed response: {0}")]
    Parse(String),
    /// A configured `api_key_env` variable is not set in the environment.
    /// The payload is the variable's *name*, never its value.
    #[error("api key environment variable {0} is not set")]
    MissingKey(String),
    /// The configuration named a provider/endpoint combination we cannot build
    /// (e.g. `http`/`ollama` without an `endpoint`), or the config file itself
    /// could not be extracted.
    #[error("invalid embedding configuration: {0}")]
    Config(String),
}
68
/// Which provider backs a collection's embedding (or rerank). The OpenAI-compatible
/// trio (`openai`, `ollama`, `http`) share one adapter; `cohere` is its own; `fake`
/// is deterministic and for tests/acceptance only.
///
/// Variants are (de)serialized in `snake_case`, so a config file spells them
/// `openai`, `ollama`, `http`, `cohere`, and `fake`. Only `cohere` and `fake`
/// can be built as a reranker; the other kinds are rejected for rerank.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    /// OpenAI `/v1/embeddings` (Bearer auth).
    Openai,
    /// An Ollama server's OpenAI-compatible `/v1/embeddings` (usually no auth).
    /// Requires an explicit `endpoint`.
    Ollama,
    /// Any OpenAI-compatible server at an explicit `endpoint` (vLLM, LM Studio, …).
    Http,
    /// Cohere `/v2/embed` and `/v2/rerank`. Always requires an API key.
    Cohere,
    /// A deterministic, network-free hash embedder/reranker (tests/acceptance).
    Fake,
}
86
/// A collection's embedding configuration (server config table
/// `[embedding.<collection>]`). Secrets are referenced by env-var *name* only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// The provider backing this collection.
    pub provider: ProviderKind,
    /// The model id passed to the provider (ignored by `fake`). Defaults to empty.
    #[serde(default)]
    pub model: String,
    /// Request URL override, used verbatim as the POST target (so it must include
    /// the path, e.g. `http://localhost:11434/v1/embeddings`). Required for
    /// `http`/`ollama`; optional for `openai`/`cohere`, which fall back to their
    /// public endpoints when this is empty.
    #[serde(default)]
    pub endpoint: String,
    /// The collection's vector dimension; the embedder must return this many floats.
    pub dim: u32,
    /// The *name* of the environment variable holding the API key. The value is
    /// read from the environment when the provider is built (registry-build time),
    /// held in memory by the provider, and never persisted. Empty ⇒ no auth header
    /// (e.g. local Ollama); `cohere` rejects an empty name.
    #[serde(default)]
    pub api_key_env: String,
}
107
/// A collection's rerank configuration (server config table `[rerank.<collection>]`).
/// Secrets are referenced by env-var *name* only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankConfig {
    /// The provider backing rerank (`cohere` or `fake`); any other kind is a
    /// configuration error when the reranker is built.
    pub provider: ProviderKind,
    /// The rerank model id (ignored by `fake`). Defaults to empty.
    #[serde(default)]
    pub model: String,
    /// Request URL override, used verbatim as the POST target; when empty the
    /// provider's public endpoint is used.
    #[serde(default)]
    pub endpoint: String,
    /// The *name* of the environment variable holding the API key, read when the
    /// reranker is built. `cohere` rejects an empty name.
    #[serde(default)]
    pub api_key_env: String,
}
123
/// Embeds a batch of texts into dense vectors (one per input).
///
/// `Send + Sync` so a provider can be shared behind an `Arc` across threads.
pub trait EmbeddingProvider: Send + Sync {
    /// Embed `texts`, returning one `dim`-length vector per input, in order.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] when the provider cannot be reached or its
    /// response cannot be turned into correctly sized vectors.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError>;
    /// The dimensionality every returned vector must have.
    fn dim(&self) -> usize;
}
131
/// Scores `(query, document)` pairs for relevance; higher is more relevant.
///
/// `Send + Sync` so a provider can be shared behind an `Arc` across threads.
pub trait RerankProvider: Send + Sync {
    /// Return one relevance score per `doc`, in the input order.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] when the provider cannot be reached or its
    /// response cannot be parsed into scores.
    fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError>;
}
137
138// ---------------------------------------------------------------------------
139// Fake provider (deterministic, network-free) — tests & the acceptance script.
140// ---------------------------------------------------------------------------
141
/// A deterministic embedder that hashes text into a unit-ish vector. Never used
/// in production config paths beyond the explicit `fake` selection; it exists so
/// the whole text-in/text-out path is testable without a network.
pub struct FakeEmbedder {
    /// Length of every vector this embedder produces.
    dim: usize,
}

impl FakeEmbedder {
    /// A fake embedder producing `dim`-length vectors. `dim` is stored as-is and
    /// not validated.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}
156
/// 64-bit FNV-1a of `bytes`: start from the offset basis, then for every byte
/// XOR it in and multiply by the FNV prime (wrapping). Kept local — the original
/// note says the tokenizer uses the same stable hash — so this module has no
/// cross-crate coupling for a one-liner.
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes.iter().fold(OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
167
168impl EmbeddingProvider for FakeEmbedder {
169    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
170        Ok(texts
171            .iter()
172            .map(|t| {
173                // Per-dimension hash of (text, i) → a stable, content-dependent
174                // vector. Deterministic so tests can assert identical-text →
175                // identical-vector and different-text → different-vector.
176                (0..self.dim)
177                    .map(|i| {
178                        let h = fnv1a(format!("{t}:{i}").as_bytes());
179                        // Map the high bits into [-1, 1).
180                        (h >> 40) as f32 / f32::from(1u16 << 11) - 1.0
181                    })
182                    .collect()
183            })
184            .collect())
185    }
186    fn dim(&self) -> usize {
187        self.dim
188    }
189}
190
191/// A deterministic reranker scoring documents by lexical token overlap with the
192/// query. Network-free; backs tests and the acceptance script.
193pub struct FakeReranker;
194
195impl RerankProvider for FakeReranker {
196    fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError> {
197        let q: std::collections::HashSet<String> =
198            query.split_whitespace().map(|w| w.to_lowercase()).collect();
199        Ok(docs
200            .iter()
201            .map(|d| {
202                let overlap = d
203                    .split_whitespace()
204                    .filter(|w| q.contains(&w.to_lowercase()))
205                    .count();
206                overlap as f32
207            })
208            .collect())
209    }
210}
211
212// ---------------------------------------------------------------------------
213// OpenAI-compatible embedder (OpenAI / Ollama / any OpenAI-shaped server).
214// ---------------------------------------------------------------------------
215
/// An embedder that speaks the OpenAI `/v1/embeddings` request/response shape.
pub struct OpenAiCompatEmbedder {
    /// Full URL the request is POSTed to.
    url: String,
    /// Model id sent in the request body.
    model: String,
    /// Bearer token; `None` means no `Authorization` header is sent.
    api_key: Option<String>,
    /// Expected length of every returned vector.
    dim: usize,
}
223
/// Build the OpenAI `/v1/embeddings` request body: `{"model": …, "input": […]}`
/// with `texts` passed through as the `input` array. Pure (unit-tested).
fn openai_body(model: &str, texts: &[String]) -> Value {
    json!({ "model": model, "input": texts })
}
228
229/// Parse an OpenAI `/v1/embeddings` response into vectors (order preserved). Pure.
230fn parse_openai(body: &Value) -> Result<Vec<Vec<f32>>, ProviderError> {
231    let data = body
232        .get("data")
233        .and_then(Value::as_array)
234        .ok_or_else(|| ProviderError::Parse("missing `data` array".into()))?;
235    data.iter()
236        .map(|row| {
237            row.get("embedding")
238                .and_then(Value::as_array)
239                .ok_or_else(|| ProviderError::Parse("a `data` row had no `embedding` array".into()))
240                .map(|arr| {
241                    arr.iter()
242                        .filter_map(|v| v.as_f64().map(|f| f as f32))
243                        .collect()
244                })
245        })
246        .collect()
247}
248
249impl EmbeddingProvider for OpenAiCompatEmbedder {
250    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
251        // Live HTTP shell around the tested `openai_body` / `parse_openai` helpers
252        // (not exercised in CI; the `fake` provider covers the path).
253        let mut req = ureq::post(&self.url).timeout(PROVIDER_TIMEOUT);
254        if let Some(key) = &self.api_key {
255            req = req.set("Authorization", &format!("Bearer {key}"));
256        }
257        let resp = req
258            .send_json(openai_body(&self.model, texts))
259            .map_err(|e| ProviderError::Http(e.to_string()))?;
260        let body: Value = resp
261            .into_json()
262            .map_err(|e| ProviderError::Parse(e.to_string()))?;
263        let vectors = parse_openai(&body)?;
264        check_dims(&vectors, self.dim)?;
265        Ok(vectors)
266    }
267    fn dim(&self) -> usize {
268        self.dim
269    }
270}
271
272// ---------------------------------------------------------------------------
273// Cohere embedder + reranker.
274// ---------------------------------------------------------------------------
275
/// An embedder that speaks Cohere `/v2/embed`.
pub struct CohereEmbedder {
    /// Full URL the request is POSTed to.
    url: String,
    /// Model id sent in the request body.
    model: String,
    /// Bearer token; always sent.
    api_key: String,
    /// Expected length of every returned vector.
    dim: usize,
}
283
/// Build the Cohere `/v2/embed` request body. Pure (unit-tested).
///
/// Only `float` embeddings are requested, and `input_type` is always
/// `"search_document"`.
// NOTE(review): query text embedded through this adapter is therefore also
// tagged `search_document`; confirm whether Cohere's separate query input type
// is wanted for search-time embedding.
fn cohere_embed_body(model: &str, texts: &[String]) -> Value {
    json!({
        "model": model,
        "texts": texts,
        "input_type": "search_document",
        "embedding_types": ["float"],
    })
}
293
294/// Parse a Cohere `/v2/embed` response (`{"embeddings":{"float":[[...]]}}`). Pure.
295fn parse_cohere_embed(body: &Value) -> Result<Vec<Vec<f32>>, ProviderError> {
296    let floats = body
297        .get("embeddings")
298        .and_then(|e| e.get("float"))
299        .and_then(Value::as_array)
300        .ok_or_else(|| ProviderError::Parse("missing `embeddings.float` array".into()))?;
301    Ok(floats
302        .iter()
303        .map(|row| {
304            row.as_array()
305                .map(|arr| {
306                    arr.iter()
307                        .filter_map(|v| v.as_f64().map(|f| f as f32))
308                        .collect()
309                })
310                .unwrap_or_default()
311        })
312        .collect())
313}
314
315impl EmbeddingProvider for CohereEmbedder {
316    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
317        let resp = ureq::post(&self.url)
318            .timeout(PROVIDER_TIMEOUT)
319            .set("Authorization", &format!("Bearer {}", self.api_key))
320            .send_json(cohere_embed_body(&self.model, texts))
321            .map_err(|e| ProviderError::Http(e.to_string()))?;
322        let body: Value = resp
323            .into_json()
324            .map_err(|e| ProviderError::Parse(e.to_string()))?;
325        let vectors = parse_cohere_embed(&body)?;
326        check_dims(&vectors, self.dim)?;
327        Ok(vectors)
328    }
329    fn dim(&self) -> usize {
330        self.dim
331    }
332}
333
/// A reranker that speaks Cohere `/v2/rerank`.
pub struct CohereReranker {
    /// Full URL the request is POSTed to.
    url: String,
    /// Rerank model id sent in the request body.
    model: String,
    /// Bearer token; always sent.
    api_key: String,
}
340
/// Build the Cohere `/v2/rerank` request body:
/// `{"model": …, "query": …, "documents": […]}`. Pure (unit-tested).
fn cohere_rerank_body(model: &str, query: &str, docs: &[String]) -> Value {
    json!({ "model": model, "query": query, "documents": docs })
}
345
346/// Parse a Cohere `/v2/rerank` response into a score *per input document*, in the
347/// original order. Cohere returns `{"results":[{"index":i,"relevance_score":s}]}`
348/// (possibly truncated/reordered), so we scatter by `index`. Pure (unit-tested).
349fn parse_cohere_rerank(body: &Value, n_docs: usize) -> Result<Vec<f32>, ProviderError> {
350    let results = body
351        .get("results")
352        .and_then(Value::as_array)
353        .ok_or_else(|| ProviderError::Parse("missing `results` array".into()))?;
354    let mut scores = vec![0.0_f32; n_docs];
355    for r in results {
356        let idx = r
357            .get("index")
358            .and_then(Value::as_u64)
359            .ok_or_else(|| ProviderError::Parse("a result had no `index`".into()))?
360            as usize;
361        let score = r
362            .get("relevance_score")
363            .and_then(Value::as_f64)
364            .ok_or_else(|| ProviderError::Parse("a result had no `relevance_score`".into()))?
365            as f32;
366        if idx < n_docs {
367            scores[idx] = score;
368        }
369    }
370    Ok(scores)
371}
372
373impl RerankProvider for CohereReranker {
374    fn rerank(&self, query: &str, docs: &[String]) -> Result<Vec<f32>, ProviderError> {
375        let resp = ureq::post(&self.url)
376            .timeout(PROVIDER_TIMEOUT)
377            .set("Authorization", &format!("Bearer {}", self.api_key))
378            .send_json(cohere_rerank_body(&self.model, query, docs))
379            .map_err(|e| ProviderError::Http(e.to_string()))?;
380        let body: Value = resp
381            .into_json()
382            .map_err(|e| ProviderError::Parse(e.to_string()))?;
383        parse_cohere_rerank(&body, docs.len())
384    }
385}
386
387/// Reject a provider response whose vectors do not all match the collection's dim
388/// (a misconfigured model is a clear error, not a silent wrong-length insert).
389fn check_dims(vectors: &[Vec<f32>], dim: usize) -> Result<(), ProviderError> {
390    if let Some(bad) = vectors.iter().find(|v| v.len() != dim) {
391        return Err(ProviderError::Parse(format!(
392            "provider returned a {}-dim vector but the collection expects {dim}",
393            bad.len()
394        )));
395    }
396    Ok(())
397}
398
399// ---------------------------------------------------------------------------
400// Registry: build the per-collection providers from config (resolves secrets).
401// ---------------------------------------------------------------------------
402
// Default request URLs for the public providers, used when a config leaves
// `endpoint` empty. Each is the full POST target (path included), not a base URL.

/// OpenAI's public embeddings endpoint.
const OPENAI_DEFAULT: &str = "https://api.openai.com/v1/embeddings";
/// Cohere's public embed endpoint.
const COHERE_EMBED_DEFAULT: &str = "https://api.cohere.com/v2/embed";
/// Cohere's public rerank endpoint.
const COHERE_RERANK_DEFAULT: &str = "https://api.cohere.com/v2/rerank";
407
408/// Resolve an `api_key_env` name to its value, or `None` when the name is empty.
409fn resolve_key(api_key_env: &str) -> Result<Option<String>, ProviderError> {
410    if api_key_env.is_empty() {
411        return Ok(None);
412    }
413    std::env::var(api_key_env)
414        .map(Some)
415        .map_err(|_| ProviderError::MissingKey(api_key_env.to_owned()))
416}
417
/// The `[embedding.*]` and `[rerank.*]` tables of a Quiver config file, used by
/// [`EmbedRegistry::from_toml_path`]. Every other config key is ignored.
/// Both tables default to empty when absent.
#[derive(Debug, Default, Deserialize)]
struct ProviderTables {
    /// Collection name → embedding configuration.
    #[serde(default)]
    embedding: HashMap<String, EmbeddingConfig>,
    /// Collection name → rerank configuration.
    #[serde(default)]
    rerank: HashMap<String, RerankConfig>,
}
427
/// Per-collection embedding/rerank providers, built once at startup from config.
///
/// Providers are held behind `Arc`, so cloning the registry shares them rather
/// than rebuilding them. The default registry is empty.
#[derive(Clone, Default)]
pub struct EmbedRegistry {
    /// Collection name → embedder.
    embedders: HashMap<String, Arc<dyn EmbeddingProvider>>,
    /// Collection name → reranker.
    rerankers: HashMap<String, Arc<dyn RerankProvider>>,
}
434
435impl EmbedRegistry {
436    /// Build the registry from the server config's `embedding`/`rerank` tables,
437    /// resolving each `api_key_env` from the environment (a missing required key is
438    /// a hard startup error, surfacing misconfiguration immediately).
439    pub fn from_config(
440        embedding: &HashMap<String, EmbeddingConfig>,
441        rerank: &HashMap<String, RerankConfig>,
442    ) -> Result<Self, ProviderError> {
443        let mut embedders: HashMap<String, Arc<dyn EmbeddingProvider>> = HashMap::new();
444        for (collection, cfg) in embedding {
445            embedders.insert(collection.clone(), build_embedder(cfg)?);
446        }
447        let mut rerankers: HashMap<String, Arc<dyn RerankProvider>> = HashMap::new();
448        for (collection, cfg) in rerank {
449            rerankers.insert(collection.clone(), build_reranker(cfg)?);
450        }
451        Ok(Self {
452            embedders,
453            rerankers,
454        })
455    }
456
457    /// Build a registry from the `[embedding.*]` / `[rerank.*]` tables of a Quiver
458    /// TOML config file — the same tables `quiver serve` reads — so the MCP server
459    /// (`quiver mcp`) can offer text-in/text-out tools with the same configuration
460    /// surface as the network server (ADR-0058). Any other config keys are ignored.
461    ///
462    /// A missing file yields an *empty* registry rather than an error: the MCP
463    /// server still starts, and the text tools report "no embedding provider
464    /// configured" only when actually invoked. A malformed file, or a provider that
465    /// cannot be built (e.g. a missing required API key), is a hard error.
466    pub fn from_toml_path(path: &Path) -> Result<Self, ProviderError> {
467        let tables: ProviderTables = Figment::from(Toml::file(path))
468            .extract()
469            .map_err(|e| ProviderError::Config(e.to_string()))?;
470        Self::from_config(&tables.embedding, &tables.rerank)
471    }
472
473    /// The embedder configured for `collection`, if any.
474    #[must_use]
475    pub fn embedder(&self, collection: &str) -> Option<Arc<dyn EmbeddingProvider>> {
476        self.embedders.get(collection).cloned()
477    }
478
479    /// The reranker configured for `collection`, if any.
480    #[must_use]
481    pub fn reranker(&self, collection: &str) -> Option<Arc<dyn RerankProvider>> {
482        self.rerankers.get(collection).cloned()
483    }
484
485    /// Whether any embedding or rerank provider is configured (so callers can skip
486    /// per-request work entirely on the common no-providers path).
487    #[must_use]
488    pub fn is_empty(&self) -> bool {
489        self.embedders.is_empty() && self.rerankers.is_empty()
490    }
491}
492
493/// Build one embedder from its config, resolving the API key from the environment.
494fn build_embedder(cfg: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>, ProviderError> {
495    let dim = cfg.dim as usize;
496    match cfg.provider {
497        ProviderKind::Fake => Ok(Arc::new(FakeEmbedder::new(dim))),
498        ProviderKind::Openai => Ok(Arc::new(OpenAiCompatEmbedder {
499            url: if cfg.endpoint.is_empty() {
500                OPENAI_DEFAULT.to_owned()
501            } else {
502                cfg.endpoint.clone()
503            },
504            model: cfg.model.clone(),
505            api_key: resolve_key(&cfg.api_key_env)?,
506            dim,
507        })),
508        ProviderKind::Ollama | ProviderKind::Http => {
509            if cfg.endpoint.is_empty() {
510                return Err(ProviderError::Config(format!(
511                    "provider {:?} requires an `endpoint` (e.g. http://localhost:11434/v1/embeddings)",
512                    cfg.provider
513                )));
514            }
515            Ok(Arc::new(OpenAiCompatEmbedder {
516                url: cfg.endpoint.clone(),
517                model: cfg.model.clone(),
518                api_key: resolve_key(&cfg.api_key_env)?,
519                dim,
520            }))
521        }
522        ProviderKind::Cohere => Ok(Arc::new(CohereEmbedder {
523            url: if cfg.endpoint.is_empty() {
524                COHERE_EMBED_DEFAULT.to_owned()
525            } else {
526                cfg.endpoint.clone()
527            },
528            model: cfg.model.clone(),
529            api_key: resolve_key(&cfg.api_key_env)?.ok_or_else(|| {
530                ProviderError::Config("cohere embedding requires api_key_env".into())
531            })?,
532            dim,
533        })),
534    }
535}
536
537/// Build one reranker from its config (only `cohere` and `fake` are supported).
538fn build_reranker(cfg: &RerankConfig) -> Result<Arc<dyn RerankProvider>, ProviderError> {
539    match cfg.provider {
540        ProviderKind::Fake => Ok(Arc::new(FakeReranker)),
541        ProviderKind::Cohere => Ok(Arc::new(CohereReranker {
542            url: if cfg.endpoint.is_empty() {
543                COHERE_RERANK_DEFAULT.to_owned()
544            } else {
545                cfg.endpoint.clone()
546            },
547            model: cfg.model.clone(),
548            api_key: resolve_key(&cfg.api_key_env)?.ok_or_else(|| {
549                ProviderError::Config("cohere rerank requires api_key_env".into())
550            })?,
551        })),
552        other => Err(ProviderError::Config(format!(
553            "rerank provider {other:?} is not supported (use `cohere` or `fake`)"
554        ))),
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// An `EmbeddingConfig` with model `"m"` and `dim = 4`; only the knobs the
    /// builder tests vary are parameters.
    fn embedding_cfg(provider: ProviderKind, endpoint: &str, api_key_env: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider,
            model: "m".into(),
            endpoint: endpoint.into(),
            dim: 4,
            api_key_env: api_key_env.into(),
        }
    }

    /// A `RerankConfig` with model `"m"`, no endpoint override, and no key.
    fn rerank_cfg(provider: ProviderKind) -> RerankConfig {
        RerankConfig {
            provider,
            model: "m".into(),
            endpoint: String::new(),
            api_key_env: String::new(),
        }
    }

    /// Write `contents` to `quiver.toml` in a fresh temp dir. The `TempDir` guard
    /// is returned so the file outlives the call.
    fn write_config(contents: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("quiver.toml");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    #[test]
    fn fake_embedder_is_deterministic_and_content_dependent() {
        let e = FakeEmbedder::new(8);
        let a = e.embed(&["hello world".into()]).unwrap();
        let b = e.embed(&["hello world".into()]).unwrap();
        let c = e.embed(&["different text".into()]).unwrap();
        assert_eq!(a.len(), 1);
        assert_eq!(a[0].len(), 8);
        assert_eq!(a, b, "identical text → identical vector");
        assert_ne!(a, c, "different text → different vector");
        assert_eq!(e.dim(), 8);
    }

    #[test]
    fn fake_embedder_batches_in_order() {
        let e = FakeEmbedder::new(4);
        let batch = e.embed(&["a".into(), "b".into()]).unwrap();
        let a = e.embed(&["a".into()]).unwrap();
        let b = e.embed(&["b".into()]).unwrap();
        assert_eq!(batch[0], a[0]);
        assert_eq!(batch[1], b[0]);
    }

    #[test]
    fn fake_reranker_scores_by_overlap() {
        let r = FakeReranker;
        let scores = r
            .rerank(
                "quick brown fox",
                &[
                    "the quick brown fox".into(),
                    "lazy dog".into(),
                    "fox".into(),
                ],
            )
            .unwrap();
        assert_eq!(scores, vec![3.0, 0.0, 1.0]);
    }

    #[test]
    fn openai_body_and_parse_roundtrip() {
        let body = openai_body("text-embedding-3-small", &["hi".into(), "yo".into()]);
        assert_eq!(body["model"], "text-embedding-3-small");
        assert_eq!(body["input"][1], "yo");
        let resp = json!({"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]});
        let vecs = parse_openai(&resp).unwrap();
        assert_eq!(vecs, vec![vec![0.1_f32, 0.2], vec![0.3, 0.4]]);
    }

    #[test]
    fn parse_openai_rejects_malformed() {
        assert!(parse_openai(&json!({"oops": 1})).is_err());
        assert!(parse_openai(&json!({"data":[{"no_embedding": 1}]})).is_err());
    }

    #[test]
    fn cohere_embed_body_and_parse_roundtrip() {
        let body = cohere_embed_body("embed-v4.0", &["doc".into()]);
        assert_eq!(body["model"], "embed-v4.0");
        assert_eq!(body["input_type"], "search_document");
        assert_eq!(body["texts"][0], "doc");
        let resp = json!({"embeddings":{"float":[[1.0,2.0,3.0]]}});
        assert_eq!(
            parse_cohere_embed(&resp).unwrap(),
            vec![vec![1.0_f32, 2.0, 3.0]]
        );
        assert!(parse_cohere_embed(&json!({"embeddings":{}})).is_err());
    }

    #[test]
    fn cohere_rerank_scatters_by_index() {
        let body = cohere_rerank_body("rerank-v3.5", "q", &["a".into(), "b".into()]);
        assert_eq!(body["query"], "q");
        // Cohere may reorder and reference docs by their input index.
        let resp = json!({"results":[
            {"index":1,"relevance_score":0.9},
            {"index":0,"relevance_score":0.1},
        ]});
        assert_eq!(parse_cohere_rerank(&resp, 2).unwrap(), vec![0.1_f32, 0.9]);
        // Out-of-range indices are ignored, missing fields error.
        assert_eq!(
            parse_cohere_rerank(&json!({"results":[{"index":9,"relevance_score":1.0}]}), 2)
                .unwrap(),
            vec![0.0_f32, 0.0]
        );
        assert!(parse_cohere_rerank(&json!({"nope":1}), 2).is_err());
        assert!(parse_cohere_rerank(&json!({"results":[{"index":0}]}), 1).is_err());
    }

    #[test]
    fn check_dims_enforces_collection_dim() {
        assert!(check_dims(&[vec![1.0, 2.0]], 2).is_ok());
        assert!(check_dims(&[vec![1.0, 2.0, 3.0]], 2).is_err());
    }

    #[test]
    fn registry_builds_fake_and_resolves_emptiness() {
        let mut embedding = HashMap::new();
        embedding.insert(
            "docs".to_owned(),
            EmbeddingConfig {
                model: String::new(),
                dim: 16,
                ..embedding_cfg(ProviderKind::Fake, "", "")
            },
        );
        let mut rerank = HashMap::new();
        rerank.insert(
            "docs".to_owned(),
            RerankConfig {
                model: String::new(),
                ..rerank_cfg(ProviderKind::Fake)
            },
        );
        let reg = EmbedRegistry::from_config(&embedding, &rerank).unwrap();
        assert!(!reg.is_empty());
        assert_eq!(reg.embedder("docs").unwrap().dim(), 16);
        assert!(reg.embedder("missing").is_none());
        assert!(reg.reranker("docs").is_some());
        assert!(EmbedRegistry::default().is_empty());
    }

    #[test]
    fn from_toml_path_loads_embedding_and_rerank_tables() {
        // A fake provider needs no network/keys, so the loaded registry is usable
        // in-process. Unrelated tables (here `[server]`) must be ignored.
        let (_dir, path) = write_config(
            r#"
[server]
host = "127.0.0.1"

[embedding.docs]
provider = "fake"
dim = 16

[rerank.docs]
provider = "fake"
"#,
        );
        let reg = EmbedRegistry::from_toml_path(&path).unwrap();
        assert_eq!(reg.embedder("docs").unwrap().dim(), 16);
        assert!(reg.reranker("docs").is_some());
        assert!(reg.embedder("missing").is_none());
    }

    #[test]
    fn from_toml_path_missing_file_is_empty_not_an_error() {
        let reg = EmbedRegistry::from_toml_path(Path::new("definitely-not-here.toml")).unwrap();
        assert!(reg.is_empty());
    }

    #[test]
    fn from_toml_path_propagates_a_misconfigured_provider() {
        // `http` requires an `endpoint`; omitting it is a hard configuration error.
        let (_dir, path) = write_config(
            r#"
[embedding.docs]
provider = "http"
dim = 8
"#,
        );
        assert!(matches!(
            EmbedRegistry::from_toml_path(&path),
            Err(ProviderError::Config(_))
        ));
    }

    #[test]
    fn http_provider_requires_endpoint() {
        // Both endpoint-only kinds share the same rule.
        for kind in [ProviderKind::Http, ProviderKind::Ollama] {
            assert!(matches!(
                build_embedder(&embedding_cfg(kind, "", "")),
                Err(ProviderError::Config(_))
            ));
        }
    }

    #[test]
    fn missing_api_key_is_a_hard_error() {
        let cfg = embedding_cfg(ProviderKind::Openai, "", "QUIVER_TEST_DEFINITELY_UNSET_KEY");
        assert!(matches!(
            build_embedder(&cfg),
            Err(ProviderError::MissingKey(_))
        ));
    }

    #[test]
    fn openai_builds_without_a_key_on_default_or_overridden_endpoint() {
        // Default endpoint when unset, no key required.
        assert!(build_embedder(&embedding_cfg(ProviderKind::Openai, "", "")).is_ok());
        // An explicit endpoint is accepted too; building makes no network call.
        let overridden = embedding_cfg(ProviderKind::Openai, "http://localhost:9/v1/embeddings", "");
        assert!(build_embedder(&overridden).is_ok());
    }

    #[test]
    fn cohere_embedding_requires_an_api_key() {
        assert!(matches!(
            build_embedder(&embedding_cfg(ProviderKind::Cohere, "", "")),
            Err(ProviderError::Config(_))
        ));
    }

    #[test]
    fn rerank_rejects_unsupported_and_keyless_providers() {
        // Only `cohere` and `fake` can rerank.
        assert!(matches!(
            build_reranker(&rerank_cfg(ProviderKind::Openai)),
            Err(ProviderError::Config(_))
        ));
        // Cohere rerank, like Cohere embedding, needs a key.
        assert!(matches!(
            build_reranker(&rerank_cfg(ProviderKind::Cohere)),
            Err(ProviderError::Config(_))
        ));
        assert!(build_reranker(&rerank_cfg(ProviderKind::Fake)).is_ok());
    }
}
802            api_key_env: String::new(),
803        };
804        assert!(matches!(build_reranker(&rr), Err(ProviderError::Config(_))));
805    }
806}