use super::TextEmbedder;
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// Embedding models this client can request from the OpenAI API.
///
/// The type is `Copy` and also comparable/hashable, so it can be used directly
/// in `==` checks or as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpenAIModel {
    /// `text-embedding-3-small` (1536-dimensional vectors).
    TextEmbedding3Small,
    /// `text-embedding-3-large` (3072-dimensional vectors).
    TextEmbedding3Large,
    /// `text-embedding-ada-002` (1536-dimensional vectors).
    Ada002,
}
impl OpenAIModel {
pub fn as_str(&self) -> &'static str {
match self {
OpenAIModel::TextEmbedding3Small => "text-embedding-3-small",
OpenAIModel::TextEmbedding3Large => "text-embedding-3-large",
OpenAIModel::Ada002 => "text-embedding-ada-002",
}
}
pub fn dimension(&self) -> usize {
match self {
OpenAIModel::TextEmbedding3Small => 1536,
OpenAIModel::TextEmbedding3Large => 3072,
OpenAIModel::Ada002 => 1536,
}
}
pub fn cost_per_million_tokens(&self) -> f64 {
match self {
OpenAIModel::TextEmbedding3Small => 0.02,
OpenAIModel::TextEmbedding3Large => 0.13,
OpenAIModel::Ada002 => 0.10,
}
}
}
/// JSON body POSTed to the embeddings endpoint.
///
/// Field names and attributes define the wire format; change them only in
/// step with the API schema.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
/// Model identifier string (this file fills it from `OpenAIModel::as_str`).
model: String,
/// Texts to embed; the caller expects one embedding back per entry.
input: Vec<String>,
/// Presumably the requested output dimensionality; omitted from the JSON
/// entirely when `None` (this file always sends `None`).
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
/// Body of a successful response from the embeddings endpoint.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
/// Returned embeddings; the caller sorts them by `index` rather than
/// relying on array order.
data: Vec<EmbeddingData>,
/// Token accounting reported by the API.
/// NOTE(review): deserialized but never read in this file's visible code.
usage: Usage,
}
/// One embedding vector together with the position of the input it belongs to.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
/// The embedding vector.
embedding: Vec<f32>,
/// Position of the corresponding text in the request's `input` list; used
/// to restore input order (presumably zero-based — confirm with the API docs).
index: usize,
}
/// Token usage block of the embeddings response.
/// NOTE(review): neither field is read in this file's visible code, so the
/// compiler may emit dead-code warnings for them.
#[derive(Debug, Deserialize)]
struct Usage {
/// Tokens attributed to the input texts (as reported by the API).
prompt_tokens: usize,
/// Total tokens for the request (as reported by the API).
total_tokens: usize,
}
/// Client-side sliding-window limiter for outgoing API requests.
struct RateLimiter {
/// Maximum number of requests admitted within any rolling 60-second window.
requests_per_minute: usize,
/// Admission timestamps; entries older than 60 seconds are pruned on each check.
last_requests: Arc<std::sync::Mutex<Vec<std::time::Instant>>>,
}
impl RateLimiter {
    /// Creates a limiter admitting at most `requests_per_minute` requests within
    /// any rolling 60-second window. A limit of 0 is treated as 1 per window.
    fn new(requests_per_minute: usize) -> Self {
        Self {
            requests_per_minute,
            last_requests: Arc::new(std::sync::Mutex::new(Vec::new())),
        }
    }

