1#![allow(dead_code)]
2
3use std::time::Duration;
4
5use reqwest::Client;
6use serde::Deserialize;
7use tokio::time::sleep;
8
9use crate::error::{NanoError, Result};
10
/// Embedding model requested when `NANOGRAPH_EMBED_MODEL` is unset.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// OpenAI-compatible API root used when `OPENAI_BASE_URL` is unset.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Per-request HTTP timeout; overridden by `NANOGRAPH_EMBED_TIMEOUT_MS`.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Total call attempts including the first; overridden by `NANOGRAPH_EMBED_RETRY_ATTEMPTS`.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Base delay for exponential retry backoff; overridden by `NANOGRAPH_EMBED_RETRY_BACKOFF_MS`.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
16
/// Backend used to produce embedding vectors.
#[derive(Clone)]
enum EmbeddingTransport {
    /// Deterministic, offline vectors (see `mock_embedding`); selected by
    /// `NANOGRAPH_EMBEDDINGS_MOCK` or the test constructor.
    Mock,
    /// OpenAI-compatible HTTP backend.
    OpenAi {
        /// Bearer token sent with every request.
        api_key: String,
        /// API root with any trailing `/` already stripped.
        base_url: String,
        /// Shared reqwest client with the configured timeout.
        http: Client,
    },
}
26
/// Crate-internal client for turning text into fixed-dimension embedding
/// vectors, either via a mock transport or an OpenAI-compatible API.
#[derive(Clone)]
pub(crate) struct EmbeddingClient {
    /// Model name sent in OpenAI requests (also exposed via `model()`).
    model: String,
    /// Maximum call attempts per batch (clamped to at least 1 at call time).
    retry_attempts: usize,
    /// Base backoff delay in milliseconds; doubled per retry.
    retry_backoff_ms: u64,
    /// Which backend actually produces the vectors.
    transport: EmbeddingTransport,
}
34
/// Outcome of a single failed OpenAI call attempt, used internally by the
/// retry loop to decide between retrying and surfacing the error.
struct EmbedCallError {
    /// Human-readable failure description, eventually wrapped in `NanoError::Execution`.
    message: String,
    /// True for transient failures (timeouts, connection errors, 429, 5xx).
    retryable: bool,
}
39
/// Success payload of `POST {base_url}/embeddings`; only the `data` array is
/// deserialized, any other fields in the response are ignored.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbeddingDatum>,
}
44
/// One embedding in the response; `index` ties it back to the position of the
/// corresponding input string (the caller re-sorts and validates this).
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingDatum {
    index: usize,
    embedding: Vec<f32>,
}
50
/// Error payload shape used by OpenAI-compatible APIs: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct OpenAiErrorEnvelope {
    error: OpenAiErrorBody,
}
55
/// Inner error object; only the human-readable `message` field is kept.
#[derive(Debug, Deserialize)]
struct OpenAiErrorBody {
    message: String,
}
60
impl EmbeddingClient {
    /// Builds a client from environment variables.
    ///
    /// Recognized variables:
    /// - `NANOGRAPH_EMBED_MODEL`: model name (default `text-embedding-3-small`).
    /// - `NANOGRAPH_EMBED_RETRY_ATTEMPTS` / `NANOGRAPH_EMBED_RETRY_BACKOFF_MS`:
    ///   retry tuning; unparsable or zero values fall back to the defaults.
    /// - `NANOGRAPH_EMBEDDINGS_MOCK`: truthy value selects the offline mock.
    /// - `OPENAI_API_KEY` (required unless mocked), `OPENAI_BASE_URL`,
    ///   `NANOGRAPH_EMBED_TIMEOUT_MS`: OpenAI transport settings.
    ///
    /// # Errors
    /// Returns `NanoError::Execution` when the API key is missing (non-mock
    /// mode) or the HTTP client cannot be constructed.
    pub(crate) fn from_env() -> Result<Self> {
        // Empty or whitespace-only values are treated as unset.
        let model = std::env::var("NANOGRAPH_EMBED_MODEL")
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
            .unwrap_or_else(|| DEFAULT_EMBED_MODEL.to_string());
        let retry_attempts =
            parse_env_usize("NANOGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
        let retry_backoff_ms =
            parse_env_u64("NANOGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);

        // Mock mode short-circuits before any OpenAI credential is required.
        if env_flag("NANOGRAPH_EMBEDDINGS_MOCK") {
            return Ok(Self {
                model,
                retry_attempts,
                retry_backoff_ms,
                transport: EmbeddingTransport::Mock,
            });
        }

        let api_key = std::env::var("OPENAI_API_KEY")
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
            .ok_or_else(|| {
                NanoError::Execution(
                    "OPENAI_API_KEY is required when an embedding call is needed".to_string(),
                )
            })?;
        // Trailing '/' is stripped so "{base_url}/embeddings" joins cleanly.
        // NOTE(review): unlike the other vars this value is not whitespace-
        // trimmed; confirm a padded OPENAI_BASE_URL cannot occur.
        let base_url = std::env::var("OPENAI_BASE_URL")
            .ok()
            .map(|v| v.trim_end_matches('/').to_string())
            .filter(|v| !v.is_empty())
            .unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
        let timeout_ms = parse_env_u64("NANOGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
        // Timeout applies per HTTP request; retries extend total wall time.
        let http = Client::builder()
            .timeout(Duration::from_millis(timeout_ms))
            .build()
            .map_err(|e| {
                NanoError::Execution(format!("failed to initialize HTTP client: {}", e))
            })?;

        Ok(Self {
            model,
            retry_attempts,
            retry_backoff_ms,
            transport: EmbeddingTransport::OpenAi {
                api_key,
                base_url,
                http,
            },
        })
    }

    /// Test-only constructor: default settings with the mock transport.
    #[cfg(test)]
    pub(crate) fn mock_for_tests() -> Self {
        Self {
            model: DEFAULT_EMBED_MODEL.to_string(),
            retry_attempts: DEFAULT_RETRY_ATTEMPTS,
            retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
            transport: EmbeddingTransport::Mock,
        }
    }

    /// Name of the embedding model this client requests.
    pub(crate) fn model(&self) -> &str {
        &self.model
    }

    /// Embeds a single text; convenience wrapper over [`Self::embed_texts`].
    pub(crate) async fn embed_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
        let mut vectors = self.embed_texts(&[input.to_string()], expected_dim).await?;
        // One input yields exactly one vector; the error arm is a defensive
        // guard that should be unreachable given embed_texts' size checks.
        vectors.pop().ok_or_else(|| {
            NanoError::Execution("embedding provider returned no vector".to_string())
        })
    }

    /// Embeds a batch of texts, returning one `expected_dim`-length vector per
    /// input, in input order.
    ///
    /// # Errors
    /// `NanoError::Execution` when `expected_dim` is zero, or (OpenAI mode)
    /// when the request ultimately fails after retries or the response does
    /// not match the inputs in count, order, or dimension.
    pub(crate) async fn embed_texts(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> Result<Vec<Vec<f32>>> {
        if expected_dim == 0 {
            return Err(NanoError::Execution(
                "embedding dimension must be greater than zero".to_string(),
            ));
        }
        // Nothing to embed: skip transport entirely.
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        match &self.transport {
            EmbeddingTransport::Mock => Ok(inputs
                .iter()
                .map(|input| mock_embedding(input, expected_dim))
                .collect()),
            EmbeddingTransport::OpenAi { .. } => {
                self.embed_texts_openai_with_retry(inputs, expected_dim)
                    .await
            }
        }
    }

    /// Calls the OpenAI endpoint with exponential backoff.
    ///
    /// Attempts up to `retry_attempts` times (minimum 1). After a retryable
    /// failure it sleeps `retry_backoff_ms * 2^(attempt-1)` milliseconds, with
    /// the exponent capped at 10 and the multiply saturating to avoid overflow.
    async fn embed_texts_openai_with_retry(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> Result<Vec<Vec<f32>>> {
        let max_attempt = self.retry_attempts.max(1);
        let mut attempt = 0usize;
        loop {
            attempt += 1;
            match self.embed_texts_openai_once(inputs, expected_dim).await {
                Ok(vectors) => return Ok(vectors),
                Err(err) => {
                    // Give up on non-retryable errors or once the budget is spent.
                    if !err.retryable || attempt >= max_attempt {
                        return Err(NanoError::Execution(err.message));
                    }
                    let shift = (attempt - 1).min(10) as u32;
                    let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
                    sleep(Duration::from_millis(delay)).await;
                }
            }
        }
    }

    /// Performs exactly one embeddings request and validates the response.
    ///
    /// Must only be called on the OpenAI transport; the mock arm is
    /// `unreachable!` because `embed_texts` routes mock traffic elsewhere.
    async fn embed_texts_openai_once(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> std::result::Result<Vec<Vec<f32>>, EmbedCallError> {
        let (api_key, base_url, http) = match &self.transport {
            EmbeddingTransport::OpenAi {
                api_key,
                base_url,
                http,
            } => (api_key, base_url, http),
            EmbeddingTransport::Mock => unreachable!("mock transport should not call OpenAI"),
        };

        // "dimensions" asks the API to return vectors of exactly expected_dim.
        let request = serde_json::json!({
            "model": self.model,
            "input": inputs,
            "dimensions": expected_dim,
        });
        let url = format!("{}/embeddings", base_url);
        let response = http
            .post(&url)
            .bearer_auth(api_key)
            .json(&request)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(err) => {
                // NOTE(review): is_request() matches a broad class of reqwest
                // errors, not all of them transient — confirm retrying on it
                // is intended.
                let retryable = err.is_timeout() || err.is_connect() || err.is_request();
                return Err(EmbedCallError {
                    message: format!("embedding request failed: {}", err),
                    retryable,
                });
            }
        };

        // Capture the status before consuming the body so both appear in errors.
        let status = response.status();
        let body = match response.text().await {
            Ok(body) => body,
            Err(err) => {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding response read failed (status {}): {}",
                        status, err
                    ),
                    // Server-side trouble (5xx) and rate limits (429) are transient.
                    retryable: status.is_server_error() || status.as_u16() == 429,
                });
            }
        };

        if !status.is_success() {
            // Prefer the structured OpenAI error message; fall back to raw body.
            let message = parse_openai_error_message(&body).unwrap_or_else(|| body.clone());
            return Err(EmbedCallError {
                message: format!(
                    "embedding request failed with status {}: {}",
                    status, message
                ),
                retryable: status.is_server_error() || status.as_u16() == 429,
            });
        }

        let mut parsed: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).map_err(|err| EmbedCallError {
                message: format!("embedding response decode failed: {}", err),
                retryable: false,
            })?;

        // A well-formed response has exactly one datum per input.
        if parsed.data.len() != inputs.len() {
            return Err(EmbedCallError {
                message: format!(
                    "embedding response size mismatch: expected {}, got {}",
                    inputs.len(),
                    parsed.data.len()
                ),
                retryable: false,
            });
        }

        // Restore input order, then require the indices to be exactly 0..len
        // (the sort makes duplicates/gaps show up as an index mismatch below).
        parsed.data.sort_by_key(|item| item.index);
        let mut vectors = Vec::with_capacity(parsed.data.len());
        for (idx, item) in parsed.data.into_iter().enumerate() {
            if item.index != idx {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding response index mismatch at position {}: got {}",
                        idx, item.index
                    ),
                    retryable: false,
                });
            }
            if item.embedding.len() != expected_dim {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding dimension mismatch: expected {}, got {}",
                        expected_dim,
                        item.embedding.len()
                    ),
                    retryable: false,
                });
            }
            vectors.push(item.embedding);
        }
        Ok(vectors)
    }
}
293
294fn parse_openai_error_message(body: &str) -> Option<String> {
295 serde_json::from_str::<OpenAiErrorEnvelope>(body)
296 .ok()
297 .map(|e| e.error.message)
298 .filter(|msg| !msg.trim().is_empty())
299}
300
/// Reads `name` from the environment as a positive `usize`, returning
/// `default` when the variable is unset, unparsable, or zero.
fn parse_env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(value) if value > 0 => value,
            _ => default,
        },
        Err(_) => default,
    }
}
308
/// Reads `name` from the environment as a positive `u64`, returning
/// `default` when the variable is unset, unparsable, or zero.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    let parsed = std::env::var(name)
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok());
    match parsed {
        Some(value) if value > 0 => value,
        _ => default,
    }
}
316
/// Interprets the environment variable `name` as a boolean flag.
///
/// Accepts "1", "true", "yes", or "on" (case-insensitive, surrounding
/// whitespace ignored); anything else, including an unset variable, is false.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
326
327fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
328 let mut seed = fnv1a64(input.as_bytes());
329 let mut out = Vec::with_capacity(dim);
330 for _ in 0..dim {
331 seed = xorshift64(seed);
332 let ratio = (seed as f64 / u64::MAX as f64) as f32;
333 out.push((ratio * 2.0) - 1.0);
334 }
335
336 let norm = out
337 .iter()
338 .map(|v| (*v as f64) * (*v as f64))
339 .sum::<f64>()
340 .sqrt() as f32;
341 if norm > f32::EPSILON {
342 for value in &mut out {
343 *value /= norm;
344 }
345 }
346 out
347}
348
/// 64-bit FNV-1a hash: XOR each byte into the accumulator, then multiply by
/// the FNV prime (wrapping). Used only to seed the mock-embedding PRNG.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(PRIME)
        })
}
357
/// One step of Marsaglia's xorshift64 generator with the classic
/// (13, 7, 17) shift triple. Note the zero state maps to zero.
fn xorshift64(state: u64) -> u64 {
    let a = state ^ (state << 13);
    let b = a ^ (a >> 7);
    b ^ (b << 17)
}
364
#[cfg(test)]
mod tests {
    use super::*;

    // The mock transport must be stable across calls (same text -> same
    // vector), sensitive to the input text, and honor the requested dimension.
    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let a = client.embed_text("alpha", 8).await.unwrap();
        let b = client.embed_text("alpha", 8).await.unwrap();
        let c = client.embed_text("beta", 8).await.unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.len(), 8);
    }
}