1use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::debug;
14
15use crate::schema::{ApiProtocol, ModelSchema, ModelSource};
16use crate::InferenceError;
17
/// Backend that serves inference requests by calling remote HTTP APIs
/// (OpenAI-compatible servers, Anthropic, and Google).
pub struct RemoteBackend {
    // Shared reqwest client; connection pooling makes reuse cheap.
    client: Client,
}
22
23impl RemoteBackend {
24 pub fn new() -> Self {
25 let client = Client::builder()
26 .timeout(std::time::Duration::from_secs(120))
27 .build()
28 .unwrap_or_default();
29 Self { client }
30 }
31
32 pub async fn generate(
34 &self,
35 schema: &ModelSchema,
36 prompt: &str,
37 context: Option<&str>,
38 temperature: f64,
39 max_tokens: usize,
40 ) -> Result<String, InferenceError> {
41 let (endpoint, api_key, protocol) = extract_remote_config(schema)?;
42
43 match protocol {
44 ApiProtocol::OpenAiCompat => {
45 self.generate_openai(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
46 }
47 ApiProtocol::Anthropic => {
48 self.generate_anthropic(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
49 }
50 ApiProtocol::Google => {
51 self.generate_google(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
52 }
53 }
54 }
55
56 pub async fn embed(
58 &self,
59 schema: &ModelSchema,
60 texts: &[String],
61 ) -> Result<Vec<Vec<f32>>, InferenceError> {
62 let (endpoint, api_key, protocol) = extract_remote_config(schema)?;
63
64 match protocol {
65 ApiProtocol::OpenAiCompat => {
66 self.embed_openai(&endpoint, &api_key, &schema.name, texts).await
67 }
68 _ => Err(InferenceError::InferenceFailed(format!(
69 "embedding not supported for {:?} protocol", protocol
70 ))),
71 }
72 }
73
74 async fn generate_openai(
77 &self,
78 endpoint: &str,
79 api_key: &str,
80 model: &str,
81 prompt: &str,
82 context: Option<&str>,
83 temperature: f64,
84 max_tokens: usize,
85 ) -> Result<String, InferenceError> {
86 let url = format_endpoint(endpoint, "/v1/chat/completions");
87
88 let mut messages = Vec::new();
89 if let Some(ctx) = context {
90 messages.push(serde_json::json!({
91 "role": "system",
92 "content": ctx,
93 }));
94 }
95 messages.push(serde_json::json!({
96 "role": "user",
97 "content": prompt,
98 }));
99
100 let body = serde_json::json!({
101 "model": model,
102 "messages": messages,
103 "temperature": temperature,
104 "max_tokens": max_tokens,
105 });
106
107 debug!(url = %url, model = %model, "openai-compat generate request");
108
109 let resp = self.client
110 .post(&url)
111 .header("Authorization", format!("Bearer {api_key}"))
112 .header("Content-Type", "application/json")
113 .json(&body)
114 .send()
115 .await
116 .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
117
118 let status = resp.status();
119 let text = resp.text().await
120 .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
121
122 if !status.is_success() {
123 return Err(InferenceError::InferenceFailed(format!(
124 "API returned {status}: {text}"
125 )));
126 }
127
128 let parsed: OpenAiResponse = serde_json::from_str(&text)
129 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
130
131 parsed.choices.first()
132 .and_then(|c| c.message.content.clone())
133 .ok_or_else(|| InferenceError::InferenceFailed("empty response".into()))
134 }
135
136 async fn embed_openai(
137 &self,
138 endpoint: &str,
139 api_key: &str,
140 model: &str,
141 texts: &[String],
142 ) -> Result<Vec<Vec<f32>>, InferenceError> {
143 let url = format_endpoint(endpoint, "/v1/embeddings");
144
145 let body = serde_json::json!({
146 "model": model,
147 "input": texts,
148 });
149
150 let resp = self.client
151 .post(&url)
152 .header("Authorization", format!("Bearer {api_key}"))
153 .header("Content-Type", "application/json")
154 .json(&body)
155 .send()
156 .await
157 .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
158
159 let status = resp.status();
160 let text = resp.text().await
161 .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
162
163 if !status.is_success() {
164 return Err(InferenceError::InferenceFailed(format!(
165 "API returned {status}: {text}"
166 )));
167 }
168
169 let parsed: OpenAiEmbedResponse = serde_json::from_str(&text)
170 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
171
172 Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
173 }
174
175 async fn generate_anthropic(
178 &self,
179 endpoint: &str,
180 api_key: &str,
181 model: &str,
182 prompt: &str,
183 context: Option<&str>,
184 temperature: f64,
185 max_tokens: usize,
186 ) -> Result<String, InferenceError> {
187 let url = format_endpoint(endpoint, "/v1/messages");
188
189 let mut body = serde_json::json!({
190 "model": model,
191 "max_tokens": max_tokens,
192 "temperature": temperature,
193 "messages": [{
194 "role": "user",
195 "content": prompt,
196 }],
197 });
198
199 if let Some(ctx) = context {
200 body["system"] = serde_json::Value::String(ctx.to_string());
201 }
202
203 debug!(url = %url, model = %model, "anthropic generate request");
204
205 let resp = self.client
206 .post(&url)
207 .header("x-api-key", api_key)
208 .header("anthropic-version", "2023-06-01")
209 .header("Content-Type", "application/json")
210 .json(&body)
211 .send()
212 .await
213 .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
214
215 let status = resp.status();
216 let text = resp.text().await
217 .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
218
219 if !status.is_success() {
220 return Err(InferenceError::InferenceFailed(format!(
221 "API returned {status}: {text}"
222 )));
223 }
224
225 let parsed: AnthropicResponse = serde_json::from_str(&text)
226 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
227
228 parsed.content.into_iter()
229 .find(|c| c.content_type == "text")
230 .map(|c| c.text)
231 .ok_or_else(|| InferenceError::InferenceFailed("no text in response".into()))
232 }
233
234 async fn generate_google(
237 &self,
238 endpoint: &str,
239 api_key: &str,
240 model: &str,
241 prompt: &str,
242 context: Option<&str>,
243 _temperature: f64,
244 _max_tokens: usize,
245 ) -> Result<String, InferenceError> {
246 let url = format!(
248 "{}/v1beta/models/{}:generateContent?key={}",
249 endpoint.trim_end_matches('/'),
250 model,
251 api_key,
252 );
253
254 let mut parts = vec![serde_json::json!({"text": prompt})];
255 if let Some(ctx) = context {
256 parts.insert(0, serde_json::json!({"text": ctx}));
257 }
258
259 let body = serde_json::json!({
260 "contents": [{
261 "parts": parts,
262 }],
263 });
264
265 debug!(model = %model, "google generate request");
266
267 let resp = self.client
268 .post(&url)
269 .header("Content-Type", "application/json")
270 .json(&body)
271 .send()
272 .await
273 .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
274
275 let status = resp.status();
276 let text = resp.text().await
277 .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
278
279 if !status.is_success() {
280 return Err(InferenceError::InferenceFailed(format!(
281 "API returned {status}: {text}"
282 )));
283 }
284
285 let parsed: GoogleResponse = serde_json::from_str(&text)
286 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
287
288 parsed.candidates.into_iter()
289 .next()
290 .and_then(|c| c.content.parts.into_iter().next())
291 .map(|p| p.text)
292 .ok_or_else(|| InferenceError::InferenceFailed("no text in response".into()))
293 }
294}
295
impl Default for RemoteBackend {
    /// Equivalent to [`RemoteBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
301
302fn extract_remote_config(schema: &ModelSchema) -> Result<(String, String, ApiProtocol), InferenceError> {
306 match &schema.source {
307 ModelSource::RemoteApi { endpoint, api_key_env, protocol, .. } => {
308 let api_key = std::env::var(api_key_env).map_err(|_| {
309 InferenceError::InferenceFailed(format!(
310 "API key env var {} not set for model {}",
311 api_key_env, schema.id
312 ))
313 })?;
314 Ok((endpoint.clone(), api_key, *protocol))
315 }
316 ModelSource::Ollama { model_tag, host } => {
317 Ok((host.clone(), String::new(), ApiProtocol::OpenAiCompat))
319 }
320 _ => Err(InferenceError::InferenceFailed(format!(
321 "model {} is not remote", schema.id
322 ))),
323 }
324}
325
/// Joins `base` and `path` into a full URL, skipping the append when
/// `base` already ends with `path` (so configured endpoints may include
/// or omit the API path).
///
/// The suffix check keeps the leading `/` of `path`, so the match must
/// fall on a path-segment boundary: a base like
/// `https://x.com/av1/chat/completions` no longer spuriously matches
/// `/v1/chat/completions` (the old `trim_start_matches('/')` comparison
/// accepted any character before the suffix).
fn format_endpoint(base: &str, path: &str) -> String {
    let base = base.trim_end_matches('/');
    // Normalize to exactly one leading slash on the path.
    let suffix = format!("/{}", path.trim_start_matches('/'));
    if base.ends_with(&suffix) {
        base.to_string()
    } else {
        format!("{base}{suffix}")
    }
}
336
/// Subset of an OpenAI-style chat-completion response used here.
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
}
343
/// One completion choice; only the message is read.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
}
348
/// Assistant message; `content` may be absent (e.g. tool-call replies).
#[derive(Debug, Deserialize)]
struct OpenAiMessage {
    content: Option<String>,
}
353
/// Subset of an OpenAI-style embeddings response used here.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedResponse {
    data: Vec<OpenAiEmbedData>,
}
358
/// One embedding vector from the embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedData {
    embedding: Vec<f32>,
}
363
/// Subset of an Anthropic messages response used here.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}
368
/// One content block; `type` is renamed since it is a Rust keyword.
/// `text` defaults to empty for non-text blocks that carry no text field.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    #[serde(rename = "type")]
    content_type: String,
    #[serde(default)]
    text: String,
}
376
/// Subset of a Google generateContent response used here.
#[derive(Debug, Deserialize)]
struct GoogleResponse {
    candidates: Vec<GoogleCandidate>,
}
381
/// One candidate; only its content is read.
#[derive(Debug, Deserialize)]
struct GoogleCandidate {
    content: GoogleContent,
}
386
/// Candidate content: an ordered list of parts.
#[derive(Debug, Deserialize)]
struct GoogleContent {
    parts: Vec<GooglePart>,
}
391
/// A text part of a candidate's content.
#[derive(Debug, Deserialize)]
struct GooglePart {
    text: String,
}
396
#[cfg(test)]
mod tests {
    use super::*;

