1use crate::config::EmbeddingConfig;
4use crate::engine::EmbeddingEngine;
5use crate::error::EmbeddingError;
6
/// Top-level JSON body returned by the `/v1/embeddings` endpoint.
///
/// Only the `data` array is deserialized; any other fields present in the
/// payload are ignored.
#[derive(serde::Deserialize)]
struct EmbeddingResponse {
    /// Embedding entries, kept in the order the server returned them.
    data: Vec<EmbeddingData>,
}
12
/// A single element of the `data` array in an embeddings response.
///
/// Only the vector itself is read; other per-item fields are ignored.
#[derive(serde::Deserialize)]
struct EmbeddingData {
    /// The embedding vector for one input.
    embedding: Vec<f32>,
}
17
/// Embedding engine that delegates to a remote HTTP API exposing a
/// `/v1/embeddings` endpoint.
pub struct RemoteEmbedding {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// URL prefix to which `/v1/embeddings` is appended.
    base_url: String,
    /// Credential sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// Supplies the model identifier sent with each request and the
    /// dimensionality reported by the engine.
    config: EmbeddingConfig,
}
25
26impl RemoteEmbedding {
27 pub fn new(
29 base_url: impl Into<String>,
30 api_key: impl Into<String>,
31 config: EmbeddingConfig,
32 ) -> Self {
33 Self {
34 client: reqwest::Client::new(),
35 base_url: base_url.into(),
36 api_key: api_key.into(),
37 config,
38 }
39 }
40
41 #[must_use]
43 pub fn with_client(mut self, client: reqwest::Client) -> Self {
44 self.client = client;
45 self
46 }
47
48 async fn post_embeddings(
49 &self,
50 input: serde_json::Value,
51 ) -> Result<EmbeddingResponse, EmbeddingError> {
52 let url = format!("{}/v1/embeddings", self.base_url);
53 let body = serde_json::json!({
54 "model": self.config.model_id,
55 "input": input,
56 });
57
58 let response = self
59 .client
60 .post(&url)
61 .header("Authorization", format!("Bearer {}", self.api_key))
62 .header("Content-Type", "application/json")
63 .json(&body)
64 .send()
65 .await
66 .map_err(|e| EmbeddingError::Inference(format!("request failed: {e}")))?;
67
68 let status = response.status();
69 if !status.is_success() {
70 let body_text = response.text().await.unwrap_or_else(|_| "unknown".into());
71 return Err(EmbeddingError::Inference(format!(
72 "API error {status}: {body_text}"
73 )));
74 }
75
76 response
77 .json::<EmbeddingResponse>()
78 .await
79 .map_err(|e| EmbeddingError::Inference(format!("failed to parse response: {e}")))
80 }
81}
82
83#[async_trait::async_trait]
84impl EmbeddingEngine for RemoteEmbedding {
85 fn name(&self) -> &'static str {
86 "remote"
87 }
88
89 fn dimensions(&self) -> usize {
90 self.config.dimensions
91 }
92
93 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
94 if text.is_empty() {
95 return Err(EmbeddingError::InvalidInput(
96 "text must not be empty".into(),
97 ));
98 }
99 let resp = self
100 .post_embeddings(serde_json::Value::String(text.to_owned()))
101 .await?;
102 resp.data
103 .into_iter()
104 .next()
105 .map(|d| d.embedding)
106 .ok_or_else(|| EmbeddingError::Inference("empty response data".into()))
107 }
108
109 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
110 if texts.iter().any(|t| t.is_empty()) {
111 return Err(EmbeddingError::InvalidInput(
112 "text must not be empty".into(),
113 ));
114 }
115 let input: Vec<serde_json::Value> = texts
116 .iter()
117 .map(|t| serde_json::Value::String((*t).to_owned()))
118 .collect();
119 let resp = self
120 .post_embeddings(serde_json::Value::Array(input))
121 .await?;
122 Ok(resp.data.into_iter().map(|d| d.embedding).collect())
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client with proxying disabled, for talking to the local mock
    /// server.
    fn no_proxy_client() -> reqwest::Client {
        reqwest::Client::builder().no_proxy().build().unwrap()
    }

    /// Configuration shared by all tests: a dummy model with three dimensions.
    fn test_config() -> EmbeddingConfig {
        EmbeddingConfig::builder()
            .model_id("text-embedding-test")
            .dimensions(3)
            .build()
    }

    /// Builds an embeddings response body containing `embeddings`, with extra
    /// per-item and top-level fields that the parser is expected to ignore.
    fn mock_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
        let data: Vec<serde_json::Value> = embeddings
            .into_iter()
            .enumerate()
            .map(|(i, emb)| {
                serde_json::json!({
                    "embedding": emb,
                    "index": i,
                    "object": "embedding"
                })
            })
            .collect();
        serde_json::json!({
            "data": data,
            "model": "text-embedding-test",
            "object": "list",
            "usage": {"prompt_tokens": 5, "total_tokens": 5}
        })
    }

    #[test]
    fn remote_returns_name() {
        let engine = RemoteEmbedding::new("http://localhost", "key", test_config());
        assert_eq!(engine.name(), "remote");
    }

    #[test]
    fn remote_returns_correct_dimensions() {
        let engine = RemoteEmbedding::new("http://localhost", "key", test_config());
        assert_eq!(engine.dimensions(), 3);
    }

    #[tokio::test]
    async fn remote_rejects_empty_input() {
        let engine = RemoteEmbedding::new("http://localhost", "key", test_config());
        let result = engine.embed("").await;
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    async fn remote_embed_batch_rejects_empty_text() {
        // Validation happens before any request, so no server is needed.
        let engine = RemoteEmbedding::new("http://localhost", "key", test_config());
        let result = engine.embed_batch(&["hello", ""]).await;
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    async fn remote_sends_correct_request() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .and(header("Authorization", "Bearer test-key"))
            .and(header("Content-Type", "application/json"))
            // A single text must be sent as a plain JSON string.
            .and(body_partial_json(serde_json::json!({ "input": "hello" })))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(mock_response(vec![vec![0.1, 0.2, 0.3]])),
            )
            .expect(1)
            .mount(&server)
            .await;

        let engine = RemoteEmbedding::new(server.uri(), "test-key", test_config())
            .with_client(no_proxy_client());
        let result = engine.embed("hello").await.unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn remote_parses_openai_embedding_response() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(mock_response(vec![vec![1.0, 2.0, 3.0]])),
            )
            .mount(&server)
            .await;

        let engine =
            RemoteEmbedding::new(server.uri(), "key", test_config()).with_client(no_proxy_client());
        let result = engine.embed("test").await.unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0] - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn remote_embed_batch_sends_array_input() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            // The mock only answers when the batch arrives as a JSON array in
            // the original order; otherwise the call below fails.
            .and(body_partial_json(
                serde_json::json!({ "input": ["hello", "world"] }),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(mock_response(vec![
                vec![0.1, 0.2, 0.3],
                vec![0.4, 0.5, 0.6],
            ])))
            .expect(1)
            .mount(&server)
            .await;

        let engine =
            RemoteEmbedding::new(server.uri(), "key", test_config()).with_client(no_proxy_client());
        let results = engine.embed_batch(&["hello", "world"]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(results[1], vec![0.4, 0.5, 0.6]);
    }

    #[tokio::test]
    async fn remote_returns_error_on_401() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(401).set_body_string("Unauthorized"))
            .mount(&server)
            .await;

        let engine = RemoteEmbedding::new(server.uri(), "bad-key", test_config())
            .with_client(no_proxy_client());
        let result = engine.embed("test").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("401"),
            "Expected 401 in error: {err}"
        );
    }

    #[tokio::test]
    async fn remote_returns_error_on_500() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
            .mount(&server)
            .await;

        let engine =
            RemoteEmbedding::new(server.uri(), "key", test_config()).with_client(no_proxy_client());
        let result = engine.embed("test").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("500"),
            "Expected 500 in error: {err}"
        );
    }
}