use serde::{Deserialize, Serialize};
use crate::provider::Provider;
/// Embedding models known to this module.
///
/// The type is `Copy` and hashable, and derives serde `Serialize`/`Deserialize`
/// with no renaming attributes. Per-model metadata is available through
/// `EmbeddingModel::info`, and a display name through `EmbeddingModel::name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EmbeddingModel {
/// Qwen3 Embedding 8B (display name "Qwen3 Embedding 8B").
Qwen3Embedding8B,
/// OpenAI Text Embedding 3 Small (display name "Text Embedding 3 Small").
/// This is the variant returned by `EmbeddingModel::default()`.
TextEmbedding3Small,
}
/// Metadata record describing a single [`EmbeddingModel`], as built by
/// `EmbeddingModel::info`.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
/// The model this record describes.
pub model: EmbeddingModel,
/// Human-readable description string for the model.
pub description: &'static str,
/// Context window size. The unit is not stated in this file
/// (presumably tokens — TODO confirm).
pub context_window: usize,
/// Provider / model-identifier pairs listed for this model.
pub providers: Vec<EmbeddingProviderMapping>,
}
/// Associates a `Provider` with a model identifier string, e.g.
/// `Provider::OpenRouter` with `"qwen/qwen3-embedding-8b"`.
#[derive(Debug, Clone)]
pub struct EmbeddingProviderMapping {
/// The provider this mapping applies to.
pub provider: Provider,
/// Model identifier string paired with `provider`
/// (presumably the id that provider's API expects — TODO confirm).
pub model_id: &'static str,
}
impl EmbeddingProviderMapping {
pub const fn new(provider: Provider, model_id: &'static str) -> Self {
Self { provider, model_id }
}
}
impl EmbeddingModel {
    /// Builds the metadata record for this model.
    ///
    /// The returned [`EmbeddingModelInfo`] holds the model itself, its description
    /// string, its context window size, and a freshly allocated provider list.
    /// Both models currently list a single `Provider::OpenRouter` mapping.
    pub fn info(&self) -> EmbeddingModelInfo {
        // Exhaustive on purpose: adding a variant must fail to compile here until
        // its metadata is supplied.
        match *self {
            EmbeddingModel::Qwen3Embedding8B => EmbeddingModelInfo {
                model: *self,
                description: "Qwen3 Embedding 8B - Multilingual embedding model",
                context_window: 32_768,
                providers: vec![EmbeddingProviderMapping::new(
                    Provider::OpenRouter,
                    "qwen/qwen3-embedding-8b",
                )],
            },
            EmbeddingModel::TextEmbedding3Small => EmbeddingModelInfo {
                model: *self,
                description: "OpenAI Text Embedding 3 Small - Fast, efficient embedding model",
                context_window: 8_191,
                providers: vec![EmbeddingProviderMapping::new(
                    Provider::OpenRouter,
                    "openai/text-embedding-3-small",
                )],
            },
        }
    }

    /// Returns the short display name of the model.
    ///
    /// This is the same text the `Display` implementation writes.
    pub fn name(&self) -> &'static str {
        match *self {
            EmbeddingModel::Qwen3Embedding8B => "Qwen3 Embedding 8B",
            EmbeddingModel::TextEmbedding3Small => "Text Embedding 3 Small",
        }
    }

    /// Returns every model variant, in declaration order.
    pub fn all() -> &'static [EmbeddingModel] {
        &[
            EmbeddingModel::Qwen3Embedding8B,
            EmbeddingModel::TextEmbedding3Small,
        ]
    }
}

