1//! OpenAI-compatible embedding provider for Codemem.
2//!
3//! Works with OpenAI, Azure OpenAI, Together.ai, and any OpenAI-compatible API.
4//! Default model: text-embedding-3-small.
5
6use codemem_core::CodememError;
7
/// Default OpenAI API base URL. Override via the `base_url` argument of
/// `OpenAIProvider::new` to target Azure OpenAI, Together.ai, or any other
/// OpenAI-compatible endpoint.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Default embedding model; text-embedding-3-* models accept a request-level
/// `dimensions` parameter for truncated embeddings.
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";
13
/// OpenAI embedding provider.
///
/// Holds the credentials and HTTP state needed to call an OpenAI-compatible
/// `/embeddings` endpoint with blocking requests.
pub struct OpenAIProvider {
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Model name sent in each request body (e.g. "text-embedding-3-small").
    model: String,
    // Expected embedding width; also forwarded as the `dimensions` request
    // parameter for text-embedding-3-* models.
    dimensions: usize,
    // API base URL without the trailing "/embeddings" path segment.
    base_url: String,
    // Reused blocking HTTP client (keeps connection pooling across calls).
    client: reqwest::blocking::Client,
}
22
23impl OpenAIProvider {
24    /// Create a new OpenAI provider.
25    pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
26        Self {
27            api_key: api_key.to_string(),
28            model: model.to_string(),
29            dimensions,
30            base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
31            client: reqwest::blocking::Client::new(),
32        }
33    }
34
35    /// Create with default model (text-embedding-3-small, 768 dims).
36    #[cfg(test)]
37    pub fn with_api_key(api_key: &str) -> Self {
38        Self::new(api_key, DEFAULT_MODEL, 768, None)
39    }
40}
41
42impl super::EmbeddingProvider for OpenAIProvider {
43    fn dimensions(&self) -> usize {
44        self.dimensions
45    }
46
47    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
48        let url = format!("{}/embeddings", self.base_url);
49
50        let mut body = serde_json::json!({
51            "model": self.model,
52            "input": text,
53        });
54
55        // text-embedding-3-* supports custom dimensions
56        if self.model.starts_with("text-embedding-3") {
57            body["dimensions"] = serde_json::json!(self.dimensions);
58        }
59
60        let response = self
61            .client
62            .post(&url)
63            .header("Authorization", format!("Bearer {}", self.api_key))
64            .header("Content-Type", "application/json")
65            .json(&body)
66            .send()
67            .map_err(|e| CodememError::Embedding(format!("OpenAI request failed: {e}")))?;
68
69        if !response.status().is_success() {
70            let status = response.status();
71            let body = response.text().unwrap_or_default();
72            return Err(CodememError::Embedding(format!(
73                "OpenAI returned status {}: {}",
74                status, body
75            )));
76        }
77
78        let json: serde_json::Value = response
79            .json()
80            .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
81
82        let embedding = json
83            .get("data")
84            .and_then(|v| v.as_array())
85            .and_then(|arr| arr.first())
86            .and_then(|item| item.get("embedding"))
87            .and_then(|v| v.as_array())
88            .ok_or_else(|| CodememError::Embedding("Missing embedding in OpenAI response".into()))?
89            .iter()
90            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
91            .collect();
92
93        Ok(embedding)
94    }
95
96    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
97        if texts.is_empty() {
98            return Ok(vec![]);
99        }
100
101        let url = format!("{}/embeddings", self.base_url);
102
103        let mut body = serde_json::json!({
104            "model": self.model,
105            "input": texts,
106        });
107
108        if self.model.starts_with("text-embedding-3") {
109            body["dimensions"] = serde_json::json!(self.dimensions);
110        }
111
112        let response = self
113            .client
114            .post(&url)
115            .header("Authorization", format!("Bearer {}", self.api_key))
116            .header("Content-Type", "application/json")
117            .json(&body)
118            .send()
119            .map_err(|e| CodememError::Embedding(format!("OpenAI batch request failed: {e}")))?;
120
121        if !response.status().is_success() {
122            let status = response.status();
123            let body = response.text().unwrap_or_default();
124            return Err(CodememError::Embedding(format!(
125                "OpenAI returned status {}: {}",
126                status, body
127            )));
128        }
129
130        let json: serde_json::Value = response
131            .json()
132            .map_err(|e| CodememError::Embedding(format!("OpenAI response parse error: {e}")))?;
133
134        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
135            CodememError::Embedding("Missing 'data' array in OpenAI response".into())
136        })?;
137
138        // OpenAI returns embeddings sorted by index, but we sort explicitly to be safe
139        let mut indexed: Vec<(usize, Vec<f32>)> = data
140            .iter()
141            .filter_map(|item| {
142                let index = item.get("index")?.as_u64()? as usize;
143                let embedding: Vec<f32> = item
144                    .get("embedding")?
145                    .as_array()?
146                    .iter()
147                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
148                    .collect();
149                Some((index, embedding))
150            })
151            .collect();
152        indexed.sort_by_key(|(i, _)| *i);
153
154        if indexed.len() != texts.len() {
155            return Err(CodememError::Embedding(format!(
156                "OpenAI returned {} embeddings, expected {}",
157                indexed.len(),
158                texts.len()
159            )));
160        }
161
162        Ok(indexed.into_iter().map(|(_, e)| e).collect())
163    }
164
165    fn name(&self) -> &str {
166        "openai"
167    }
168}
169
// Unit tests live in a sibling file; the #[path] attribute points the module
// at tests/openai_tests.rs relative to this file's directory.
#[cfg(test)]
#[path = "tests/openai_tests.rs"]
mod tests;