//! Embedding client (omnigraph_compiler/embedding.rs).

1#![allow(dead_code)]
2
3use std::time::Duration;
4
5use reqwest::Client;
6use serde::Deserialize;
7use tokio::time::sleep;
8
9use crate::error::{NanoError, Result};
10
/// Model used when `NANOGRAPH_EMBED_MODEL` is unset or blank.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// API root used when `OPENAI_BASE_URL` is unset or blank.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Per-request HTTP timeout (milliseconds) when `NANOGRAPH_EMBED_TIMEOUT_MS` is unset.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Total call attempts (first try plus retries) for retryable failures.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Base backoff delay (milliseconds); doubled on each subsequent retry.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
16
/// How embedding vectors are produced: locally (deterministic mock) or via
/// the OpenAI HTTP API.
#[derive(Clone)]
enum EmbeddingTransport {
    /// Offline transport: vectors are derived from a hash of the input text.
    Mock,
    /// Remote transport hitting `{base_url}/embeddings` with bearer auth.
    OpenAi {
        api_key: String,
        /// API root with any trailing '/' already stripped (see `from_env`).
        base_url: String,
        http: Client,
    },
}
26
/// Client for producing text embeddings, configured from the environment.
///
/// Cheap to clone: the underlying `reqwest::Client` is itself reference-counted.
#[derive(Clone)]
pub(crate) struct EmbeddingClient {
    /// Embedding model name sent with every request.
    model: String,
    /// Total attempts for retryable failures (clamped to at least 1 at call time).
    retry_attempts: usize,
    /// Base backoff in milliseconds; doubled per retry.
    retry_backoff_ms: u64,
    transport: EmbeddingTransport,
}
34
/// Internal error for a single OpenAI call, carrying whether the retry loop
/// should try again (`retryable`) or surface the failure immediately.
struct EmbedCallError {
    message: String,
    retryable: bool,
}
39
/// Wire shape of a successful OpenAI `/embeddings` response (only the fields
/// this module reads).
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbeddingDatum>,
}
44
/// One embedding entry; `index` is the position of the corresponding input,
/// used to restore request order before validation.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingDatum {
    index: usize,
    embedding: Vec<f32>,
}
50
/// Wire shape of an OpenAI error body: `{"error": {"message": ...}}`.
#[derive(Debug, Deserialize)]
struct OpenAiErrorEnvelope {
    error: OpenAiErrorBody,
}
55
/// Inner error payload; only the human-readable message is used.
#[derive(Debug, Deserialize)]
struct OpenAiErrorBody {
    message: String,
}
60
impl EmbeddingClient {
    /// Builds a client from process environment variables.
    ///
    /// Recognized variables:
    /// - `NANOGRAPH_EMBED_MODEL`: model name (blank/unset falls back to the default).
    /// - `NANOGRAPH_EMBED_RETRY_ATTEMPTS`, `NANOGRAPH_EMBED_RETRY_BACKOFF_MS`:
    ///   retry policy; unparsable or zero values fall back to the defaults.
    /// - `NANOGRAPH_EMBEDDINGS_MOCK`: truthy ("1"/"true"/"yes"/"on") selects the
    ///   offline mock transport and skips all OpenAI configuration below.
    /// - `OPENAI_API_KEY`: required for the OpenAI transport.
    /// - `OPENAI_BASE_URL`: API root; trailing '/' is stripped so URL joining
    ///   in `embed_texts_openai_once` never produces "//embeddings".
    /// - `NANOGRAPH_EMBED_TIMEOUT_MS`: per-request HTTP timeout.
    ///
    /// # Errors
    /// Returns `NanoError::Execution` when the API key is absent (non-mock mode)
    /// or the HTTP client cannot be constructed.
    pub(crate) fn from_env() -> Result<Self> {
        let model = std::env::var("NANOGRAPH_EMBED_MODEL")
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
            .unwrap_or_else(|| DEFAULT_EMBED_MODEL.to_string());
        let retry_attempts =
            parse_env_usize("NANOGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
        let retry_backoff_ms =
            parse_env_u64("NANOGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);

        // Mock mode short-circuits before the API key requirement below, so it
        // works in environments with no OpenAI credentials at all.
        if env_flag("NANOGRAPH_EMBEDDINGS_MOCK") {
            return Ok(Self {
                model,
                retry_attempts,
                retry_backoff_ms,
                transport: EmbeddingTransport::Mock,
            });
        }

        let api_key = std::env::var("OPENAI_API_KEY")
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
            .ok_or_else(|| {
                NanoError::Execution(
                    "OPENAI_API_KEY is required when an embedding call is needed".to_string(),
                )
            })?;
        let base_url = std::env::var("OPENAI_BASE_URL")
            .ok()
            .map(|v| v.trim_end_matches('/').to_string())
            .filter(|v| !v.is_empty())
            .unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
        let timeout_ms = parse_env_u64("NANOGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
        let http = Client::builder()
            .timeout(Duration::from_millis(timeout_ms))
            .build()
            .map_err(|e| {
                NanoError::Execution(format!("failed to initialize HTTP client: {}", e))
            })?;

        Ok(Self {
            model,
            retry_attempts,
            retry_backoff_ms,
            transport: EmbeddingTransport::OpenAi {
                api_key,
                base_url,
                http,
            },
        })
    }

    /// Test-only constructor: mock transport with default model/retry settings,
    /// independent of any environment variables.
    #[cfg(test)]
    pub(crate) fn mock_for_tests() -> Self {
        Self {
            model: DEFAULT_EMBED_MODEL.to_string(),
            retry_attempts: DEFAULT_RETRY_ATTEMPTS,
            retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
            transport: EmbeddingTransport::Mock,
        }
    }

    /// Name of the embedding model this client sends with each request.
    pub(crate) fn model(&self) -> &str {
        &self.model
    }

    /// Embeds a single string, returning one vector of `expected_dim` floats.
    ///
    /// Delegates to [`Self::embed_texts`]; the `pop()` failure branch is
    /// defensive — `embed_texts` already validates that the provider returns
    /// exactly one vector per input.
    pub(crate) async fn embed_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
        let mut vectors = self.embed_texts(&[input.to_string()], expected_dim).await?;
        vectors.pop().ok_or_else(|| {
            NanoError::Execution("embedding provider returned no vector".to_string())
        })
    }

    /// Embeds a batch of strings, preserving input order.
    ///
    /// Rejects `expected_dim == 0`, short-circuits on an empty batch, and
    /// dispatches to the configured transport. Every returned vector has
    /// exactly `expected_dim` elements.
    pub(crate) async fn embed_texts(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> Result<Vec<Vec<f32>>> {
        if expected_dim == 0 {
            return Err(NanoError::Execution(
                "embedding dimension must be greater than zero".to_string(),
            ));
        }
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        match &self.transport {
            EmbeddingTransport::Mock => Ok(inputs
                .iter()
                .map(|input| mock_embedding(input, expected_dim))
                .collect()),
            EmbeddingTransport::OpenAi { .. } => {
                self.embed_texts_openai_with_retry(inputs, expected_dim)
                    .await
            }
        }
    }

    /// Calls the OpenAI endpoint with exponential backoff.
    ///
    /// Makes up to `retry_attempts` total tries (at least one). Between tries
    /// it sleeps `retry_backoff_ms * 2^(attempt-1)` ms; the shift is capped at
    /// 10 (×1024) and the multiply saturates, so the delay cannot overflow.
    /// Non-retryable errors and exhaustion surface as `NanoError::Execution`.
    async fn embed_texts_openai_with_retry(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> Result<Vec<Vec<f32>>> {
        let max_attempt = self.retry_attempts.max(1);
        let mut attempt = 0usize;
        loop {
            attempt += 1;
            match self.embed_texts_openai_once(inputs, expected_dim).await {
                Ok(vectors) => return Ok(vectors),
                Err(err) => {
                    if !err.retryable || attempt >= max_attempt {
                        return Err(NanoError::Execution(err.message));
                    }
                    let shift = (attempt - 1).min(10) as u32;
                    let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
                    sleep(Duration::from_millis(delay)).await;
                }
            }
        }
    }

    /// Performs one POST to `{base_url}/embeddings` and validates the reply.
    ///
    /// Retryability classification:
    /// - transport errors: timeouts, connect failures, and `is_request()`
    ///   errors (NOTE(review): reqwest's `is_request()` also covers some
    ///   request-construction failures that a retry cannot fix — confirm this
    ///   classification is intended);
    /// - HTTP errors: 5xx and 429 are retryable, other statuses are not;
    /// - decode/shape mismatches: never retryable.
    ///
    /// On success, entries are sorted by their reported `index` and checked to
    /// be exactly 0..inputs.len() with `expected_dim`-sized vectors.
    async fn embed_texts_openai_once(
        &self,
        inputs: &[String],
        expected_dim: usize,
    ) -> std::result::Result<Vec<Vec<f32>>, EmbedCallError> {
        let (api_key, base_url, http) = match &self.transport {
            EmbeddingTransport::OpenAi {
                api_key,
                base_url,
                http,
            } => (api_key, base_url, http),
            EmbeddingTransport::Mock => unreachable!("mock transport should not call OpenAI"),
        };

        let request = serde_json::json!({
            "model": self.model,
            "input": inputs,
            "dimensions": expected_dim,
        });
        let url = format!("{}/embeddings", base_url);
        let response = http
            .post(&url)
            .bearer_auth(api_key)
            .json(&request)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(err) => {
                let retryable = err.is_timeout() || err.is_connect() || err.is_request();
                return Err(EmbedCallError {
                    message: format!("embedding request failed: {}", err),
                    retryable,
                });
            }
        };

        // Capture the status before consuming the response body, so error
        // messages can still report it if the body read fails.
        let status = response.status();
        let body = match response.text().await {
            Ok(body) => body,
            Err(err) => {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding response read failed (status {}): {}",
                        status, err
                    ),
                    retryable: status.is_server_error() || status.as_u16() == 429,
                });
            }
        };

        if !status.is_success() {
            // Prefer the structured OpenAI error message; fall back to raw body.
            let message = parse_openai_error_message(&body).unwrap_or_else(|| body.clone());
            return Err(EmbedCallError {
                message: format!(
                    "embedding request failed with status {}: {}",
                    status, message
                ),
                retryable: status.is_server_error() || status.as_u16() == 429,
            });
        }

        let mut parsed: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).map_err(|err| EmbedCallError {
                message: format!("embedding response decode failed: {}", err),
                retryable: false,
            })?;

