use std::sync::Arc;
use axum::extract::State;
use axum::Json;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use crate::error::{ServerError, ServerResult};
use crate::queue::BatchRequest;
use crate::state::AppState;
/// JSON body accepted by the [`embeddings`] handler.
#[derive(Deserialize, Debug)]
pub struct EmbeddingRequest {
    /// Model name supplied by the client.
    ///
    /// It is deserialized but not used for model selection; the response
    /// reports `AppState::model_id` instead.
    pub model: String,
    /// One string, or an array of strings, to embed.
    pub input: EmbeddingInput,
}
/// The `input` field of an [`EmbeddingRequest`].
///
/// It is deserialized untagged, so the JSON shape alone selects the variant:
/// - a bare JSON string becomes [`EmbeddingInput::Single`];
/// - a JSON array of strings becomes [`EmbeddingInput::Batch`].
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text to embed.
    Single(String),
    /// Several texts to embed. Output order follows input order.
    Batch(Vec<String>),
}
/// Successful response body produced by the [`embeddings`] handler.
#[derive(Serialize, Debug)]
pub struct EmbeddingResponse {
    /// Object tag; the handler always sets this to `"list"`.
    pub object: String,
    /// One entry per input text.
    pub data: Vec<EmbeddingData>,
    /// Model identifier taken from server state, not from the request.
    pub model: String,
    /// Approximate usage accounting for the request.
    pub usage: EmbeddingUsage,
}
/// A single embedding vector within an [`EmbeddingResponse`].
#[derive(Serialize, Debug)]
pub struct EmbeddingData {
    /// Object tag; the handler always sets this to `"embedding"`.
    pub object: String,
    /// Zero-based position of the corresponding text in the request input.
    pub index: usize,
    /// The embedding values returned by the worker.
    pub embedding: Vec<f32>,
}
/// Usage counters reported alongside an [`EmbeddingResponse`].
///
/// The handler fills both fields with the same value. That value is a count of
/// whitespace-separated words, with each input counted as at least one, rather
/// than a tokenizer-exact token count.
#[derive(Serialize, Debug)]
pub struct EmbeddingUsage {
    /// Approximate number of tokens in the inputs.
    pub prompt_tokens: usize,
    /// Same value as `prompt_tokens` (embeddings produce no completion tokens).
    pub total_tokens: usize,
}
/// Handles `POST /v1/embeddings`: embeds one or more strings via the worker queue.
///
/// Every input is enqueued as a [`BatchRequest::Embed`] before any reply is
/// awaited. This lets all inputs of a batch request be in flight at once,
/// instead of paying one full queue round-trip per input. Replies are then
/// collected in input order, so `data[i].index == i`.
///
/// # Errors
///
/// - [`ServerError::InvalidRequest`] if `input` is an empty array, or if the
///   worker reports an error for any input (its message is forwarded).
/// - [`ServerError::WorkerDead`] if the queue is closed or a reply channel is
///   dropped before a result is sent.
pub async fn embeddings(
    State(state): State<Arc<AppState>>,
    Json(request): Json<EmbeddingRequest>,
) -> ServerResult<Json<EmbeddingResponse>> {
    let inputs: Vec<String> = match request.input {
        EmbeddingInput::Single(s) => vec![s],
        EmbeddingInput::Batch(v) => v,
    };
    if inputs.is_empty() {
        return Err(ServerError::InvalidRequest {
            message: "input must contain at least one string".to_string(),
        });
    }

    // Approximate usage: whitespace-separated words, at least 1 per input.
    let total_tokens: usize = inputs
        .iter()
        .map(|text| text.split_whitespace().count().max(1))
        .sum();

    // Phase 1: enqueue everything. Oneshot replies never block the worker, so
    // holding all receivers here cannot deadlock even if the queue is bounded.
    let mut pending = Vec::with_capacity(inputs.len());
    for text in inputs {
        let (reply_tx, reply_rx) = oneshot::channel::<Result<Vec<f32>, String>>();
        state
            .queue
            .send(BatchRequest::Embed {
                text,
                reply: reply_tx,
            })
            .await
            .map_err(|_| ServerError::WorkerDead)?;
        pending.push(reply_rx);
    }

    // Phase 2: collect replies in submission order; fail on the first error.
    let mut data = Vec::with_capacity(pending.len());
    for (index, reply_rx) in pending.into_iter().enumerate() {
        let embedding = reply_rx
            .await
            .map_err(|_| ServerError::WorkerDead)?
            .map_err(|message| ServerError::InvalidRequest { message })?;
        data.push(EmbeddingData {
            object: "embedding".to_string(),
            index,
            embedding,
        });
    }

    Ok(Json(EmbeddingResponse {
        object: "list".to_string(),
        data,
        model: state.model_id.clone(),
        usage: EmbeddingUsage {
            prompt_tokens: total_tokens,
            total_tokens,
        },
    }))
}
#[cfg(test)]
mod tests {
    //! HTTP-level tests for `POST /v1/embeddings`.
    //!
    //! Tests built on `build_test_app` expect the embedding worker to be
    //! unavailable (503). Tests built on `build_live_test_app` expect a working
    //! worker (200).

