Skip to main content

oxibonsai_runtime/
embeddings.rs

1//! OpenAI v1 embeddings endpoint.
2//!
3//! Implements `POST /v1/embeddings` — an OpenAI-compatible embedding API that
4//! converts text (or token ID arrays) into dense float vectors.
5//!
6//! # Backends
7//!
8//! The [`EmbedderRegistry`] manages two complementary backends:
9//!
//! - **[`TfIdfEmbedder`]** — fitted on-the-fly from every request that supplies
//!   two or more inputs (or explicitly via [`EmbedderRegistry::fit_tfidf`]); once
//!   fitted it serves all later requests until it is refitted.
12//! - **[`IdentityEmbedder`]** — byte-hash fallback used before TF-IDF is fitted,
13//!   always available and fully deterministic.
14//!
15//! # Encoding formats
16//!
17//! - `"float"` (default) — embedding returned as a JSON array of `f32` values.
18//! - `"base64"` — embedding encoded as a hex string; each `f32` is serialised
19//!   as four little-endian bytes rendered as lowercase hex.
20//!
21//! # Dimensions
22//!
23//! Setting `dimensions` in the request truncates the embedding vectors to that
24//! length before returning them.  If `dimensions` exceeds the natural embedding
25//! size the full vector is returned unmodified.
26
27use axum::{
28    extract::State,
29    http::StatusCode,
30    response::{IntoResponse, Json, Response},
31    Router,
32};
33use serde::{Deserialize, Serialize};
34use std::sync::Arc;
35
36use oxibonsai_rag::embedding::{Embedder, IdentityEmbedder, TfIdfEmbedder};
37
38// ─── Request / Response types ─────────────────────────────────────────────────
39
/// Input accepted by the embeddings endpoint.
///
/// All four variants are deserialized from untagged JSON, so the format is
/// inferred from the structure of the value supplied in the `"input"` field:
///
/// | JSON value | Variant |
/// |---|---|
/// | `"some text"` | `Single` |
/// | `["text one", "text two"]` | `Batch` |
/// | `[42, 1337]` | `TokenIds` |
/// | `[[42, 1337], [9, 99]]` | `BatchTokenIds` |
///
/// serde tries untagged variants in declaration order, so the order of the
/// variants below is part of the wire contract.  In particular an empty JSON
/// array (`[]`) matches `Batch` first and therefore deserializes as an empty
/// batch of strings.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text string.
    Single(String),
    /// A batch of text strings, one embedding per element.
    Batch(Vec<String>),
    /// A single token-ID sequence.  Token IDs are not decoded; the sequence
    /// is rendered as space-joined decimal numbers before embedding.
    TokenIds(Vec<u32>),
    /// A batch of token-ID sequences, one embedding per inner sequence.
    BatchTokenIds(Vec<Vec<u32>>),
}
63
64impl EmbeddingInput {
65    /// Convert all inputs to `String` form for embedding.
66    ///
67    /// Token-ID sequences are rendered as space-separated decimal numbers so
68    /// they can be passed through the text-based embedder.
69    pub fn as_strings(&self) -> Vec<String> {
70        match self {
71            EmbeddingInput::Single(s) => vec![s.clone()],
72            EmbeddingInput::Batch(v) => v.clone(),
73            EmbeddingInput::TokenIds(ids) => {
74                vec![ids
75                    .iter()
76                    .map(|id| id.to_string())
77                    .collect::<Vec<_>>()
78                    .join(" ")]
79            }
80            EmbeddingInput::BatchTokenIds(batch) => batch
81                .iter()
82                .map(|ids| {
83                    ids.iter()
84                        .map(|id| id.to_string())
85                        .collect::<Vec<_>>()
86                        .join(" ")
87                })
88                .collect(),
89        }
90    }
91
92    /// Number of distinct inputs.
93    pub fn len(&self) -> usize {
94        match self {
95            EmbeddingInput::Single(_) => 1,
96            EmbeddingInput::Batch(v) => v.len(),
97            EmbeddingInput::TokenIds(_) => 1,
98            EmbeddingInput::BatchTokenIds(v) => v.len(),
99        }
100    }
101
102    /// Whether the input contains no items.
103    pub fn is_empty(&self) -> bool {
104        self.len() == 0
105    }
106}
107
/// `POST /v1/embeddings` request body.
///
/// Only `input` is required; every other field is optional.
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
    /// The model name.  It does not select a backend; the handler echoes it
    /// back in the response and substitutes `"bonsai-embeddings"` when absent.
    pub model: Option<String>,
    /// The text(s) or token sequence(s) to embed.
    pub input: EmbeddingInput,
    /// Encoding format.  Only the exact value `"base64"` switches the output
    /// to the string encoding; any other value, or none, yields float arrays.
    pub encoding_format: Option<String>,
    /// If set, truncate each embedding to at most this many dimensions.
    /// Values larger than the embedding length leave the vector unchanged.
    pub dimensions: Option<usize>,
    /// Opaque caller identifier (not processed).
    pub user: Option<String>,
}
122
/// The serialised form of a single embedding.
///
/// When the request specifies `encoding_format = "base64"` the `Base64` variant
/// is used; otherwise `Float`.  The enum is untagged, so the JSON output is
/// either a bare array of numbers or a bare string with no wrapper object.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum EmbeddingData {
    /// Embedding as a JSON array of `f32` values.
    Float(Vec<f32>),
    /// Embedding encoded as a hex string (see [`EmbedderRegistry::encode_base64`]).
    ///
    /// NOTE(review): despite the variant name this is hexadecimal, not RFC 4648
    /// base64 — clients that base64-decode the field will not recover the floats.
    Base64(String),
}
135
/// A single embedding object in the response's `data` array.
#[derive(Debug, Serialize)]
pub struct EmbeddingObject {
    /// Object type tag; the handler always sets this to `"embedding"`.
    pub object: String,
    /// The dense vector, or its string-encoded form when requested.
    pub embedding: EmbeddingData,
    /// Zero-based position of the corresponding item among the request inputs.
    pub index: usize,
}
146
/// Token usage reported in the embeddings response.
///
/// The handler does not run a tokenizer: it approximates the count as the
/// number of whitespace-separated words per input, with a minimum of one
/// per input.
#[derive(Debug, Serialize)]
pub struct EmbeddingUsage {
    /// Approximate total tokens across all inputs.
    pub prompt_tokens: usize,
    /// Same as `prompt_tokens` (there are no completion tokens for embeddings).
    pub total_tokens: usize,
}
155
/// `POST /v1/embeddings` response body.
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
    /// Object type tag; the handler always sets this to `"list"`.
    pub object: String,
    /// One [`EmbeddingObject`] per input, in input order.
    pub data: Vec<EmbeddingObject>,
    /// The model name echoed from the request, or `"bonsai-embeddings"` when
    /// the request did not supply one.
    pub model: String,
    /// Token usage statistics.
    pub usage: EmbeddingUsage,
}
168
169// ─── EmbedderRegistry ─────────────────────────────────────────────────────────
170
/// Thread-safe registry that holds the active embedding backends.
///
/// On creation the TF-IDF slot is empty and the `IdentityEmbedder` is used as
/// a deterministic fall-back.  The registry never fits TF-IDF on its own: a
/// model is installed only when [`EmbedderRegistry::fit_tfidf`] is called
/// with a non-empty corpus, after which it is used for all subsequent
/// requests until it is replaced by another fit.
pub struct EmbedderRegistry {
    // Dimension of the identity fallback; also passed to `TfIdfEmbedder::fit`.
    default_dim: usize,
    // `None` until the first successful `fit_tfidf`; replaced wholesale on refit.
    tfidf: std::sync::Mutex<Option<TfIdfEmbedder>>,
    // Always-available fallback embedder built in `new`.
    identity: IdentityEmbedder,
}
182
183impl EmbedderRegistry {
184    /// Create a new registry.
185    ///
186    /// `default_dim` controls the dimensionality of the `IdentityEmbedder`
187    /// fallback and is also used as `max_features` when fitting TF-IDF.
188    pub fn new(default_dim: usize) -> Self {
189        let dim = default_dim.max(1);
190        // `dim` is guaranteed ≥ 1 by the `.max(1)` clamp above, so
191        // `IdentityEmbedder::new` cannot fail here.  We still handle the
192        // `Err` branch explicitly to avoid `.expect()` in production code.
193        let identity = match IdentityEmbedder::new(dim) {
194            Ok(embedder) => embedder,
195            Err(_) => unreachable!("dim ≥ 1 was guaranteed by max(1) above"),
196        };
197        Self {
198            default_dim: dim,
199            tfidf: std::sync::Mutex::new(None),
200            identity,
201        }
202    }
203
204    /// Embed a slice of text strings, returning one dense vector per input.
205    ///
206    /// Uses the TF-IDF backend when it has been fitted; falls back to
207    /// `IdentityEmbedder` otherwise.  Texts that fail to embed are silently
208    /// replaced with a zero vector of the appropriate dimension.
209    pub fn embed_texts(&self, texts: &[String]) -> Vec<Vec<f32>> {
210        let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
211        if let Some(ref tfidf) = *guard {
212            texts
213                .iter()
214                .map(|t| {
215                    tfidf
216                        .embed(t)
217                        .unwrap_or_else(|_| vec![0.0; tfidf.embedding_dim()])
218                })
219                .collect()
220        } else {
221            texts
222                .iter()
223                .map(|t| {
224                    self.identity
225                        .embed(t)
226                        .unwrap_or_else(|_| vec![0.0; self.default_dim])
227                })
228                .collect()
229        }
230    }
231
232    /// Fit the TF-IDF backend from `corpus`.
233    ///
234    /// After this call [`embed_texts`](Self::embed_texts) will use TF-IDF for
235    /// all subsequent requests.  Subsequent calls replace the existing model.
236    pub fn fit_tfidf(&self, corpus: &[String]) {
237        if corpus.is_empty() {
238            return;
239        }
240        let refs: Vec<&str> = corpus.iter().map(String::as_str).collect();
241        let fitted = TfIdfEmbedder::fit(&refs, self.default_dim);
242        let mut guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
243        *guard = Some(fitted);
244    }
245
246    /// Return the current embedding dimension.
247    ///
248    /// Returns the TF-IDF vocabulary size when a fitted model is present,
249    /// otherwise the configured `default_dim`.
250    pub fn embedding_dim(&self) -> usize {
251        let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
252        if let Some(ref tfidf) = *guard {
253            tfidf.embedding_dim()
254        } else {
255            self.default_dim
256        }
257    }
258
259    /// Encode an embedding vector as a hex string (pure Rust, no external deps).
260    ///
261    /// Each `f32` is serialised as four bytes in little-endian order, with each
262    /// byte represented as two lowercase hex digits.  The result is therefore
263    /// `8 * embedding.len()` characters long.
264    pub fn encode_base64(embedding: &[f32]) -> String {
265        let mut out = String::with_capacity(embedding.len() * 8);
266        for value in embedding {
267            let bytes = value.to_le_bytes();
268            for byte in bytes {
269                use std::fmt::Write as _;
270                let _ = write!(out, "{byte:02x}");
271            }
272        }
273        out
274    }
275}
276
277// ─── App state ────────────────────────────────────────────────────────────────
278
/// Axum application state for the embeddings sub-router.
///
/// The router factory wraps one instance in an `Arc` and hands it to the
/// handler through the `State` extractor.
pub struct EmbeddingAppState {
    /// The embedding registry used by the handler to fit and embed.
    pub registry: EmbedderRegistry,
}
284
285impl EmbeddingAppState {
286    /// Create a new state with the given embedding dimensionality.
287    pub fn new(dim: usize) -> Self {
288        Self {
289            registry: EmbedderRegistry::new(dim),
290        }
291    }
292}
293
294// ─── Handler ──────────────────────────────────────────────────────────────────
295
296/// Handler for `POST /v1/embeddings`.
297///
298/// Computes dense vector representations for all supplied inputs and returns
299/// an OpenAI-compatible response.
300#[tracing::instrument(skip(state))]
301pub async fn create_embeddings(
302    State(state): State<Arc<EmbeddingAppState>>,
303    Json(req): Json<EmbeddingRequest>,
304) -> Result<Response, StatusCode> {
305    if req.input.is_empty() {
306        return Err(StatusCode::UNPROCESSABLE_ENTITY);
307    }
308
309    let texts = req.input.as_strings();
310    let use_base64 = req
311        .encoding_format
312        .as_deref()
313        .map(|f| f == "base64")
314        .unwrap_or(false);
315
316    // Fit TF-IDF on the fly when the caller provides a meaningful corpus
317    // (≥ 2 documents).  Single-document batches are not large enough to build
318    // a useful IDF weighting, so we fall back to the IdentityEmbedder in that
319    // case to keep embedding dimensions stable across truncation scenarios.
320    if texts.len() >= 2 {
321        state.registry.fit_tfidf(&texts);
322    }
323
324    let raw_embeddings = state.registry.embed_texts(&texts);
325
326    // Count tokens for usage: approximate as whitespace-split word count.
327    let prompt_tokens: usize = texts
328        .iter()
329        .map(|t| t.split_whitespace().count().max(1))
330        .sum();
331
332    let model_name = req.model.unwrap_or_else(|| "bonsai-embeddings".to_string());
333
334    let data: Vec<EmbeddingObject> = raw_embeddings
335        .into_iter()
336        .enumerate()
337        .map(|(index, mut vec)| {
338            // Optionally truncate to requested dimensions.
339            if let Some(dim) = req.dimensions {
340                vec.truncate(dim);
341            }
342
343            let embedding = if use_base64 {
344                EmbeddingData::Base64(EmbedderRegistry::encode_base64(&vec))
345            } else {
346                EmbeddingData::Float(vec)
347            };
348
349            EmbeddingObject {
350                object: "embedding".to_owned(),
351                embedding,
352                index,
353            }
354        })
355        .collect();
356
357    let response = EmbeddingResponse {
358        object: "list".to_owned(),
359        data,
360        model: model_name,
361        usage: EmbeddingUsage {
362            prompt_tokens,
363            total_tokens: prompt_tokens,
364        },
365    };
366
367    Ok(Json(response).into_response())
368}
369
370// ─── Router factory ───────────────────────────────────────────────────────────
371
372/// Build a standalone Axum router for the embeddings endpoint.
373///
374/// Mount this at the root with [`Router::merge`] or nest it under a path
375/// prefix with [`Router::nest`].  The router exposes a single route:
376///
377/// ```text
378/// POST /v1/embeddings
379/// ```
380pub fn create_embeddings_router(dim: usize) -> Router {
381    let state = Arc::new(EmbeddingAppState::new(dim));
382    Router::new()
383        .route("/v1/embeddings", axum::routing::post(create_embeddings))
384        .with_state(state)
385}
386
387// ─── Unit tests ───────────────────────────────────────────────────────────────
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned `Vec<String>` the APIs expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    // ── EmbeddingInput ────────────────────────────────────────────────────────

    #[test]
    fn embedding_input_single_as_strings() {
        let single = EmbeddingInput::Single(String::from("hello world"));
        assert_eq!(single.as_strings(), vec!["hello world"]);
        assert_eq!(single.len(), 1);
        assert!(!single.is_empty());
    }

    #[test]
    fn embedding_input_batch_as_strings() {
        let batch = EmbeddingInput::Batch(owned(&["foo", "bar"]));
        let rendered = batch.as_strings();
        assert_eq!(rendered.len(), 2);
        assert_eq!(rendered[0], "foo");
        assert_eq!(rendered[1], "bar");
        assert_eq!(batch.len(), 2);
    }

    #[test]
    fn embedding_input_token_ids_as_strings() {
        let rendered = EmbeddingInput::TokenIds(vec![1u32, 2, 3]).as_strings();
        assert_eq!(rendered.len(), 1);
        assert_eq!(rendered[0], "1 2 3");
    }

    #[test]
    fn embedding_input_batch_token_ids_as_strings() {
        let sequences = vec![vec![10u32, 20], vec![30u32]];
        let rendered = EmbeddingInput::BatchTokenIds(sequences).as_strings();
        assert_eq!(rendered.len(), 2);
        assert_eq!(rendered[0], "10 20");
        assert_eq!(rendered[1], "30");
    }

    #[test]
    fn embedding_input_empty_batch_is_empty() {
        let empty = EmbeddingInput::Batch(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    // ── EmbedderRegistry ─────────────────────────────────────────────────────

    #[test]
    fn embedder_registry_basic_embed() {
        let registry = EmbedderRegistry::new(32);
        let embeddings = registry.embed_texts(&owned(&["hello world", "foo bar baz"]));
        assert_eq!(embeddings.len(), 2);
        // Without a fitted model every vector has exactly `default_dim` elements.
        for emb in embeddings.iter() {
            assert_eq!(emb.len(), 32, "expected 32 dimensions, got {}", emb.len());
        }
    }

    #[test]
    fn embedder_registry_tfidf_fit_changes_dim() {
        let registry = EmbedderRegistry::new(64);
        let mut corpus: Vec<String> = Vec::with_capacity(20);
        for i in 0..20 {
            corpus.push(format!("document number {i} with some unique words term{i}"));
        }
        registry.fit_tfidf(&corpus);
        // Once fitted, the reported dimension is taken from the TF-IDF model.
        let dim = registry.embedding_dim();
        assert!(dim > 0, "expected positive dimension after fit");
    }

    #[test]
    fn embedder_registry_fit_empty_corpus_is_noop() {
        let registry = EmbedderRegistry::new(16);
        registry.fit_tfidf(&[]);
        // The identity fallback must still be active (dim == default_dim).
        assert_eq!(registry.embedding_dim(), 16);
    }

    #[test]
    fn embedder_registry_embed_after_fit() {
        let registry = EmbedderRegistry::new(32);
        let corpus = owned(&[
            "the quick brown fox",
            "jumped over the lazy dog",
            "the fox and the dog",
        ]);
        registry.fit_tfidf(&corpus);
        for emb in registry.embed_texts(&corpus).iter() {
            assert!(!emb.is_empty(), "embedding must not be empty after fit");
        }
    }

    // ── encode_base64 ─────────────────────────────────────────────────────────

    #[test]
    fn encode_base64_non_empty() {
        let encoded = EmbedderRegistry::encode_base64(&[1.0f32, 0.5f32, -1.0f32]);
        // 3 values × 4 bytes × 2 hex digits = 24 characters.
        assert_eq!(
            encoded.len(),
            24,
            "expected 24 hex chars for 3 f32 values, got {}",
            encoded.len()
        );
        assert!(!encoded.is_empty());
    }

    #[test]
    fn encode_base64_empty_input() {
        assert!(EmbedderRegistry::encode_base64(&[]).is_empty());
    }

    #[test]
    fn encode_base64_deterministic() {
        let values = [std::f32::consts::PI, 2.71f32];
        let first = EmbedderRegistry::encode_base64(&values);
        let second = EmbedderRegistry::encode_base64(&values);
        assert_eq!(first, second, "encoding must be deterministic");
    }

    #[test]
    fn encode_base64_known_value() {
        // 1.0f32 is 0x3f800000, i.e. bytes 00 00 80 3f in little-endian order.
        let encoded = EmbedderRegistry::encode_base64(&[1.0f32]);
        assert_eq!(encoded, "0000803f");
    }

    // ── EmbeddingResponse serialisation ──────────────────────────────────────

    #[test]
    fn embedding_response_serialises_correctly() {
        let item = EmbeddingObject {
            object: "embedding".to_owned(),
            embedding: EmbeddingData::Float(vec![0.1, 0.2]),
            index: 0,
        };
        let usage = EmbeddingUsage {
            prompt_tokens: 3,
            total_tokens: 3,
        };
        let resp = EmbeddingResponse {
            object: "list".to_owned(),
            data: vec![item],
            model: "bonsai-embeddings".to_owned(),
            usage,
        };

        let json = serde_json::to_string(&resp).expect("serialisation must succeed");
        for fragment in [
            "\"object\":\"list\"",
            "\"object\":\"embedding\"",
            "\"index\":0",
        ] {
            assert!(json.contains(fragment));
        }
    }
}