use anyhow::{anyhow, bail, Result};
17use memvid_core::{EmbeddingConfig, EmbeddingProvider, VecEmbedder};
18use reqwest::blocking::Client;
19use serde::{Deserialize, Serialize};
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::time::Duration;
22use tracing::{debug, info, warn};
23
/// OpenAI embeddings REST endpoint used by every request.
const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";

/// Upper bound on the number of texts sent in a single API request.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout applied to each HTTP request by the blocking client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum input length (in bytes) before truncation — keeps a single
/// text under the model's token limit. TODO confirm this bound against
/// the configured model's actual token limit.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
36
37fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39 if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40 std::borrow::Cow::Borrowed(text)
41 } else {
42 let end = text[..MAX_EMBEDDING_TEXT_LEN]
44 .char_indices()
45 .rev()
46 .next()
47 .map(|(i, c)| i + c.len_utf8())
48 .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
49 warn!(
50 "Truncating embedding text from {} to {} chars to avoid token limit",
51 text.len(),
52 end
53 );
54 std::borrow::Cow::Owned(text[..end].to_string())
55 }
56}
57
/// JSON request body for the OpenAI `/v1/embeddings` endpoint.
///
/// Borrows the model name and input texts so building a request does not
/// copy the (potentially large) batch.
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier taken from `EmbeddingConfig::model`.
    model: &'a str,
    // Batch of texts to embed.
    input: Vec<&'a str>,
    // Optional output dimensionality; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}
