use std::time::Duration;
use reqwest::Client;
use serde::Deserialize;
use tokio::time::sleep;
use crate::error::{NanoError, Result};
/// Embedding model name used when `NANOGRAPH_EMBED_MODEL` provides no usable value.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// API base URL used when `OPENAI_BASE_URL` provides no usable value.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// HTTP client timeout in milliseconds, used when `NANOGRAPH_EMBED_TIMEOUT_MS` is unset,
/// unparsable, or zero.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Total number of attempts per embedding request, used when
/// `NANOGRAPH_EMBED_RETRY_ATTEMPTS` is unset, unparsable, or zero.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Base retry delay in milliseconds (doubled after each failed attempt), used when
/// `NANOGRAPH_EMBED_RETRY_BACKOFF_MS` is unset, unparsable, or zero.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
/// Backend that an `EmbeddingClient` uses to produce vectors.
#[derive(Clone)]
enum EmbeddingTransport {
/// Produces deterministic vectors locally via `mock_embedding`; no HTTP request is made.
Mock,
/// Sends requests to the `/embeddings` endpoint under `base_url`.
OpenAi {
/// Sent as the bearer token on every request.
api_key: String,
/// Endpoint prefix; `/embeddings` is appended to it when building the request URL.
base_url: String,
/// HTTP client used to send the request.
http: Client,
},
}
/// Client that turns text into embedding vectors through the configured transport.
#[derive(Clone)]
pub(crate) struct EmbeddingClient {
/// Model name sent in the `model` field of each OpenAI request.
model: String,
/// Maximum number of attempts per request; values below 1 are treated as 1 by the retry loop.
retry_attempts: usize,
/// Base backoff delay in milliseconds; multiplied by a power of two on each retry.
retry_backoff_ms: u64,
/// Backend used to produce the vectors.
transport: EmbeddingTransport,
}
/// Failure of a single OpenAI embedding attempt, as seen by the retry loop.
struct EmbedCallError {
/// Human-readable description; becomes the `NanoError::Execution` payload once retries stop.
message: String,
/// Whether the retry loop is allowed to attempt the call again.
retryable: bool,
}
/// Subset of a successful embeddings response body that this module deserializes.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
/// One entry per embedded input; the count is checked against the number of inputs sent.
data: Vec<OpenAiEmbeddingDatum>,
}
/// A single embedding entry inside `OpenAiEmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingDatum {
/// Position reported by the provider; entries are sorted by it and it must equal the
/// entry's position after sorting.
index: usize,
/// The embedding vector; its length is checked against the requested dimension.
embedding: Vec<f32>,
}
/// Outer JSON object of an error response body: `{ "error": { ... } }`.
#[derive(Debug, Deserialize)]
struct OpenAiErrorEnvelope {
/// The nested error object carrying the message.
error: OpenAiErrorBody,
}
/// Inner error object of an error response body; only `message` is read.
#[derive(Debug, Deserialize)]
struct OpenAiErrorBody {
/// Provider-supplied error text, surfaced in this module's error messages when non-blank.
message: String,
}
impl EmbeddingClient {
/// Builds a client from process environment variables.
///
/// Variables read:
/// - `NANOGRAPH_EMBED_MODEL`: model name; `DEFAULT_EMBED_MODEL` when unset or blank.
/// - `NANOGRAPH_EMBED_RETRY_ATTEMPTS`, `NANOGRAPH_EMBED_RETRY_BACKOFF_MS`: retry settings,
///   read through `parse_env_usize` / `parse_env_u64` with the module defaults.
/// - `NANOGRAPH_EMBEDDINGS_MOCK`: when `env_flag` reports it as set, the mock transport is
///   returned immediately and neither an API key nor an HTTP client is required.
/// - `OPENAI_API_KEY`: required for the OpenAI transport.
/// - `OPENAI_BASE_URL`: optional; surrounding whitespace and trailing slashes are removed,
///   and `DEFAULT_OPENAI_BASE_URL` is used when nothing remains.
/// - `NANOGRAPH_EMBED_TIMEOUT_MS`: HTTP client timeout in milliseconds.
///
/// # Errors
///
/// Returns `NanoError::Execution` when the OpenAI transport is selected and
/// `OPENAI_API_KEY` is unset or blank, or when the HTTP client cannot be built.
pub(crate) fn from_env() -> Result<Self> {
    // Reads a variable, treating unset, non-Unicode, and whitespace-only values as absent.
    fn trimmed_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }

