use async_trait::async_trait;
use crate::embedding::error::EmbeddingError;
/// A single embedding: a flat vector of `f32` components.
pub type Vector = Vec<f32>;

/// Common interface for anything that turns text into [`Vector`]s.
///
/// The `Send + Sync` supertraits mean implementors can be shared across
/// threads, e.g. behind a `Box<dyn EmbeddingProvider>`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single piece of text.
    async fn embed(&self, text: String) -> Result<Vector, EmbeddingError>;

    /// Embeds several texts, returning one vector per input in input order.
    ///
    /// The default implementation awaits [`embed`](Self::embed) once per text,
    /// one after another. It returns the first error it encounters and discards
    /// any vectors already computed. Providers with a native batch endpoint
    /// should override this.
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vector>, EmbeddingError> {
        let mut vectors: Vec<Vector> = Vec::with_capacity(texts.len());
        for input in texts {
            let vector = self.embed(input).await?;
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Number of components each returned [`Vector`] is expected to have.
    fn dimension(&self) -> usize;

    /// Static name identifying this provider implementation.
    fn provider_name(&self) -> &'static str;

    /// Owned copy of this provider's usage counters, if it tracks any.
    ///
    /// The default returns `None`, meaning no metrics are available.
    fn metrics(&self) -> Option<ProviderMetrics> {
        Option::None
    }
}
/// Cumulative usage counters for an embedding provider.
///
/// `Default` yields an all-zero snapshot. The fields are public, so a snapshot
/// can be internally inconsistent (e.g. more failures than requests). The
/// derived ratios below tolerate that instead of overflowing.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProviderMetrics {
    /// Total number of requests issued, counting both successes and failures.
    pub total_requests: u64,
    /// Total number of tokens processed.
    /// NOTE(review): whether this counts input tokens only is not visible here.
    pub total_tokens: u64,
    /// Number of requests that failed.
    pub failed_requests: u64,
    /// Estimated accumulated cost, in US dollars.
    pub estimated_cost_usd: f64,
}

impl ProviderMetrics {
    /// Failed-request count clamped so it never exceeds `total_requests`.
    ///
    /// This guards the `u64` subtraction in `success_rate`. Without the clamp,
    /// `failed_requests > total_requests` would panic with overflow in debug
    /// builds and wrap in release builds. It also keeps both rates in
    /// `[0.0, 1.0]`.
    fn effective_failures(&self) -> u64 {
        self.failed_requests.min(self.total_requests)
    }

    /// Fraction of requests that succeeded, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when no requests have been recorded.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            return 1.0;
        }
        let successful = self.total_requests - self.effective_failures();
        successful as f64 / self.total_requests as f64
    }

    /// Fraction of requests that failed, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no requests have been recorded.
    #[must_use]
    pub fn failure_rate(&self) -> f64 {
        if self.total_requests == 0 {
            return 0.0;
        }
        self.effective_failures() as f64 / self.total_requests as f64
    }

    /// Average estimated cost (USD) per recorded request.
    ///
    /// Returns `0.0` when no requests have been recorded, avoiding division by zero.
    #[must_use]
    pub fn avg_cost_per_request(&self) -> f64 {
        if self.total_requests == 0 {
            return 0.0;
        }
        self.estimated_cost_usd / self.total_requests as f64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 768-dimensional mock shared by the provider-level tests.
    fn mock_768() -> MockProvider {
        MockProvider {
            dimension: 768,
            name: "mock",
        }
    }

    #[test]
    fn test_provider_metrics_default() {
        let defaults = ProviderMetrics::default();
        assert_eq!(defaults.total_requests, 0);
        assert_eq!(defaults.total_tokens, 0);
        assert_eq!(defaults.failed_requests, 0);
        assert_eq!(defaults.estimated_cost_usd, 0.0);
    }

    #[test]
    fn test_provider_metrics_success_rate() {
        let observed = ProviderMetrics {
            total_requests: 100,
            failed_requests: 5,
            ..Default::default()
        };
        assert_eq!(observed.success_rate(), 0.95);
        assert_eq!(observed.failure_rate(), 0.05);

        // With no requests recorded, the rates fall back to their neutral values.
        let fresh = ProviderMetrics::default();
        assert_eq!(fresh.success_rate(), 1.0);
        assert_eq!(fresh.failure_rate(), 0.0);
    }

    #[test]
    fn test_provider_metrics_avg_cost() {
        let observed = ProviderMetrics {
            total_requests: 1000,
            estimated_cost_usd: 0.50,
            ..Default::default()
        };
        assert_eq!(observed.avg_cost_per_request(), 0.0005);

        let fresh = ProviderMetrics::default();
        assert_eq!(fresh.avg_cost_per_request(), 0.0);
    }

    /// Minimal provider that returns an all-zero vector of a fixed dimension.
    struct MockProvider {
        dimension: usize,
        name: &'static str,
    }

    #[async_trait]
    impl EmbeddingProvider for MockProvider {
        async fn embed(&self, _text: String) -> Result<Vector, EmbeddingError> {
            let zeros: Vector = vec![0.0_f32; self.dimension];
            Ok(zeros)
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn provider_name(&self) -> &'static str {
            self.name
        }
    }

    #[tokio::test]
    async fn test_provider_trait_object() {
        let provider: Box<dyn EmbeddingProvider> = Box::new(mock_768());
        assert_eq!(provider.dimension(), 768);
        assert_eq!(provider.provider_name(), "mock");

        let vector = provider
            .embed(String::from("test"))
            .await
            .expect("mock embed never fails");
        assert_eq!(vector.len(), 768);
    }

    #[tokio::test]
    async fn test_default_batch_implementation() {
        let provider: Box<dyn EmbeddingProvider> = Box::new(mock_768());
        let inputs: Vec<String> = ["first", "second", "third"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let expected_count = inputs.len();

        let vectors = provider
            .embed_batch(inputs)
            .await
            .expect("mock batch never fails");
        assert_eq!(vectors.len(), expected_count);
        for (index, vector) in vectors.iter().enumerate() {
            assert_eq!(vector.len(), 768, "embedding {} has the wrong dimension", index);
        }
    }

    #[test]
    fn test_metrics_optional() {
        let provider = mock_768();
        assert!(provider.metrics().is_none());
    }
}