codemem_embeddings/
openai.rs1use codemem_core::CodememError;
7
/// API root used when the caller does not supply a custom base URL.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
10
/// Embedding model selected when only an API key is provided.
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";
13
14pub struct OpenAIProvider {
16 api_key: String,
17 model: String,
18 dimensions: usize,
19 base_url: String,
20 client: reqwest::blocking::Client,
21}
22
23impl OpenAIProvider {
24 pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
26 Self {
27 api_key: api_key.to_string(),
28 model: model.to_string(),
29 dimensions,
30 base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
31 client: reqwest::blocking::Client::new(),
32 }
33 }
34
35 pub fn with_api_key(api_key: &str) -> Self {
37 Self::new(api_key, DEFAULT_MODEL, 768, None)
38 }
39}
40
41impl super::EmbeddingProvider for OpenAIProvider {
42 fn dimensions(&self) -> usize {
43 self.dimensions
44 }
45
46 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
47 let url = format!("{}/embeddings", self.base_url);
48
49 let mut body = serde_json::json!({
50 "model": self.model,
51 "input": text,
52 });
53
54 if self.model.starts_with("text-embedding-3") {
56 body["dimensions"] = serde_json::json!(self.dimensions);
57 }
58
59 let response = self
60 .client
61 .post(&url)
62 .header("Authorization", format!("Bearer {}", self.api_key))
63 .header("Content-Type", "application/json")
64 .json(&body)
65 .send()
66 .map_err(|e| CodememError::Embedding(format!("OpenAI request failed: {e}")))?;
67
68 if !response.status().is_success() {
69 let status = response.status();
70 let body = response.text().unwrap_or_default();
71 return Err(CodememError::Embedding(format!(
72 "OpenAI returned status {}: {}",
73 status, body
74 )));
75 }
76
77 let json: serde_json::Value = response
78 .json()
79 .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
80
81 let embedding = json
82 .get("data")
83 .and_then(|v| v.as_array())
84 .and_then(|arr| arr.first())
85 .and_then(|item| item.get("embedding"))
86 .and_then(|v| v.as_array())
87 .ok_or_else(|| CodememError::Embedding("Missing embedding in OpenAI response".into()))?
88 .iter()
89 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
90 .collect();
91
92 Ok(embedding)
93 }
94
95 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
96 if texts.is_empty() {
97 return Ok(vec![]);
98 }
99
100 let url = format!("{}/embeddings", self.base_url);
101
102 let mut body = serde_json::json!({
103 "model": self.model,
104 "input": texts,
105 });
106
107 if self.model.starts_with("text-embedding-3") {
108 body["dimensions"] = serde_json::json!(self.dimensions);
109 }
110
111 let response = self
112 .client
113 .post(&url)
114 .header("Authorization", format!("Bearer {}", self.api_key))
115 .header("Content-Type", "application/json")
116 .json(&body)
117 .send()
118 .map_err(|e| CodememError::Embedding(format!("OpenAI batch request failed: {e}")))?;
119
120 if !response.status().is_success() {
121 let status = response.status();
122 let body = response.text().unwrap_or_default();
123 return Err(CodememError::Embedding(format!(
124 "OpenAI returned status {}: {}",
125 status, body
126 )));
127 }
128
129 let json: serde_json::Value = response
130 .json()
131 .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
132
133 let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
134 CodememError::Embedding("Missing 'data' array in OpenAI response".into())
135 })?;
136
137 let mut indexed: Vec<(usize, Vec<f32>)> = data
139 .iter()
140 .filter_map(|item| {
141 let index = item.get("index")?.as_u64()? as usize;
142 let embedding: Vec<f32> = item
143 .get("embedding")?
144 .as_array()?
145 .iter()
146 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
147 .collect();
148 Some((index, embedding))
149 })
150 .collect();
151 indexed.sort_by_key(|(i, _)| *i);
152
153 if indexed.len() != texts.len() {
154 return Err(CodememError::Embedding(format!(
155 "OpenAI returned {} embeddings, expected {}",
156 indexed.len(),
157 texts.len()
158 )));
159 }
160
161 Ok(indexed.into_iter().map(|(_, e)| e).collect())
162 }
163
164 fn name(&self) -> &str {
165 "openai"
166 }
167}
168
169#[cfg(test)]
170#[path = "tests/openai_tests.rs"]
171mod tests;