// pylon_plugin/builtin/ai_proxy.rs

use crate::Plugin;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use std::time::Duration;

// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------

/// Supported AI providers.
///
/// Every variant carries an API key. `Debug` is therefore implemented by hand
/// and prints `<redacted>` in place of the key, so that logging a provider
/// (or a struct containing one) with `{:?}` never writes the secret out.
#[derive(Clone)]
pub enum AiProvider {
    /// Anthropic, addressed at `api.anthropic.com`.
    Anthropic {
        api_key: String,
        model: String,
    },
    /// OpenAI, addressed at `api.openai.com`.
    OpenAI {
        api_key: String,
        model: String,
    },
    /// Any OpenAI-compatible endpoint reachable at `base_url`.
    Custom {
        base_url: String,
        api_key: String,
        /// Model name to send in the request body; omitted when `None`.
        model: Option<String>,
    },
}

impl std::fmt::Debug for AiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The key is deliberately never formatted; only its presence is shown.
        const REDACTED: &str = "<redacted>";
        match self {
            AiProvider::Anthropic { model, .. } => f
                .debug_struct("Anthropic")
                .field("api_key", &REDACTED)
                .field("model", model)
                .finish(),
            AiProvider::OpenAI { model, .. } => f
                .debug_struct("OpenAI")
                .field("api_key", &REDACTED)
                .field("model", model)
                .finish(),
            AiProvider::Custom {
                base_url, model, ..
            } => f
                .debug_struct("Custom")
                .field("base_url", base_url)
                .field("api_key", &REDACTED)
                .field("model", model)
                .finish(),
        }
    }
}

/// A single message in a conversation.
///
/// `role` is a free-form string; the constructors on this type produce the
/// three roles `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiMessage {
    /// Who is speaking.
    pub role: String,
    /// The message text.
    pub content: String,
}

impl AiMessage {
    /// Shared constructor: copies `role` and `content` into owned strings.
    fn with_role(role: &str, content: &str) -> Self {
        Self {
            role: role.to_owned(),
            content: content.to_owned(),
        }
    }

    /// Builds a message with role `"system"`.
    pub fn system(content: &str) -> Self {
        Self::with_role("system", content)
    }

    /// Builds a message with role `"user"`.
    pub fn user(content: &str) -> Self {
        Self::with_role("user", content)
    }

    /// Builds a message with role `"assistant"`.
    pub fn assistant(content: &str) -> Self {
        Self::with_role("assistant", content)
    }
}

// ---------------------------------------------------------------------------
// AI proxy plugin
// ---------------------------------------------------------------------------

