use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use oxibonsai_rag::embedding::{Embedder, IdentityEmbedder, TfIdfEmbedder};
/// The `input` field of an embeddings request.
///
/// Deserialised as an untagged enum: serde tries the variants in declaration
/// order and keeps the first one that matches the JSON shape. An empty JSON
/// array therefore matches `Batch` (the first array-shaped variant).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// One text, given as a JSON string.
Single(String),
/// Several texts, given as a JSON array of strings.
Batch(Vec<String>),
/// One input given as a JSON array of unsigned integers (token ids).
TokenIds(Vec<u32>),
/// Several inputs, each given as an array of unsigned integers (token ids).
BatchTokenIds(Vec<Vec<u32>>),
}
impl EmbeddingInput {
pub fn as_strings(&self) -> Vec<String> {
match self {
EmbeddingInput::Single(s) => vec![s.clone()],
EmbeddingInput::Batch(v) => v.clone(),
EmbeddingInput::TokenIds(ids) => {
vec![ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(" ")]
}
EmbeddingInput::BatchTokenIds(batch) => batch
.iter()
.map(|ids| {
ids.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(" ")
})
.collect(),
}
}
pub fn len(&self) -> usize {
match self {
EmbeddingInput::Single(_) => 1,
EmbeddingInput::Batch(v) => v.len(),
EmbeddingInput::TokenIds(_) => 1,
EmbeddingInput::BatchTokenIds(v) => v.len(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// JSON request body accepted by the `/v1/embeddings` handler.
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
/// Model name echoed back in the response; the handler substitutes
/// `"bonsai-embeddings"` when this is absent.
pub model: Option<String>,
/// Text(s) or token-id sequence(s) to embed.
pub input: EmbeddingInput,
/// `"base64"` selects the string-encoded embedding variant; any other value
/// (or none) yields arrays of floats.
pub encoding_format: Option<String>,
/// When set, each returned embedding is truncated to at most this many
/// components.
pub dimensions: Option<usize>,
/// Caller-supplied identifier; not read by any code visible in this file.
pub user: Option<String>,
}
/// Payload of a single embedding in the response.
///
/// Serialised untagged, so the JSON value is the inner value itself with no
/// variant name: an array of numbers for `Float`, a string for `Base64`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum EmbeddingData {
/// Embedding components as plain floats.
Float(Vec<f32>),
/// Embedding encoded as a string (used when the request asks for `"base64"`).
Base64(String),
}
/// One entry of the response's `data` array.
#[derive(Debug, Serialize)]
pub struct EmbeddingObject {
/// Object type tag; the handler sets this to `"embedding"`.
pub object: String,
/// The embedding itself, as floats or as an encoded string.
pub embedding: EmbeddingData,
/// Zero-based position of the corresponding input item.
pub index: usize,
}
/// Token accounting reported with an embeddings response.
#[derive(Debug, Serialize)]
pub struct EmbeddingUsage {
/// Token count of the inputs; the handler estimates it by counting
/// whitespace-separated words, with a minimum of one per input.
pub prompt_tokens: usize,
/// Total token count; the handler reports the same value as `prompt_tokens`.
pub total_tokens: usize,
}
/// JSON response body produced by the `/v1/embeddings` handler.
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
/// Object type tag; the handler sets this to `"list"`.
pub object: String,
/// One embedding per input item, in input order.
pub data: Vec<EmbeddingObject>,
/// Model name: the one from the request, or the handler's default.
pub model: String,
/// Token usage for the request.
pub usage: EmbeddingUsage,
}
pub struct EmbedderRegistry {
default_dim: usize,
tfidf: std::sync::Mutex<Option<TfIdfEmbedder>>,
identity: IdentityEmbedder,
}
impl EmbedderRegistry {
pub fn new(default_dim: usize) -> Self {
let dim = default_dim.max(1);
let identity = match IdentityEmbedder::new(dim) {
Ok(embedder) => embedder,
Err(_) => unreachable!("dim ≥ 1 was guaranteed by max(1) above"),
};
Self {
default_dim: dim,
tfidf: std::sync::Mutex::new(None),
identity,
}
}
pub fn embed_texts(&self, texts: &[String]) -> Vec<Vec<f32>> {
let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
if let Some(ref tfidf) = *guard {
texts
.iter()
.map(|t| {
tfidf
.embed(t)
.unwrap_or_else(|_| vec![0.0; tfidf.embedding_dim()])
})
.collect()
} else {
texts
.iter()
.map(|t| {
self.identity
.embed(t)
.unwrap_or_else(|_| vec![0.0; self.default_dim])
})
.collect()
}
}
pub fn fit_tfidf(&self, corpus: &[String]) {
if corpus.is_empty() {
return;
}
let refs: Vec<&str> = corpus.iter().map(String::as_str).collect();
let fitted = TfIdfEmbedder::fit(&refs, self.default_dim);
let mut guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
*guard = Some(fitted);
}
pub fn embedding_dim(&self) -> usize {
let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
if let Some(ref tfidf) = *guard {
tfidf.embedding_dim()
} else {
self.default_dim
}
}
pub fn encode_base64(embedding: &[f32]) -> String {
let mut out = String::with_capacity(embedding.len() * 8);
for value in embedding {
let bytes = value.to_le_bytes();
for byte in bytes {
use std::fmt::Write as _;
let _ = write!(out, "{byte:02x}");
}
}
out
}
}
/// Shared application state for the embeddings routes.
pub struct EmbeddingAppState {
    /// Embedder registry used to serve every request.
    pub registry: EmbedderRegistry,
}

impl EmbeddingAppState {
    /// Builds the state around a fresh [`EmbedderRegistry`] created with `dim`.
    pub fn new(dim: usize) -> Self {
        let registry = EmbedderRegistry::new(dim);
        Self { registry }
    }
}
/// Encodes `values` as standard, padded base64 (RFC 4648 alphabet) over the
/// concatenated little-endian bytes of each `f32`.
fn encode_f32_le_base64(values: &[f32]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let bytes: Vec<u8> = values.iter().flat_map(|value| value.to_le_bytes()).collect();
    // Every (possibly partial) group of 3 bytes becomes 4 output characters.
    let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);
    for group in bytes.chunks(3) {
        let first = u32::from(group[0]);
        let second = group.get(1).copied().map_or(0, u32::from);
        let third = group.get(2).copied().map_or(0, u32::from);
        let bits = (first << 16) | (second << 8) | third;

        encoded.push(ALPHABET[((bits >> 18) & 0x3f) as usize] as char);
        encoded.push(ALPHABET[((bits >> 12) & 0x3f) as usize] as char);
        // Sextets that lie entirely past the end of the input become padding.
        encoded.push(if group.len() > 1 {
            ALPHABET[((bits >> 6) & 0x3f) as usize] as char
        } else {
            '='
        });
        encoded.push(if group.len() > 2 {
            ALPHABET[(bits & 0x3f) as usize] as char
        } else {
            '='
        });
    }
    encoded
}