    // format_endpoint: append when the path is missing, skip when the base
    // already ends with it, and tolerate a trailing slash on the base.
    #[test]
    fn format_endpoint_no_dup() {
        assert_eq!(
            format_endpoint("https://api.openai.com", "/v1/chat/completions"),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            format_endpoint("https://api.openai.com/v1/chat/completions", "/v1/chat/completions"),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            format_endpoint("https://api.openai.com/", "/v1/chat/completions"),
            "https://api.openai.com/v1/chat/completions"
        );
    }

    // A RemoteApi schema whose key env var is unset must fail with an
    // error, not panic. The var name is chosen to be absent in any env.
    #[test]
    fn extract_config_missing_env() {
        let schema = ModelSchema {
            id: "test/model:v1".into(),
            name: "Test".into(),
            provider: "test".into(),
            family: "test".into(),
            version: "1".into(),
            capabilities: vec![],
            context_length: 4096,
            param_count: String::new(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: ModelSource::RemoteApi {
                endpoint: "https://api.test.com".into(),
                api_key_env: "NONEXISTENT_TEST_KEY_12345".into(),
                api_version: None,
                protocol: ApiProtocol::OpenAiCompat,
            },
            tags: vec![],
            available: false,
        };
        let result = extract_remote_config(&schema);
        assert!(result.is_err());
    }

    // Decoding smoke tests against minimal provider payloads.
    #[test]
    fn parse_openai_response() {
        let json = r#"{"choices":[{"message":{"content":"Hello world"}}]}"#;
        let resp: OpenAiResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello world"));
    }

    #[test]
    fn parse_anthropic_response() {
        let json = r#"{"content":[{"type":"text","text":"Hello world"}]}"#;
        let resp: AnthropicResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.content[0].text, "Hello world");
    }

    #[test]
    fn parse_google_response() {
        let json = r#"{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}"#;
        let resp: GoogleResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.candidates[0].content.parts[0].text, "Hello world");
    }
}