66
/// Top-level JSON response from the embeddings endpoint.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input text; may arrive out of order (sorted by `index`
    // before use).
    data: Vec<OpenAIEmbeddingData>,
    // Model name echoed back by the server; logged for diagnostics.
    model: String,
    // Token accounting reported by the API.
    usage: OpenAIUsage,
}
74
/// A single embedding vector within the response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    // The embedding values.
    embedding: Vec<f32>,
    // Position of the corresponding input text; used to restore input order.
    index: usize,
}
80
/// Token usage block returned by the API.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    // Deserialized for completeness but currently unused.
    #[allow(dead_code)]
    prompt_tokens: usize,
    // Total tokens consumed by the request; logged at debug level.
    total_tokens: usize,
}
87
/// Wrapper for the error body OpenAI returns on non-success statuses.
#[derive(Debug, Deserialize)]
struct OpenAIErrorResponse {
    error: OpenAIError,
}
93
/// Error details inside an [`OpenAIErrorResponse`].
#[derive(Debug, Deserialize)]
struct OpenAIError {
    // Human-readable description of the failure.
    message: String,
    // Error category (`type` in JSON, renamed — `type` is a Rust keyword).
    #[serde(rename = "type")]
    error_type: String,
}
100
/// Blocking OpenAI embeddings client.
///
/// Cloning is cheap: clones share the HTTP client and the readiness flag
/// (via `Arc`).
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    // Bearer token for the API. Deliberately excluded from the manual
    // `Debug` impl below so it cannot leak into logs.
    api_key: String,
    // Model name, expected dimension, and optional batch size.
    config: EmbeddingConfig,
    // Reusable HTTP client configured with `REQUEST_TIMEOUT`.
    client: Client,
    // Set to `true` by `init` once a test request succeeds; shared across clones.
    ready: std::sync::Arc<AtomicBool>,
}
111
// Manual `Debug` impl: shows model/dimension/readiness but intentionally
// omits `api_key` (secret) and `client` (noisy, uninformative).
impl std::fmt::Debug for OpenAIEmbeddingProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbeddingProvider")
            .field("model", &self.config.model)
            .field("dimension", &self.config.dimension)
            .field("ready", &self.ready.load(Ordering::Relaxed))
            .finish()
    }
}
121
122impl OpenAIEmbeddingProvider {
123 pub fn new(api_key: String, config: EmbeddingConfig) -> Result<Self> {
137 if api_key.is_empty() {
138 bail!("OpenAI API key cannot be empty");
139 }
140
141 let client = Client::builder()
142 .timeout(REQUEST_TIMEOUT)
143 .build()
144 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
145
146 Ok(Self {
147 api_key,
148 config,
149 client,
150 ready: std::sync::Arc::new(AtomicBool::new(false)),
151 })
152 }
153
154 pub fn from_env() -> Result<Self> {
159 let api_key = std::env::var("OPENAI_API_KEY")
160 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
161
162 let config = match std::env::var("OPENAI_EMBEDDING_MODEL") {
163 Ok(model) => match model.as_str() {
164 "text-embedding-3-small" => EmbeddingConfig::openai_small(),
165 "text-embedding-ada-002" => EmbeddingConfig::openai_ada(),
166 "text-embedding-3-large" | _ => EmbeddingConfig::openai_large(),
167 },
168 Err(_) => EmbeddingConfig::openai_large(),
169 };
170
171 Self::new(api_key, config)
172 }
173
174 pub fn large(api_key: String) -> Result<Self> {
176 Self::new(api_key, EmbeddingConfig::openai_large())
177 }
178
179 pub fn small(api_key: String) -> Result<Self> {
181 Self::new(api_key, EmbeddingConfig::openai_small())
182 }
183
184 pub fn ada(api_key: String) -> Result<Self> {
186 Self::new(api_key, EmbeddingConfig::openai_ada())
187 }
188
189 fn call_openai(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
191 if texts.is_empty() {
192 return Ok(Vec::new());
193 }
194
195 let request = OpenAIEmbeddingRequest {
196 model: &self.config.model,
197 input: texts.to_vec(),
198 dimensions: None, };
200
201 let response = self
202 .client
203 .post(OPENAI_EMBEDDINGS_URL)
204 .header("Authorization", format!("Bearer {}", self.api_key))
205 .header("Content-Type", "application/json")
206 .json(&request)
207 .send()
208 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
209
210 let status = response.status();
211 let body = response
212 .text()
213 .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
214
215 if !status.is_success() {
216 if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
218 bail!(
219 "OpenAI API error ({}): {}",
220 error_response.error.error_type,
221 error_response.error.message
222 );
223 }
224 bail!("OpenAI API request failed with status {}: {}", status, body);
225 }
226
227 let embedding_response: OpenAIEmbeddingResponse = serde_json::from_str(&body)
228 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
229
230 debug!(
231 "OpenAI embeddings: {} texts, {} tokens, model={}",
232 texts.len(),
233 embedding_response.usage.total_tokens,
234 embedding_response.model
235 );
236
237 let mut data = embedding_response.data;
239 data.sort_by_key(|d| d.index);
240
241 let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();
242
243 if let Some(first) = embeddings.first() {
245 if first.len() != self.config.dimension {
246 warn!(
247 "OpenAI returned dimension {} but expected {}",
248 first.len(),
249 self.config.dimension
250 );
251 }
252 }
253
254 Ok(embeddings)
255 }
256
257 fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
259 let mut last_error = None;
260
261 for attempt in 0..max_retries {
262 match self.call_openai(texts) {
263 Ok(embeddings) => return Ok(embeddings),
264 Err(e) => {
265 let error_str = e.to_string();
266 if error_str.contains("rate_limit") || error_str.contains("429") {
267 let backoff = Duration::from_millis(500 * (1 << attempt));
268 warn!(
269 "Rate limited by OpenAI, retrying in {:?} (attempt {}/{})",
270 backoff,
271 attempt + 1,
272 max_retries
273 );
274 std::thread::sleep(backoff);
275 last_error = Some(e);
276 continue;
277 }
278 return Err(e);
279 }
280 }
281 }
282
283 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
284 }
285}
286
287impl EmbeddingProvider for OpenAIEmbeddingProvider {
288 fn kind(&self) -> &str {
289 "openai"
290 }
291
292 fn model(&self) -> &str {
293 &self.config.model
294 }
295
296 fn dimension(&self) -> usize {
297 self.config.dimension
298 }
299
300 fn embed_text(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
301 let text = truncate_for_embedding(text);
302 self.embed_with_retry(&[&text], 3)
303 .map(|mut v| v.pop().unwrap_or_default())
304 .map_err(|e| memvid_core::MemvidError::EmbeddingFailed {
305 reason: e.to_string().into_boxed_str(),
306 })
307 }
308
309 fn embed_batch(&self, texts: &[&str]) -> memvid_core::Result<Vec<Vec<f32>>> {
310 if texts.is_empty() {
311 return Ok(Vec::new());
312 }
313
314 let truncated: Vec<std::borrow::Cow<'_, str>> =
316 texts.iter().map(|t| truncate_for_embedding(t)).collect();
317 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
318
319 let batch_size = self
321 .config
322 .batch_size
323 .unwrap_or(MAX_BATCH_SIZE)
324 .min(MAX_BATCH_SIZE);
325 let mut all_embeddings = Vec::with_capacity(texts.len());
326
327 for chunk in truncated_refs.chunks(batch_size) {
328 let embeddings = self.embed_with_retry(chunk, 3).map_err(|e| {
329 memvid_core::MemvidError::EmbeddingFailed {
330 reason: e.to_string().into_boxed_str(),
331 }
332 })?;
333 all_embeddings.extend(embeddings);
334 }
335
336 Ok(all_embeddings)
337 }
338
339 fn is_ready(&self) -> bool {
340 self.ready.load(Ordering::Relaxed)
341 }
342
343 fn init(&mut self) -> memvid_core::Result<()> {
344 info!(
346 "Initializing OpenAI embedding provider with model: {}",
347 self.config.model
348 );
349
350 let test_embedding = self.embed_with_retry(&["test"], 1).map_err(|e| {
351 memvid_core::MemvidError::EmbeddingFailed {
352 reason: format!("Failed to initialize OpenAI provider: {}", e).into_boxed_str(),
353 }
354 })?;
355
356 if let Some(emb) = test_embedding.first() {
357 info!(
358 "OpenAI provider initialized: model={}, dimension={}",
359 self.config.model,
360 emb.len()
361 );
362 if emb.len() != self.config.dimension {
364 warn!(
365 "Updating dimension from {} to {}",
366 self.config.dimension,
367 emb.len()
368 );
369 }
370 }
371
372 self.ready.store(true, Ordering::Relaxed);
373 Ok(())
374 }
375}
376
// Adapter: `VecEmbedder` delegates straight to the `EmbeddingProvider`
// methods above, so both trait interfaces share one implementation.
impl VecEmbedder for OpenAIEmbeddingProvider {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        self.embed_text(text)
    }

    fn embed_chunks(&self, texts: &[&str]) -> memvid_core::Result<Vec<Vec<f32>>> {
        self.embed_batch(texts)
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}
391
392pub fn try_openai_provider() -> Option<OpenAIEmbeddingProvider> {
394 match OpenAIEmbeddingProvider::from_env() {
395 Ok(provider) => {
396 info!("OpenAI embedding provider available");
397 Some(provider)
398 }
399 Err(e) => {
400 debug!("OpenAI provider not available: {}", e);
401 None
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    // Known dimensions for each preset config.
    #[test]
    fn test_config_dimensions() {
        assert_eq!(EmbeddingConfig::openai_large().dimension, 3072);
        assert_eq!(EmbeddingConfig::openai_small().dimension, 1536);
        assert_eq!(EmbeddingConfig::openai_ada().dimension, 1536);
    }

    // Construction must be rejected when no API key is provided.
    #[test]
    fn test_empty_api_key() {
        let provider =
            OpenAIEmbeddingProvider::new(String::new(), EmbeddingConfig::openai_large());
        assert!(provider.is_err());
    }

    // Hits the live API; run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_real_embedding() {
        let provider = OpenAIEmbeddingProvider::from_env().expect("OPENAI_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
        assert_eq!(embedding.len(), 3072);
    }
}