1use axum::{
28 extract::State,
29 http::StatusCode,
30 response::{IntoResponse, Json, Response},
31 Router,
32};
33use serde::{Deserialize, Serialize};
34use std::sync::Arc;
35
36use oxibonsai_rag::embedding::{Embedder, IdentityEmbedder, TfIdfEmbedder};
37
/// Payload accepted in the `input` field of an [`EmbeddingRequest`].
///
/// The enum is `#[serde(untagged)]`: serde tries the variants in declaration
/// order and keeps the first one that deserializes. The order below is
/// therefore significant — an empty JSON array, for example, is accepted by
/// `Batch` before `TokenIds` is ever tried.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text.
    Single(String),
    /// Several texts submitted together.
    Batch(Vec<String>),
    /// A single input expressed as a list of token IDs.
    TokenIds(Vec<u32>),
    /// Several inputs, each expressed as a list of token IDs.
    BatchTokenIds(Vec<Vec<u32>>),
}
63
64impl EmbeddingInput {
65 pub fn as_strings(&self) -> Vec<String> {
70 match self {
71 EmbeddingInput::Single(s) => vec![s.clone()],
72 EmbeddingInput::Batch(v) => v.clone(),
73 EmbeddingInput::TokenIds(ids) => {
74 vec![ids
75 .iter()
76 .map(|id| id.to_string())
77 .collect::<Vec<_>>()
78 .join(" ")]
79 }
80 EmbeddingInput::BatchTokenIds(batch) => batch
81 .iter()
82 .map(|ids| {
83 ids.iter()
84 .map(|id| id.to_string())
85 .collect::<Vec<_>>()
86 .join(" ")
87 })
88 .collect(),
89 }
90 }
91
92 pub fn len(&self) -> usize {
94 match self {
95 EmbeddingInput::Single(_) => 1,
96 EmbeddingInput::Batch(v) => v.len(),
97 EmbeddingInput::TokenIds(_) => 1,
98 EmbeddingInput::BatchTokenIds(v) => v.len(),
99 }
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.len() == 0
105 }
106}
107
/// JSON body accepted by [`create_embeddings`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
    /// Model name echoed back in the response. The handler substitutes
    /// `"bonsai-embeddings"` when it is absent and does not use it to choose
    /// an embedder.
    pub model: Option<String>,
    /// The text(s) or token-ID list(s) to embed; the only required field.
    pub input: EmbeddingInput,
    /// Output encoding. Exactly `"base64"` selects the encoded-string form;
    /// any other value, or no value, yields a JSON array of floats.
    pub encoding_format: Option<String>,
    /// When present, each embedding is truncated to at most this many values.
    /// A value larger than the embedding length leaves it unchanged.
    pub dimensions: Option<usize>,
    /// Accepted during deserialization; the handler in this file never reads it.
    pub user: Option<String>,
}
122
/// The embedding payload of one [`EmbeddingObject`].
///
/// Serialized with `#[serde(untagged)]`, so the JSON value is the variant's
/// content with no wrapper: an array of numbers for `Float`, a string for
/// `Base64`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum EmbeddingData {
    /// The raw embedding values.
    Float(Vec<f32>),
    /// The embedding encoded as text.
    ///
    /// NOTE(review): the only producer in this file,
    /// `EmbedderRegistry::encode_base64`, emits lowercase hexadecimal rather
    /// than RFC 4648 base64 — confirm that clients expect this.
    Base64(String),
}
135
/// One entry of [`EmbeddingResponse::data`]: the embedding of a single input.
#[derive(Debug, Serialize)]
pub struct EmbeddingObject {
    /// Object-type tag; [`create_embeddings`] always sets it to `"embedding"`.
    pub object: String,
    /// The embedding itself, as floats or as an encoded string.
    pub embedding: EmbeddingData,
    /// Zero-based position of the corresponding input within the request.
    pub index: usize,
}
146
/// Token accounting reported alongside the embeddings.
#[derive(Debug, Serialize)]
pub struct EmbeddingUsage {
    /// As filled in by [`create_embeddings`]: the number of
    /// whitespace-separated words in each input (counted as at least 1 per
    /// input), summed over all inputs.
    pub prompt_tokens: usize,
    /// [`create_embeddings`] sets this to the same value as `prompt_tokens`.
    pub total_tokens: usize,
}
155
/// JSON body returned by [`create_embeddings`] on success.
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
    /// Object-type tag; [`create_embeddings`] always sets it to `"list"`.
    pub object: String,
    /// One entry per input, in request order.
    pub data: Vec<EmbeddingObject>,
    /// The requested model name, or `"bonsai-embeddings"` when none was given.
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
168
/// Holds the embedders used to serve embedding requests.
///
/// Texts are embedded with the identity embedder until
/// [`EmbedderRegistry::fit_tfidf`] stores a fitted TF-IDF embedder; from then
/// on the TF-IDF embedder is used instead.
pub struct EmbedderRegistry {
    /// Dimension the identity embedder was built with (clamped to at least 1
    /// by [`EmbedderRegistry::new`]); also passed to `TfIdfEmbedder::fit`.
    default_dim: usize,
    /// The fitted TF-IDF embedder, or `None` before the first successful fit.
    /// Behind a mutex so it can be replaced through `&self`.
    tfidf: std::sync::Mutex<Option<TfIdfEmbedder>>,
    /// Fallback embedder used while `tfidf` is `None`.
    identity: IdentityEmbedder,
}
182
183impl EmbedderRegistry {
184 pub fn new(default_dim: usize) -> Self {
189 let dim = default_dim.max(1);
190 let identity = match IdentityEmbedder::new(dim) {
194 Ok(embedder) => embedder,
195 Err(_) => unreachable!("dim ≥ 1 was guaranteed by max(1) above"),
196 };
197 Self {
198 default_dim: dim,
199 tfidf: std::sync::Mutex::new(None),
200 identity,
201 }
202 }
203
204 pub fn embed_texts(&self, texts: &[String]) -> Vec<Vec<f32>> {
210 let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
211 if let Some(ref tfidf) = *guard {
212 texts
213 .iter()
214 .map(|t| {
215 tfidf
216 .embed(t)
217 .unwrap_or_else(|_| vec![0.0; tfidf.embedding_dim()])
218 })
219 .collect()
220 } else {
221 texts
222 .iter()
223 .map(|t| {
224 self.identity
225 .embed(t)
226 .unwrap_or_else(|_| vec![0.0; self.default_dim])
227 })
228 .collect()
229 }
230 }
231
232 pub fn fit_tfidf(&self, corpus: &[String]) {
237 if corpus.is_empty() {
238 return;
239 }
240 let refs: Vec<&str> = corpus.iter().map(String::as_str).collect();
241 let fitted = TfIdfEmbedder::fit(&refs, self.default_dim);
242 let mut guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
243 *guard = Some(fitted);
244 }
245
246 pub fn embedding_dim(&self) -> usize {
251 let guard = self.tfidf.lock().expect("embedder registry mutex poisoned");
252 if let Some(ref tfidf) = *guard {
253 tfidf.embedding_dim()
254 } else {
255 self.default_dim
256 }
257 }
258
259 pub fn encode_base64(embedding: &[f32]) -> String {
265 let mut out = String::with_capacity(embedding.len() * 8);
266 for value in embedding {
267 let bytes = value.to_le_bytes();
268 for byte in bytes {
269 use std::fmt::Write as _;
270 let _ = write!(out, "{byte:02x}");
271 }
272 }
273 out
274 }
275}
276
/// State shared with the embeddings handler; the router built by
/// [`create_embeddings_router`] wraps it in an `Arc`.
pub struct EmbeddingAppState {
    /// The embedders used to answer requests.
    pub registry: EmbedderRegistry,
}
284
285impl EmbeddingAppState {
286 pub fn new(dim: usize) -> Self {
288 Self {
289 registry: EmbedderRegistry::new(dim),
290 }
291 }
292}
293
294#[tracing::instrument(skip(state))]
301pub async fn create_embeddings(
302 State(state): State<Arc<EmbeddingAppState>>,
303 Json(req): Json<EmbeddingRequest>,
304) -> Result<Response, StatusCode> {
305 if req.input.is_empty() {
306 return Err(StatusCode::UNPROCESSABLE_ENTITY);
307 }
308
309 let texts = req.input.as_strings();
310 let use_base64 = req
311 .encoding_format
312 .as_deref()
313 .map(|f| f == "base64")
314 .unwrap_or(false);
315
316 if texts.len() >= 2 {
321 state.registry.fit_tfidf(&texts);
322 }
323
324 let raw_embeddings = state.registry.embed_texts(&texts);
325
326 let prompt_tokens: usize = texts
328 .iter()
329 .map(|t| t.split_whitespace().count().max(1))
330 .sum();
331
332 let model_name = req.model.unwrap_or_else(|| "bonsai-embeddings".to_string());
333
334 let data: Vec<EmbeddingObject> = raw_embeddings
335 .into_iter()
336 .enumerate()
337 .map(|(index, mut vec)| {
338 if let Some(dim) = req.dimensions {
340 vec.truncate(dim);
341 }
342
343 let embedding = if use_base64 {
344 EmbeddingData::Base64(EmbedderRegistry::encode_base64(&vec))
345 } else {
346 EmbeddingData::Float(vec)
347 };
348
349 EmbeddingObject {
350 object: "embedding".to_owned(),
351 embedding,
352 index,
353 }
354 })
355 .collect();
356
357 let response = EmbeddingResponse {
358 object: "list".to_owned(),
359 data,
360 model: model_name,
361 usage: EmbeddingUsage {
362 prompt_tokens,
363 total_tokens: prompt_tokens,
364 },
365 };
366
367 Ok(Json(response).into_response())
368}
369
370pub fn create_embeddings_router(dim: usize) -> Router {
381 let state = Arc::new(EmbeddingAppState::new(dim));
382 Router::new()
383 .route("/v1/embeddings", axum::routing::post(create_embeddings))
384 .with_state(state)
385}
386
/// Unit tests for input normalisation, the embedder registry, the text
/// encoding of embeddings, and response serialisation.
#[cfg(test)]
mod tests {
    use super::*;

