// erio_embedding/remote.rs

//! Remote embedding engine calling an OpenAI-compatible embedding API.

3use crate::config::EmbeddingConfig;
4use crate::engine::EmbeddingEngine;
5use crate::error::EmbeddingError;
6
7/// Response format from an OpenAI-compatible embedding API.
8#[derive(serde::Deserialize)]
9struct EmbeddingResponse {
10    data: Vec<EmbeddingData>,
11}
12
13#[derive(serde::Deserialize)]
14struct EmbeddingData {
15    embedding: Vec<f32>,
16}
17
18/// Remote embedding engine that calls an OpenAI-compatible embedding API.
19pub struct RemoteEmbedding {
20    client: reqwest::Client,
21    base_url: String,
22    api_key: String,
23    config: EmbeddingConfig,
24}
25
26impl RemoteEmbedding {
27    /// Creates a new `RemoteEmbedding` with the given base URL and API key.
28    pub fn new(
29        base_url: impl Into<String>,
30        api_key: impl Into<String>,
31        config: EmbeddingConfig,
32    ) -> Self {
33        Self {
34            client: reqwest::Client::new(),
35            base_url: base_url.into(),
36            api_key: api_key.into(),
37            config,
38        }
39    }
40
41    /// Sets a custom reqwest client (e.g. for testing with `no_proxy()`).
42    #[must_use]
43    pub fn with_client(mut self, client: reqwest::Client) -> Self {
44        self.client = client;
45        self
46    }
47
48    async fn post_embeddings(
49        &self,
50        input: serde_json::Value,
51    ) -> Result<EmbeddingResponse, EmbeddingError> {
52        let url = format!("{}/v1/embeddings", self.base_url);
53        let body = serde_json::json!({
54            "model": self.config.model_id,
55            "input": input,
56        });
57
58        let response = self
59            .client
60            .post(&url)
61            .header("Authorization", format!("Bearer {}", self.api_key))
62            .header("Content-Type", "application/json")
63            .json(&body)
64            .send()
65            .await
66            .map_err(|e| EmbeddingError::Inference(format!("request failed: {e}")))?;
67
68        let status = response.status();
69        if !status.is_success() {
70            let body_text = response.text().await.unwrap_or_else(|_| "unknown".into());
71            return Err(EmbeddingError::Inference(format!(
72                "API error {status}: {body_text}"
73            )));
74        }
75
76        response
77            .json::<EmbeddingResponse>()
78            .await
79            .map_err(|e| EmbeddingError::Inference(format!("failed to parse response: {e}")))
80    }
81}
82
83#[async_trait::async_trait]
84impl EmbeddingEngine for RemoteEmbedding {
85    fn name(&self) -> &'static str {
86        "remote"
87    }
88
89    fn dimensions(&self) -> usize {
90        self.config.dimensions
91    }
92
93    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
94        if text.is_empty() {
95            return Err(EmbeddingError::InvalidInput(
96                "text must not be empty".into(),
97            ));
98        }
99        let resp = self
100            .post_embeddings(serde_json::Value::String(text.to_owned()))
101            .await?;
102        resp.data
103            .into_iter()
104            .next()
105            .map(|d| d.embedding)
106            .ok_or_else(|| EmbeddingError::Inference("empty response data".into()))
107    }
108
109    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
110        if texts.iter().any(|t| t.is_empty()) {
111            return Err(EmbeddingError::InvalidInput(
112                "text must not be empty".into(),
113            ));
114        }
115        let input: Vec<serde_json::Value> = texts
116            .iter()
117            .map(|t| serde_json::Value::String((*t).to_owned()))
118            .collect();
119        let resp = self
120            .post_embeddings(serde_json::Value::Array(input))
121            .await?;
122        Ok(resp.data.into_iter().map(|d| d.embedding).collect())
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path every mock in this module is mounted on.
    const EMBEDDINGS_PATH: &str = "/v1/embeddings";

    /// Client that ignores proxy settings from the environment, so requests reach
    /// the local mock server directly.
    fn no_proxy_client() -> reqwest::Client {
        reqwest::Client::builder()
            .no_proxy()
            .build()
            .unwrap()
    }

    /// Three-dimensional configuration shared by every test.
    fn test_config() -> EmbeddingConfig {
        EmbeddingConfig::builder()
            .model_id("text-embedding-test")
            .dimensions(3)
            .build()
    }

    /// Engine aimed at `server`, authenticating with `key`.
    fn engine_for(server: &MockServer, key: &str) -> RemoteEmbedding {
        RemoteEmbedding::new(server.uri(), key, test_config()).with_client(no_proxy_client())
    }

    /// Engine for tests that never reach the network.
    fn offline_engine() -> RemoteEmbedding {
        RemoteEmbedding::new("http://localhost", "key", test_config())
    }

    /// OpenAI-style response body listing `embeddings` in order.
    fn mock_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
        let mut data = Vec::with_capacity(embeddings.len());
        for (position, vector) in embeddings.into_iter().enumerate() {
            data.push(serde_json::json!({
                "embedding": vector,
                "index": position,
                "object": "embedding"
            }));
        }
        serde_json::json!({
            "data": data,
            "model": "text-embedding-test",
            "object": "list",
            "usage": {"prompt_tokens": 5, "total_tokens": 5}
        })
    }

    /// Answers every POST to the embeddings path on `server` with `reply`.
    async fn respond_with(server: &MockServer, reply: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path(EMBEDDINGS_PATH))
            .respond_with(reply)
            .mount(server)
            .await;
    }

    /// Checks that an HTTP `status` reply becomes an error whose text names the code.
    async fn assert_status_reported(status: u16, reply_body: &str, key: &str) {
        let server = MockServer::start().await;
        respond_with(&server, ResponseTemplate::new(status).set_body_string(reply_body)).await;

        let err = engine_for(&server, key)
            .embed("test")
            .await
            .expect_err("a non-success status must produce an error");
        let code = status.to_string();
        assert!(
            err.to_string().contains(&code),
            "Expected {code} in error: {err}"
        );
    }

    #[test]
    fn remote_returns_name() {
        assert_eq!(offline_engine().name(), "remote");
    }

    #[test]
    fn remote_returns_correct_dimensions() {
        assert_eq!(offline_engine().dimensions(), 3);
    }

    #[tokio::test]
    async fn remote_rejects_empty_input() {
        let outcome = offline_engine().embed("").await;
        assert!(matches!(
            outcome.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    async fn remote_sends_correct_request() {
        let server = MockServer::start().await;
        let reply =
            ResponseTemplate::new(200).set_body_json(mock_response(vec![vec![0.1, 0.2, 0.3]]));
        Mock::given(method("POST"))
            .and(path(EMBEDDINGS_PATH))
            .and(header("Authorization", "Bearer test-key"))
            .and(header("Content-Type", "application/json"))
            .respond_with(reply)
            .expect(1)
            .mount(&server)
            .await;

        let embedding = engine_for(&server, "test-key").embed("hello").await.unwrap();
        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn remote_parses_openai_embedding_response() {
        let server = MockServer::start().await;
        respond_with(
            &server,
            ResponseTemplate::new(200).set_body_json(mock_response(vec![vec![1.0, 2.0, 3.0]])),
        )
        .await;

        let embedding = engine_for(&server, "key").embed("test").await.unwrap();
        assert_eq!(embedding.len(), 3);
        assert!((embedding[0] - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn remote_embed_batch_sends_array_input() {
        let server = MockServer::start().await;
        let vectors = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        respond_with(
            &server,
            ResponseTemplate::new(200).set_body_json(mock_response(vectors.clone())),
        )
        .await;

        let results = engine_for(&server, "key")
            .embed_batch(&["hello", "world"])
            .await
            .unwrap();
        assert_eq!(results, vectors);
    }

    #[tokio::test]
    async fn remote_returns_error_on_401() {
        assert_status_reported(401, "Unauthorized", "bad-key").await;
    }

    #[tokio::test]
    async fn remote_returns_error_on_500() {
        assert_status_reported(500, "Internal Server Error", "key").await;
    }
}