1use crate::Plugin;
2use std::io::{BufRead, BufReader, Read, Write};
3use std::net::TcpStream;
4use std::time::Duration;
5
/// Which LLM backend to talk to, together with its credentials.
///
/// `Debug` is implemented by hand so that API keys never appear in logs or
/// panic messages.
#[derive(Clone)]
pub enum AiProvider {
    /// Anthropic Messages API.
    Anthropic {
        /// Anthropic API key.
        api_key: String,
        /// Model identifier, e.g. `claude-sonnet-4-20250514`.
        model: String,
    },
    /// OpenAI Chat Completions API.
    OpenAI {
        /// OpenAI API key.
        api_key: String,
        /// Model identifier, e.g. `gpt-4`.
        model: String,
    },
    /// An OpenAI-compatible endpoint reached via `base_url`.
    Custom {
        /// Endpoint URL, e.g. `http://localhost:11434` or a full
        /// `http://localhost:11434/v1/chat/completions` URL.
        base_url: String,
        /// Bearer token; may be empty.
        api_key: String,
        /// Model to request; `None` means no model is specified.
        model: Option<String>,
    },
}

impl std::fmt::Debug for AiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Debug output routinely ends up in logs, so keys are never printed.
        // An empty key is shown as such: it is a common misconfiguration.
        fn redact(key: &str) -> &'static str {
            if key.is_empty() {
                "<empty>"
            } else {
                "<redacted>"
            }
        }

        match self {
            AiProvider::Anthropic { api_key, model } => f
                .debug_struct("Anthropic")
                .field("api_key", &redact(api_key))
                .field("model", model)
                .finish(),
            AiProvider::OpenAI { api_key, model } => f
                .debug_struct("OpenAI")
                .field("api_key", &redact(api_key))
                .field("model", model)
                .finish(),
            AiProvider::Custom {
                base_url,
                api_key,
                model,
            } => f
                .debug_struct("Custom")
                .field("base_url", base_url)
                .field("api_key", &redact(api_key))
                .field("model", model)
                .finish(),
        }
    }
}
27
/// One chat turn: a role (typically `"system"`, `"user"` or `"assistant"`)
/// and its text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiMessage {
    /// Speaker role, sent verbatim as the JSON `role` field.
    pub role: String,
    /// Message text, sent verbatim as the JSON `content` field.
    pub content: String,
}
34
35impl AiMessage {
36 pub fn system(content: &str) -> Self {
37 Self {
38 role: "system".into(),
39 content: content.into(),
40 }
41 }
42
43 pub fn user(content: &str) -> Self {
44 Self {
45 role: "user".into(),
46 content: content.into(),
47 }
48 }
49
50 pub fn assistant(content: &str) -> Self {
51 Self {
52 role: "assistant".into(),
53 content: content.into(),
54 }
55 }
56}
57
58pub struct AiProxyPlugin {
68 provider: AiProvider,
69}
70
71impl AiProxyPlugin {
72 pub fn anthropic(api_key: &str, model: &str) -> Self {
73 Self {
74 provider: AiProvider::Anthropic {
75 api_key: api_key.to_string(),
76 model: model.to_string(),
77 },
78 }
79 }
80
81 pub fn openai(api_key: &str, model: &str) -> Self {
82 Self {
83 provider: AiProvider::OpenAI {
84 api_key: api_key.to_string(),
85 model: model.to_string(),
86 },
87 }
88 }
89
90 pub fn custom(base_url: &str, api_key: &str) -> Self {
91 Self {
92 provider: AiProvider::Custom {
93 base_url: base_url.to_string(),
94 api_key: api_key.to_string(),
95 model: None,
96 },
97 }
98 }
99
100 pub fn custom_with_model(base_url: &str, api_key: &str, model: &str) -> Self {
103 Self {
104 provider: AiProvider::Custom {
105 base_url: base_url.to_string(),
106 api_key: api_key.to_string(),
107 model: if model.is_empty() {
108 None
109 } else {
110 Some(model.to_string())
111 },
112 },
113 }
114 }
115
116 pub fn provider(&self) -> &AiProvider {
118 &self.provider
119 }
120
121 pub fn stream_completion(
126 &self,
127 messages: &[AiMessage],
128 on_chunk: &mut dyn FnMut(&str),
129 ) -> Result<String, String> {
130 match &self.provider {
131 AiProvider::Anthropic { api_key, model } => {
132 self.stream_anthropic(api_key, model, messages, on_chunk)
133 }
134 AiProvider::OpenAI { api_key, model } => {
135 self.stream_openai(api_key, model, messages, on_chunk)
136 }
137 AiProvider::Custom {
138 base_url,
139 api_key,
140 model,
141 } => self.stream_custom(base_url, api_key, model.as_deref(), messages, on_chunk),
142 }
143 }
144
145 pub fn completion(&self, messages: &[AiMessage]) -> Result<String, String> {
147 let mut full = String::new();
148 self.stream_completion(messages, &mut |chunk| {
149 full.push_str(chunk);
150 })?;
151 Ok(full)
152 }
153
154 fn stream_anthropic(
159 &self,
160 api_key: &str,
161 model: &str,
162 messages: &[AiMessage],
163 on_chunk: &mut dyn FnMut(&str),
164 ) -> Result<String, String> {
165 let msgs: Vec<serde_json::Value> = messages
166 .iter()
167 .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
168 .collect();
169
170 let body = serde_json::json!({
171 "model": model,
172 "max_tokens": 4096,
173 "stream": true,
174 "messages": msgs,
175 })
176 .to_string();
177
178 self.stream_https_request(
179 "api.anthropic.com",
180 443,
181 "/v1/messages",
182 &[
183 ("x-api-key", api_key),
184 ("anthropic-version", "2023-06-01"),
185 ("content-type", "application/json"),
186 ],
187 &body,
188 on_chunk,
189 parse_anthropic_sse,
190 )
191 }
192
193 fn stream_openai(
194 &self,
195 api_key: &str,
196 model: &str,
197 messages: &[AiMessage],
198 on_chunk: &mut dyn FnMut(&str),
199 ) -> Result<String, String> {
200 let msgs: Vec<serde_json::Value> = messages
201 .iter()
202 .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
203 .collect();
204
205 let body = serde_json::json!({
206 "model": model,
207 "stream": true,
208 "max_tokens": 4096,
209 "messages": msgs,
210 })
211 .to_string();
212
213 self.stream_https_request(
214 "api.openai.com",
215 443,
216 "/v1/chat/completions",
217 &[
218 ("Authorization", &format!("Bearer {api_key}")),
219 ("Content-Type", "application/json"),
220 ],
221 &body,
222 on_chunk,
223 parse_openai_sse,
224 )
225 }
226
227 fn stream_custom(
228 &self,
229 base_url: &str,
230 api_key: &str,
231 model: Option<&str>,
232 messages: &[AiMessage],
233 on_chunk: &mut dyn FnMut(&str),
234 ) -> Result<String, String> {
235 let is_https = base_url.starts_with("https://");
236 let url = base_url
237 .strip_prefix("https://")
238 .or_else(|| base_url.strip_prefix("http://"))
239 .unwrap_or(base_url);
240
241 let (host, path) = match url.find('/') {
242 Some(i) => (&url[..i], &url[i..]),
243 None => (url, "/v1/chat/completions"),
244 };
245
246 let msgs: Vec<serde_json::Value> = messages
247 .iter()
248 .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
249 .collect();
250
251 let mut body_value = serde_json::json!({
252 "stream": true,
253 "messages": msgs,
254 });
255
256 if let Some(m) = model {
258 body_value["model"] = serde_json::json!(m);
259 }
260
261 let body = body_value.to_string();
262
263 if is_https {
264 let port = 443;
265 return self.stream_https_request(
266 host,
267 port,
268 path,
269 &[
270 ("Authorization", &format!("Bearer {api_key}")),
271 ("Content-Type", "application/json"),
272 ],
273 &body,
274 on_chunk,
275 parse_openai_sse,
276 );
277 }
278
279 self.stream_http_request(host, 80, path, api_key, &body, on_chunk)
280 }
281
282 fn stream_https_request(
292 &self,
293 _host: &str,
294 _port: u16,
295 _path: &str,
296 _headers: &[(&str, &str)],
297 _body: &str,
298 _on_chunk: &mut dyn FnMut(&str),
299 _parse_chunk: fn(&str) -> Option<String>,
300 ) -> Result<String, String> {
301 Err(
302 "HTTPS streaming requires a TLS library. Configure a TLS-terminating \
303 reverse proxy or use a plain-HTTP custom endpoint (e.g. Ollama)."
304 .into(),
305 )
306 }
307
308 fn stream_http_request(
310 &self,
311 host: &str,
312 port: u16,
313 path: &str,
314 api_key: &str,
315 body: &str,
316 on_chunk: &mut dyn FnMut(&str),
317 ) -> Result<String, String> {
318 let addr = format!("{host}:{port}");
319 let mut stream =
320 TcpStream::connect(&addr).map_err(|e| format!("Connection failed: {e}"))?;
321 stream.set_read_timeout(Some(Duration::from_secs(120))).ok();
322
323 let mut req = format!(
325 "POST {path} HTTP/1.1\r\n\
326 Host: {host}\r\n\
327 Content-Type: application/json\r\n\
328 Content-Length: {}\r\n\
329 Connection: keep-alive\r\n",
330 body.len()
331 );
332 if !api_key.is_empty() {
333 req.push_str(&format!("Authorization: Bearer {api_key}\r\n"));
334 }
335 req.push_str("\r\n");
336 req.push_str(body);
337
338 stream
339 .write_all(req.as_bytes())
340 .map_err(|e| format!("Write failed: {e}"))?;
341
342 let mut reader = BufReader::new(stream);
344 let mut header_line = String::new();
345 let mut status_code: u16 = 0;
346 let mut first_line = true;
347 loop {
348 header_line.clear();
349 reader
350 .read_line(&mut header_line)
351 .map_err(|e| format!("Read failed: {e}"))?;
352 if first_line {
353 status_code = header_line
355 .split_whitespace()
356 .nth(1)
357 .and_then(|s| s.parse().ok())
358 .unwrap_or(0);
359 first_line = false;
360 }
361 if header_line.trim().is_empty() {
362 break;
363 }
364 }
365
366 if status_code != 200 {
367 let mut err_body = vec![0u8; 4096];
369 let n = reader.read(&mut err_body).unwrap_or(0);
370 let err_text = String::from_utf8_lossy(&err_body[..n]);
371 return Err(format!("Provider returned HTTP {status_code}: {err_text}"));
372 }
373
374 let mut full_response = String::new();
376 let mut line = String::new();
377 loop {
378 line.clear();
379 match reader.read_line(&mut line) {
380 Ok(0) => break,
381 Ok(_) => {
382 let trimmed = line.trim();
383 if trimmed.is_empty() {
384 continue;
385 }
386 if let Some(text) = parse_openai_sse(trimmed) {
387 full_response.push_str(&text);
388 on_chunk(&text);
389 }
390 if trimmed == "data: [DONE]" {
392 break;
393 }
394 }
395 Err(_) => break,
396 }
397 }
398
399 Ok(full_response)
400 }
401}
402
403impl Plugin for AiProxyPlugin {
404 fn name(&self) -> &str {
405 "ai-proxy"
406 }
407}
408
409fn parse_anthropic_sse(line: &str) -> Option<String> {
418 let data = line.strip_prefix("data: ")?;
419 let parsed: serde_json::Value = serde_json::from_str(data).ok()?;
420 if parsed.get("type").and_then(|t| t.as_str()) != Some("content_block_delta") {
421 return None;
422 }
423 let delta = parsed.get("delta")?;
424 if delta.get("type").and_then(|t| t.as_str()) != Some("text_delta") {
426 return None;
427 }
428 delta
429 .get("text")
430 .and_then(|t| t.as_str())
431 .map(|s| s.to_string())
432}
433
434fn parse_openai_sse(line: &str) -> Option<String> {
439 let data = line.strip_prefix("data: ")?;
440 if data == "[DONE]" {
441 return None;
442 }
443 let parsed: serde_json::Value = serde_json::from_str(data).ok()?;
444 parsed
445 .get("choices")
446 .and_then(|c| c.get(0))
447 .and_then(|c| c.get("delta"))
448 .and_then(|d| d.get("content"))
449 .and_then(|t| t.as_str())
450 .map(|s| s.to_string())
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loopback port that nothing is listening on (barring a race
    /// with another process).
    ///
    /// Binding port 0 lets the OS choose a free port; the listener is dropped
    /// on return, so a later connection attempt is refused. This avoids
    /// depending on a hard-coded port being unused on the test machine.
    fn unused_local_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral port");
        listener.local_addr().expect("read the bound address").port()
    }

    #[test]
    fn creates_anthropic_provider() {
        let plugin = AiProxyPlugin::anthropic("sk-ant-test", "claude-sonnet-4-20250514");
        match plugin.provider() {
            AiProvider::Anthropic { api_key, model } => {
                assert_eq!(api_key, "sk-ant-test");
                assert_eq!(model, "claude-sonnet-4-20250514");
            }
            _ => panic!("Expected Anthropic provider"),
        }
    }

    #[test]
    fn creates_openai_provider() {
        let plugin = AiProxyPlugin::openai("sk-test", "gpt-4");
        match plugin.provider() {
            AiProvider::OpenAI { api_key, model } => {
                assert_eq!(api_key, "sk-test");
                assert_eq!(model, "gpt-4");
            }
            _ => panic!("Expected OpenAI provider"),
        }
    }

    #[test]
    fn creates_custom_provider() {
        let plugin = AiProxyPlugin::custom("http://localhost:11434/v1/chat/completions", "key");
        match plugin.provider() {
            AiProvider::Custom {
                base_url,
                api_key,
                model,
            } => {
                assert_eq!(base_url, "http://localhost:11434/v1/chat/completions");
                assert_eq!(api_key, "key");
                assert!(model.is_none());
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn creates_custom_provider_with_model() {
        let plugin = AiProxyPlugin::custom_with_model("http://localhost:11434", "key", "llama3");
        match plugin.provider() {
            AiProvider::Custom {
                base_url,
                api_key,
                model,
            } => {
                assert_eq!(base_url, "http://localhost:11434");
                assert_eq!(api_key, "key");
                assert_eq!(model.as_deref(), Some("llama3"));
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn custom_with_empty_model_stores_none() {
        let plugin = AiProxyPlugin::custom_with_model("http://localhost:11434", "key", "");
        match plugin.provider() {
            AiProvider::Custom { model, .. } => {
                assert!(model.is_none());
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn ai_message_constructors() {
        let sys = AiMessage::system("You are helpful.");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "You are helpful.");

        let user = AiMessage::user("Hello!");
        assert_eq!(user.role, "user");
        assert_eq!(user.content, "Hello!");

        let asst = AiMessage::assistant("Hi there.");
        assert_eq!(asst.role, "assistant");
        assert_eq!(asst.content, "Hi there.");
    }

    #[test]
    fn plugin_name() {
        let plugin = AiProxyPlugin::openai("key", "model");
        assert_eq!(plugin.name(), "ai-proxy");
    }

    #[test]
    fn completion_without_server_returns_error() {
        let url = format!("http://127.0.0.1:{}", unused_local_port());
        let plugin = AiProxyPlugin::custom(&url, "");
        let msgs = vec![AiMessage::user("hi")];
        let result = plugin.completion(&msgs);
        assert!(result.is_err());
    }

    #[test]
    fn parse_anthropic_sse_extracts_text() {
        let line = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert_eq!(parse_anthropic_sse(line), Some("Hello".to_string()));
    }

    #[test]
    fn parse_anthropic_sse_ignores_non_delta() {
        let line = r#"data: {"type":"message_start","message":{}}"#;
        assert_eq!(parse_anthropic_sse(line), None);
    }

    #[test]
    fn parse_anthropic_sse_ignores_non_text_delta() {
        let line = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"x\":1}"}}"#;
        assert_eq!(parse_anthropic_sse(line), None);
    }

    #[test]
    fn parse_openai_sse_extracts_content() {
        let line = r#"data: {"id":"x","choices":[{"index":0,"delta":{"content":" world"}}]}"#;
        assert_eq!(parse_openai_sse(line), Some(" world".to_string()));
    }

    #[test]
    fn parse_openai_sse_handles_done() {
        assert_eq!(parse_openai_sse("data: [DONE]"), None);
    }

    #[test]
    fn parse_openai_sse_ignores_non_data() {
        assert_eq!(parse_openai_sse("event: message"), None);
    }

    #[test]
    fn parse_openai_sse_ignores_role_only_delta() {
        let line = r#"data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}"#;
        assert_eq!(parse_openai_sse(line), None);
    }

    #[test]
    fn parse_openai_sse_ignores_malformed_json() {
        assert_eq!(parse_openai_sse("data: {not json"), None);
    }

    #[test]
    fn https_returns_informative_error() {
        let plugin = AiProxyPlugin::anthropic("key", "model");
        let msgs = vec![AiMessage::user("hi")];
        let result = plugin.completion(&msgs);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("TLS"), "Error should mention TLS: {err}");
    }

    #[test]
    fn https_custom_endpoint_returns_informative_error() {
        let plugin =
            AiProxyPlugin::custom("https://example.invalid/v1/chat/completions", "key");
        let msgs = vec![AiMessage::user("hi")];
        let err = plugin.completion(&msgs).unwrap_err();
        assert!(err.contains("TLS"), "Error should mention TLS: {err}");
    }
}