    // ── EmbeddingInput ────────────────────────────────────────────────

    #[test]
    fn embedding_input_single_as_strings() {
        let single = EmbeddingInput::Single(String::from("hello world"));
        assert_eq!(single.as_strings(), vec!["hello world"]);
        assert_eq!(single.len(), 1);
        assert!(!single.is_empty());
    }

    #[test]
    fn embedding_input_batch_as_strings() {
        let batch = EmbeddingInput::Batch(vec![String::from("foo"), String::from("bar")]);
        let rendered = batch.as_strings();
        assert_eq!(rendered, ["foo", "bar"]);
        assert_eq!(batch.len(), 2);
    }

    #[test]
    fn embedding_input_token_ids_as_strings() {
        let rendered = EmbeddingInput::TokenIds(vec![1u32, 2, 3]).as_strings();
        assert_eq!(rendered, ["1 2 3"]);
    }

    #[test]
    fn embedding_input_batch_token_ids_as_strings() {
        let rows = vec![vec![10u32, 20], vec![30u32]];
        let rendered = EmbeddingInput::BatchTokenIds(rows).as_strings();
        assert_eq!(rendered, ["10 20", "30"]);
    }

    #[test]
    fn embedding_input_empty_batch_is_empty() {
        let empty = EmbeddingInput::Batch(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    // ── EmbedderRegistry ──────────────────────────────────────────────

    #[test]
    fn embedder_registry_basic_embed() {
        let registry = EmbedderRegistry::new(32);
        let texts = ["hello world", "foo bar baz"].map(String::from);
        let embeddings = registry.embed_texts(&texts);
        assert_eq!(embeddings.len(), 2);
        // No fit has happened, so the default dimension applies.
        for emb in &embeddings {
            assert_eq!(emb.len(), 32, "expected 32 dimensions, got {}", emb.len());
        }
    }

    #[test]
    fn embedder_registry_tfidf_fit_changes_dim() {
        let registry = EmbedderRegistry::new(64);
        let mut corpus: Vec<String> = Vec::with_capacity(20);
        for i in 0..20 {
            corpus.push(format!("document number {i} with some unique words term{i}"));
        }
        registry.fit_tfidf(&corpus);
        // Only positivity is checked; the exact dimension is not pinned.
        let dim = registry.embedding_dim();
        assert!(dim > 0, "expected positive dimension after fit");
    }

    #[test]
    fn embedder_registry_fit_empty_corpus_is_noop() {
        let registry = EmbedderRegistry::new(16);
        registry.fit_tfidf(&[]);
        // Still reports the default dimension after the ignored fit.
        assert_eq!(registry.embedding_dim(), 16);
    }

    #[test]
    fn embedder_registry_embed_after_fit() {
        let registry = EmbedderRegistry::new(32);
        let corpus: Vec<String> = [
            "the quick brown fox",
            "jumped over the lazy dog",
            "the fox and the dog",
        ]
        .iter()
        .map(|line| line.to_string())
        .collect();
        registry.fit_tfidf(&corpus);
        for emb in registry.embed_texts(&corpus).iter() {
            assert!(!emb.is_empty(), "embedding must not be empty after fit");
        }
    }

    // ── encode_base64 ─────────────────────────────────────────────────

    #[test]
    fn encode_base64_non_empty() {
        let values = [1.0f32, 0.5f32, -1.0f32];
        let encoded = EmbedderRegistry::encode_base64(&values);
        // 3 values × 4 bytes × 2 hex digits per byte.
        assert_eq!(
            encoded.len(),
            24,
            "expected 24 hex chars for 3 f32 values, got {}",
            encoded.len()
        );
        assert!(!encoded.is_empty());
    }

    #[test]
    fn encode_base64_empty_input() {
        assert!(EmbedderRegistry::encode_base64(&[]).is_empty());
    }

    #[test]
    fn encode_base64_deterministic() {
        let values = [std::f32::consts::PI, 2.71f32];
        let first = EmbedderRegistry::encode_base64(&values);
        let second = EmbedderRegistry::encode_base64(&values);
        assert_eq!(first, second, "encoding must be deterministic");
    }

    #[test]
    fn encode_base64_known_value() {
        // 1.0f32 is 0x3f80_0000, i.e. bytes 00 00 80 3f in little-endian order.
        let encoded = EmbedderRegistry::encode_base64(&[1.0f32]);
        assert_eq!(encoded, "0000803f");
    }

    // ── Response serialisation ────────────────────────────────────────

    #[test]
    fn embedding_response_serialises_correctly() {
        let entry = EmbeddingObject {
            object: "embedding".to_owned(),
            embedding: EmbeddingData::Float(vec![0.1, 0.2]),
            index: 0,
        };
        let usage = EmbeddingUsage {
            prompt_tokens: 3,
            total_tokens: 3,
        };
        let resp = EmbeddingResponse {
            object: "list".to_owned(),
            data: vec![entry],
            model: "bonsai-embeddings".to_owned(),
            usage,
        };
        let json = serde_json::to_string(&resp).expect("serialisation must succeed");
        for needle in [
            "\"object\":\"list\"",
            "\"object\":\"embedding\"",
            "\"index\":0",
        ] {
            assert!(json.contains(needle));
        }
    }
}