/// Handler for `POST /v1/embeddings`.
///
/// Embeds every item of `req.input` and returns an [`EmbeddingResponse`] with
/// one entry per input, in input order.
///
/// * When the request holds two or more items, the registry's TF-IDF model is
///   first refitted on those items.
/// * `dimensions`, when set, truncates each embedding to at most that length.
/// * `encoding_format == "base64"` returns each embedding as base64 of its
///   little-endian `f32` bytes; any other value (or none) returns floats.
/// * `usage` is estimated by counting whitespace-separated words, with a
///   minimum of one per input.
///
/// # Errors
/// Returns `422 Unprocessable Entity` when the input is an empty batch or
/// when `dimensions` is `0`.
///
/// NOTE(review): fitting and embedding take the registry lock separately, so
/// a concurrent batch request can replace the fitted model between the two
/// calls — confirm this is acceptable.
#[tracing::instrument(skip(state))]
pub async fn create_embeddings(
    State(state): State<Arc<EmbeddingAppState>>,
    Json(req): Json<EmbeddingRequest>,
) -> Result<Response, StatusCode> {
    if req.input.is_empty() {
        return Err(StatusCode::UNPROCESSABLE_ENTITY);
    }
    // Truncating to zero components would return useless empty embeddings;
    // reject the request instead of answering with them.
    let dimensions = req.dimensions;
    if dimensions == Some(0) {
        return Err(StatusCode::UNPROCESSABLE_ENTITY);
    }

    let texts = req.input.as_strings();
    let use_base64 = req.encoding_format.as_deref() == Some("base64");

    if texts.len() >= 2 {
        state.registry.fit_tfidf(&texts);
    }
    let raw_embeddings = state.registry.embed_texts(&texts);

    let prompt_tokens: usize = texts
        .iter()
        .map(|text| text.split_whitespace().count().max(1))
        .sum();
    let model_name = req
        .model
        .unwrap_or_else(|| "bonsai-embeddings".to_string());

    let data: Vec<EmbeddingObject> = raw_embeddings
        .into_iter()
        .enumerate()
        .map(|(index, mut vector)| {
            if let Some(limit) = dimensions {
                vector.truncate(limit);
            }
            let embedding = if use_base64 {
                // Real base64, so that clients decoding the field as base64
                // recover the original f32 values.
                EmbeddingData::Base64(encode_f32_le_base64(&vector))
            } else {
                EmbeddingData::Float(vector)
            };
            EmbeddingObject {
                object: "embedding".to_owned(),
                embedding,
                index,
            }
        })
        .collect();

    let response = EmbeddingResponse {
        object: "list".to_owned(),
        data,
        model: model_name,
        usage: EmbeddingUsage {
            prompt_tokens,
            total_tokens: prompt_tokens,
        },
    };
    Ok(Json(response).into_response())
}
/// Builds a router exposing `POST /v1/embeddings`, backed by a fresh
/// [`EmbeddingAppState`] created with `dim`.
pub fn create_embeddings_router(dim: usize) -> Router {
    use axum::routing::post;

    let shared_state = Arc::new(EmbeddingAppState::new(dim));
    Router::new()
        .route("/v1/embeddings", post(create_embeddings))
        .with_state(shared_state)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned strings the API works with.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    // --- EmbeddingInput -----------------------------------------------------

    #[test]
    fn embedding_input_single_as_strings() {
        let single = EmbeddingInput::Single(String::from("hello world"));
        assert_eq!(single.as_strings(), owned(&["hello world"]));
        assert_eq!(single.len(), 1);
        assert!(!single.is_empty());
    }

    #[test]
    fn embedding_input_batch_as_strings() {
        let batch = EmbeddingInput::Batch(owned(&["foo", "bar"]));
        let rendered = batch.as_strings();
        assert_eq!(rendered.len(), 2);
        assert_eq!(rendered[0], "foo");
        assert_eq!(rendered[1], "bar");
        assert_eq!(batch.len(), 2);
    }

    #[test]
    fn embedding_input_token_ids_as_strings() {
        let rendered = EmbeddingInput::TokenIds(vec![1u32, 2, 3]).as_strings();
        assert_eq!(rendered.len(), 1);
        assert_eq!(rendered[0], "1 2 3");
    }

    #[test]
    fn embedding_input_batch_token_ids_as_strings() {
        let rows = vec![vec![10u32, 20], vec![30u32]];
        let rendered = EmbeddingInput::BatchTokenIds(rows).as_strings();
        assert_eq!(rendered.len(), 2);
        assert_eq!(rendered[0], "10 20");
        assert_eq!(rendered[1], "30");
    }

    #[test]
    fn embedding_input_empty_batch_is_empty() {
        let empty = EmbeddingInput::Batch(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    // --- EmbedderRegistry ---------------------------------------------------

    #[test]
    fn embedder_registry_basic_embed() {
        let registry = EmbedderRegistry::new(32);
        let embeddings = registry.embed_texts(&owned(&["hello world", "foo bar baz"]));
        assert_eq!(embeddings.len(), 2);
        for emb in embeddings.iter() {
            assert_eq!(emb.len(), 32, "expected 32 dimensions, got {}", emb.len());
        }
    }

    #[test]
    fn embedder_registry_tfidf_fit_changes_dim() {
        let registry = EmbedderRegistry::new(64);
        let mut corpus: Vec<String> = Vec::with_capacity(20);
        for i in 0..20 {
            corpus.push(format!(
                "document number {i} with some unique words term{i}"
            ));
        }
        registry.fit_tfidf(&corpus);
        assert!(
            registry.embedding_dim() > 0,
            "expected positive dimension after fit"
        );
    }

    #[test]
    fn embedder_registry_fit_empty_corpus_is_noop() {
        let registry = EmbedderRegistry::new(16);
        let nothing: Vec<String> = Vec::new();
        registry.fit_tfidf(&nothing);
        assert_eq!(registry.embedding_dim(), 16);
    }

    #[test]
    fn embedder_registry_embed_after_fit() {
        let registry = EmbedderRegistry::new(32);
        let corpus = owned(&[
            "the quick brown fox",
            "jumped over the lazy dog",
            "the fox and the dog",
        ]);
        registry.fit_tfidf(&corpus);
        for emb in registry.embed_texts(&corpus).iter() {
            assert!(!emb.is_empty(), "embedding must not be empty after fit");
        }
    }

    // --- encode_base64 (emits hex; these tests pin that output) -------------

    #[test]
    fn encode_base64_non_empty() {
        let encoded = EmbedderRegistry::encode_base64(&[1.0f32, 0.5f32, -1.0f32]);
        assert_eq!(
            encoded.len(),
            24,
            "expected 24 hex chars for 3 f32 values, got {}",
            encoded.len()
        );
        assert!(!encoded.is_empty());
    }

    #[test]
    fn encode_base64_empty_input() {
        assert!(EmbedderRegistry::encode_base64(&[]).is_empty());
    }

    #[test]
    fn encode_base64_deterministic() {
        let values = [std::f32::consts::PI, 2.71f32];
        let first = EmbedderRegistry::encode_base64(&values);
        let second = EmbedderRegistry::encode_base64(&values);
        assert_eq!(first, second, "encoding must be deterministic");
    }

    #[test]
    fn encode_base64_known_value() {
        // 1.0f32 is 0x3f800000, i.e. bytes 00 00 80 3f in little-endian order.
        assert_eq!(EmbedderRegistry::encode_base64(&[1.0f32]), "0000803f");
    }

    // --- Response serialisation ---------------------------------------------

    #[test]
    fn embedding_response_serialises_correctly() {
        let entry = EmbeddingObject {
            object: "embedding".to_owned(),
            embedding: EmbeddingData::Float(vec![0.1, 0.2]),
            index: 0,
        };
        let usage = EmbeddingUsage {
            prompt_tokens: 3,
            total_tokens: 3,
        };
        let resp = EmbeddingResponse {
            object: "list".to_owned(),
            data: vec![entry],
            model: "bonsai-embeddings".to_owned(),
            usage,
        };
        let json = serde_json::to_string(&resp).expect("serialisation must succeed");
        for fragment in [
            "\"object\":\"list\"",
            "\"object\":\"embedding\"",
            "\"index\":0",
        ] {
            assert!(json.contains(fragment));
        }
    }
}