codemem_embeddings/
openai.rs1use codemem_core::CodememError;
7
/// API root used by [`OpenAIProvider::new`] when the caller passes no `base_url`.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
10
/// Embedding model selected by [`OpenAIProvider::with_api_key`].
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";
13
14pub struct OpenAIProvider {
16 api_key: String,
17 model: String,
18 dimensions: usize,
19 base_url: String,
20 client: reqwest::blocking::Client,
21}
22
23impl OpenAIProvider {
24 pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
26 Self {
27 api_key: api_key.to_string(),
28 model: model.to_string(),
29 dimensions,
30 base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
31 client: reqwest::blocking::Client::new(),
32 }
33 }
34
35 pub fn with_api_key(api_key: &str) -> Self {
37 Self::new(api_key, DEFAULT_MODEL, 768, None)
38 }
39}
40
41impl super::EmbeddingProvider for OpenAIProvider {
42 fn dimensions(&self) -> usize {
43 self.dimensions
44 }
45
46 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
47 let url = format!("{}/embeddings", self.base_url);
48
49 let mut body = serde_json::json!({
50 "model": self.model,
51 "input": text,
52 });
53
54 if self.model.starts_with("text-embedding-3") {
56 body["dimensions"] = serde_json::json!(self.dimensions);
57 }
58
59 let response = self
60 .client
61 .post(&url)
62 .header("Authorization", format!("Bearer {}", self.api_key))
63 .header("Content-Type", "application/json")
64 .json(&body)
65 .send()
66 .map_err(|e| CodememError::Embedding(format!("OpenAI request failed: {e}")))?;
67
68 if !response.status().is_success() {
69 let status = response.status();
70 let body = response.text().unwrap_or_default();
71 return Err(CodememError::Embedding(format!(
72 "OpenAI returned status {}: {}",
73 status, body
74 )));
75 }
76
77 let json: serde_json::Value = response
78 .json()
79 .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
80
81 let embedding = json
82 .get("data")
83 .and_then(|v| v.as_array())
84 .and_then(|arr| arr.first())
85 .and_then(|item| item.get("embedding"))
86 .and_then(|v| v.as_array())
87 .ok_or_else(|| CodememError::Embedding("Missing embedding in OpenAI response".into()))?
88 .iter()
89 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
90 .collect();
91
92 Ok(embedding)
93 }
94
95 fn name(&self) -> &str {
96 "openai"
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingProvider;

    /// Canned success payload carrying a three-element embedding.
    const THREE_VALUE_RESPONSE: &str = r#"{"data": [{"embedding": [0.1, 0.2, 0.3]}]}"#;

    // ---- construction ----

    /// `with_api_key` fills in the default model, dimensions and base URL.
    #[test]
    fn openai_provider_construction() {
        let defaults = OpenAIProvider::with_api_key("test-key");

        assert_eq!(defaults.model, DEFAULT_MODEL);
        assert_eq!(defaults.dimensions, 768);
        assert_eq!(defaults.base_url, DEFAULT_BASE_URL);
    }

    /// `new` stores every explicitly supplied setting.
    #[test]
    fn openai_provider_custom_base_url() {
        let endpoint = "https://api.together.xyz/v1";
        let custom = OpenAIProvider::new("test-key", "custom-model", 1536, Some(endpoint));

        assert_eq!(custom.base_url, "https://api.together.xyz/v1");
        assert_eq!(custom.model, "custom-model");
        assert_eq!(custom.dimensions, 1536);
    }

    /// The provider identifies itself as "openai".
    #[test]
    fn openai_name_returns_openai() {
        assert_eq!(OpenAIProvider::with_api_key("test-key").name(), "openai");
    }

    // ---- HTTP round trips against a local mock server ----

    /// A 200 response is decoded into the embedding vector.
    #[test]
    fn openai_embed_success_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .match_header("Authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREE_VALUE_RESPONSE)
            .create();

        let base = server.url();
        let outcome = OpenAIProvider::new("test-key", "custom-model", 3, Some(&base)).embed("test");
        endpoint.assert();

        let vector = outcome.unwrap();
        assert_eq!(vector.len(), 3);
        for (actual, expected) in vector.iter().zip([0.1_f32, 0.2, 0.3].iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
    }

    /// A 401 response surfaces as an error that mentions the status code.
    #[test]
    fn openai_embed_unauthorized_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .with_status(401)
            .with_body("Unauthorized")
            .create();

        let base = server.url();
        let outcome = OpenAIProvider::new("bad-key", "custom-model", 768, Some(&base)).embed("test");
        endpoint.assert();

        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("401"), "Error should contain '401': {err}");
    }

    /// `text-embedding-3*` models send the `dimensions` field in the body.
    #[test]
    fn openai_embed_text_embedding_3_includes_dimensions() {
        let expected_fragment = r#"{"dimensions": 3}"#.to_string();

        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", "/embeddings")
            .match_body(mockito::Matcher::PartialJsonString(expected_fragment))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREE_VALUE_RESPONSE)
            .create();

        let base = server.url();
        let outcome =
            OpenAIProvider::new("test-key", "text-embedding-3-small", 3, Some(&base)).embed("test");
        endpoint.assert();

        assert!(outcome.is_ok());
    }
}