    use serde_json::json;

    use crate::test_helpers::{build_live_test_app, build_test_app, post_json};

    #[tokio::test]
    async fn test_embeddings_missing_input_returns_422() {
        let app = build_test_app().await;
        let (status, _body) =
            post_json(app, "/v1/embeddings", json!({"model": "test-model"})).await;
        assert_eq!(status.as_u16(), 422, "missing input field should yield 422");
    }

    #[tokio::test]
    async fn test_embeddings_empty_body_returns_422() {
        let app = build_test_app().await;
        let (status, _body) = post_json(app, "/v1/embeddings", json!({})).await;
        assert_eq!(status.as_u16(), 422);
    }

    #[tokio::test]
    async fn test_embeddings_single_input_worker_dead_returns_503() {
        let app = build_test_app().await;
        let (status, body) = post_json(
            app,
            "/v1/embeddings",
            json!({
                "model": "test-model",
                "input": "hello world"
            }),
        )
        .await;
        assert_eq!(status.as_u16(), 503);
        assert_eq!(
            body["error"]["type"].as_str().unwrap_or(""),
            "service_unavailable"
        );
    }

    #[tokio::test]
    async fn test_embeddings_batch_input_worker_dead_returns_503() {
        let app = build_test_app().await;
        let (status, body) = post_json(
            app,
            "/v1/embeddings",
            json!({
                "model": "test-model",
                "input": ["foo", "bar", "baz"]
            }),
        )
        .await;
        assert_eq!(
            status.as_u16(),
            503,
            "batch dead worker should be 503: {body}"
        );
    }

    #[tokio::test]
    async fn test_embeddings_empty_batch_returns_400() {
        let app = build_test_app().await;
        let (status, body) = post_json(
            app,
            "/v1/embeddings",
            json!({
                "model": "test-model",
                "input": []
            }),
        )
        .await;
        assert_eq!(status.as_u16(), 400, "empty batch should yield 400: {body}");
    }

    #[tokio::test]
    async fn test_embeddings_single_text_returns_200() {
        let app = build_live_test_app().await;
        let body = json!({
            "model": "test",
            "input": "hello world"
        });
        let (status, json) = post_json(app, "/v1/embeddings", body).await;
        assert_eq!(
            status.as_u16(),
            200,
            "live worker should return 200: {json}"
        );
        assert_eq!(
            json["object"].as_str().unwrap_or(""),
            "list",
            "object field must be 'list': {json}"
        );
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .expect("test: data[0].embedding must be an array");
        assert!(!embedding.is_empty(), "embedding vector must not be empty");
    }

    /// A batch must yield one entry per input, in input order, with usage
    /// counted as one whitespace-separated word per single-word input.
    #[tokio::test]
    async fn test_embeddings_batch_returns_200() {
        let app = build_live_test_app().await;
        let body = json!({
            "model": "test",
            "input": ["hello", "world"]
        });
        let (status, json) = post_json(app, "/v1/embeddings", body).await;
        assert_eq!(
            status.as_u16(),
            200,
            "live worker batch should return 200: {json}"
        );
        let data = json["data"]
            .as_array()
            .expect("test: data field must be an array");
        assert_eq!(
            data.len(),
            2,
            "two inputs must produce two embeddings: {json}"
        );
        for (expected_index, item) in data.iter().enumerate() {
            assert_eq!(
                item["index"].as_u64(),
                Some(expected_index as u64),
                "data entries must be indexed in input order: {json}"
            );
            assert_eq!(
                item["object"].as_str(),
                Some("embedding"),
                "each data entry must be tagged 'embedding': {json}"
            );
        }
        assert_eq!(
            json["usage"]["prompt_tokens"].as_u64(),
            Some(2),
            "two one-word inputs count as two prompt tokens: {json}"
        );
        assert_eq!(
            json["usage"]["total_tokens"].as_u64(),
            Some(2),
            "total_tokens must equal prompt_tokens: {json}"
        );
    }

    #[tokio::test]
    async fn test_embeddings_response_has_model_field() {
        let app = build_live_test_app().await;
        let body = json!({
            "model": "test-model",
            "input": "test"
        });
        let (status, json) = post_json(app, "/v1/embeddings", body).await;
        assert_eq!(
            status.as_u16(),
            200,
            "live worker should return 200: {json}"
        );
        assert_eq!(
            json["model"].as_str().unwrap_or(""),
            "test-model",
            "model field must match AppState model_id: {json}"
        );
    }
}