use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Errors produced while requesting embeddings from a provider.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Transport-level failure (connection, TLS, timeout, body read, ...).
    #[error("HTTP request failed: {0}")]
    HttpError(String),
    /// The provider answered, but the payload could not be parsed or was
    /// missing expected data.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),
    /// HTTP 429 from the provider; `retry_after` is populated from the
    /// `Retry-After` header when it is present and parseable as seconds.
    #[error("Rate limited, retry after {retry_after:?}")]
    RateLimited {
        retry_after: Option<std::time::Duration>,
    },
    /// HTTP 401, or a session token could not be obtained.
    #[error("Authentication failed")]
    AuthFailed,
    /// Input exceeded the provider's `max_input_length()` (measured in bytes).
    #[error("Text too long: {length} > {max}")]
    TextTooLong { length: usize, max: usize },
}
impl From<reqwest::Error> for EmbeddingError {
fn from(e: reqwest::Error) -> Self {
EmbeddingError::HttpError(e.to_string())
}
}
/// Abstraction over text-embedding backends (remote APIs or local mocks).
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Number of components in each embedding vector.
    fn dimension(&self) -> usize;
    /// Identifier of the underlying model (useful for logging/config).
    fn model_name(&self) -> &str;
    /// Maximum accepted input length. Implementations in this file compare
    /// against `str::len`, i.e. bytes — not characters or tokens.
    fn max_input_length(&self) -> usize;
    /// Embeds a single text into a `dimension()`-length vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    /// Embeds many texts. The default implementation calls `embed` once per
    /// text, sequentially, and short-circuits on the first error; providers
    /// with a batch endpoint are expected to override it.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed(text).await?);
        }
        Ok(embeddings)
    }
}
/// Embedding provider backed by the OpenAI `/v1/embeddings` endpoint.
pub struct OpenAiEmbeddings {
    client: reqwest::Client,
    // Sent as a bearer token on every request.
    api_key: String,
    model: String,
    // Reported via `dimension()`; must match what `model` actually returns.
    dimension: usize,
}
impl OpenAiEmbeddings {
    /// Default provider using `text-embedding-3-small` (1536 dimensions).
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "text-embedding-3-small", 1536)
    }

    /// Legacy `text-embedding-ada-002` model (1536 dimensions).
    pub fn ada_002(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "text-embedding-ada-002", 1536)
    }

    /// Higher-quality `text-embedding-3-large` model (3072 dimensions).
    pub fn large(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "text-embedding-3-large", 3072)
    }

    /// Provider for an arbitrary model name. `dimension` must match the
    /// vector size the model actually returns; it is only reported through
    /// `dimension()`, not validated against responses.
    ///
    /// All other constructors delegate here so the struct literal exists in
    /// exactly one place (previously duplicated four times).
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimension: usize,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: model.into(),
            dimension,
        }
    }
}
/// Request body for the OpenAI embeddings endpoint (borrows to avoid copies).
#[derive(Debug, Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: &'a str,
    input: &'a [String],
}
/// Subset of the OpenAI response this module consumes; other fields
/// (usage, object, per-item index) are ignored by serde.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbeddingData>,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingData {
    embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddings {
    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn max_input_length(&self) -> usize {
        // NOTE(review): this is compared against `str::len` (bytes); the
        // API's real limit is token-based — confirm the approximation.
        32_000
    }

    /// Embeds one text by delegating to `embed_batch` (which enforces the
    /// input-length limit) and unwrapping the single result.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let embeddings = self.embed_batch(&[text.to_string()]).await?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| EmbeddingError::InvalidResponse("No embedding returned".to_string()))
    }

    /// Sends one embeddings request covering the whole batch.
    ///
    /// # Errors
    /// - `TextTooLong` if any input exceeds `max_input_length()`
    /// - `AuthFailed` on 401, `RateLimited` on 429 (honoring `Retry-After`)
    /// - `HttpError` / `InvalidResponse` for other failures
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        // Validate every input up front. Previously only the single-text
        // `embed` path checked the limit, so overlong batch entries were
        // sent straight to the API and failed with an opaque HTTP error.
        let max = self.max_input_length();
        if let Some(long) = texts.iter().find(|t| t.len() > max) {
            return Err(EmbeddingError::TextTooLong {
                length: long.len(),
                max,
            });
        }
        let request = OpenAiEmbeddingRequest {
            model: &self.model,
            input: texts,
        };
        let response = self
            .client
            .post("https://api.openai.com/v1/embeddings")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if status == reqwest::StatusCode::UNAUTHORIZED {
            return Err(EmbeddingError::AuthFailed);
        }
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Best-effort parse of Retry-After; only the integer-seconds
            // form is recognized, the HTTP-date form yields None.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(std::time::Duration::from_secs);
            return Err(EmbeddingError::RateLimited { retry_after });
        }
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::HttpError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }
        let result: OpenAiEmbeddingResponse = response.json().await.map_err(|e| {
            EmbeddingError::InvalidResponse(format!("Failed to parse response: {}", e))
        })?;
        // Guard against a truncated/malformed payload before callers rely
        // on positional correspondence with `texts`.
        if result.data.len() != texts.len() {
            return Err(EmbeddingError::InvalidResponse(format!(
                "Expected {} embeddings, got {}",
                texts.len(),
                result.data.len()
            )));
        }
        // NOTE(review): entries are assumed to arrive in request order; the
        // API also returns a per-item `index` that could be used to re-sort.
        Ok(result.data.into_iter().map(|d| d.embedding).collect())
    }
}
/// Embedding provider for a NEAR AI-compatible endpoint; authenticates with
/// bearer tokens fetched from the shared `SessionManager` per request.
pub struct NearAiEmbeddings {
    client: reqwest::Client,
    // Base URL; `/v1/embeddings` is appended when building each request.
    base_url: String,
    session: std::sync::Arc<crate::llm::SessionManager>,
    model: String,
    // Reported via `dimension()`; must match what `model` actually returns.
    dimension: usize,
}
impl NearAiEmbeddings {
    /// Creates a provider pointed at `base_url`, defaulting to the
    /// `text-embedding-3-small` model with 1536 dimensions.
    pub fn new(
        base_url: impl Into<String>,
        session: std::sync::Arc<crate::llm::SessionManager>,
    ) -> Self {
        let base_url = base_url.into();
        Self {
            client: reqwest::Client::new(),
            base_url,
            session,
            model: String::from("text-embedding-3-small"),
            dimension: 1536,
        }
    }

