Skip to main content

car_inference/
remote.rs

1//! Remote inference backend — HTTP client for cloud API models.
2//!
3//! Supports three protocols:
4//! - **OpenAI-compatible** — `/v1/chat/completions` (OpenAI, Ollama, vLLM, etc.)
5//! - **Anthropic** — `/v1/messages`
6//! - **Google** — Gemini API
7//!
8//! The backend reads API keys from environment variables (never stored in config).
9//! All calls are async and return the same types as local inference.
10
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::debug;
14
15use crate::schema::{ApiProtocol, ModelSchema, ModelSource};
16use crate::InferenceError;
17
18/// Remote inference client. Reuses a single HTTP client for connection pooling.
19pub struct RemoteBackend {
20    client: Client,
21}
22
23impl RemoteBackend {
24    pub fn new() -> Self {
25        let client = Client::builder()
26            .timeout(std::time::Duration::from_secs(120))
27            .build()
28            .unwrap_or_default();
29        Self { client }
30    }
31
32    /// Generate text via a remote API.
33    pub async fn generate(
34        &self,
35        schema: &ModelSchema,
36        prompt: &str,
37        context: Option<&str>,
38        temperature: f64,
39        max_tokens: usize,
40    ) -> Result<String, InferenceError> {
41        let (endpoint, api_key, protocol) = extract_remote_config(schema)?;
42
43        match protocol {
44            ApiProtocol::OpenAiCompat => {
45                self.generate_openai(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
46            }
47            ApiProtocol::Anthropic => {
48                self.generate_anthropic(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
49            }
50            ApiProtocol::Google => {
51                self.generate_google(&endpoint, &api_key, &schema.name, prompt, context, temperature, max_tokens).await
52            }
53        }
54    }
55
56    /// Generate embeddings via a remote API (OpenAI-compatible only for now).
57    pub async fn embed(
58        &self,
59        schema: &ModelSchema,
60        texts: &[String],
61    ) -> Result<Vec<Vec<f32>>, InferenceError> {
62        let (endpoint, api_key, protocol) = extract_remote_config(schema)?;
63
64        match protocol {
65            ApiProtocol::OpenAiCompat => {
66                self.embed_openai(&endpoint, &api_key, &schema.name, texts).await
67            }
68            _ => Err(InferenceError::InferenceFailed(format!(
69                "embedding not supported for {:?} protocol", protocol
70            ))),
71        }
72    }
73
74    // --- OpenAI-compatible ---
75
76    async fn generate_openai(
77        &self,
78        endpoint: &str,
79        api_key: &str,
80        model: &str,
81        prompt: &str,
82        context: Option<&str>,
83        temperature: f64,
84        max_tokens: usize,
85    ) -> Result<String, InferenceError> {
86        let url = format_endpoint(endpoint, "/v1/chat/completions");
87
88        let mut messages = Vec::new();
89        if let Some(ctx) = context {
90            messages.push(serde_json::json!({
91                "role": "system",
92                "content": ctx,
93            }));
94        }
95        messages.push(serde_json::json!({
96            "role": "user",
97            "content": prompt,
98        }));
99
100        let body = serde_json::json!({
101            "model": model,
102            "messages": messages,
103            "temperature": temperature,
104            "max_tokens": max_tokens,
105        });
106
107        debug!(url = %url, model = %model, "openai-compat generate request");
108
109        let resp = self.client
110            .post(&url)
111            .header("Authorization", format!("Bearer {api_key}"))
112            .header("Content-Type", "application/json")
113            .json(&body)
114            .send()
115            .await
116            .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
117
118        let status = resp.status();
119        let text = resp.text().await
120            .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
121
122        if !status.is_success() {
123            return Err(InferenceError::InferenceFailed(format!(
124                "API returned {status}: {text}"
125            )));
126        }
127
128        let parsed: OpenAiResponse = serde_json::from_str(&text)
129            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
130
131        parsed.choices.first()
132            .and_then(|c| c.message.content.clone())
133            .ok_or_else(|| InferenceError::InferenceFailed("empty response".into()))
134    }
135
136    async fn embed_openai(
137        &self,
138        endpoint: &str,
139        api_key: &str,
140        model: &str,
141        texts: &[String],
142    ) -> Result<Vec<Vec<f32>>, InferenceError> {
143        let url = format_endpoint(endpoint, "/v1/embeddings");
144
145        let body = serde_json::json!({
146            "model": model,
147            "input": texts,
148        });
149
150        let resp = self.client
151            .post(&url)
152            .header("Authorization", format!("Bearer {api_key}"))
153            .header("Content-Type", "application/json")
154            .json(&body)
155            .send()
156            .await
157            .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
158
159        let status = resp.status();
160        let text = resp.text().await
161            .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
162
163        if !status.is_success() {
164            return Err(InferenceError::InferenceFailed(format!(
165                "API returned {status}: {text}"
166            )));
167        }
168
169        let parsed: OpenAiEmbedResponse = serde_json::from_str(&text)
170            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
171
172        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
173    }
174
175    // --- Anthropic ---
176
177    async fn generate_anthropic(
178        &self,
179        endpoint: &str,
180        api_key: &str,
181        model: &str,
182        prompt: &str,
183        context: Option<&str>,
184        temperature: f64,
185        max_tokens: usize,
186    ) -> Result<String, InferenceError> {
187        let url = format_endpoint(endpoint, "/v1/messages");
188
189        let mut body = serde_json::json!({
190            "model": model,
191            "max_tokens": max_tokens,
192            "temperature": temperature,
193            "messages": [{
194                "role": "user",
195                "content": prompt,
196            }],
197        });
198
199        if let Some(ctx) = context {
200            body["system"] = serde_json::Value::String(ctx.to_string());
201        }
202
203        debug!(url = %url, model = %model, "anthropic generate request");
204
205        let resp = self.client
206            .post(&url)
207            .header("x-api-key", api_key)
208            .header("anthropic-version", "2023-06-01")
209            .header("Content-Type", "application/json")
210            .json(&body)
211            .send()
212            .await
213            .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
214
215        let status = resp.status();
216        let text = resp.text().await
217            .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
218
219        if !status.is_success() {
220            return Err(InferenceError::InferenceFailed(format!(
221                "API returned {status}: {text}"
222            )));
223        }
224
225        let parsed: AnthropicResponse = serde_json::from_str(&text)
226            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
227
228        parsed.content.into_iter()
229            .find(|c| c.content_type == "text")
230            .map(|c| c.text)
231            .ok_or_else(|| InferenceError::InferenceFailed("no text in response".into()))
232    }
233
234    // --- Google (Gemini) ---
235
236    async fn generate_google(
237        &self,
238        endpoint: &str,
239        api_key: &str,
240        model: &str,
241        prompt: &str,
242        context: Option<&str>,
243        _temperature: f64,
244        _max_tokens: usize,
245    ) -> Result<String, InferenceError> {
246        // Gemini uses ?key= query param, not header auth
247        let url = format!(
248            "{}/v1beta/models/{}:generateContent?key={}",
249            endpoint.trim_end_matches('/'),
250            model,
251            api_key,
252        );
253
254        let mut parts = vec![serde_json::json!({"text": prompt})];
255        if let Some(ctx) = context {
256            parts.insert(0, serde_json::json!({"text": ctx}));
257        }
258
259        let body = serde_json::json!({
260            "contents": [{
261                "parts": parts,
262            }],
263        });
264
265        debug!(model = %model, "google generate request");
266
267        let resp = self.client
268            .post(&url)
269            .header("Content-Type", "application/json")
270            .json(&body)
271            .send()
272            .await
273            .map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
274
275        let status = resp.status();
276        let text = resp.text().await
277            .map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
278
279        if !status.is_success() {
280            return Err(InferenceError::InferenceFailed(format!(
281                "API returned {status}: {text}"
282            )));
283        }
284
285        let parsed: GoogleResponse = serde_json::from_str(&text)
286            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
287
288        parsed.candidates.into_iter()
289            .next()
290            .and_then(|c| c.content.parts.into_iter().next())
291            .map(|p| p.text)
292            .ok_or_else(|| InferenceError::InferenceFailed("no text in response".into()))
293    }
294}
295
296impl Default for RemoteBackend {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302// --- Helpers ---
303
304/// Extract endpoint, API key, and protocol from a model schema.
305fn extract_remote_config(schema: &ModelSchema) -> Result<(String, String, ApiProtocol), InferenceError> {
306    match &schema.source {
307        ModelSource::RemoteApi { endpoint, api_key_env, protocol, .. } => {
308            let api_key = std::env::var(api_key_env).map_err(|_| {
309                InferenceError::InferenceFailed(format!(
310                    "API key env var {} not set for model {}",
311                    api_key_env, schema.id
312                ))
313            })?;
314            Ok((endpoint.clone(), api_key, *protocol))
315        }
316        ModelSource::Ollama { model_tag, host } => {
317            // Ollama uses OpenAI-compatible API
318            Ok((host.clone(), String::new(), ApiProtocol::OpenAiCompat))
319        }
320        _ => Err(InferenceError::InferenceFailed(format!(
321            "model {} is not remote", schema.id
322        ))),
323    }
324}
325
/// Join an API base URL with an endpoint path without duplicating path segments.
///
/// `base` may be a bare origin (`https://api.openai.com`), a versioned prefix
/// (`http://localhost:11434/v1`), or already the full endpoint URL. Trailing slashes on
/// `base` and leading slashes on `path` are ignored. If the end of `base`'s URL path
/// already equals a leading run of whole segments of `path`, only the remaining segments
/// are appended; otherwise `path` is appended in full. The scheme and host are never
/// treated as part of an overlap.
fn format_endpoint(base: &str, path: &str) -> String {
    let base = base.trim_end_matches('/');
    let path = path.trim_start_matches('/');
    if path.is_empty() {
        return base.to_string();
    }

    // Isolate base's URL path: everything from the first '/' after "scheme://host".
    let host_start = base.find("://").map_or(0, |i| i + 3);
    let base_path = match base[host_start..].find('/') {
        Some(i) => &base[host_start + i..],
        None => return format!("{base}/{path}"),
    };

    // Candidate overlaps: the whole path, then each shorter prefix ending at a segment
    // boundary. Longest first, so a full endpoint URL is used as-is.
    let mut ends: Vec<usize> = path.match_indices('/').map(|(i, _)| i).collect();
    ends.push(path.len());
    for &end in ends.iter().rev() {
        let head = &path[..end];
        if let Some(rest) = base_path.strip_suffix(head) {
            // `rest` ending in '/' guarantees the overlap starts on a segment boundary
            // (so ".../apiv1" does not match "v1").
            if rest.ends_with('/') {
                return format!("{base}{}", &path[end..]);
            }
        }
    }

    format!("{base}/{path}")
}
336
337// --- Response types ---
338
339#[derive(Debug, Deserialize)]
340struct OpenAiResponse {
341    choices: Vec<OpenAiChoice>,
342}
343
344#[derive(Debug, Deserialize)]
345struct OpenAiChoice {
346    message: OpenAiMessage,
347}
348
349#[derive(Debug, Deserialize)]
350struct OpenAiMessage {
351    content: Option<String>,
352}
353
354#[derive(Debug, Deserialize)]
355struct OpenAiEmbedResponse {
356    data: Vec<OpenAiEmbedData>,
357}
358
359#[derive(Debug, Deserialize)]
360struct OpenAiEmbedData {
361    embedding: Vec<f32>,
362}
363
364#[derive(Debug, Deserialize)]
365struct AnthropicResponse {
366    content: Vec<AnthropicContent>,
367}
368
369#[derive(Debug, Deserialize)]
370struct AnthropicContent {
371    #[serde(rename = "type")]
372    content_type: String,
373    #[serde(default)]
374    text: String,
375}
376
377#[derive(Debug, Deserialize)]
378struct GoogleResponse {
379    candidates: Vec<GoogleCandidate>,
380}
381
382#[derive(Debug, Deserialize)]
383struct GoogleCandidate {
384    content: GoogleContent,
385}
386
387#[derive(Debug, Deserialize)]
388struct GoogleContent {
389    parts: Vec<GooglePart>,
390}
391
392#[derive(Debug, Deserialize)]
393struct GooglePart {
394    text: String,
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema for a remote OpenAI-compatible model whose key lives in `key_env`.
    fn remote_schema(key_env: &str) -> ModelSchema {
        ModelSchema {
            id: "test/model:v1".into(),
            name: "Test".into(),
            provider: "test".into(),
            family: "test".into(),
            version: "1".into(),
            capabilities: vec![],
            context_length: 4096,
            param_count: String::new(),
            quantization: None,
            performance: Default::default(),
            cost: Default::default(),
            source: ModelSource::RemoteApi {
                endpoint: "https://api.test.com".into(),
                api_key_env: key_env.into(),
                api_version: None,
                protocol: ApiProtocol::OpenAiCompat,
            },
            tags: vec![],
            available: false,
        }
    }

    #[test]
    fn format_endpoint_no_dup() {
        const WANT: &str = "https://api.openai.com/v1/chat/completions";
        let bases = [
            "https://api.openai.com",
            "https://api.openai.com/v1/chat/completions",
            "https://api.openai.com/",
        ];
        for base in bases.iter() {
            assert_eq!(format_endpoint(base, "/v1/chat/completions"), WANT, "base = {base}");
        }
    }

    #[test]
    fn extract_config_missing_env() {
        let schema = remote_schema("NONEXISTENT_TEST_KEY_12345");
        assert!(extract_remote_config(&schema).is_err());
    }

    #[test]
    fn parse_openai_response() {
        let raw = r#"{"choices":[{"message":{"content":"Hello world"}}]}"#;
        let parsed = serde_json::from_str::<OpenAiResponse>(raw).expect("valid OpenAI body");
        assert_eq!(parsed.choices[0].message.content.as_deref(), Some("Hello world"));
    }

    #[test]
    fn parse_anthropic_response() {
        let raw = r#"{"content":[{"type":"text","text":"Hello world"}]}"#;
        let parsed = serde_json::from_str::<AnthropicResponse>(raw).expect("valid Anthropic body");
        assert_eq!(parsed.content[0].text, "Hello world");
    }

    #[test]
    fn parse_google_response() {
        let raw = r#"{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}"#;
        let parsed = serde_json::from_str::<GoogleResponse>(raw).expect("valid Gemini body");
        assert_eq!(parsed.candidates[0].content.parts[0].text, "Hello world");
    }
}