difflore_core/context/embedding/
openai.rs1use async_trait::async_trait;
2
3use crate::errors::CoreError;
4
5use super::{Embedder, embedding_http_client};
6
7pub struct OpenAICompatEmbedder {
12 pub base_url: String,
13 pub api_key: String,
14 pub model: String,
15 pub dim: usize,
16 client: reqwest::Client,
17}
18
19impl OpenAICompatEmbedder {
20 pub fn new(base_url: String, api_key: String, model: String, dim: usize) -> Self {
21 Self {
22 base_url,
23 api_key,
24 model,
25 dim,
26 client: embedding_http_client(),
27 }
28 }
29
30 pub(crate) fn endpoint(&self) -> String {
31 let trimmed = self.base_url.trim_end_matches('/');
32 if trimmed.ends_with("/embeddings") {
33 trimmed.to_owned()
34 } else {
35 format!("{trimmed}/embeddings")
36 }
37 }
38
39 fn authed_post(&self, url: &str) -> reqwest::RequestBuilder {
44 let request = self.client.post(url);
45 if self.api_key.is_empty() {
46 request
47 } else {
48 request.bearer_auth(&self.api_key)
49 }
50 }
51}
52
53fn provider_status_error(status: reqwest::StatusCode) -> CoreError {
54 CoreError::Internal(format!(
55 "embedding provider returned {status}; check provider URL, model, dimensions, and API key"
56 ))
57}
58
59#[async_trait]
60impl Embedder for OpenAICompatEmbedder {
61 async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError> {
62 let url = self.endpoint();
63 let body = serde_json::json!({
70 "model": self.model,
71 "input": text,
72 });
73
74 let resp = self
75 .authed_post(&url)
76 .json(&body)
77 .send()
78 .await
79 .map_err(|e| CoreError::Internal(format!("embedding request failed: {e}")))?;
80
81 if !resp.status().is_success() {
82 let status = resp.status();
83 return Err(provider_status_error(status));
84 }
85
86 let json: serde_json::Value = resp
87 .json()
88 .await
89 .map_err(|e| CoreError::Internal(format!("embedding response parse error: {e}")))?;
90
91 let vec = json
92 .get("data")
93 .and_then(|d| d.get(0))
94 .and_then(|d| d.get("embedding"))
95 .and_then(|e| e.as_array())
96 .ok_or_else(|| {
97 CoreError::Internal("embedding response missing data[0].embedding".into())
98 })?
99 .iter()
100 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
101 .collect::<Vec<f32>>();
102
103 if vec.len() != self.dim {
107 return Err(CoreError::Internal(format!(
108 "embedding provider returned {} dimensions but {} are configured; \
109 re-run `difflore embeddings setup --dim {}` to match your provider/model",
110 vec.len(),
111 self.dim,
112 vec.len()
113 )));
114 }
115
116 Ok(vec)
117 }
118
119 async fn embed_batch(
120 &self,
121 texts: &[String],
122 _rule_ids: Option<&[String]>,
123 ) -> Result<Vec<Vec<f32>>, CoreError> {
124 if texts.is_empty() {
125 return Ok(Vec::new());
126 }
127 let body = serde_json::json!({
130 "model": self.model,
131 "input": texts,
132 });
133 let resp = self
134 .authed_post(&self.endpoint())
135 .json(&body)
136 .send()
137 .await
138 .map_err(|e| CoreError::Internal(format!("embedding request failed: {e}")))?;
139 if !resp.status().is_success() {
140 let status = resp.status();
141 return Err(provider_status_error(status));
142 }
143 let json: serde_json::Value = resp
144 .json()
145 .await
146 .map_err(|e| CoreError::Internal(format!("embedding response parse error: {e}")))?;
147 let data = json
148 .get("data")
149 .and_then(|d| d.as_array())
150 .ok_or_else(|| CoreError::Internal("embedding response missing data array".into()))?;
151 if data.len() != texts.len() {
152 return Err(CoreError::Internal(format!(
153 "embedding response length mismatch: expected {}, got {}",
154 texts.len(),
155 data.len()
156 )));
157 }
158 let mut indexed: Vec<(usize, Vec<f32>)> = Vec::with_capacity(data.len());
161 for item in data {
162 let index = item
163 .get("index")
164 .and_then(serde_json::Value::as_u64)
165 .map_or(indexed.len(), |i| i as usize);
166 let vector = item
167 .get("embedding")
168 .and_then(|e| e.as_array())
169 .ok_or_else(|| {
170 CoreError::Internal("embedding response item missing embedding array".into())
171 })?
172 .iter()
173 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
174 .collect::<Vec<f32>>();
175 if vector.len() != self.dim {
176 return Err(CoreError::Internal(format!(
177 "embedding provider returned {} dimensions but {} are configured; \
178 re-run `difflore embeddings setup --dim {}` to match your provider/model",
179 vector.len(),
180 self.dim,
181 vector.len()
182 )));
183 }
184 indexed.push((index, vector));
185 }
186 indexed.sort_by_key(|(index, _)| *index);
187 Ok(indexed.into_iter().map(|(_, vector)| vector).collect())
188 }
189
190 fn dim(&self) -> usize {
191 self.dim
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::provider_status_error;

    /// The status error must name the status code and point at the likely
    /// misconfiguration, but must never carry credential-like material.
    #[test]
    fn provider_status_error_does_not_echo_response_body() {
        let message = provider_status_error(reqwest::StatusCode::UNAUTHORIZED).to_string();

        for required in ["401", "check provider URL"] {
            assert!(message.contains(required));
        }
        for forbidden in ["Authorization", "sk-"] {
            assert!(!message.contains(forbidden));
        }
    }
}