use std::fmt;
/// Batch of input texts; this is the argument type of `EmbeddingProvider::embed`.
pub type EmbeddingInput = Vec<String>;
/// A single embedding as a list of `f32` components; the element type of
/// `EmbeddingResponse::embeddings`.
pub type EmbeddingVector = Vec<f32>;
/// Token accounting carried in the `usage` field of `EmbeddingResponse`.
///
/// The derived `Default` yields a token count of zero, and the type can be
/// serialized and deserialized with serde.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct EmbeddingUsage {
/// Token count for the request.
/// NOTE(review): presumably the figure reported by the provider — confirm what is counted.
pub total_tokens: u32,
}
/// Successful result of `EmbeddingProvider::embed`.
///
/// Serializable and deserializable with serde. This type performs no
/// validation: nothing here ties `dimension` to the actual length of the
/// vectors in `embeddings`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddingResponse {
/// The embedding vectors.
/// NOTE(review): presumably one vector per input string, in input order — confirm with implementors.
pub embeddings: Vec<EmbeddingVector>,
/// Model name or identifier for this response.
/// NOTE(review): presumably matches `EmbeddingProvider::model_id` — confirm.
pub model: String,
/// Dimension associated with the response.
/// NOTE(review): presumably the length of each vector in `embeddings` — not enforced here.
pub dimension: usize,
/// Token usage recorded for the request.
pub usage: EmbeddingUsage,
}
/// Error type returned by the fallible methods of `EmbeddingProvider`.
///
/// `Display` and `std::error::Error` are generated by `thiserror` from the
/// `#[error(...)]` templates on each variant; the payload fields are
/// interpolated into those messages.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
/// The provider is not configured; the string describes what is missing or wrong.
#[error("embedding provider not configured: {0}")]
NotConfigured(String),
/// The embedding API reported an error.
#[error("embedding API error (status {status}): {message}")]
Api {
/// Status code reported with the failure.
/// NOTE(review): presumably an HTTP status code — confirm.
status: u16,
/// Error text associated with the failure.
message: String,
},
/// A network-level failure; the string carries the details.
#[error("embedding network error: {0}")]
Network(String),
/// An embedding dimension differed from the one that was expected.
#[error("embedding dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
/// The dimension that was expected.
expected: usize,
/// The dimension that was actually encountered.
actual: usize,
},
/// The submitted batch exceeds the maximum allowed size.
#[error("embedding batch too large: {size} exceeds max {max}")]
BatchTooLarge {
/// Size of the rejected batch.
size: usize,
/// Maximum batch size allowed.
max: usize,
},
/// The request was rate limited.
#[error("embedding rate limited, retry after {retry_after_secs}s")]
RateLimited {
/// Number of seconds to wait before retrying.
retry_after_secs: u64,
},
/// An internal error; the string carries the details.
#[error("embedding internal error: {0}")]
Internal(String),
}
/// Human-readable rendering of the usage counters.
impl fmt::Display for EmbeddingUsage {
    /// Writes `EmbeddingUsage(tokens=N)` where `N` is `total_tokens`.
    ///
    /// Width, fill and alignment flags on the caller's format spec are not
    /// applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "EmbeddingUsage(tokens={})",
            self.total_tokens
        ))
    }
}
/// Abstraction over a backend that turns text into embedding vectors.
///
/// Implementors must be `Send + Sync`. The `async_trait` attribute is what
/// allows the `async fn` methods below to be declared in this trait.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Returns the provider's name as a borrowed string.
fn name(&self) -> &str;
/// Returns the provider's embedding dimension.
/// NOTE(review): presumably the length of every vector produced by `embed` — confirm.
fn dimension(&self) -> usize;
/// Returns the identifier of the model as a borrowed string.
fn model_id(&self) -> &str;
/// Embeds a batch of input strings, taking ownership of `inputs`.
///
/// # Errors
///
/// Returns an `EmbeddingError` on failure; which variants are produced is
/// up to the implementation.
async fn embed(&self, inputs: EmbeddingInput) -> Result<EmbeddingResponse, EmbeddingError>;
/// Checks the provider's health, returning `Ok(())` on success.
///
/// # Errors
///
/// Returns an `EmbeddingError` on failure; what exactly is probed is
/// implementation-defined and not visible from this trait.
async fn health_check(&self) -> Result<(), EmbeddingError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed usage record reports zero tokens.
    #[test]
    fn embedding_usage_default_is_zero() {
        assert_eq!(EmbeddingUsage::default().total_tokens, 0);
    }

    /// `Display` renders the token count inside the fixed wrapper text.
    #[test]
    fn embedding_usage_display() {
        let rendered = EmbeddingUsage { total_tokens: 42 }.to_string();
        assert_eq!(rendered, "EmbeddingUsage(tokens=42)");
    }

    /// Every field set at construction is readable afterwards.
    #[test]
    fn embedding_response_fields() {
        let built = EmbeddingResponse {
            embeddings: vec![vec![0.1, 0.2, 0.3]],
            model: String::from("test-model"),
            dimension: 3,
            usage: EmbeddingUsage { total_tokens: 10 },
        };

        assert_eq!(built.embeddings.len(), 1);
        assert_eq!(built.dimension, 3);
        assert_eq!(built.model, "test-model");
        assert_eq!(built.usage.total_tokens, 10);
    }

    /// A response survives a JSON encode/decode cycle.
    #[test]
    fn embedding_response_serde_roundtrip() {
        let original = EmbeddingResponse {
            embeddings: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            model: String::from("test"),
            dimension: 2,
            usage: EmbeddingUsage { total_tokens: 5 },
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded = serde_json::from_str::<EmbeddingResponse>(&encoded).expect("deserialize");

        assert_eq!(decoded.embeddings.len(), 2);
        assert_eq!(decoded.dimension, 2);
        assert_eq!(decoded.usage.total_tokens, 5);
    }

    /// The `NotConfigured` message names the condition.
    #[test]
    fn embedding_error_display_not_configured() {
        let rendered = EmbeddingError::NotConfigured(String::from("missing api_key")).to_string();
        assert!(rendered.contains("not configured"));
    }

    /// The `Api` message includes both the status and the message text.
    #[test]
    fn embedding_error_display_api() {
        let failure = EmbeddingError::Api {
            status: 500,
            message: String::from("server error"),
        };
        let rendered = failure.to_string();

        assert!(rendered.contains("500"));
        assert!(rendered.contains("server error"));
    }

    /// The `DimensionMismatch` message includes both dimensions.
    #[test]
    fn embedding_error_display_dimension_mismatch() {
        let failure = EmbeddingError::DimensionMismatch {
            expected: 768,
            actual: 1536,
        };
        let rendered = failure.to_string();

        assert!(rendered.contains("768"));
        assert!(rendered.contains("1536"));
    }

    /// The `BatchTooLarge` message includes the size and the limit.
    #[test]
    fn embedding_error_display_batch_too_large() {
        let failure = EmbeddingError::BatchTooLarge {
            size: 3000,
            max: 2048,
        };
        let rendered = failure.to_string();

        assert!(rendered.contains("3000"));
        assert!(rendered.contains("2048"));
    }

    /// The `RateLimited` message includes the retry delay.
    #[test]
    fn embedding_error_display_rate_limited() {
        let failure = EmbeddingError::RateLimited {
            retry_after_secs: 30,
        };
        let rendered = failure.to_string();

        assert!(rendered.contains("30"));
    }

    /// The `Network` message includes the detail and the word "network".
    #[test]
    fn embedding_error_display_network() {
        let rendered = EmbeddingError::Network(String::from("connection refused")).to_string();

        assert!(rendered.contains("connection refused"));
        assert!(rendered.contains("network"));
    }

    /// The `Internal` message includes the detail and the word "internal".
    #[test]
    fn embedding_error_display_internal() {
        let rendered = EmbeddingError::Internal(String::from("something broke")).to_string();

        assert!(rendered.contains("something broke"));
        assert!(rendered.contains("internal"));
    }

    /// `Display` on a default usage record shows a zero count.
    #[test]
    fn embedding_usage_display_zero() {
        let rendered = EmbeddingUsage::default().to_string();
        assert_eq!(rendered, "EmbeddingUsage(tokens=0)");
    }

    /// A response may be built with no vectors and a zero dimension.
    #[test]
    fn embedding_response_empty_vectors() {
        let empty = EmbeddingResponse {
            embeddings: Vec::new(),
            model: String::from("empty"),
            dimension: 0,
            usage: EmbeddingUsage::default(),
        };

        assert!(empty.embeddings.is_empty());
        assert_eq!(empty.dimension, 0);
    }

    /// A usage record survives a JSON encode/decode cycle.
    #[test]
    fn embedding_usage_serde_roundtrip() {
        let original = EmbeddingUsage { total_tokens: 100 };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded = serde_json::from_str::<EmbeddingUsage>(&encoded).expect("deserialize");

        assert_eq!(decoded.total_tokens, 100);
    }
}