    /// Builder-style override of the model name and its vector dimension.
    pub fn with_model(mut self, model: impl Into<String>, dimension: usize) -> Self {
        self.dimension = dimension;
        self.model = model.into();
        self
    }
}
/// Request body for the NEAR AI embeddings endpoint (borrows to avoid copies).
#[derive(Debug, Serialize)]
struct NearAiEmbeddingRequest<'a> {
    model: &'a str,
    input: &'a [String],
}
/// Subset of the NEAR AI response this module consumes; any other fields
/// are ignored by serde.
#[derive(Debug, Deserialize)]
struct NearAiEmbeddingResponse {
    data: Vec<NearAiEmbeddingData>,
}
#[derive(Debug, Deserialize)]
struct NearAiEmbeddingData {
    embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for NearAiEmbeddings {
    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn max_input_length(&self) -> usize {
        // NOTE(review): this is compared against `str::len` (bytes); the
        // backend's real limit is token-based — confirm the approximation.
        32_000
    }

    /// Embeds one text by delegating to `embed_batch` (which enforces the
    /// input-length limit) and unwrapping the single result.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let embeddings = self.embed_batch(&[text.to_string()]).await?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| EmbeddingError::InvalidResponse("No embedding returned".to_string()))
    }

    /// Sends one embeddings request covering the whole batch, authenticated
    /// with a bearer token from the session manager.
    ///
    /// # Errors
    /// - `TextTooLong` if any input exceeds `max_input_length()`
    /// - `AuthFailed` if the token fetch fails or the server returns 401
    /// - `RateLimited` on 429 (honoring `Retry-After`)
    /// - `HttpError` / `InvalidResponse` for other failures
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        use secrecy::ExposeSecret;
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        // Validate every input up front. Previously only the single-text
        // `embed` path checked the limit, so overlong batch entries were
        // sent straight to the backend and failed with an opaque HTTP error.
        let max = self.max_input_length();
        if let Some(long) = texts.iter().find(|t| t.len() > max) {
            return Err(EmbeddingError::TextTooLong {
                length: long.len(),
                max,
            });
        }
        let request = NearAiEmbeddingRequest {
            model: &self.model,
            input: texts,
        };
        // Token-fetch failures surface as AuthFailed (the original cause is
        // discarded here).
        let token = self
            .session
            .get_token()
            .await
            .map_err(|_| EmbeddingError::AuthFailed)?;
        let url = format!("{}/v1/embeddings", self.base_url);
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", token.expose_secret()))
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if status == reqwest::StatusCode::UNAUTHORIZED {
            return Err(EmbeddingError::AuthFailed);
        }
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Best-effort parse of Retry-After; only the integer-seconds
            // form is recognized, the HTTP-date form yields None.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(std::time::Duration::from_secs);
            return Err(EmbeddingError::RateLimited { retry_after });
        }
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::HttpError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }
        let result: NearAiEmbeddingResponse = response.json().await.map_err(|e| {
            EmbeddingError::InvalidResponse(format!("Failed to parse response: {}", e))
        })?;
        // Guard against a truncated/malformed payload before callers rely
        // on positional correspondence with `texts`.
        if result.data.len() != texts.len() {
            return Err(EmbeddingError::InvalidResponse(format!(
                "Expected {} embeddings, got {}",
                texts.len(),
                result.data.len()
            )));
        }
        // NOTE(review): entries are assumed to arrive in request order —
        // verify the backend guarantees this.
        Ok(result.data.into_iter().map(|d| d.embedding).collect())
    }
}
/// Deterministic, offline embedding provider for tests: vectors are derived
/// from a hash of the input text, so equal texts embed identically.
pub struct MockEmbeddings {
    dimension: usize,
}
impl MockEmbeddings {
    /// Creates a mock provider emitting vectors of the given length.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}
#[async_trait]
impl EmbeddingProvider for MockEmbeddings {
    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        "mock-embedding"
    }