/// The default embedding model is [`EmbeddingModel::TextEmbedding3Small`].
///
/// This used to be an inherent `default()` function. Implementing the standard
/// trait instead keeps `EmbeddingModel::default()` call sites working (the trait
/// is in the prelude) and additionally lets the type take part in
/// `#[derive(Default)]`, `#[serde(default)]`, `unwrap_or_default()` and
/// `..Default::default()`.
impl Default for EmbeddingModel {
    fn default() -> Self {
        EmbeddingModel::TextEmbedding3Small
    }
}
/// Formats the model as its display name (see `EmbeddingModel::name`).
impl std::fmt::Display for EmbeddingModel {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The name is emitted verbatim, without applying width or padding flags.
        formatter.write_str(self.name())
    }
}
/// Serializable body of an embedding request.
///
/// `encoding_format` and `dimensions` are left out of the serialized output
/// entirely when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
/// Model identifier string placed in the request.
pub model: String,
/// The text to embed: either one string or a list of strings.
pub input: EmbeddingInput,
/// Optional encoding format; skipped during serialization when unset.
/// Accepted values are not defined in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
/// Optional dimensions value; skipped during serialization when unset
/// (presumably the requested output vector length — TODO confirm against the provider API).
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
}
/// Input payload of an [`EmbeddingRequest`].
///
/// Because the enum is `#[serde(untagged)]`, the variant name is not written:
/// `Single` serializes as a bare string and `Multiple` as an array of strings.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// One input string.
Single(String),
/// A batch of input strings.
Multiple(Vec<String>),
}
impl EmbeddingRequest {
pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
Self {
model: model.into(),
input: EmbeddingInput::Single(input.into()),
encoding_format: None,
dimensions: None,
}
}
pub fn new_batch(model: impl Into<String>, inputs: Vec<String>) -> Self {
Self {
model: model.into(),
input: EmbeddingInput::Multiple(inputs),
encoding_format: None,
dimensions: None,
}
}
pub fn with_encoding_format(mut self, format: impl Into<String>) -> Self {
self.encoding_format = Some(format.into());
self
}
pub fn with_dimensions(mut self, dimensions: u32) -> Self {
self.dimensions = Some(dimensions);
self
}
}
/// Deserialized body of an embedding response.
///
/// `object`, `data` and `model` must be present in the input; `usage` may be
/// absent, in which case it deserializes to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
/// Value of the response's `object` field (e.g. `"list"`).
pub object: String,
/// The embedding entries carried by the response.
pub data: Vec<EmbeddingData>,
/// Value of the response's `model` field.
pub model: String,
/// Token usage figures, when the response includes them.
#[serde(default)]
pub usage: Option<EmbeddingUsage>,
}
impl EmbeddingResponse {
    /// Returns the vector of the first entry in `data`, or `None` when `data`
    /// is empty.
    pub fn embedding(&self) -> Option<&[f32]> {
        let head = self.data.first()?;
        Some(head.embedding.as_slice())
    }

    /// Returns a borrowed view of every vector, in the order the entries
    /// appear in `data` (the entries' `index` fields are not consulted).
    pub fn embeddings(&self) -> Vec<&[f32]> {
        let mut vectors = Vec::with_capacity(self.data.len());
        for entry in &self.data {
            vectors.push(entry.embedding.as_slice());
        }
        vectors
    }
}
/// One entry of an [`EmbeddingResponse`]'s `data` array.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
/// Value of the entry's `object` field (e.g. `"embedding"`).
pub object: String,
/// Value of the entry's `index` field
/// (presumably the position of the corresponding input — TODO confirm).
pub index: usize,
/// The embedding vector.
pub embedding: Vec<f32>,
}
/// Token usage figures attached to an [`EmbeddingResponse`].
///
/// Each counter falls back to `0` when its field is missing from the input.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct EmbeddingUsage {
/// Value of the `prompt_tokens` field; `0` when absent.
#[serde(default)]
pub prompt_tokens: u32,
/// Value of the `total_tokens` field; `0` when absent.
#[serde(default)]
pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Qwen3 model reports at least one provider and a non-zero context window.
    #[test]
    fn test_embedding_model_info() {
        let EmbeddingModelInfo {
            providers,
            context_window,
            ..
        } = EmbeddingModel::Qwen3Embedding8B.info();
        assert!(!providers.is_empty());
        assert!(context_window > 0);
    }

    /// Every model returned by `EmbeddingModel::all` lists at least one provider.
    #[test]
    fn test_all_embedding_models_have_providers() {
        EmbeddingModel::all().iter().for_each(|model| {
            let providers = model.info().providers;
            assert!(
                !providers.is_empty(),
                "Embedding model {} has no providers",
                model.name()
            );
        });
    }

    /// `new` stores the model id; `new_batch` stores its inputs as `Multiple`.
    #[test]
    fn test_embedding_request() {
        let single = EmbeddingRequest::new("model", "Hello, world!");
        assert_eq!(single.model, "model");

        let texts = vec!["Hello".to_string(), "World".to_string()];
        let batch = EmbeddingRequest::new_batch("model", texts);
        match batch.input {
            EmbeddingInput::Multiple(inputs) => assert_eq!(inputs.len(), 2),
            EmbeddingInput::Single(_) => panic!("Expected Multiple input"),
        }
    }

    /// A representative response body deserializes into `EmbeddingResponse`.
    #[test]
    fn test_embedding_response_parsing() {
        let json = r#"{
"object": "list",
"data": [{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
}],
"model": "qwen/qwen3-embedding-8b",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}"#;
        let response = serde_json::from_str::<EmbeddingResponse>(json).unwrap();
        assert_eq!(response.data.len(), 1);

        let first = response.embedding().unwrap();
        assert_eq!(first.len(), 5);

        let usage = response.usage.as_ref().unwrap();
        assert_eq!(usage.prompt_tokens, 5);
    }
}