use serde::{Deserialize, Serialize};
/// Usage figures attached to a successful embedding response.
///
/// Carried in the `usage` field of `EmbedResponse::Embeddings`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EmbedUsage {
/// Input token count.
///
/// NOTE(review): presumably the total tokens consumed by the request's
/// inputs, as reported by the backend — confirm against the producer.
pub input_tokens: u32,
}
/// Machine-readable error category carried by `EmbedResponse::Error`.
///
/// Variants serialize as snake_case strings, e.g. `EmbedUnsupported` becomes
/// `"embed_unsupported"`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only — confirm against the code that emits them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbedErrorCode {
/// Wire name `"queue_full"`; presumably the request queue had no free capacity.
QueueFull,
/// Wire name `"backend_unavailable"`; presumably the embedding backend could not be reached.
BackendUnavailable,
/// Wire name `"invalid_request"`; presumably the request failed validation.
InvalidRequest,
/// Wire name `"frame_too_large"`; presumably a message exceeded a size limit.
FrameTooLarge,
/// Wire name `"internal"`; presumably an unexpected server-side failure.
Internal,
/// Wire name `"embed_unsupported"`; presumably embedding is not supported by the selected model or backend.
EmbedUnsupported,
}
/// Response to an embedding request: either the computed embeddings or an error.
///
/// Serialized as an internally tagged enum: the variant is identified by a
/// `"type"` field whose value is the snake_case variant name (`"embeddings"`
/// or `"error"`), with the variant's fields alongside it in the same object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbedResponse {
/// Successful result; `EmbedResponse::is_ok` returns `true` for this variant.
Embeddings {
/// Response identifier. NOTE(review): presumably echoes the id of the originating request — confirm.
id: String,
/// The embedding vectors. NOTE(review): presumably one vector per input, in input order — confirm.
embeddings: Vec<Vec<f32>>,
/// Declared dimensionality. NOTE(review): presumably the length of each
/// vector in `embeddings`; nothing in this type enforces that.
dimensions: u32,
/// Model name. NOTE(review): presumably the model that produced the vectors — confirm.
model: String,
/// Usage figures reported for this response.
usage: EmbedUsage,
/// Backend name. NOTE(review): presumably the inference backend that served the request — confirm.
backend: String,
},
/// Failure result; `EmbedResponse::is_ok` returns `false` for this variant.
Error {
/// Response identifier. NOTE(review): presumably echoes the id of the originating request — confirm.
id: String,
/// Machine-readable error category.
code: EmbedErrorCode,
/// Free-form message text accompanying `code`.
message: String,
},
}
impl EmbedResponse {
pub fn id(&self) -> &str {
match self {
EmbedResponse::Embeddings { id, .. } | EmbedResponse::Error { id, .. } => id,
}
}
pub fn is_ok(&self) -> bool {
matches!(self, EmbedResponse::Embeddings { .. })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` as a JSON string and immediately decodes it again.
    fn json_round_trip(value: &EmbedResponse) -> EmbedResponse {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// A populated `Embeddings` response survives a JSON round trip unchanged.
    #[test]
    fn embeddings_variant_round_trips() {
        let original = EmbedResponse::Embeddings {
            id: String::from("r1"),
            embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
            dimensions: 3,
            model: String::from("embeddinggemma-300m"),
            usage: EmbedUsage { input_tokens: 12 },
            backend: String::from("llamacpp"),
        };

        let decoded = json_round_trip(&original);

        assert_eq!(original, decoded);
        assert!(original.is_ok());
        assert_eq!(original.id(), "r1");
    }

    /// An `Error` response survives a JSON round trip and is not reported as ok.
    #[test]
    fn error_variant_round_trips() {
        let original = EmbedResponse::Error {
            id: String::from("r1"),
            code: EmbedErrorCode::InvalidRequest,
            message: String::from("dimensions must be one of [128, 256, 512, 768]"),
        };

        let decoded = json_round_trip(&original);

        assert_eq!(original, decoded);
        assert!(!original.is_ok());
    }

    /// The serialized object carries the variant in a `type` field next to the payload.
    #[test]
    fn embeddings_serializes_with_type_tag() {
        let response = EmbedResponse::Embeddings {
            id: String::from("r1"),
            embeddings: vec![vec![0.1]],
            dimensions: 1,
            model: String::from("m"),
            usage: EmbedUsage { input_tokens: 1 },
            backend: String::from("llamacpp"),
        };

        let json: serde_json::Value = serde_json::to_value(&response).unwrap();

        assert_eq!(json["type"], "embeddings");
        assert_eq!(json["dimensions"], 1);
    }

    /// Error codes are written as snake_case JSON strings.
    #[test]
    fn error_code_serializes_snake_case() {
        let encoded = serde_json::to_string(&EmbedErrorCode::EmbedUnsupported).unwrap();
        assert_eq!(encoded, "\"embed_unsupported\"");
    }
}