mod sim;
#[cfg(feature = "embedding-openai")]
mod openai;
pub use sim::SimEmbeddingProvider;
#[cfg(feature = "embedding-openai")]
pub use openai::OpenAIEmbeddingProvider;
use async_trait::async_trait;
#[derive(Debug, Clone, thiserror::Error)]
pub enum EmbeddingError {
#[error("Request timed out")]
Timeout,
#[error("Rate limit exceeded, retry after {retry_after_secs:?}s")]
RateLimit {
retry_after_secs: Option<u64>,
},
#[error("Context length exceeded: {tokens} tokens")]
ContextOverflow {
tokens: usize,
},
#[error("Invalid response: {message}")]
InvalidResponse {
message: String,
},
#[error("Service unavailable: {message}")]
ServiceUnavailable {
message: String,
},
#[error("Authentication failed")]
AuthenticationFailed,
#[error("JSON error: {message}")]
JsonError {
message: String,
},
#[error("Network error: {message}")]
NetworkError {
message: String,
},
#[error("Invalid request: {message}")]
InvalidRequest {
message: String,
},
#[error("Empty input provided")]
EmptyInput,
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
expected: usize,
actual: usize,
},
}
impl EmbeddingError {
    /// Builds [`EmbeddingError::Timeout`].
    #[must_use]
    pub fn timeout() -> Self {
        Self::Timeout
    }