        if parsed.data.len() != inputs.len() {
            return Err(EmbedCallError {
                message: format!(
                    "embedding response size mismatch: expected {}, got {}",
                    inputs.len(),
                    parsed.data.len()
                ),
                retryable: false,
            });
        }

        // Restore request order, then require indices to be exactly 0..n
        // (this also rejects duplicated or missing indices after the sort).
        parsed.data.sort_by_key(|item| item.index);
        let mut vectors = Vec::with_capacity(parsed.data.len());
        for (idx, item) in parsed.data.into_iter().enumerate() {
            if item.index != idx {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding response index mismatch at position {}: got {}",
                        idx, item.index
                    ),
                    retryable: false,
                });
            }
            if item.embedding.len() != expected_dim {
                return Err(EmbedCallError {
                    message: format!(
                        "embedding dimension mismatch: expected {}, got {}",
                        expected_dim,
                        item.embedding.len()
                    ),
                    retryable: false,
                });
            }
            vectors.push(item.embedding);
        }
        Ok(vectors)
    }
}
293
294fn parse_openai_error_message(body: &str) -> Option<String> {
295    serde_json::from_str::<OpenAiErrorEnvelope>(body)
296        .ok()
297        .map(|e| e.error.message)
298        .filter(|msg| !msg.trim().is_empty())
299}
300
/// Reads a positive `usize` from the environment variable `name`.
///
/// Falls back to `default` when the variable is unset, unparsable, or zero.
fn parse_env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(value) if value > 0 => value,
            _ => default,
        },
        Err(_) => default,
    }
}
308
/// Reads a positive `u64` from the environment variable `name`.
///
/// Falls back to `default` when the variable is unset, unparsable, or zero.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(value) if value > 0 => value,
            _ => default,
        },
        Err(_) => default,
    }
}
316
/// Reports whether the environment variable `name` is set to a truthy value.
///
/// Truthy values (after trimming and lowercasing): "1", "true", "yes", "on".
/// Unset variables and all other values are falsy.
fn env_flag(name: &str) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    matches!(
        raw.trim().to_ascii_lowercase().as_str(),
        "1" | "true" | "yes" | "on"
    )
}
326
327fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
328    let mut seed = fnv1a64(input.as_bytes());
329    let mut out = Vec::with_capacity(dim);
330    for _ in 0..dim {
331        seed = xorshift64(seed);
332        let ratio = (seed as f64 / u64::MAX as f64) as f32;
333        out.push((ratio * 2.0) - 1.0);
334    }
335
336    let norm = out
337        .iter()
338        .map(|v| (*v as f64) * (*v as f64))
339        .sum::<f64>()
340        .sqrt() as f32;
341    if norm > f32::EPSILON {
342        for value in &mut out {
343            *value /= norm;
344        }
345    }
346    out
347}
348
/// 64-bit FNV-1a hash of `bytes` (xor the byte, then multiply by the prime).
fn fnv1a64(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;
    bytes.iter().fold(FNV_OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
357
/// One step of Marsaglia's xorshift64 generator (shifts 13, 7, 17).
/// Note: a zero input maps to zero, which is why callers seed from a hash.
fn xorshift64(seed: u64) -> u64 {
    let a = seed ^ (seed << 13);
    let b = a ^ (a >> 7);
    b ^ (b << 17)
}
364
#[cfg(test)]
mod tests {
    use super::*;

    // The mock transport must be a pure function of the input text: equal
    // inputs give equal vectors of the requested length, distinct inputs
    // give distinct vectors.
    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let first = client.embed_text("alpha", 8).await.unwrap();
        let repeat = client.embed_text("alpha", 8).await.unwrap();
        let other = client.embed_text("beta", 8).await.unwrap();
        assert_eq!(first.len(), 8);
        assert_eq!(first, repeat);
        assert_ne!(first, other);
    }
}
379}