mnemo_core/embedding/
openai.rs1use crate::embedding::EmbeddingProvider;
2use crate::error::{Error, Result};
3use serde::{Deserialize, Serialize};
4
5pub struct OpenAiEmbedding {
6 api_key: String,
7 model: String,
8 dimensions: usize,
9 client: reqwest::Client,
10}
11
12#[derive(Serialize)]
13struct EmbeddingRequest {
14 model: String,
15 input: Vec<String>,
16 dimensions: usize,
17}
18
19#[derive(Deserialize)]
20struct EmbeddingResponse {
21 data: Vec<EmbeddingData>,
22}
23
24#[derive(Deserialize)]
25struct EmbeddingData {
26 embedding: Vec<f32>,
27}
28
29impl OpenAiEmbedding {
30 pub fn new(api_key: String, model: String, dimensions: usize) -> Self {
31 Self {
32 api_key,
33 model,
34 dimensions,
35 client: reqwest::Client::builder()
36 .timeout(std::time::Duration::from_secs(30))
37 .connect_timeout(std::time::Duration::from_secs(10))
38 .build()
39 .unwrap_or_else(|e| {
40 tracing::error!(error = %e, "failed to build HTTP client with timeouts, using default");
41 reqwest::Client::default()
42 }),
43 }
44 }
45}
46
47#[async_trait::async_trait]
48impl EmbeddingProvider for OpenAiEmbedding {
49 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
50 let results = self.embed_batch(&[text]).await?;
51 results
52 .into_iter()
53 .next()
54 .ok_or_else(|| Error::Embedding("empty response from OpenAI".to_string()))
55 }
56
57 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
58 let request = EmbeddingRequest {
59 model: self.model.clone(),
60 input: texts.iter().map(|s| s.to_string()).collect(),
61 dimensions: self.dimensions,
62 };
63
64 let response = self
65 .client
66 .post("https://api.openai.com/v1/embeddings")
67 .header("Authorization", format!("Bearer {}", self.api_key))
68 .json(&request)
69 .send()
70 .await?;
71
72 if !response.status().is_success() {
73 let status = response.status();
74 let body = response.text().await.unwrap_or_default();
75 return Err(Error::Embedding(format!(
76 "OpenAI API error {status}: {body}"
77 )));
78 }
79
80 let resp: EmbeddingResponse = response.json().await?;
81 Ok(resp.data.into_iter().map(|d| d.embedding).collect())
82 }
83
84 fn dimensions(&self) -> usize {
85 self.dimensions
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::NoopEmbedding;

    #[tokio::test]
    async fn test_noop_embedding() {
        let provider = NoopEmbedding::new(1536);
        let result = provider.embed("test").await.unwrap();
        assert_eq!(result.len(), 1536);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[tokio::test]
    async fn test_noop_batch() {
        let provider = NoopEmbedding::new(768);
        let result = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|v| v.len() == 768));
    }

    #[tokio::test]
    async fn test_noop_dimensions() {
        let provider = NoopEmbedding::new(256);
        assert_eq!(provider.dimensions(), 256);
    }

    #[tokio::test]
    async fn test_openai_dimensions_reports_configured_value() {
        // Construction only builds an HTTP client; no request is sent.
        let provider = OpenAiEmbedding::new(
            "test-key".to_string(),
            "text-embedding-3-small".to_string(),
            512,
        );
        assert_eq!(provider.dimensions(), 512);
    }

    #[tokio::test]
    #[ignore = "requires OPENAI_API_KEY and network access to api.openai.com"]
    async fn test_openai_embedding() {
        let api_key = std::env::var("OPENAI_API_KEY")
            .expect("OPENAI_API_KEY must be set to run this test");
        let provider =
            OpenAiEmbedding::new(api_key, "text-embedding-3-small".to_string(), 1536);
        let result = provider.embed("hello world").await.unwrap();
        assert_eq!(result.len(), 1536);
    }
}