62/// Proxies requests to LLM providers with streaming support.
63///
64/// Supports Anthropic, OpenAI, and any OpenAI-compatible custom endpoint
65/// (Ollama, Together, Groq, etc.). HTTPS endpoints require a TLS-terminating
66/// reverse proxy; plain HTTP endpoints (e.g. local Ollama) work directly.
67pub struct AiProxyPlugin {
68    provider: AiProvider,
69}
70
71impl AiProxyPlugin {
72    pub fn anthropic(api_key: &str, model: &str) -> Self {
73        Self {
74            provider: AiProvider::Anthropic {
75                api_key: api_key.to_string(),
76                model: model.to_string(),
77            },
78        }
79    }
80
81    pub fn openai(api_key: &str, model: &str) -> Self {
82        Self {
83            provider: AiProvider::OpenAI {
84                api_key: api_key.to_string(),
85                model: model.to_string(),
86            },
87        }
88    }
89
90    pub fn custom(base_url: &str, api_key: &str) -> Self {
91        Self {
92            provider: AiProvider::Custom {
93                base_url: base_url.to_string(),
94                api_key: api_key.to_string(),
95                model: None,
96            },
97        }
98    }
99
100    /// Create a custom provider with an explicit model name included in
101    /// the request body sent to the upstream endpoint.
102    pub fn custom_with_model(base_url: &str, api_key: &str, model: &str) -> Self {
103        Self {
104            provider: AiProvider::Custom {
105                base_url: base_url.to_string(),
106                api_key: api_key.to_string(),
107                model: if model.is_empty() {
108                    None
109                } else {
110                    Some(model.to_string())
111                },
112            },
113        }
114    }
115
116    /// Returns a reference to the configured provider.
117    pub fn provider(&self) -> &AiProvider {
118        &self.provider
119    }
120
121    /// Stream a completion request to the configured provider.
122    ///
123    /// Calls `on_chunk` for each text token received from the provider.
124    /// Returns the full accumulated response text on success.
125    pub fn stream_completion(
126        &self,
127        messages: &[AiMessage],
128        on_chunk: &mut dyn FnMut(&str),
129    ) -> Result<String, String> {
130        match &self.provider {
131            AiProvider::Anthropic { api_key, model } => {
132                self.stream_anthropic(api_key, model, messages, on_chunk)
133            }
134            AiProvider::OpenAI { api_key, model } => {
135                self.stream_openai(api_key, model, messages, on_chunk)
136            }
137            AiProvider::Custom {
138                base_url,
139                api_key,
140                model,
141            } => self.stream_custom(base_url, api_key, model.as_deref(), messages, on_chunk),
142        }
143    }
144
145    /// Non-streaming convenience wrapper. Waits for the full response.
146    pub fn completion(&self, messages: &[AiMessage]) -> Result<String, String> {
147        let mut full = String::new();
148        self.stream_completion(messages, &mut |chunk| {
149            full.push_str(chunk);
150        })?;
151        Ok(full)
152    }
153
154    // -----------------------------------------------------------------------
155    // Provider-specific streaming
156    // -----------------------------------------------------------------------
157
158    fn stream_anthropic(
159        &self,
160        api_key: &str,
161        model: &str,
162        messages: &[AiMessage],
163        on_chunk: &mut dyn FnMut(&str),
164    ) -> Result<String, String> {
165        let msgs: Vec<serde_json::Value> = messages
166            .iter()
167            .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
168            .collect();
169
170        let body = serde_json::json!({
171            "model": model,
172            "max_tokens": 4096,
173            "stream": true,
174            "messages": msgs,
175        })
176        .to_string();
177
178        self.stream_https_request(
179            "api.anthropic.com",
180            443,
181            "/v1/messages",
182            &[
183                ("x-api-key", api_key),
184                ("anthropic-version", "2023-06-01"),
185                ("content-type", "application/json"),
186            ],
187            &body,
188            on_chunk,
189            parse_anthropic_sse,
190        )
191    }
192
193    fn stream_openai(
194        &self,
195        api_key: &str,
196        model: &str,
197        messages: &[AiMessage],
198        on_chunk: &mut dyn FnMut(&str),
199    ) -> Result<String, String> {
200        let msgs: Vec<serde_json::Value> = messages
201            .iter()
202            .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
203            .collect();
204
205        let body = serde_json::json!({
206            "model": model,
207            "stream": true,
208            "max_tokens": 4096,
209            "messages": msgs,
210        })
211        .to_string();
212
213        self.stream_https_request(
214            "api.openai.com",
215            443,
216            "/v1/chat/completions",
217            &[
218                ("Authorization", &format!("Bearer {api_key}")),
219                ("Content-Type", "application/json"),
220            ],
221            &body,
222            on_chunk,
223            parse_openai_sse,
224        )
225    }
226
227    fn stream_custom(
228        &self,
229        base_url: &str,
230        api_key: &str,
231        model: Option<&str>,
232        messages: &[AiMessage],
233        on_chunk: &mut dyn FnMut(&str),
234    ) -> Result<String, String> {
235        let is_https = base_url.starts_with("https://");
236        let url = base_url
237            .strip_prefix("https://")
238            .or_else(|| base_url.strip_prefix("http://"))
239            .unwrap_or(base_url);
240
241        let (host, path) = match url.find('/') {
242            Some(i) => (&url[..i], &url[i..]),
243            None => (url, "/v1/chat/completions"),
244        };
245
246        let msgs: Vec<serde_json::Value> = messages
247            .iter()
248            .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
249            .collect();
250
251        let mut body_value = serde_json::json!({
252            "stream": true,
253            "messages": msgs,
254        });
255
256        // Include model in the request body when configured.
257        if let Some(m) = model {
258            body_value["model"] = serde_json::json!(m);
259        }
260
261        let body = body_value.to_string();
262
263        if is_https {
264            let port = 443;
265            return self.stream_https_request(
266                host,
267                port,
268                path,
269                &[
270                    ("Authorization", &format!("Bearer {api_key}")),
271                    ("Content-Type", "application/json"),
272                ],
273                &body,
274                on_chunk,
275                parse_openai_sse,
276            );
277        }
278
279        self.stream_http_request(host, 80, path, api_key, &body, on_chunk)
280    }
281
282    // -----------------------------------------------------------------------
283    // Transport
284    // -----------------------------------------------------------------------
285
286    /// HTTPS transport stub. Real HTTPS requires a TLS library (rustls or
287    /// native-tls) which we deliberately avoid to keep the dependency tree
288    /// minimal. Users who need HTTPS should either:
289    ///   - Use a local TLS-terminating proxy (nginx, caddy, stunnel).
290    ///   - Use a plain-HTTP custom endpoint (e.g. local Ollama on port 11434).
291    fn stream_https_request(
292        &self,
293        _host: &str,
294        _port: u16,
295        _path: &str,
296        _headers: &[(&str, &str)],
297        _body: &str,
298        _on_chunk: &mut dyn FnMut(&str),
299        _parse_chunk: fn(&str) -> Option<String>,
300    ) -> Result<String, String> {
301        Err(
302            "HTTPS streaming requires a TLS library. Configure a TLS-terminating \
303             reverse proxy or use a plain-HTTP custom endpoint (e.g. Ollama)."
304                .into(),
305        )
306    }
307
308    /// Plain HTTP streaming for local/custom endpoints (Ollama, vLLM, etc.).
309    fn stream_http_request(
310        &self,
311        host: &str,
312        port: u16,
313        path: &str,
314        api_key: &str,
315        body: &str,
316        on_chunk: &mut dyn FnMut(&str),
317    ) -> Result<String, String> {
318        let addr = format!("{host}:{port}");
319        let mut stream =
320            TcpStream::connect(&addr).map_err(|e| format!("Connection failed: {e}"))?;
321        stream.set_read_timeout(Some(Duration::from_secs(120))).ok();
322
323        // Build raw HTTP/1.1 request.
324        let mut req = format!(
325            "POST {path} HTTP/1.1\r\n\
326             Host: {host}\r\n\
327             Content-Type: application/json\r\n\
328             Content-Length: {}\r\n\
329             Connection: keep-alive\r\n",
330            body.len()
331        );
332        if !api_key.is_empty() {
333            req.push_str(&format!("Authorization: Bearer {api_key}\r\n"));
334        }
335        req.push_str("\r\n");
336        req.push_str(body);
337
338        stream
339            .write_all(req.as_bytes())
340            .map_err(|e| format!("Write failed: {e}"))?;
341
342        // Consume response headers.
343        let mut reader = BufReader::new(stream);
344        let mut header_line = String::new();
345        let mut status_code: u16 = 0;
346        let mut first_line = true;
347        loop {
348            header_line.clear();
349            reader
350                .read_line(&mut header_line)
351                .map_err(|e| format!("Read failed: {e}"))?;
352            if first_line {
353                // Parse "HTTP/1.1 200 OK"
354                status_code = header_line
355                    .split_whitespace()
356                    .nth(1)
357                    .and_then(|s| s.parse().ok())
358                    .unwrap_or(0);
359                first_line = false;
360            }
361            if header_line.trim().is_empty() {
362                break;
363            }
364        }
365
366        if status_code != 200 {
367            // Read the error body (up to 4 KB).
368            let mut err_body = vec![0u8; 4096];
369            let n = reader.read(&mut err_body).unwrap_or(0);
370            let err_text = String::from_utf8_lossy(&err_body[..n]);
371            return Err(format!("Provider returned HTTP {status_code}: {err_text}"));
372        }
373
374        // Read SSE data lines.
375        let mut full_response = String::new();
376        let mut line = String::new();
377        loop {
378            line.clear();
379            match reader.read_line(&mut line) {
380                Ok(0) => break,
381                Ok(_) => {
382                    let trimmed = line.trim();
383                    if trimmed.is_empty() {
384                        continue;
385                    }
386                    if let Some(text) = parse_openai_sse(trimmed) {
387                        full_response.push_str(&text);
388                        on_chunk(&text);
389                    }
390                    // Check for [DONE] sentinel.
391                    if trimmed == "data: [DONE]" {
392                        break;
393                    }
394                }
395                Err(_) => break,
396            }
397        }
398
399        Ok(full_response)
400    }
401}
402
403impl Plugin for AiProxyPlugin {
404    fn name(&self) -> &str {
405        "ai-proxy"
406    }
407}
408
// ---------------------------------------------------------------------------
// SSE parsers — free functions so they can be passed as fn pointers
// ---------------------------------------------------------------------------

