use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use std::sync::LazyLock;
/// Process-wide registry of rate limiters keyed by a hash of
/// (api_base_url, api_key), so every client talking to the same endpoint
/// with the same credentials shares one limiter. Populated lazily by
/// `get_or_create_global_rate_limiter`.
static GLOBAL_RATE_LIMITERS: LazyLock<RwLock<HashMap<String, Arc<RateLimiter>>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
/// Derives the registry key for a (base URL, API key) pair by feeding both
/// strings through one `DefaultHasher` and rendering the 64-bit digest as
/// lowercase hex. The raw API key never appears in the key itself.
fn make_global_key(api_base_url: &str, api_key: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    // Tuple hashing writes each element in order — identical to hashing
    // the two strings one after the other.
    (api_base_url, api_key).hash(&mut hasher);
    format!("{:x}", hasher.finish())
}
/// Returns the shared rate limiter for this (endpoint, credential) pair,
/// creating and registering one on first use.
///
/// The fast path takes only the read lock; on a miss we take the write lock
/// and insert through the entry API, so two tasks racing on the same key
/// still end up sharing a single limiter.
async fn get_or_create_global_rate_limiter(
    api_base_url: &str,
    api_key: &str,
    calls_per_minute: u32,
) -> Arc<RateLimiter> {
    let key = make_global_key(api_base_url, api_key);
    if let Some(existing) = GLOBAL_RATE_LIMITERS.read().await.get(&key) {
        return Arc::clone(existing);
    }
    let mut registry = GLOBAL_RATE_LIMITERS.write().await;
    Arc::clone(
        registry
            .entry(key)
            .or_insert_with(|| Arc::new(RateLimiter::new(calls_per_minute))),
    )
}
/// Paces callers so consecutive `acquire` calls are spaced at least
/// `min_interval_nanos` apart, using a lock-free atomic "next free slot"
/// timestamp instead of a mutex.
pub struct RateLimiter {
    // Absolute timestamp (nanos, on the same clock as `current_nanos`) of the
    // most recently claimed slot; advanced via CAS in `acquire` and pushed
    // forward by `backoff_on_rate_limit`.
    next_allowed_nanos: AtomicU64,
    // Minimum spacing between calls, derived from calls-per-minute.
    min_interval_nanos: u64,
}
impl RateLimiter {
pub fn new(calls_per_minute: u32) -> Self {
let calls = calls_per_minute.max(1) as u64;
let min_interval_nanos = 60_000_000_000u64 / calls; Self {
next_allowed_nanos: AtomicU64::new(0),
min_interval_nanos,
}
}
pub async fn acquire(&self) {
loop {
let now_nanos = Self::current_nanos();
let next_allowed = self.next_allowed_nanos.load(Ordering::SeqCst);
let our_slot = std::cmp::max(now_nanos, next_allowed) + self.min_interval_nanos;
match self.next_allowed_nanos.compare_exchange(
next_allowed,
our_slot,
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => {
let wait_nanos = our_slot.saturating_sub(now_nanos);
if wait_nanos > 1_000_000 { let wait = Duration::from_nanos(wait_nanos);
debug!("Rate limiter: acquired slot, waiting {:?}", wait);
tokio::time::sleep(wait).await;
}
return;
}
Err(_) => {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
}
}
fn current_nanos() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64
}
pub async fn backoff_on_rate_limit(&self) {
warn!("Rate limit hit (429), backing off for 5 seconds");
let now_nanos = Self::current_nanos();
let new_next = now_nanos + 5_000_000_000u64; self.next_allowed_nanos.store(new_next, Ordering::SeqCst);
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
/// One cached embedding plus the instant it was stored, used for TTL checks.
struct CacheItem {
    embedding: Vec<f32>,
    created_at: Instant,
}

/// Bounded, TTL-aware embedding cache.
///
/// Eviction is by insertion/update order (front of `access_order` is the
/// oldest write): `put` refreshes a key's position but `peek` does not, so
/// this is FIFO-by-last-write rather than true LRU. Expired entries linger
/// until overwritten or evicted; `peek` simply refuses to return them.
struct InnerCache {
    entries: HashMap<String, CacheItem>,
    access_order: Vec<String>,
    max_entries: usize,
    ttl: Duration,
}

impl InnerCache {
    /// Creates an empty cache holding at most `max_entries` items, each
    /// considered fresh for `ttl_secs` seconds.
    fn new(max_entries: usize, ttl_secs: u64) -> Self {
        Self {
            entries: HashMap::new(),
            access_order: Vec::new(),
            max_entries,
            ttl: Duration::from_secs(ttl_secs),
        }
    }

    /// Returns a clone of the embedding for `key` if present and unexpired.
    /// Read-only: recency bookkeeping is untouched.
    fn peek(&self, key: &str) -> Option<Vec<f32>> {
        self.entries
            .get(key)
            .filter(|item| item.created_at.elapsed() < self.ttl)
            .map(|item| item.embedding.clone())
    }

    /// Inserts (or refreshes) `key`, evicting the oldest entry first when the
    /// cache is full and `key` is new.
    fn put(&mut self, key: String, embedding: Vec<f32>) {
        let is_new = !self.entries.contains_key(&key);
        if is_new && self.entries.len() >= self.max_entries {
            if let Some(oldest) = self.access_order.first().cloned() {
                self.entries.remove(&oldest);
                self.access_order.remove(0);
            }
        }
        let item = CacheItem {
            embedding,
            created_at: Instant::now(),
        };
        self.entries.insert(key.clone(), item);
        // Move (or place) the key at the back: most recently written.
        self.access_order.retain(|existing| existing != &key);
        self.access_order.push(key);
    }

    /// Builds a cache key of the form `"<model>:<hex hash of text>"`, so the
    /// same text embedded under different models caches separately.
    fn compute_key(model: &str, text: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        format!("{}:{:x}", model, hasher.finish())
    }
}
/// Configuration for `EmbeddingClient`. The `Default` impl populates the
/// string fields from environment variables with fallbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    // Base URL of the OpenAI-compatible API; `/embeddings` is appended per request.
    pub api_base_url: String,
    // Bearer token sent in the Authorization header.
    pub api_key: String,
    // Model identifier sent in each embedding request.
    pub model_name: String,
    // Maximum number of texts sent in one API call.
    pub batch_size: usize,
    // HTTP request timeout, in seconds.
    pub timeout_secs: u64,
    // Rate-limit budget: API calls allowed per minute.
    pub calls_per_minute: u32,
    // Capacity of the in-memory embedding cache.
    pub cache_max_entries: usize,
    // Time-to-live of cached embeddings, in seconds.
    pub cache_ttl_secs: u64,
}
impl Default for EmbeddingConfig {
    /// Builds a config from the environment with sensible fallbacks:
    /// `EMBEDDING_API_BASE_URL` (default OpenAI v1), `EMBEDDING_API_KEY`
    /// (falling back to `LLM_API_KEY`, then empty) and `EMBEDDING_MODEL`
    /// (default `text-embedding-3-small`); numeric knobs are fixed defaults.
    fn default() -> Self {
        use std::env::var;
        let api_base_url = var("EMBEDDING_API_BASE_URL")
            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
        let api_key = var("EMBEDDING_API_KEY")
            .or_else(|_| var("LLM_API_KEY"))
            .unwrap_or_else(|_| "".to_string());
        let model_name =
            var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-3-small".to_string());
        Self {
            api_base_url,
            api_key,
            model_name,
            batch_size: 10,
            timeout_secs: 30,
            calls_per_minute: 30,
            cache_max_entries: 10_000,
            cache_ttl_secs: 3_600,
        }
    }
}
/// Client for an OpenAI-compatible `/embeddings` endpoint with built-in
/// rate limiting and an in-memory TTL cache.
pub struct EmbeddingClient {
    config: EmbeddingConfig,
    // HTTP client configured with `config.timeout_secs`.
    client: reqwest::Client,
    // Paces outgoing API calls; private to this client or shared process-wide
    // depending on which constructor was used.
    rate_limiter: Arc<RateLimiter>,
    // Caches embeddings keyed by (model, hash-of-text).
    cache: Arc<RwLock<InnerCache>>,
}
impl EmbeddingClient {
pub fn new(config: EmbeddingConfig) -> Result<Self> {
let calls_per_minute = config.calls_per_minute;
let cache = Arc::new(RwLock::new(InnerCache::new(
config.cache_max_entries,
config.cache_ttl_secs,
)));
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| crate::Error::Embedding(format!("Failed to create HTTP client: {}", e)))?;
info!(
"EmbeddingClient initialized: model={}, rate_limit={}/min, cache={}entries/{}s",
config.model_name,
config.calls_per_minute,
config.cache_max_entries,
config.cache_ttl_secs,
);
Ok(Self {
config,
client,
rate_limiter: Arc::new(RateLimiter::new(calls_per_minute)),
cache,
})
}
pub async fn new_with_global_limiter(config: EmbeddingConfig) -> Result<Self> {
let rate_limiter = get_or_create_global_rate_limiter(
&config.api_base_url,
&config.api_key,
config.calls_per_minute,
).await;
let cache = Arc::new(RwLock::new(InnerCache::new(
config.cache_max_entries,
config.cache_ttl_secs,
)));
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| crate::Error::Embedding(format!("Failed to create HTTP client: {}", e)))?;
info!(
"EmbeddingClient initialized (global limiter): model={}, rate_limit={}/min, cache={}entries/{}s",
config.model_name,
config.calls_per_minute,
config.cache_max_entries,
config.cache_ttl_secs,
);
Ok(Self {
config,
client,
rate_limiter,
cache,
})
}
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
let cache_key = InnerCache::compute_key(&self.config.model_name, text);
{
let cache = self.cache.read().await;
if let Some(cached) = cache.peek(&cache_key) {
debug!("Cache hit for text (len={})", text.chars().count());
return Ok(cached);
}
}
let results = self.embed_batch_raw(&[text.to_string()]).await?;
let embedding = results
.into_iter()
.next()
.ok_or_else(|| crate::Error::Embedding("No embedding returned".to_string()))?;
{
let mut cache = self.cache.write().await;
cache.put(cache_key, embedding.clone());
}
Ok(embedding)
}
pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut miss_texts: Vec<String> = Vec::new();
let mut miss_indices: Vec<usize> = Vec::new();
{
let cache = self.cache.read().await;
for (idx, text) in texts.iter().enumerate() {
let key = InnerCache::compute_key(&self.config.model_name, text);
if let Some(cached) = cache.peek(&key) {
results[idx] = Some(cached);
} else {
miss_texts.push(text.clone());
miss_indices.push(idx);
}
}
}
if miss_texts.is_empty() {
debug!("All {} embeddings served from cache", texts.len());
return Ok(results.into_iter().map(|opt| opt.unwrap()).collect());
}
debug!(
"{}/{} cache misses, calling API",
miss_texts.len(),
texts.len()
);
let mut api_results: Vec<Vec<f32>> = Vec::with_capacity(miss_texts.len());
for chunk in miss_texts.chunks(self.config.batch_size) {
let embeddings = self.embed_batch_raw(chunk).await?;
api_results.extend(embeddings);
}
{
let mut cache = self.cache.write().await;
for (api_idx, (text, embedding)) in
miss_texts.iter().zip(api_results.iter()).enumerate()
{
let key = InnerCache::compute_key(&self.config.model_name, text);
cache.put(key, embedding.clone());
let result_idx = miss_indices[api_idx];
results[result_idx] = Some(embedding.clone());
}
}
Ok(results.into_iter().map(|opt| opt.unwrap()).collect())
}
pub async fn embed_batch_chunked(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
self.embed_batch(texts).await
}
pub async fn dimension(&self) -> Result<usize> {
let embedding = self.embed("test").await?;
Ok(embedding.len())
}
async fn embed_batch_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
#[derive(Serialize)]
struct EmbeddingRequest {
input: Vec<String>,
model: String,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f32>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
let request = EmbeddingRequest {
input: texts.to_vec(),
model: self.config.model_name.clone(),
};
let url = format!("{}/embeddings", self.config.api_base_url);
self.rate_limiter.acquire().await;
let response = self
.client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", self.config.api_key),
)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| crate::Error::Embedding(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
self.rate_limiter.backoff_on_rate_limit().await;
}
return Err(crate::Error::Embedding(format!(
"Embedding API error ({}): {}",
status, body
)));
}
let embedding_response: EmbeddingResponse = response
.json()
.await
.map_err(|e| {
crate::Error::Embedding(format!("Failed to parse response: {}", e))
})?;
Ok(embedding_response
.data
.into_iter()
.map(|d| d.embedding)
.collect())
}
}