//! OpenAI embedding provider (mnemo_core/embedding/openai.rs).
use crate::embedding::EmbeddingProvider;
use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};

/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
pub struct OpenAiEmbedding {
    /// Secret key sent as the `Authorization: Bearer …` header.
    api_key: String,
    /// Embedding model identifier, e.g. "text-embedding-3-small".
    model: String,
    /// Requested embedding vector length; forwarded in every API request
    /// and reported by `EmbeddingProvider::dimensions`.
    dimensions: usize,
    /// Reused HTTP client, configured with timeouts in `new`.
    client: reqwest::Client,
}
11
/// JSON request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// Model identifier to embed with.
    model: String,
    /// Batch of input texts; one embedding is returned per entry.
    input: Vec<String>,
    /// Desired output vector length.
    // NOTE(review): this field is always serialized; some older embedding
    // models reject a `dimensions` parameter — confirm against supported models.
    dimensions: usize,
}
18
/// Successful JSON response from `POST /v1/embeddings`.
/// Only the fields this module consumes are deserialized.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per input text.
    data: Vec<EmbeddingData>,
}
23
/// A single embedding entry within the response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector for one input text.
    // NOTE(review): the API's per-entry `index` field is not captured, so
    // callers rely on the response preserving input order — verify.
    embedding: Vec<f32>,
}
28
29impl OpenAiEmbedding {
30    pub fn new(api_key: String, model: String, dimensions: usize) -> Self {
31        Self {
32            api_key,
33            model,
34            dimensions,
35            client: reqwest::Client::builder()
36                .timeout(std::time::Duration::from_secs(30))
37                .connect_timeout(std::time::Duration::from_secs(10))
38                .build()
39                .unwrap_or_else(|e| {
40                    tracing::error!(error = %e, "failed to build HTTP client with timeouts, using default");
41                    reqwest::Client::default()
42                }),
43        }
44    }
45}
46
47#[async_trait::async_trait]
48impl EmbeddingProvider for OpenAiEmbedding {
49    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
50        let results = self.embed_batch(&[text]).await?;
51        results
52            .into_iter()
53            .next()
54            .ok_or_else(|| Error::Embedding("empty response from OpenAI".to_string()))
55    }
56
57    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
58        let request = EmbeddingRequest {
59            model: self.model.clone(),
60            input: texts.iter().map(|s| s.to_string()).collect(),
61            dimensions: self.dimensions,
62        };
63
64        let response = self
65            .client
66            .post("https://api.openai.com/v1/embeddings")
67            .header("Authorization", format!("Bearer {}", self.api_key))
68            .json(&request)
69            .send()
70            .await?;
71
72        if !response.status().is_success() {
73            let status = response.status();
74            let body = response.text().await.unwrap_or_default();
75            return Err(Error::Embedding(format!(
76                "OpenAI API error {status}: {body}"
77            )));
78        }
79
80        let resp: EmbeddingResponse = response.json().await?;
81        Ok(resp.data.into_iter().map(|d| d.embedding).collect())
82    }
83
84    fn dimensions(&self) -> usize {
85        self.dimensions
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::NoopEmbedding;

    /// A noop provider returns an all-zero vector of the configured size.
    #[tokio::test]
    async fn test_noop_embedding() {
        let provider = NoopEmbedding::new(1536);
        let result = provider.embed("test").await.unwrap();
        assert_eq!(result.len(), 1536);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    /// Batch embedding yields one vector per input, each of the configured size.
    #[tokio::test]
    async fn test_noop_batch() {
        let provider = NoopEmbedding::new(768);
        let result = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|v| v.len() == 768));
    }

    /// `dimensions()` reports the value passed to the constructor.
    #[tokio::test]
    async fn test_noop_dimensions() {
        let provider = NoopEmbedding::new(256);
        assert_eq!(provider.dimensions(), 256);
    }

    /// Live round trip against the real API; run via `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore] // Requires OPENAI_API_KEY
    async fn test_openai_embedding() {
        // `expect` (not `unwrap`) so a missing key panics with the reason.
        let api_key = std::env::var("OPENAI_API_KEY")
            .expect("OPENAI_API_KEY must be set to run this ignored test");
        let provider = OpenAiEmbedding::new(api_key, "text-embedding-3-small".to_string(), 1536);
        let result = provider.embed("hello world").await.unwrap();
        assert_eq!(result.len(), 1536);
    }
}