Skip to main content

codemem_embeddings/
openai.rs

1//! OpenAI-compatible embedding provider for Codemem.
2//!
3//! Works with OpenAI, Azure OpenAI, Together.ai, and any OpenAI-compatible API.
4//! Default model: text-embedding-3-small.
5
6use codemem_core::CodememError;
7
/// Root URL of the official OpenAI REST API, used when the caller does not
/// supply a base URL of their own.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
10
/// Embedding model requested when the caller does not name one explicitly.
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";
13
14/// OpenAI embedding provider.
15pub struct OpenAIProvider {
16    api_key: String,
17    model: String,
18    dimensions: usize,
19    base_url: String,
20    client: reqwest::blocking::Client,
21}
22
23impl OpenAIProvider {
24    /// Create a new OpenAI provider.
25    pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
26        Self {
27            api_key: api_key.to_string(),
28            model: model.to_string(),
29            dimensions,
30            base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
31            client: reqwest::blocking::Client::new(),
32        }
33    }
34
35    /// Create with default model (text-embedding-3-small, 768 dims).
36    pub fn with_api_key(api_key: &str) -> Self {
37        Self::new(api_key, DEFAULT_MODEL, 768, None)
38    }
39}
40
41impl super::EmbeddingProvider for OpenAIProvider {
42    fn dimensions(&self) -> usize {
43        self.dimensions
44    }
45
46    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
47        let url = format!("{}/embeddings", self.base_url);
48
49        let mut body = serde_json::json!({
50            "model": self.model,
51            "input": text,
52        });
53
54        // text-embedding-3-* supports custom dimensions
55        if self.model.starts_with("text-embedding-3") {
56            body["dimensions"] = serde_json::json!(self.dimensions);
57        }
58
59        let response = self
60            .client
61            .post(&url)
62            .header("Authorization", format!("Bearer {}", self.api_key))
63            .header("Content-Type", "application/json")
64            .json(&body)
65            .send()
66            .map_err(|e| CodememError::Embedding(format!("OpenAI request failed: {e}")))?;
67
68        if !response.status().is_success() {
69            let status = response.status();
70            let body = response.text().unwrap_or_default();
71            return Err(CodememError::Embedding(format!(
72                "OpenAI returned status {}: {}",
73                status, body
74            )));
75        }
76
77        let json: serde_json::Value = response
78            .json()
79            .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
80
81        let embedding = json
82            .get("data")
83            .and_then(|v| v.as_array())
84            .and_then(|arr| arr.first())
85            .and_then(|item| item.get("embedding"))
86            .and_then(|v| v.as_array())
87            .ok_or_else(|| CodememError::Embedding("Missing embedding in OpenAI response".into()))?
88            .iter()
89            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
90            .collect();
91
92        Ok(embedding)
93    }
94
95    fn name(&self) -> &str {
96        "openai"
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingProvider;

    /// Canned success body carrying a single three-element embedding.
    const THREE_VALUE_RESPONSE: &str = r#"{"data": [{"embedding": [0.1, 0.2, 0.3]}]}"#;

    /// `with_api_key` fills in the default model, 768 dimensions and the
    /// default base URL.
    #[test]
    fn openai_provider_construction() {
        let OpenAIProvider {
            model,
            dimensions,
            base_url,
            ..
        } = OpenAIProvider::with_api_key("test-key");

        assert_eq!(model, DEFAULT_MODEL);
        assert_eq!(dimensions, 768);
        assert_eq!(base_url, DEFAULT_BASE_URL);
    }

    /// `new` stores the caller-supplied base URL, model and dimensions as given.
    #[test]
    fn openai_provider_custom_base_url() {
        let (model, dims, url) = ("custom-model", 1536, "https://api.together.xyz/v1");
        let provider = OpenAIProvider::new("test-key", model, dims, Some(url));

        assert_eq!(provider.base_url, url);
        assert_eq!(provider.model, model);
        assert_eq!(provider.dimensions, dims);
    }

    /// The provider identifies itself as "openai".
    #[test]
    fn openai_name_returns_openai() {
        assert_eq!(OpenAIProvider::with_api_key("test-key").name(), "openai");
    }

    /// A 200 response with an embedding array is decoded into the matching
    /// `f32` values, and the bearer token is sent.
    #[test]
    fn openai_embed_success_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .match_header("Authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREE_VALUE_RESPONSE)
            .create();

        let provider = OpenAIProvider::new("test-key", "custom-model", 3, Some(&server.url()));
        let outcome = provider.embed("test");
        endpoint.assert();

        let embedding = outcome.unwrap();
        let expected = [0.1_f32, 0.2, 0.3];
        assert_eq!(embedding.len(), expected.len());
        for (actual, wanted) in embedding.iter().zip(expected.iter()) {
            assert!((actual - wanted).abs() < 1e-6);
        }
    }

    /// A 401 response surfaces as an error whose message mentions the status.
    #[test]
    fn openai_embed_unauthorized_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .with_status(401)
            .with_body("Unauthorized")
            .create();

        let provider = OpenAIProvider::new("bad-key", "custom-model", 768, Some(&server.url()));
        let failure = provider.embed("test");
        endpoint.assert();

        assert!(failure.is_err());
        let err = failure.unwrap_err().to_string();
        assert!(err.contains("401"), "Error should contain '401': {err}");
    }

    /// For `text-embedding-3*` models the request body includes `dimensions`.
    #[test]
    fn openai_embed_text_embedding_3_includes_dimensions() {
        let wanted_body = mockito::Matcher::PartialJsonString(String::from(r#"{"dimensions": 3}"#));

        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .match_body(wanted_body)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREE_VALUE_RESPONSE)
            .create();

        let provider =
            OpenAIProvider::new("test-key", "text-embedding-3-small", 3, Some(&server.url()));
        let outcome = provider.embed("test");
        endpoint.assert();

        assert!(outcome.is_ok());
    }
}