413/// Extract text from an Anthropic SSE data line.
414///
415/// Anthropic sends `content_block_delta` events with
416/// `{"delta":{"type":"text_delta","text":"..."}}`.
417fn parse_anthropic_sse(line: &str) -> Option<String> {
418    let data = line.strip_prefix("data: ")?;
419    let parsed: serde_json::Value = serde_json::from_str(data).ok()?;
420    if parsed.get("type").and_then(|t| t.as_str()) != Some("content_block_delta") {
421        return None;
422    }
423    let delta = parsed.get("delta")?;
424    // Only extract text from text_delta events; ignore tool_use or other delta types.
425    if delta.get("type").and_then(|t| t.as_str()) != Some("text_delta") {
426        return None;
427    }
428    delta
429        .get("text")
430        .and_then(|t| t.as_str())
431        .map(|s| s.to_string())
432}
433
434/// Extract text from an OpenAI-compatible SSE data line.
435///
436/// OpenAI (and compatible APIs like Ollama, Together, Groq) sends
437/// `{"choices":[{"delta":{"content":"..."}}]}`.
438fn parse_openai_sse(line: &str) -> Option<String> {
439    let data = line.strip_prefix("data: ")?;
440    if data == "[DONE]" {
441        return None;
442    }
443    let parsed: serde_json::Value = serde_json::from_str(data).ok()?;
444    parsed
445        .get("choices")
446        .and_then(|c| c.get(0))
447        .and_then(|c| c.get("delta"))
448        .and_then(|d| d.get("content"))
449        .and_then(|t| t.as_str())
450        .map(|s| s.to_string())
451}
452
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructors ------------------------------------------------------

    #[test]
    fn creates_anthropic_provider() {
        let plugin = AiProxyPlugin::anthropic("sk-ant-test", "claude-sonnet-4-20250514");
        match plugin.provider() {
            AiProvider::Anthropic { api_key, model } => {
                assert_eq!(api_key, "sk-ant-test");
                assert_eq!(model, "claude-sonnet-4-20250514");
            }
            _ => panic!("Expected Anthropic provider"),
        }
    }

    #[test]
    fn creates_openai_provider() {
        let plugin = AiProxyPlugin::openai("sk-test", "gpt-4");
        match plugin.provider() {
            AiProvider::OpenAI { api_key, model } => {
                assert_eq!(api_key, "sk-test");
                assert_eq!(model, "gpt-4");
            }
            _ => panic!("Expected OpenAI provider"),
        }
    }

    #[test]
    fn creates_custom_provider() {
        let plugin = AiProxyPlugin::custom("http://localhost:11434/v1/chat/completions", "key");
        match plugin.provider() {
            AiProvider::Custom {
                base_url,
                api_key,
                model,
            } => {
                assert_eq!(base_url, "http://localhost:11434/v1/chat/completions");
                assert_eq!(api_key, "key");
                assert!(model.is_none());
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn creates_custom_provider_with_model() {
        let plugin = AiProxyPlugin::custom_with_model("http://localhost:11434", "key", "llama3");
        match plugin.provider() {
            AiProvider::Custom {
                base_url,
                api_key,
                model,
            } => {
                assert_eq!(base_url, "http://localhost:11434");
                assert_eq!(api_key, "key");
                assert_eq!(model.as_deref(), Some("llama3"));
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn custom_with_empty_model_stores_none() {
        let plugin = AiProxyPlugin::custom_with_model("http://localhost:11434", "key", "");
        match plugin.provider() {
            AiProvider::Custom { model, .. } => {
                assert!(model.is_none());
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn ai_message_constructors() {
        let sys = AiMessage::system("You are helpful.");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "You are helpful.");

        let user = AiMessage::user("Hello!");
        assert_eq!(user.role, "user");
        assert_eq!(user.content, "Hello!");

        let asst = AiMessage::assistant("Hi there.");
        assert_eq!(asst.role, "assistant");
        assert_eq!(asst.content, "Hi there.");
    }

    #[test]
    fn plugin_name() {
        let plugin = AiProxyPlugin::openai("key", "model");
        assert_eq!(plugin.name(), "ai-proxy");
    }

    // --- Transport ---------------------------------------------------------

    #[test]
    fn completion_without_server_returns_error() {
        // A request to an address where no server is expected to be listening
        // should return an error, not panic.
        let plugin = AiProxyPlugin::custom("http://127.0.0.1:19999", "");
        let msgs = vec![AiMessage::user("hi")];
        let result = plugin.completion(&msgs);
        assert!(result.is_err());
    }

    #[test]
    fn https_returns_informative_error() {
        let plugin = AiProxyPlugin::anthropic("key", "model");
        let msgs = vec![AiMessage::user("hi")];
        let result = plugin.completion(&msgs);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("TLS"), "Error should mention TLS: {err}");
    }

    // --- SSE parsers -------------------------------------------------------

    #[test]
    fn parse_anthropic_sse_extracts_text() {
        let line = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert_eq!(parse_anthropic_sse(line), Some("Hello".to_string()));
    }

    #[test]
    fn parse_anthropic_sse_ignores_non_delta() {
        let line = r#"data: {"type":"message_start","message":{}}"#;
        assert_eq!(parse_anthropic_sse(line), None);
    }

    #[test]
    fn parse_anthropic_sse_ignores_non_text_delta() {
        // A content_block_delta whose delta is not a text_delta (here an
        // input_json_delta) should be ignored.
        let line = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"x\":1}"}}"#;
        assert_eq!(parse_anthropic_sse(line), None);
    }

    #[test]
    fn parse_openai_sse_extracts_content() {
        let line = r#"data: {"id":"x","choices":[{"index":0,"delta":{"content":" world"}}]}"#;
        assert_eq!(parse_openai_sse(line), Some(" world".to_string()));
    }

    #[test]
    fn parse_openai_sse_handles_done() {
        assert_eq!(parse_openai_sse("data: [DONE]"), None);
    }

    #[test]
    fn parse_openai_sse_ignores_non_data() {
        assert_eq!(parse_openai_sse("event: message"), None);
    }
}