    /// Waits until another request fits in the rolling 60-second window, then records it.
    ///
    /// The admission time is recorded only when the caller is actually let through,
    /// and the window is re-checked after every sleep. Because of this, several
    /// callers that wake at the same moment cannot jointly exceed the limit. The
    /// mutex guard is always released before sleeping, so it is never held across
    /// an `.await`.
    async fn wait_if_needed(&self) {
        const WINDOW: Duration = Duration::from_secs(60);
        // A limit of 0 could never admit anything; treat it as one request per window.
        let limit = self.requests_per_minute.max(1);
        loop {
            let wait_time = {
                // A poisoned lock only means another caller panicked while holding it;
                // the timestamp list is still valid, so keep using it.
                let mut requests = self
                    .last_requests
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                let now = std::time::Instant::now();
                requests.retain(|&time| now.duration_since(time) < WINDOW);
                if requests.len() < limit {
                    requests.push(now);
                    return;
                }
                // Timestamps are pushed in order while the lock is held, so the first
                // one is the oldest; a slot frees up once it ages out of the window.
                let oldest = *requests
                    .first()
                    .expect("window holds at least `limit` >= 1 timestamps here");
                WINDOW.saturating_sub(now.duration_since(oldest))
            };
            tokio::time::sleep(wait_time).await;
        }
    }
}
/// Client for the OpenAI embeddings HTTP API with client-side rate limiting and retries.
pub struct OpenAIEmbedding {
/// HTTP client (built with a 30-second timeout in `new`).
client: reqwest::Client,
/// Secret key sent as a `Bearer` token in the `Authorization` header.
api_key: String,
/// Model used for every request.
model: OpenAIModel,
/// Throttles outgoing requests (500 per minute unless overridden with `with_rate_limit`).
rate_limiter: RateLimiter,
/// Number of retries after a failed attempt (3 unless overridden with `with_max_retries`).
max_retries: usize,
}
impl OpenAIEmbedding {
pub async fn new(api_key: String, model: OpenAIModel) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.context("Failed to create HTTP client")?;
Ok(Self {
client,
api_key,
model,
rate_limiter: RateLimiter::new(500), max_retries: 3,
})
}
pub fn with_rate_limit(mut self, requests_per_minute: usize) -> Self {
self.rate_limiter = RateLimiter::new(requests_per_minute);
self
}
pub fn with_max_retries(mut self, max_retries: usize) -> Self {
self.max_retries = max_retries;
self
}
pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
let embeddings = self.embed_batch_async(&[text]).await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("No embedding returned"))
}
pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
const BATCH_SIZE: usize = 2048;
let mut all_embeddings = Vec::new();
for chunk in texts.chunks(BATCH_SIZE) {
let chunk_embeddings = self.embed_batch_chunk(chunk).await?;
all_embeddings.extend(chunk_embeddings);
}
Ok(all_embeddings)
}
async fn embed_batch_chunk(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut retries = 0;
loop {
self.rate_limiter.wait_if_needed().await;
let request = EmbeddingRequest {
model: self.model.as_str().to_string(),
input: texts.iter().map(|s| s.to_string()).collect(),
dimensions: None,
};
let response = self
.client
.post("https://api.openai.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await;
match response {
Ok(resp) if resp.status().is_success() => {
let embedding_response: EmbeddingResponse = resp
.json()
.await
.context("Failed to parse OpenAI response")?;
let mut data = embedding_response.data;
data.sort_by_key(|d| d.index);
return Ok(data.into_iter().map(|d| d.embedding).collect());
}
Ok(resp) if resp.status().as_u16() == 429 && retries < self.max_retries => {
retries += 1;
let wait_time = Duration::from_secs(2_u64.pow(retries as u32));
tokio::time::sleep(wait_time).await;
continue;
}
Ok(resp) => {
let status = resp.status();
let body = resp
.text()
.await
.unwrap_or_else(|_| String::from("(no body)"));
return Err(anyhow!("OpenAI API error {}: {}", status, body));
}
Err(_e) if retries < self.max_retries => {
retries += 1;
let wait_time = Duration::from_secs(2_u64.pow(retries as u32));
tokio::time::sleep(wait_time).await;
continue;
}
Err(e) => {
return Err(anyhow!("Failed to call OpenAI API: {}", e));
}
}
}
}
pub fn estimate_cost(&self, texts: &[&str]) -> f64 {
let total_chars: usize = texts.iter().map(|t| t.len()).sum();
let estimated_tokens = total_chars / 4;
let cost_per_token = self.model.cost_per_million_tokens() / 1_000_000.0;
estimated_tokens as f64 * cost_per_token
}
pub fn model(&self) -> OpenAIModel {
self.model
}
}
impl TextEmbedder for OpenAIEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let runtime = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
runtime.block_on(self.embed_async(text))
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let runtime = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
runtime.block_on(self.embed_batch_async(texts))
}
fn dimension(&self) -> Result<usize> {
Ok(self.model.dimension())
}
}