codemem_embeddings/
openai.rs1use codemem_core::CodememError;
7
/// Base URL used by [`OpenAIProvider::new`] when the caller passes `None`.
///
/// Requests are sent to this value with `/embeddings` appended.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
10
/// Default embedding model name.
///
/// Within this file it is used by the test-only `with_api_key` constructor.
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";
13
14pub struct OpenAIProvider {
16 api_key: String,
17 model: String,
18 dimensions: usize,
19 base_url: String,
20 client: reqwest::blocking::Client,
21}
22
23impl OpenAIProvider {
24 pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
26 Self {
27 api_key: api_key.to_string(),
28 model: model.to_string(),
29 dimensions,
30 base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
31 client: reqwest::blocking::Client::new(),
32 }
33 }
34
35 #[cfg(test)]
37 pub fn with_api_key(api_key: &str) -> Self {
38 Self::new(api_key, DEFAULT_MODEL, 768, None)
39 }
40}
41
42impl super::EmbeddingProvider for OpenAIProvider {
43 fn dimensions(&self) -> usize {
44 self.dimensions
45 }
46
47 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
48 let url = format!("{}/embeddings", self.base_url);
49
50 let mut body = serde_json::json!({
51 "model": self.model,
52 "input": text,
53 });
54
55 if self.model.starts_with("text-embedding-3") {
57 body["dimensions"] = serde_json::json!(self.dimensions);
58 }
59
60 let response = self
61 .client
62 .post(&url)
63 .header("Authorization", format!("Bearer {}", self.api_key))
64 .header("Content-Type", "application/json")
65 .json(&body)
66 .send()
67 .map_err(|e| CodememError::Embedding(format!("OpenAI request failed: {e}")))?;
68
69 if !response.status().is_success() {
70 let status = response.status();
71 let body = response.text().unwrap_or_default();
72 return Err(CodememError::Embedding(format!(
73 "OpenAI returned status {}: {}",
74 status, body
75 )));
76 }
77
78 let json: serde_json::Value = response
79 .json()
80 .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
81
82 let embedding = json
83 .get("data")
84 .and_then(|v| v.as_array())
85 .and_then(|arr| arr.first())
86 .and_then(|item| item.get("embedding"))
87 .and_then(|v| v.as_array())
88 .ok_or_else(|| CodememError::Embedding("Missing embedding in OpenAI response".into()))?
89 .iter()
90 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
91 .collect();
92
93 Ok(embedding)
94 }
95
96 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
97 if texts.is_empty() {
98 return Ok(vec![]);
99 }
100
101 let url = format!("{}/embeddings", self.base_url);
102
103 let mut body = serde_json::json!({
104 "model": self.model,
105 "input": texts,
106 });
107
108 if self.model.starts_with("text-embedding-3") {
109 body["dimensions"] = serde_json::json!(self.dimensions);
110 }
111
112 let response = self
113 .client
114 .post(&url)
115 .header("Authorization", format!("Bearer {}", self.api_key))
116 .header("Content-Type", "application/json")
117 .json(&body)
118 .send()
119 .map_err(|e| CodememError::Embedding(format!("OpenAI batch request failed: {e}")))?;
120
121 if !response.status().is_success() {
122 let status = response.status();
123 let body = response.text().unwrap_or_default();
124 return Err(CodememError::Embedding(format!(
125 "OpenAI returned status {}: {}",
126 status, body
127 )));
128 }
129
130 let json: serde_json::Value = response
131 .json()
132 .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
133
134 let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
135 CodememError::Embedding("Missing 'data' array in OpenAI response".into())
136 })?;
137
138 let mut indexed: Vec<(usize, Vec<f32>)> = data
140 .iter()
141 .filter_map(|item| {
142 let index = item.get("index")?.as_u64()? as usize;
143 let embedding: Vec<f32> = item
144 .get("embedding")?
145 .as_array()?
146 .iter()
147 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
148 .collect();
149 Some((index, embedding))
150 })
151 .collect();
152 indexed.sort_by_key(|(i, _)| *i);
153
154 if indexed.len() != texts.len() {
155 return Err(CodememError::Embedding(format!(
156 "OpenAI returned {} embeddings, expected {}",
157 indexed.len(),
158 texts.len()
159 )));
160 }
161
162 Ok(indexed.into_iter().map(|(_, e)| e).collect())
163 }
164
165 fn name(&self) -> &str {
166 "openai"
167 }
168}
169
170#[cfg(test)]
171#[path = "tests/openai_tests.rs"]
172mod tests;