    /// Builds [`EmbeddingError::RateLimit`] with an optional retry delay in seconds.
    #[must_use]
    pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
        Self::RateLimit { retry_after_secs }
    }

    /// Builds [`EmbeddingError::ContextOverflow`] for an input of `tokens` tokens.
    #[must_use]
    pub fn context_overflow(tokens: usize) -> Self {
        Self::ContextOverflow { tokens }
    }

    /// Builds [`EmbeddingError::InvalidResponse`] carrying `message`.
    #[must_use]
    pub fn invalid_response(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::InvalidResponse { message }
    }

    /// Builds [`EmbeddingError::ServiceUnavailable`] carrying `message`.
    #[must_use]
    pub fn service_unavailable(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::ServiceUnavailable { message }
    }

    /// Builds [`EmbeddingError::JsonError`] carrying `message`.
    #[must_use]
    pub fn json_error(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::JsonError { message }
    }

    /// Builds [`EmbeddingError::NetworkError`] carrying `message`.
    #[must_use]
    pub fn network_error(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::NetworkError { message }
    }

    /// Builds [`EmbeddingError::InvalidRequest`] carrying `message`.
    #[must_use]
    pub fn invalid_request(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::InvalidRequest { message }
    }

    /// Builds [`EmbeddingError::DimensionMismatch`] from the expected and observed lengths.
    #[must_use]
    pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
        Self::DimensionMismatch { expected, actual }
    }

    /// Reports whether the error is classified as transient.
    ///
    /// `Timeout`, `RateLimit` and `ServiceUnavailable` return `true`; every
    /// other variant returns `false`. The match is exhaustive so that adding a
    /// variant forces an explicit decision here.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. } => true,
            Self::ContextOverflow { .. }
            | Self::InvalidResponse { .. }
            | Self::AuthenticationFailed
            | Self::JsonError { .. }
            | Self::NetworkError { .. }
            | Self::InvalidRequest { .. }
            | Self::EmptyInput
            | Self::DimensionMismatch { .. } => false,
        }
    }
}
/// A source of text embeddings.
///
/// Implementations must be `Send + Sync`. The async methods are expanded by
/// `#[async_trait]`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Static identifier for this provider.
    fn name(&self) -> &'static str;

    /// Expected embedding length for this provider. Presumably this is the
    /// value to pass to `validate_dimensions`; confirm against implementations.
    fn dimensions(&self) -> usize;

    /// Whether this provider is a simulation rather than a real backend.
    /// NOTE(review): inferred from the name and the `SimEmbeddingProvider` export; confirm.
    fn is_simulation(&self) -> bool;

    /// Embeds a single text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;

    /// Embeds several texts in one call.
    /// NOTE(review): presumably returns one vector per input in input order; confirm.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
}
/// Checks that `embedding` has exactly `expected` components.
///
/// # Errors
///
/// Returns [`EmbeddingError::DimensionMismatch`] carrying `expected` and the
/// actual length when the two differ.
pub fn validate_dimensions(embedding: &[f32], expected: usize) -> Result<(), EmbeddingError> {
    let actual = embedding.len();
    if actual == expected {
        Ok(())
    } else {
        Err(EmbeddingError::dimension_mismatch(expected, actual))
    }
}
/// Returns a copy of `vec` scaled to unit Euclidean (L2) length.
///
/// The norm is computed relative to the largest-magnitude component. This
/// keeps the result correct where squaring would overflow `f32` (components
/// above roughly 1.8e19, which would otherwise produce an all-zero result).
/// It also covers squares that underflow to zero (very small components,
/// which would otherwise trip the zero-vector check on a non-zero input).
///
/// # Panics
///
/// Panics with "Cannot normalize zero vector" when the computed norm is not
/// strictly positive. This covers an empty slice, an all-zero slice, and a
/// slice containing NaN.
#[must_use]
pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
    // Largest absolute component. `f32::max` ignores NaN operands, but any NaN
    // component still makes the norm NaN and is rejected by the assert below.
    let scale = vec.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
    let norm = if scale > 0.0 && scale.is_finite() {
        let scaled_sq_sum: f32 = vec
            .iter()
            .map(|x| {
                let y = x / scale;
                y * y
            })
            .sum();
        scale * scaled_sq_sum.sqrt()
    } else {
        // scale == 0 (empty, all-zero or all-NaN input) or infinite input:
        // use the direct formula, matching the unscaled behaviour.
        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
    };
    assert!(norm > 0.0, "Cannot normalize zero vector");
    vec.iter().map(|x| x / norm).collect()
}
/// Reports whether the Euclidean (L2) length of `vec` is within `tolerance` of 1.
///
/// The comparison is strict (`<`), and a NaN length never passes.
#[must_use]
pub fn is_normalized(vec: &[f32], tolerance: f32) -> bool {
    let squared_length = vec
        .iter()
        .fold(0.0_f32, |total, &component| total + component * component);
    let deviation = (squared_length.sqrt() - 1.0).abs();
    deviation < tolerance
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::EMBEDDING_DIMENSIONS_COUNT;

    #[test]
    fn test_embedding_error_constructors() {
        let err = EmbeddingError::timeout();
        assert!(matches!(err, EmbeddingError::Timeout));
        let err = EmbeddingError::rate_limit(Some(60));
        assert!(matches!(
            err,
            EmbeddingError::RateLimit {
                retry_after_secs: Some(60)
            }
        ));
        let err = EmbeddingError::context_overflow(10000);
        assert!(matches!(
            err,
            EmbeddingError::ContextOverflow { tokens: 10000 }
        ));
        let err = EmbeddingError::invalid_response("bad format");
        assert!(matches!(err, EmbeddingError::InvalidResponse { .. }));
        let err = EmbeddingError::dimension_mismatch(1536, 768);
        assert!(matches!(
            err,
            EmbeddingError::DimensionMismatch {
                expected: 1536,
                actual: 768
            }
        ));
    }

    #[test]
    fn test_embedding_error_is_retryable() {
        assert!(EmbeddingError::timeout().is_retryable());
        assert!(EmbeddingError::rate_limit(Some(60)).is_retryable());
        assert!(EmbeddingError::rate_limit(None).is_retryable());
        assert!(EmbeddingError::service_unavailable("down").is_retryable());
        assert!(!EmbeddingError::AuthenticationFailed.is_retryable());
        assert!(!EmbeddingError::EmptyInput.is_retryable());
        assert!(!EmbeddingError::json_error("parse failed").is_retryable());
        assert!(!EmbeddingError::invalid_request("bad").is_retryable());
        assert!(!EmbeddingError::context_overflow(10000).is_retryable());
        assert!(!EmbeddingError::dimension_mismatch(1536, 768).is_retryable());
    }

    #[test]
    fn test_validate_dimensions() {
        let embedding = vec![0.1; EMBEDDING_DIMENSIONS_COUNT];
        assert!(validate_dimensions(&embedding, EMBEDDING_DIMENSIONS_COUNT).is_ok());
        assert!(validate_dimensions(&[], 0).is_ok());

        let wrong_size = vec![0.1; 768];
        let err = validate_dimensions(&wrong_size, EMBEDDING_DIMENSIONS_COUNT).unwrap_err();
        // The error must report both the expected and the observed length.
        assert!(matches!(
            err,
            EmbeddingError::DimensionMismatch { expected, actual: 768 }
                if expected == EMBEDDING_DIMENSIONS_COUNT
        ));
    }

    #[test]
    fn test_normalize_vector() {
        let vec = vec![3.0, 4.0];
        let normalized = normalize_vector(&vec);
        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);
        assert!(is_normalized(&normalized, 0.001));
    }

    #[test]
    fn test_is_normalized() {
        let unit = vec![1.0, 0.0, 0.0];
        assert!(is_normalized(&unit, 0.001));
        let not_unit = vec![2.0, 0.0, 0.0];
        // Previously garbled to `¬_unit` (HTML-entity mojibake), which did not compile.
        assert!(!is_normalized(&not_unit, 0.001));
        let normalized = vec![0.6, 0.8];
        assert!(is_normalized(&normalized, 0.001));
    }

    #[test]
    #[should_panic(expected = "Cannot normalize zero vector")]
    fn test_normalize_zero_vector() {
        let zero = vec![0.0, 0.0, 0.0];
        let _ = normalize_vector(&zero);
    }
}