1use clap::ValueEnum;
17use serde::{Deserialize, Serialize};
18
/// Abstraction over a streaming chat-completion provider backend.
///
/// Implementors supply the provider's identity, its default model, the
/// HTTP endpoint, and the provider-specific JSON request body.
pub trait ProviderPlugin: Send + Sync {
    /// Short lowercase provider identifier (e.g. "openai").
    fn name(&self) -> &str;
    /// Model identifier used when the caller does not choose one.
    fn default_model(&self) -> &str;
    /// Fully-qualified endpoint URL for chat/message requests.
    fn api_url(&self) -> &str;
    /// Builds the JSON request body for `prompt` against `model`,
    /// including the optional `system` prompt. Both current
    /// implementations request a streamed response.
    fn build_request(&self, prompt: &str, system: Option<&str>, model: &str) -> serde_json::Value;
}
42
/// Zero-sized [`ProviderPlugin`] for the OpenAI Chat Completions API.
pub struct OpenAiPlugin;
/// Zero-sized [`ProviderPlugin`] for the Anthropic Messages API.
pub struct AnthropicPlugin;
47
48impl ProviderPlugin for OpenAiPlugin {
49 fn name(&self) -> &str {
50 "openai"
51 }
52 fn default_model(&self) -> &str {
53 "gpt-3.5-turbo"
54 }
55 fn api_url(&self) -> &str {
56 "https://api.openai.com/v1/chat/completions"
57 }
58 fn build_request(&self, prompt: &str, system: Option<&str>, model: &str) -> serde_json::Value {
59 let mut messages = Vec::new();
60 if let Some(sys) = system {
61 messages.push(serde_json::json!({ "role": "system", "content": sys }));
62 }
63 messages.push(serde_json::json!({ "role": "user", "content": prompt }));
64 serde_json::json!({
65 "model": model,
66 "messages": messages,
67 "stream": true,
68 "temperature": 0.7,
69 "logprobs": true,
70 "top_logprobs": 5,
71 })
72 }
73}
74
/// Anthropic API version date string.
/// NOTE(review): presumably sent as the `anthropic-version` request header —
/// the call site is outside this file; confirm there.
pub const ANTHROPIC_API_VERSION: &str = "2023-06-01";
78
79impl ProviderPlugin for AnthropicPlugin {
80 fn name(&self) -> &str {
81 "anthropic"
82 }
83 fn default_model(&self) -> &str {
84 "claude-sonnet-4-6"
85 }
86 fn api_url(&self) -> &str {
87 "https://api.anthropic.com/v1/messages"
88 }
89 fn build_request(&self, prompt: &str, system: Option<&str>, model: &str) -> serde_json::Value {
90 let mut req = serde_json::json!({
91 "model": model,
92 "messages": [{ "role": "user", "content": prompt }],
93 "max_tokens": 1024,
94 "stream": true,
95 "temperature": 0.7,
96 });
97 if let Some(sys) = system {
98 req["system"] = serde_json::Value::String(sys.to_string());
99 }
100 req
101 }
102}
103
/// One alternative token candidate with its log-probability, as found in
/// OpenAI's `top_logprobs` arrays.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITopLogprob {
    /// Candidate token text.
    pub token: String,
    /// Log-probability assigned to this candidate.
    pub logprob: f32,
}
112
/// Per-token logprob entry for one generated token position.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAILogprobContent {
    /// The token actually emitted at this position.
    pub token: String,
    /// Log-probability of the emitted token.
    pub logprob: f32,
    /// Alternative candidates for this position; empty when the API omits them.
    #[serde(default)]
    pub top_logprobs: Vec<OpenAITopLogprob>,
}
121
/// `logprobs` payload attached to a streamed choice; `content` defaults to
/// empty when the field is absent in the chunk.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIChunkLogprobs {
    #[serde(default)]
    pub content: Vec<OpenAILogprobContent>,
}
128
129#[derive(Debug, Clone, ValueEnum, PartialEq)]
134pub enum Provider {
135 Openai,
137 Anthropic,
139 Mock,
141}
142
143impl std::fmt::Display for Provider {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 Provider::Openai => write!(f, "openai"),
147 Provider::Anthropic => write!(f, "anthropic"),
148 Provider::Mock => write!(f, "mock"),
149 }
150 }
151}
152
153impl std::str::FromStr for Provider {
154 type Err = String;
155
156 fn from_str(s: &str) -> Result<Self, Self::Err> {
157 match s.to_lowercase().as_str() {
158 "openai" => Ok(Provider::Openai),
159 "anthropic" => Ok(Provider::Anthropic),
160 "mock" => Ok(Provider::Mock),
161 other => Err(format!(
162 "unknown provider: '{}' (expected openai, anthropic, or mock)",
163 other
164 )),
165 }
166 }
167}
168
/// One message in an OpenAI chat request ("system" or "user" role here).
#[derive(Debug, Serialize)]
pub struct OpenAIChatMessage {
    pub role: String,
    pub content: String,
}
179
/// Typed OpenAI chat-completions request body.
/// NOTE(review): duplicates the shape built ad-hoc by
/// `OpenAiPlugin::build_request` via `json!` — keep the two in sync.
#[derive(Debug, Serialize)]
pub struct OpenAIChatRequest {
    pub model: String,
    pub messages: Vec<OpenAIChatMessage>,
    pub stream: bool,
    pub temperature: f32,
    // Logprob fields are always serialized (no skip attribute).
    pub logprobs: bool,
    pub top_logprobs: u8,
}
196
/// Incremental content delta in a streamed OpenAI chunk; `content` is
/// `None` for empty deltas (e.g. the final chunk).
#[derive(Debug, Deserialize)]
pub struct OpenAIDelta {
    pub content: Option<String>,
}
203
/// One choice in a streamed OpenAI chunk.
#[derive(Debug, Deserialize)]
pub struct OpenAIChoice {
    pub delta: OpenAIDelta,
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    /// Present only when logprobs were requested; `None` otherwise.
    #[serde(default)]
    pub logprobs: Option<OpenAIChunkLogprobs>,
}
216
/// One SSE chunk of a streamed OpenAI chat completion.
#[derive(Debug, Deserialize)]
pub struct OpenAIChunk {
    pub choices: Vec<OpenAIChoice>,
}
223
/// One message in an Anthropic Messages request ("user" role here).
#[derive(Debug, Serialize)]
pub struct AnthropicMessage {
    pub role: String,
    pub content: String,
}
234
/// Typed Anthropic Messages API request body.
/// NOTE(review): duplicates the shape built by
/// `AnthropicPlugin::build_request` — keep the two in sync.
#[derive(Debug, Serialize)]
pub struct AnthropicRequest {
    pub model: String,
    pub messages: Vec<AnthropicMessage>,
    pub max_tokens: u32,
    pub stream: bool,
    pub temperature: f32,
    /// Top-level system prompt; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
}
252
/// Delta payload of an Anthropic stream event; `text` is only present for
/// `text_delta`-style deltas (other delta shapes deserialize with `None`).
#[derive(Debug, Deserialize)]
pub struct AnthropicContentDelta {
    #[serde(default)]
    pub text: Option<String>,
}
260
/// One SSE event from the Anthropic Messages stream
/// (e.g. "message_start", "content_block_delta", "ping").
#[derive(Debug, Deserialize)]
pub struct AnthropicStreamEvent {
    /// Event discriminator; `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    pub event_type: String,
    /// Present only for delta-carrying events.
    #[serde(default)]
    pub delta: Option<AnthropicContentDelta>,
}
271
/// JSON-RPC 2.0 request envelope for an MCP `tools/call` inference.
#[derive(Debug, Serialize)]
pub struct McpInferRequest {
    /// JSON-RPC protocol version, expected to be "2.0".
    pub jsonrpc: String,
    pub method: String,
    pub id: u64,
    pub params: McpInferParams,
}
286
/// `params` object of an MCP tool call: tool name plus its arguments.
#[derive(Debug, Serialize)]
pub struct McpInferParams {
    pub name: String,
    pub arguments: McpInferArguments,
}
295
/// Arguments for the MCP infer tool.
#[derive(Debug, Serialize)]
pub struct McpInferArguments {
    pub prompt: String,
    /// Worker identifier (e.g. "llama_cpp" in the tests below).
    pub worker: String,
}
304
/// JSON-RPC 2.0 response envelope; exactly one of `result`/`error` is
/// expected to be populated, but both are optional to tolerate either shape.
#[derive(Debug, Deserialize)]
pub struct McpInferResponse {
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    pub jsonrpc: Option<String>,
    pub result: Option<McpInferResult>,
    pub error: Option<McpError>,
}
316
/// Successful MCP result: a list of content blocks (may be empty).
#[derive(Debug, Deserialize)]
pub struct McpInferResult {
    pub content: Vec<McpContent>,
}
323
/// One MCP content block; `text` may be absent or JSON null.
#[derive(Debug, Deserialize)]
pub struct McpContent {
    pub text: Option<String>,
}
330
/// JSON-RPC error object (only the human-readable message is kept).
#[derive(Debug, Deserialize)]
pub struct McpError {
    pub message: String,
}
337
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Provider enum: Display, equality, Clone -----

    #[test]
    fn test_provider_display() {
        assert_eq!(Provider::Openai.to_string(), "openai");
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Mock.to_string(), "mock");
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Openai, Provider::Openai);
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_eq!(Provider::Mock, Provider::Mock);
        assert_ne!(Provider::Openai, Provider::Anthropic);
        assert_ne!(Provider::Openai, Provider::Mock);
        assert_ne!(Provider::Anthropic, Provider::Mock);
    }

    #[test]
    fn test_provider_openai_display_lowercase() {
        let s = format!("{}", Provider::Openai);
        assert_eq!(s, "openai");
        assert!(s.chars().all(|c| c.is_lowercase() || c.is_alphanumeric()));
    }

    #[test]
    fn test_provider_anthropic_display_lowercase() {
        assert_eq!(format!("{}", Provider::Anthropic), "anthropic");
    }

    #[test]
    fn test_provider_clone() {
        let p = Provider::Openai;
        let p2 = p.clone();
        assert_eq!(p, p2);
    }

    // ----- MCP JSON-RPC request serialization / response deserialization -----

    #[test]
    fn test_mcp_request_serializes() {
        let req = McpInferRequest {
            jsonrpc: "2.0".to_string(),
            method: "tools/call".to_string(),
            id: 1,
            params: McpInferParams {
                name: "infer".to_string(),
                arguments: McpInferArguments {
                    prompt: "hello".to_string(),
                    worker: "llama_cpp".to_string(),
                },
            },
        };
        let json = serde_json::to_string(&req).expect("serialization failed");
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"worker\":\"llama_cpp\""));
        assert!(json.contains("\"prompt\":\"hello\""));
    }

    #[test]
    fn test_mcp_response_deserializes_success() {
        let json = r#"{"jsonrpc":"2.0","result":{"content":[{"text":"enriched prompt"}]}}"#;
        let resp: McpInferResponse = serde_json::from_str(json).expect("deser failed");
        assert!(resp.error.is_none());
        let text = resp
            .result
            .as_ref()
            .and_then(|r| r.content.first())
            .and_then(|c| c.text.as_ref());
        assert_eq!(text, Some(&"enriched prompt".to_string()));
    }

    #[test]
    fn test_mcp_response_deserializes_error() {
        let json = r#"{"jsonrpc":"2.0","error":{"message":"pipeline down"}}"#;
        let resp: McpInferResponse = serde_json::from_str(json).expect("deser failed");
        assert!(resp.result.is_none());
        assert_eq!(
            resp.error.as_ref().map(|e| &e.message[..]),
            Some("pipeline down")
        );
    }

    // ----- Anthropic SSE event deserialization -----

    #[test]
    fn test_anthropic_content_block_delta_deserializes() {
        let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        let event: AnthropicStreamEvent = serde_json::from_str(json).expect("deser failed");
        assert_eq!(event.event_type, "content_block_delta");
        assert_eq!(
            event
                .delta
                .as_ref()
                .and_then(|d| d.text.as_ref())
                .map(|s| s.as_str()),
            Some("Hello")
        );
    }

    #[test]
    fn test_anthropic_message_start_deserializes() {
        // Unknown fields (like "message") are ignored by serde's default.
        let json = r#"{"type":"message_start","message":{"id":"msg_123"}}"#;
        let event: AnthropicStreamEvent = serde_json::from_str(json).expect("deser failed");
        assert_eq!(event.event_type, "message_start");
        assert!(event.delta.is_none());
    }

    // ----- OpenAI stream chunk deserialization -----

    #[test]
    fn test_openai_chunk_deserializes() {
        let json = r#"{"id":"chatcmpl-abc","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let chunk: OpenAIChunk = serde_json::from_str(json).expect("deser failed");
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(
            chunk.choices[0].delta.content.as_ref().map(|s| s.as_str()),
            Some("Hi")
        );
    }

    #[test]
    fn test_openai_chunk_empty_delta() {
        let json =
            r#"{"id":"chatcmpl-abc","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
        let chunk: OpenAIChunk = serde_json::from_str(json).expect("deser failed");
        assert!(chunk.choices[0].delta.content.is_none());
    }

    #[test]
    fn test_mcp_infer_request_contains_all_fields() {
        let req = McpInferRequest {
            jsonrpc: "2.0".to_string(),
            method: "tools/call".to_string(),
            id: 42,
            params: McpInferParams {
                name: "infer".to_string(),
                arguments: McpInferArguments {
                    prompt: "test prompt".to_string(),
                    worker: "llama_cpp".to_string(),
                },
            },
        };
        let json = serde_json::to_string(&req).expect("serialize");
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse");
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["method"], "tools/call");
        assert_eq!(parsed["id"], 42);
        assert_eq!(parsed["params"]["arguments"]["prompt"], "test prompt");
    }

    #[test]
    fn test_mcp_response_empty_content() {
        let json = r#"{"jsonrpc":"2.0","result":{"content":[]}}"#;
        let resp: McpInferResponse = serde_json::from_str(json).expect("deser");
        assert!(resp.result.as_ref().expect("result").content.is_empty());
    }

    #[test]
    fn test_mcp_response_null_text() {
        let json = r#"{"jsonrpc":"2.0","result":{"content":[{"text":null}]}}"#;
        let resp: McpInferResponse = serde_json::from_str(json).expect("deser");
        assert!(resp.result.as_ref().expect("result").content[0]
            .text
            .is_none());
    }

    #[test]
    fn test_openai_chunk_multiple_choices() {
        let json = r#"{"id":"chatcmpl-x","choices":[{"index":0,"delta":{"content":"A"},"finish_reason":null},{"index":1,"delta":{"content":"B"},"finish_reason":null}]}"#;
        let chunk: OpenAIChunk = serde_json::from_str(json).expect("deser");
        assert_eq!(chunk.choices.len(), 2);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("A"));
        assert_eq!(chunk.choices[1].delta.content.as_deref(), Some("B"));
    }

    #[test]
    fn test_openai_chunk_no_choices() {
        let json = r#"{"id":"chatcmpl-x","choices":[]}"#;
        let chunk: OpenAIChunk = serde_json::from_str(json).expect("deser");
        assert!(chunk.choices.is_empty());
    }

    #[test]
    fn test_anthropic_event_message_delta() {
        // message_delta carries a delta object without a "text" field.
        let json = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"}}"#;
        let event: AnthropicStreamEvent = serde_json::from_str(json).expect("deser");
        assert_eq!(event.event_type, "message_delta");
        assert!(event.delta.as_ref().and_then(|d| d.text.as_ref()).is_none());
    }

    #[test]
    fn test_anthropic_event_ping() {
        let json = r#"{"type":"ping"}"#;
        let event: AnthropicStreamEvent = serde_json::from_str(json).expect("deser");
        assert_eq!(event.event_type, "ping");
        assert!(event.delta.is_none());
    }

    // ----- Typed request structs and logprob payloads -----

    #[test]
    fn test_openai_chat_request_has_logprobs_fields() {
        let req = OpenAIChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![OpenAIChatMessage {
                role: "user".to_string(),
                content: "hi".to_string(),
            }],
            stream: true,
            temperature: 0.7,
            logprobs: true,
            top_logprobs: 5,
        };
        let json = serde_json::to_string(&req).expect("serialize");
        assert!(json.contains("\"logprobs\":true"));
        assert!(json.contains("\"top_logprobs\":5"));
    }

    #[test]
    fn test_openai_top_logprob_deserializes() {
        let json = r#"{"token":"hello","logprob":-0.5}"#;
        let tlp: OpenAITopLogprob = serde_json::from_str(json).expect("deser");
        assert_eq!(tlp.token, "hello");
        assert!((tlp.logprob - (-0.5)).abs() < 1e-5);
    }

    #[test]
    fn test_openai_logprob_content_deserializes() {
        let json = r#"{"token":"world","logprob":-1.2,"top_logprobs":[{"token":"world","logprob":-1.2},{"token":"earth","logprob":-2.5}]}"#;
        let lc: OpenAILogprobContent = serde_json::from_str(json).expect("deser");
        assert_eq!(lc.token, "world");
        assert_eq!(lc.top_logprobs.len(), 2);
        assert_eq!(lc.top_logprobs[1].token, "earth");
    }

    #[test]
    fn test_openai_chunk_logprobs_empty_content() {
        let json = r#"{"content":[]}"#;
        let cl: OpenAIChunkLogprobs = serde_json::from_str(json).expect("deser");
        assert!(cl.content.is_empty());
    }

    #[test]
    fn test_openai_choice_with_logprobs_deserializes() {
        let json = r#"{"delta":{"content":"Hi"},"finish_reason":null,"logprobs":{"content":[{"token":"Hi","logprob":-0.1,"top_logprobs":[{"token":"Hi","logprob":-0.1},{"token":"Hey","logprob":-2.3}]}]}}"#;
        let choice: OpenAIChoice = serde_json::from_str(json).expect("deser");
        assert_eq!(choice.delta.content.as_deref(), Some("Hi"));
        let lp = choice.logprobs.as_ref().expect("logprobs present");
        assert_eq!(lp.content.len(), 1);
        assert!((lp.content[0].logprob - (-0.1)).abs() < 1e-5);
        assert_eq!(lp.content[0].top_logprobs[1].token, "Hey");
    }

    #[test]
    fn test_openai_choice_without_logprobs_is_none() {
        let json = r#"{"delta":{"content":"Hi"},"finish_reason":null}"#;
        let choice: OpenAIChoice = serde_json::from_str(json).expect("deser");
        assert!(choice.logprobs.is_none());
    }

    #[test]
    fn test_anthropic_request_with_system_serializes() {
        let req = AnthropicRequest {
            model: "claude-sonnet-4-20250514".to_string(),
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: "hi".to_string(),
            }],
            max_tokens: 1024,
            stream: true,
            temperature: 0.7,
            system: Some("You are a helpful assistant.".to_string()),
        };
        let json = serde_json::to_string(&req).expect("serialize");
        assert!(json.contains("\"system\":\"You are a helpful assistant.\""));
    }

    #[test]
    fn test_anthropic_request_without_system_omits_field() {
        let req = AnthropicRequest {
            model: "claude-sonnet-4-20250514".to_string(),
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: "hi".to_string(),
            }],
            max_tokens: 1024,
            stream: true,
            temperature: 0.7,
            system: None,
        };
        let json = serde_json::to_string(&req).expect("serialize");
        // skip_serializing_if should drop the field entirely.
        assert!(!json.contains("system"));
    }

    #[test]
    fn test_openai_top_logprob_clone() {
        let t = OpenAITopLogprob {
            token: "foo".to_string(),
            logprob: -1.0,
        };
        let t2 = t.clone();
        assert_eq!(t2.token, t.token);
        assert!((t2.logprob - t.logprob).abs() < 1e-6);
    }

    #[test]
    fn test_openai_logprob_content_no_top_logprobs() {
        // #[serde(default)] makes missing top_logprobs an empty Vec.
        let json = r#"{"token":"test","logprob":-0.8}"#;
        let lc: OpenAILogprobContent = serde_json::from_str(json).expect("deser");
        assert!(lc.top_logprobs.is_empty());
    }

    #[test]
    fn test_openai_top_logprob_serializes() {
        let t = OpenAITopLogprob {
            token: "bar".to_string(),
            logprob: -2.0,
        };
        let json = serde_json::to_string(&t).expect("serialize");
        assert!(json.contains("\"token\":\"bar\""));
        assert!(json.contains("\"logprob\":-2.0"));
    }

    // ----- Plugin metadata sanity checks -----

    #[test]
    fn test_openai_plugin_api_url_https() {
        assert!(OpenAiPlugin.api_url().starts_with("https://"));
    }

    #[test]
    fn test_anthropic_plugin_api_url_https() {
        assert!(AnthropicPlugin.api_url().starts_with("https://"));
    }

    #[test]
    fn test_openai_plugin_api_url_contains_openai() {
        assert!(OpenAiPlugin.api_url().contains("openai.com"));
    }

    #[test]
    fn test_anthropic_plugin_api_url_contains_anthropic() {
        assert!(AnthropicPlugin.api_url().contains("anthropic.com"));
    }

    #[test]
    fn test_openai_plugin_name_matches_display() {
        assert_eq!(OpenAiPlugin.name(), Provider::Openai.to_string());
    }

    #[test]
    fn test_anthropic_plugin_name_matches_display() {
        assert_eq!(AnthropicPlugin.name(), Provider::Anthropic.to_string());
    }

    #[test]
    fn test_openai_plugin_default_model_nonempty() {
        assert!(!OpenAiPlugin.default_model().is_empty());
    }

    #[test]
    fn test_anthropic_plugin_default_model_nonempty() {
        assert!(!AnthropicPlugin.default_model().is_empty());
    }

    // ----- API version constant format -----

    #[test]
    fn test_anthropic_api_version_nonempty() {
        assert!(!super::ANTHROPIC_API_VERSION.is_empty());
    }

    #[test]
    fn test_anthropic_api_version_format() {
        let v = super::ANTHROPIC_API_VERSION;
        assert_eq!(v.len(), 10, "version should be YYYY-MM-DD");
        assert!(v.chars().nth(4) == Some('-'), "4th char should be -");
        assert!(v.chars().nth(7) == Some('-'), "7th char should be -");
    }
}