    let model =
        trimmed_env("NANOGRAPH_EMBED_MODEL").unwrap_or_else(|| DEFAULT_EMBED_MODEL.to_string());
    let retry_attempts =
        parse_env_usize("NANOGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
    let retry_backoff_ms =
        parse_env_u64("NANOGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);

    if env_flag("NANOGRAPH_EMBEDDINGS_MOCK") {
        return Ok(Self {
            model,
            retry_attempts,
            retry_backoff_ms,
            transport: EmbeddingTransport::Mock,
        });
    }

    let api_key = trimmed_env("OPENAI_API_KEY").ok_or_else(|| {
        NanoError::Execution(
            "OPENAI_API_KEY is required when an embedding call is needed".to_string(),
        )
    })?;

    // Trim whitespace first (consistent with the model and API key) so a padded or
    // whitespace-only value never ends up inside the request URL.
    let base_url = trimmed_env("OPENAI_BASE_URL")
        .map(|url| url.trim_end_matches('/').to_string())
        .filter(|url| !url.is_empty())
        .unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());

    let timeout_ms = parse_env_u64("NANOGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
    let http = Client::builder()
        .timeout(Duration::from_millis(timeout_ms))
        .build()
        .map_err(|e| {
            NanoError::Execution(format!("failed to initialize HTTP client: {}", e))
        })?;

    Ok(Self {
        model,
        retry_attempts,
        retry_backoff_ms,
        transport: EmbeddingTransport::OpenAi {
            api_key,
            base_url,
            http,
        },
    })
}
/// Test-only constructor: a mock-transport client configured with the module defaults.
#[cfg(test)]
pub(crate) fn mock_for_tests() -> Self {
    let transport = EmbeddingTransport::Mock;
    let model = String::from(DEFAULT_EMBED_MODEL);
    Self {
        transport,
        model,
        retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
        retry_attempts: DEFAULT_RETRY_ATTEMPTS,
    }
}
/// Returns the embedding model name this client was configured with.
pub(crate) fn model(&self) -> &str {
    self.model.as_str()
}
/// Embeds a single string by delegating to `embed_texts` with a one-element batch.
///
/// # Errors
///
/// Propagates any error from `embed_texts`, and returns `NanoError::Execution` if the
/// batch call yields no vector.
pub(crate) async fn embed_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
    let batch = [input.to_string()];
    let vectors = self.embed_texts(&batch, expected_dim).await?;
    match vectors.into_iter().next_back() {
        Some(vector) => Ok(vector),
        None => Err(NanoError::Execution(
            "embedding provider returned no vector".to_string(),
        )),
    }
}
/// Embeds a batch of strings, returning one vector of length `expected_dim` per input.
///
/// An empty `inputs` slice yields an empty result without touching the transport. The mock
/// transport computes vectors locally; the OpenAI transport goes through the retrying
/// HTTP path.
///
/// # Errors
///
/// Returns `NanoError::Execution` when `expected_dim` is zero, or when the OpenAI call
/// ultimately fails.
pub(crate) async fn embed_texts(
    &self,
    inputs: &[String],
    expected_dim: usize,
) -> Result<Vec<Vec<f32>>> {
    if expected_dim == 0 {
        let reason = "embedding dimension must be greater than zero".to_string();
        return Err(NanoError::Execution(reason));
    }
    if inputs.is_empty() {
        return Ok(Vec::new());
    }

    match &self.transport {
        EmbeddingTransport::OpenAi { .. } => {
            self.embed_texts_openai_with_retry(inputs, expected_dim)
                .await
        }
        EmbeddingTransport::Mock => {
            let mut vectors = Vec::with_capacity(inputs.len());
            vectors.extend(inputs.iter().map(|text| mock_embedding(text, expected_dim)));
            Ok(vectors)
        }
    }
}
/// Calls `embed_texts_openai_once` until it succeeds, fails with a non-retryable error, or
/// the attempt budget (`retry_attempts`, at least 1) is used up.
///
/// Between attempts it sleeps for `retry_backoff_ms * 2^k`, where `k` is the number of
/// retries already made, capped at 10; the multiplication saturates instead of overflowing.
///
/// # Errors
///
/// Returns `NanoError::Execution` carrying the message of the last failed attempt.
async fn embed_texts_openai_with_retry(
    &self,
    inputs: &[String],
    expected_dim: usize,
) -> Result<Vec<Vec<f32>>> {
    let attempt_limit = self.retry_attempts.max(1);
    let mut failures = 0usize;
    loop {
        let failure = match self.embed_texts_openai_once(inputs, expected_dim).await {
            Ok(vectors) => return Ok(vectors),
            Err(failure) => failure,
        };
        failures += 1;

        let exhausted = failures >= attempt_limit;
        if exhausted || !failure.retryable {
            return Err(NanoError::Execution(failure.message));
        }

        // Exponent is capped so the shift itself can never overflow a u64.
        let exponent = (failures - 1).min(10) as u32;
        let backoff_ms = self.retry_backoff_ms.saturating_mul(1u64 << exponent);
        sleep(Duration::from_millis(backoff_ms)).await;
    }
}
/// Performs one `POST {base_url}/embeddings` request and validates the response.
///
/// On success returns one vector per input, ordered by the `index` reported for each
/// entry. The response must contain exactly `inputs.len()` entries, their indices must be
/// `0..inputs.len()` with no gaps or duplicates, and every vector must have `expected_dim`
/// elements.
///
/// # Errors
///
/// Returns an `EmbedCallError` whose `retryable` flag is set for:
/// - send failures that reqwest reports as timeout, connect, or request errors;
/// - responses with a 5xx or 429 status;
/// - a response body that cannot be read after a success status (timeout or connection
///   loss mid-body), since repeating the request can succeed.
///
/// Other non-success statuses, undecodable JSON, and shape mismatches are not retryable.
///
/// # Panics
///
/// Panics if called on a client using the mock transport.
async fn embed_texts_openai_once(
    &self,
    inputs: &[String],
    expected_dim: usize,
) -> std::result::Result<Vec<Vec<f32>>, EmbedCallError> {
    let (api_key, base_url, http) = match &self.transport {
        EmbeddingTransport::OpenAi {
            api_key,
            base_url,
            http,
        } => (api_key, base_url, http),
        EmbeddingTransport::Mock => unreachable!("mock transport should not call OpenAI"),
    };

    let request = serde_json::json!({
        "model": self.model,
        "input": inputs,
        "dimensions": expected_dim,
    });
    let url = format!("{}/embeddings", base_url);
    let response = http
        .post(&url)
        .bearer_auth(api_key)
        .json(&request)
        .send()
        .await;
    let response = match response {
        Ok(resp) => resp,
        Err(err) => {
            let retryable = err.is_timeout() || err.is_connect() || err.is_request();
            return Err(EmbedCallError {
                message: format!("embedding request failed: {}", err),
                retryable,
            });
        }
    };

    let status = response.status();
    let status_retryable = status.is_server_error() || status.as_u16() == 429;
    let body = match response.text().await {
        Ok(body) => body,
        Err(err) => {
            // The server accepted the request but the body never fully arrived: this is a
            // transport failure, not a rejection, so another attempt is worthwhile.
            let interrupted_success = status.is_success()
                && (err.is_timeout() || err.is_body() || err.is_decode());
            return Err(EmbedCallError {
                message: format!(
                    "embedding response read failed (status {}): {}",
                    status, err
                ),
                retryable: status_retryable || interrupted_success,
            });
        }
    };

    if !status.is_success() {
        // Prefer the provider's structured message; otherwise surface the raw body.
        let message = parse_openai_error_message(&body).unwrap_or(body);
        return Err(EmbedCallError {
            message: format!(
                "embedding request failed with status {}: {}",
                status, message
            ),
            retryable: status_retryable,
        });
    }

    let mut parsed: OpenAiEmbeddingResponse =
        serde_json::from_str(&body).map_err(|err| EmbedCallError {
            message: format!("embedding response decode failed: {}", err),
            retryable: false,
        })?;
    if parsed.data.len() != inputs.len() {
        return Err(EmbedCallError {
            message: format!(
                "embedding response size mismatch: expected {}, got {}",
                inputs.len(),
                parsed.data.len()
            ),
            retryable: false,
        });
    }

    // Entries are not assumed to arrive in input order; sort by the reported index and
    // then require index == position, which also rejects gaps and duplicates.
    parsed.data.sort_by_key(|item| item.index);
    let mut vectors = Vec::with_capacity(parsed.data.len());
    for (idx, item) in parsed.data.into_iter().enumerate() {
        if item.index != idx {
            return Err(EmbedCallError {
                message: format!(
                    "embedding response index mismatch at position {}: got {}",
                    idx, item.index
                ),
                retryable: false,
            });
        }
        if item.embedding.len() != expected_dim {
            return Err(EmbedCallError {
                message: format!(
                    "embedding dimension mismatch: expected {}, got {}",
                    expected_dim,
                    item.embedding.len()
                ),
                retryable: false,
            });
        }
        vectors.push(item.embedding);
    }
    Ok(vectors)
}
}
/// Extracts `error.message` from a JSON error body.
///
/// Returns `None` when the body does not deserialize as an `OpenAiErrorEnvelope` or when the
/// message is empty or whitespace-only.
fn parse_openai_error_message(body: &str) -> Option<String> {
    let envelope: OpenAiErrorEnvelope = serde_json::from_str(body).ok()?;
    let message = envelope.error.message;
    if message.trim().is_empty() {
        None
    } else {
        Some(message)
    }
}
/// Reads a positive `usize` from the environment variable `name`.
///
/// Surrounding whitespace is ignored, so `" 4 "` and `"4\n"` parse the same as `"4"`.
/// Returns `default` when the variable is unset, is not valid Unicode, does not parse as a
/// `usize`, or parses to zero.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        // Trim first: `str::parse` rejects leading or trailing whitespace.
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&value| value > 0)
        .unwrap_or(default)
}
/// Reads a positive `u64` from the environment variable `name`.
///
/// Surrounding whitespace is ignored, so `" 200 "` and `"200\n"` parse the same as `"200"`.
/// Returns `default` when the variable is unset, is not valid Unicode, does not parse as a
/// `u64`, or parses to zero.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        // Trim first: `str::parse` rejects leading or trailing whitespace.
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .filter(|&value| value > 0)
        .unwrap_or(default)
}
/// Reports whether the environment variable `name` holds a truthy value.
///
/// After trimming whitespace, the values `1`, `true`, `yes`, and `on` count as truthy,
/// compared ASCII case-insensitively. Anything else, including an unset or non-Unicode
/// variable, is `false`.
fn env_flag(name: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var(name) {
        Ok(raw) => {
            let candidate = raw.trim();
            TRUTHY
                .iter()
                .any(|accepted| candidate.eq_ignore_ascii_case(accepted))
        }
        Err(_) => false,
    }
}
/// Produces a deterministic pseudo-random vector of length `dim` for `input`.
///
/// The generator state starts at the `fnv1a64` hash of the input bytes and is advanced
/// with `xorshift64` once per component. Each state is mapped into `[-1.0, 1.0]`, and the
/// vector is then divided by its Euclidean norm unless that norm is at most `f32::EPSILON`.
fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
    let mut state = fnv1a64(input.as_bytes());
    let mut components: Vec<f32> = (0..dim)
        .map(|_| {
            state = xorshift64(state);
            let unit = (state as f64 / u64::MAX as f64) as f32;
            (unit * 2.0) - 1.0
        })
        .collect();

    // Accumulate in f64 to limit rounding error before narrowing the norm back to f32.
    let length = components
        .iter()
        .map(|c| (*c as f64) * (*c as f64))
        .sum::<f64>()
        .sqrt() as f32;
    if length > f32::EPSILON {
        components.iter_mut().for_each(|c| *c /= length);
    }
    components
}
/// Computes the 64-bit FNV-1a hash of `bytes`: XOR each byte into the hash, then multiply
/// by the FNV prime with wrapping arithmetic.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &b| (acc ^ u64::from(b)).wrapping_mul(PRIME))
}
/// Advances a 64-bit xorshift state by one step using the shift triple (13, 7, 17).
///
/// A zero state maps to zero.
fn xorshift64(x: u64) -> u64 {
    let after_left = x ^ (x << 13);
    let after_right = after_left ^ (after_left >> 7);
    after_right ^ (after_right << 17)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock transport must return identical vectors for identical text, different
    /// vectors for different text, and vectors of the requested length.
    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        const DIM: usize = 8;
        let client = EmbeddingClient::mock_for_tests();

        let alpha_first = client.embed_text("alpha", DIM).await.unwrap();
        let alpha_second = client.embed_text("alpha", DIM).await.unwrap();
        let beta = client.embed_text("beta", DIM).await.unwrap();

        assert_eq!(alpha_first, alpha_second);
        assert_ne!(alpha_first, beta);
        assert_eq!(alpha_first.len(), DIM);
    }
}