use std::time::Duration;
use serde::Deserialize;
use crate::error::{MemeError, Result};
/// Maximum number of texts sent in a single embeddings request.
///
/// Inputs to `ApiEmbedding::encode_documents` longer than this are split into
/// chunks of this size; shorter inputs are sent as one request.
const EMBED_BATCH_SIZE: usize = 128;
/// Embedding backend that sends texts to an HTTP `/embeddings` endpoint.
///
/// `Clone` is derived so a copy can be moved into each spawned batch task.
/// `Debug` is implemented by hand (below) so the API key is never printed in
/// logs, panic messages, or `{:?}` output.
#[derive(Clone)]
pub(crate) struct ApiEmbedding {
    /// HTTP client used for every embeddings request.
    http: reqwest::Client,
    /// Endpoint root with any trailing `/` removed by the constructor;
    /// requests are sent to `{base_url}/embeddings`.
    base_url: String,
    /// Secret bearer token sent in the `Authorization` header; redacted from `Debug`.
    api_key: String,
    /// Model name placed in each request body.
    model: String,
    /// Requested embedding dimension, sent as the `dimensions` request field.
    dimension: usize,
    /// Retry budget consumed by the retry loop around each API call.
    max_retries: u32,
}

impl std::fmt::Debug for ApiEmbedding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiEmbedding")
            .field("http", &self.http)
            .field("base_url", &self.base_url)
            // Never expose the credential, even in debug output.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("dimension", &self.dimension)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
impl ApiEmbedding {
pub(crate) fn new(
http: reqwest::Client,
embedding_cfg: &crate::config::EmbeddingConfig,
llm_cfg: &crate::config::LlmConfig,
) -> Result<Self> {
let api_key = embedding_cfg
.api_key
.as_deref()
.or(llm_cfg.api_key.as_deref())
.ok_or_else(|| MemeError::Config("API key is required for API embedding".to_owned()))?
.to_owned();
let base_url = embedding_cfg
.base_url
.as_deref()
.unwrap_or(&llm_cfg.base_url)
.trim_end_matches('/')
.to_owned();
Ok(Self {
http,
base_url,
api_key,
model: embedding_cfg.model.clone(),
dimension: embedding_cfg.dimension,
max_retries: llm_cfg.max_retries,
})
}
#[must_use]
pub(crate) const fn dimension(&self) -> usize {
self.dimension
}
#[tracing::instrument(skip(self, texts), fields(count = texts.len()))]
pub(crate) async fn encode_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
if texts.len() <= EMBED_BATCH_SIZE {
let input: Vec<String> = texts.iter().map(|s| (*s).to_owned()).collect();
return self.embed_with_retry(input).await;
}
let chunks: Vec<Vec<String>> = texts
.chunks(EMBED_BATCH_SIZE)
.map(|chunk| chunk.iter().map(|s| (*s).to_owned()).collect())
.collect();
let max_concurrent = 4;
let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));
let mut handles = Vec::with_capacity(chunks.len());
for chunk in chunks {
let this = self.clone();
let sem = std::sync::Arc::clone(&semaphore);
handles.push(tokio::spawn(async move {
let _permit = sem.acquire().await;
this.embed_with_retry(chunk).await
}));
}
let mut all_vectors = Vec::with_capacity(texts.len());
for handle in handles {
let vectors = handle
.await
.map_err(|e| MemeError::Embedding(format!("embed task panicked: {e}")))??;
all_vectors.extend(vectors);
}
Ok(all_vectors)
}
pub(crate) async fn encode_query(&self, text: &str) -> Result<Vec<f32>> {
let results = self.embed_with_retry(vec![text.to_owned()]).await?;
results
.into_iter()
.next()
.ok_or_else(|| MemeError::Embedding("empty embedding response for query".to_owned()))
}
async fn embed_with_retry(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>> {
let mut last_err = None;
for attempt in 0..self.max_retries {
match self.call_api(&input).await {
Ok(vectors) => return Ok(vectors),
Err(e) if !e.is_retryable() => return Err(e),
Err(e) => {
tracing::warn!(attempt = attempt + 1, error = %e, "embedding API call failed");
last_err = Some(e);
}
}
if attempt + 1 < self.max_retries {
let wait = 2u64.saturating_pow(attempt).min(30);
tokio::time::sleep(Duration::from_secs(wait)).await;
}
}
Err(last_err
.unwrap_or_else(|| MemeError::Embedding("all embedding retries exhausted".to_owned())))
}
async fn call_api(&self, input: &[String]) -> Result<Vec<Vec<f32>>> {
let url = format!("{}/embeddings", self.base_url);
let body = serde_json::json!({
"model": self.model,
"input": input,
"dimensions": self.dimension,
});
let resp = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
return Err(MemeError::Embedding(format!(
"embedding API returned {status}: {text}"
)));
}
let data: EmbeddingResponse = resp.json().await.map_err(|e| {
MemeError::Embedding(format!("failed to parse embedding response: {e}"))
})?;
let mut vectors: Vec<(usize, Vec<f32>)> = data
.data
.into_iter()
.map(|d| (d.index, d.embedding))
.collect();
vectors.sort_by_key(|(idx, _)| *idx);
Ok(vectors.into_iter().map(|(_, v)| v).collect())
}
}
/// Response body of the `/embeddings` endpoint, deserialized in `ApiEmbedding::call_api`.
///
/// Only the `data` array is declared, so it is the only part of the body that is read.
/// NOTE(review): shape assumed to follow the OpenAI-style embeddings response — confirm
/// against the configured provider.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// Embedding entries; presumably one per request input (not enforced by this type).
    data: Vec<EmbeddingData>,
}
/// A single entry of the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    /// Ordering key; `call_api` sorts the returned vectors by this value.
    /// Presumably the position of the matching input in the request — confirm with provider.
    index: usize,
    /// The embedding vector for that input.
    embedding: Vec<f32>,
}