    fn max_input_length(&self) -> usize {
        10_000
    }

    /// Produces a deterministic pseudo-random unit vector seeded by the
    /// text's hash; never fails.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        use std::hash::{Hash, Hasher};

        // Seed a 64-bit LCG with the text's hash so output depends only on
        // the input text and the dimension.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        text.hash(&mut hasher);
        let mut state = hasher.finish();

        // Draw `dimension` values mapped into [-1, 1].
        let mut embedding: Vec<f32> = (0..self.dimension)
            .map(|_| {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                (state as f32 / u64::MAX as f32) * 2.0 - 1.0
            })
            .collect();

        // L2-normalize; a zero vector (e.g. dimension == 0) is left as-is
        // to avoid dividing by zero.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= norm);
        }
        Ok(embedding)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The mock provider returns a vector of the requested length with unit
    // L2 norm (within float tolerance).
    #[tokio::test]
    async fn test_mock_embeddings() {
        let provider = MockEmbeddings::new(128);
        let embedding = provider.embed("hello world").await.unwrap();
        assert_eq!(embedding.len(), 128);
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.001);
    }

    // Same input must produce a bit-identical embedding every time.
    #[tokio::test]
    async fn test_mock_embeddings_deterministic() {
        let provider = MockEmbeddings::new(64);
        let emb1 = provider.embed("test").await.unwrap();
        let emb2 = provider.embed("test").await.unwrap();
        assert_eq!(emb1, emb2);
    }

    // The trait's default `embed_batch` preserves order and count, and
    // distinct texts produce distinct vectors.
    #[tokio::test]
    async fn test_mock_embeddings_batch() {
        let provider = MockEmbeddings::new(64);
        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 64);
        assert_eq!(embeddings[1].len(), 64);
        assert_ne!(embeddings[0], embeddings[1]);
    }

    // Constructors report the model/dimension pairs they are documented with.
    #[test]
    fn test_openai_embeddings_config() {
        let provider = OpenAiEmbeddings::new("test-key");
        assert_eq!(provider.dimension(), 1536);
        assert_eq!(provider.model_name(), "text-embedding-3-small");
        let provider = OpenAiEmbeddings::large("test-key");
        assert_eq!(provider.dimension(), 3072);
        assert_eq!(provider.model_name(), "text-embedding-3-large");
    }
}