1use std::fmt;
14use std::pin::Pin;
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use tokio_stream::Stream;
19
20#[derive(Debug, Clone)]
26#[must_use]
27pub struct RunnerError {
28 pub kind: ErrorKind,
30 pub message: String,
32}
33
/// Coarse classification attached to every `RunnerError`.
///
/// The variant name is also what the error's `Display` output starts with,
/// because `Display` prints the kind through its `Debug` form.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorKind {
    /// Internal failure.
    Internal,
    /// Failure attributed to an external service.
    ExternalService,
    /// An operation timed out.
    Timeout,
    /// A required binary could not be found.
    BinaryNotFound,
    /// Authentication failed.
    AuthFailure,
    /// Invalid configuration or invalid caller-supplied settings.
    Config,
    /// A guardrail rejected the operation.
    Guardrail,
}
52
53impl ErrorKind {
54 #[must_use]
60 pub const fn is_transient(self) -> bool {
61 matches!(self, Self::Timeout | Self::ExternalService)
62 }
63}
64
65impl RunnerError {
66 pub fn internal(message: impl Into<String>) -> Self {
68 Self {
69 kind: ErrorKind::Internal,
70 message: message.into(),
71 }
72 }
73
74 pub fn external_service(service: impl Into<String>, message: impl Into<String>) -> Self {
76 Self {
77 kind: ErrorKind::ExternalService,
78 message: format!("{}: {}", service.into(), message.into()),
79 }
80 }
81
82 pub fn binary_not_found(binary: impl Into<String>) -> Self {
84 Self {
85 kind: ErrorKind::BinaryNotFound,
86 message: format!("Binary not found: {}", binary.into()),
87 }
88 }
89
90 pub fn auth_failure(message: impl Into<String>) -> Self {
92 Self {
93 kind: ErrorKind::AuthFailure,
94 message: message.into(),
95 }
96 }
97
98 pub fn config(message: impl Into<String>) -> Self {
100 Self {
101 kind: ErrorKind::Config,
102 message: message.into(),
103 }
104 }
105
106 pub fn timeout(message: impl Into<String>) -> Self {
108 Self {
109 kind: ErrorKind::Timeout,
110 message: message.into(),
111 }
112 }
113
114 pub fn guardrail(message: impl Into<String>) -> Self {
116 Self {
117 kind: ErrorKind::Guardrail,
118 message: message.into(),
119 }
120 }
121}
122
123impl fmt::Display for RunnerError {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 write!(f, "{:?}: {}", self.kind, self.message)
126 }
127}
128
129impl std::error::Error for RunnerError {}
130
131bitflags::bitflags! {
136 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
141 pub struct LlmCapabilities: u16 {
142 const STREAMING = 0b0000_0000_0001;
144 const FUNCTION_CALLING = 0b0000_0000_0010;
146 const VISION = 0b0000_0000_0100;
148 const JSON_MODE = 0b0000_0000_1000;
150 const SYSTEM_MESSAGES = 0b0000_0001_0000;
152 const SDK_TOOL_CALLING = 0b0000_0010_0000;
154 const TEMPERATURE = 0b0000_0100_0000;
156 const MAX_TOKENS = 0b0000_1000_0000;
158 const TOP_P = 0b0001_0000_0000;
160 const STOP_SEQUENCES = 0b0010_0000_0000;
162 const RESPONSE_FORMAT = 0b0100_0000_0000;
164 }
165}
166
167impl LlmCapabilities {
168 #[must_use]
170 pub const fn text_only() -> Self {
171 Self::STREAMING.union(Self::SYSTEM_MESSAGES)
172 }
173
174 #[must_use]
176 pub const fn full_featured() -> Self {
177 Self::STREAMING
178 .union(Self::FUNCTION_CALLING)
179 .union(Self::VISION)
180 .union(Self::JSON_MODE)
181 .union(Self::SYSTEM_MESSAGES)
182 }
183
184 #[must_use]
186 pub const fn supports_streaming(&self) -> bool {
187 self.contains(Self::STREAMING)
188 }
189
190 #[must_use]
192 pub const fn supports_function_calling(&self) -> bool {
193 self.contains(Self::FUNCTION_CALLING)
194 }
195
196 #[must_use]
198 pub const fn supports_vision(&self) -> bool {
199 self.contains(Self::VISION)
200 }
201
202 #[must_use]
204 pub const fn supports_json_mode(&self) -> bool {
205 self.contains(Self::JSON_MODE)
206 }
207
208 #[must_use]
210 pub const fn supports_system_messages(&self) -> bool {
211 self.contains(Self::SYSTEM_MESSAGES)
212 }
213
214 #[must_use]
216 pub const fn supports_sdk_tool_calling(&self) -> bool {
217 self.contains(Self::SDK_TOOL_CALLING)
218 }
219
220 #[must_use]
222 pub const fn supports_temperature(&self) -> bool {
223 self.contains(Self::TEMPERATURE)
224 }
225
226 #[must_use]
228 pub const fn supports_max_tokens(&self) -> bool {
229 self.contains(Self::MAX_TOKENS)
230 }
231
232 #[must_use]
234 pub const fn supports_top_p(&self) -> bool {
235 self.contains(Self::TOP_P)
236 }
237
238 #[must_use]
240 pub const fn supports_stop_sequences(&self) -> bool {
241 self.contains(Self::STOP_SEQUENCES)
242 }
243
244 #[must_use]
246 pub const fn supports_response_format(&self) -> bool {
247 self.contains(Self::RESPONSE_FORMAT)
248 }
249}
250
251#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
257#[serde(rename_all = "lowercase")]
258pub enum MessageRole {
259 System,
261 User,
263 Assistant,
265 Tool,
267}
268
269impl MessageRole {
270 #[must_use]
272 pub const fn as_str(&self) -> &'static str {
273 match self {
274 Self::System => "system",
275 Self::User => "user",
276 Self::Assistant => "assistant",
277 Self::Tool => "tool",
278 }
279 }
280}
281
/// MIME types accepted by `ImagePart::new`; any other value is rejected.
///
/// The order here is also the order in which the types are listed in the
/// rejection message.
const VALID_IMAGE_MIME_TYPES: &[&str] = &["image/png", "image/jpeg", "image/webp", "image/gif"];
284
285#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
287pub struct ImagePart {
288 pub data: String,
290 pub mime_type: String,
292}
293
294impl ImagePart {
295 pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Result<Self, RunnerError> {
303 let mime_type = mime_type.into();
304 if !VALID_IMAGE_MIME_TYPES.contains(&mime_type.as_str()) {
305 return Err(RunnerError::config(format!(
306 "Unsupported image MIME type '{mime_type}'; expected one of: {}",
307 VALID_IMAGE_MIME_TYPES.join(", ")
308 )));
309 }
310 Ok(Self {
311 data: data.into(),
312 mime_type,
313 })
314 }
315}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct ChatMessage {
320 pub role: MessageRole,
322 pub content: String,
324 #[serde(default, skip_serializing_if = "Option::is_none")]
326 pub images: Option<Vec<ImagePart>>,
327 #[serde(default, skip_serializing_if = "Option::is_none")]
329 pub tool_calls: Option<Vec<ToolCallRequest>>,
330 #[serde(default, skip_serializing_if = "Option::is_none")]
332 pub tool_call_id: Option<String>,
333 #[serde(default, skip_serializing_if = "Option::is_none")]
335 pub name: Option<String>,
336}
337
338impl ChatMessage {
339 #[must_use]
341 pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
342 Self {
343 role,
344 content: content.into(),
345 images: None,
346 tool_calls: None,
347 tool_call_id: None,
348 name: None,
349 }
350 }
351
352 #[must_use]
354 pub fn system(content: impl Into<String>) -> Self {
355 Self::new(MessageRole::System, content)
356 }
357
358 #[must_use]
360 pub fn user(content: impl Into<String>) -> Self {
361 Self::new(MessageRole::User, content)
362 }
363
364 #[must_use]
366 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
367 Self {
368 role: MessageRole::User,
369 content: content.into(),
370 images: Some(images),
371 tool_calls: None,
372 tool_call_id: None,
373 name: None,
374 }
375 }
376
377 #[must_use]
379 pub fn assistant(content: impl Into<String>) -> Self {
380 Self::new(MessageRole::Assistant, content)
381 }
382
383 #[must_use]
385 pub fn tool(
386 name: impl Into<String>,
387 tool_call_id: impl Into<String>,
388 content: impl Into<String>,
389 ) -> Self {
390 Self {
391 role: MessageRole::Tool,
392 content: content.into(),
393 images: None,
394 tool_calls: None,
395 tool_call_id: Some(tool_call_id.into()),
396 name: Some(name.into()),
397 }
398 }
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct ToolCallRequest {
408 pub id: String,
410 pub function_name: String,
412 pub arguments: serde_json::Value,
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
418pub struct ToolDefinition {
419 pub name: String,
421 pub description: String,
423 #[serde(skip_serializing_if = "Option::is_none")]
425 pub parameters: Option<serde_json::Value>,
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
430pub enum ToolChoice {
431 Auto,
433 None,
435 Required,
437 Specific {
439 name: String,
441 },
442}
443
444#[derive(Debug, Clone, Serialize, Deserialize)]
446pub enum ResponseFormat {
447 Text,
449 JsonObject,
451 JsonSchema {
453 name: String,
455 schema: serde_json::Value,
457 },
458}
459
460#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct ChatRequest {
467 pub messages: Vec<ChatMessage>,
469 pub model: Option<String>,
471 pub temperature: Option<f32>,
477 pub max_tokens: Option<u32>,
483 pub stream: bool,
485 #[serde(default, skip_serializing_if = "Option::is_none")]
487 pub tools: Option<Vec<ToolDefinition>>,
488 #[serde(default, skip_serializing_if = "Option::is_none")]
490 pub tool_choice: Option<ToolChoice>,
491 #[serde(default, skip_serializing_if = "Option::is_none")]
493 pub top_p: Option<f32>,
494 #[serde(default, skip_serializing_if = "Option::is_none")]
496 pub stop: Option<Vec<String>>,
497 #[serde(default, skip_serializing_if = "Option::is_none")]
499 pub response_format: Option<ResponseFormat>,
500}
501
502impl ChatRequest {
503 #[must_use]
505 pub const fn new(messages: Vec<ChatMessage>) -> Self {
506 Self {
507 messages,
508 model: None,
509 temperature: None,
510 max_tokens: None,
511 stream: false,
512 tools: None,
513 tool_choice: None,
514 top_p: None,
515 stop: None,
516 response_format: None,
517 }
518 }
519
520 #[must_use]
522 pub fn with_model(mut self, model: impl Into<String>) -> Self {
523 self.model = Some(model.into());
524 self
525 }
526
527 #[must_use]
529 pub const fn with_temperature(mut self, temperature: f32) -> Self {
530 self.temperature = Some(temperature);
531 self
532 }
533
534 #[must_use]
536 pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
537 self.max_tokens = Some(max_tokens);
538 self
539 }
540
541 #[must_use]
543 pub const fn with_streaming(mut self) -> Self {
544 self.stream = true;
545 self
546 }
547
548 #[must_use]
550 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
551 self.tools = Some(tools);
552 self
553 }
554
555 #[must_use]
557 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
558 self.tool_choice = Some(tool_choice);
559 self
560 }
561
562 #[must_use]
564 pub const fn with_top_p(mut self, top_p: f32) -> Self {
565 self.top_p = Some(top_p);
566 self
567 }
568
569 #[must_use]
571 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
572 self.stop = Some(stop);
573 self
574 }
575
576 #[must_use]
578 pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
579 self.response_format = Some(response_format);
580 self
581 }
582
583 #[must_use]
585 pub fn has_images(&self) -> bool {
586 self.messages
587 .iter()
588 .any(|m| m.images.as_ref().is_some_and(|imgs| !imgs.is_empty()))
589 }
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
594pub struct ChatResponse {
595 pub content: String,
597 pub model: String,
599 pub usage: Option<TokenUsage>,
601 pub finish_reason: Option<String>,
603 #[serde(skip_serializing_if = "Option::is_none")]
605 pub warnings: Option<Vec<String>>,
606 #[serde(default, skip_serializing_if = "Option::is_none")]
608 pub tool_calls: Option<Vec<ToolCallRequest>>,
609}
610
611#[derive(Debug, Clone, Serialize, Deserialize)]
613pub struct TokenUsage {
614 pub prompt_tokens: u32,
616 pub completion_tokens: u32,
618 pub total_tokens: u32,
620}
621
622#[derive(Debug, Clone, Serialize, Deserialize)]
624pub struct StreamChunk {
625 pub delta: String,
627 pub is_final: bool,
629 pub finish_reason: Option<String>,
631}
632
633pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, RunnerError>> + Send>>;
635
636#[async_trait]
646pub trait LlmProvider: Send + Sync {
647 fn name(&self) -> &'static str;
649
650 fn display_name(&self) -> &str;
652
653 fn capabilities(&self) -> LlmCapabilities;
655
656 fn default_model(&self) -> &str;
658
659 fn available_models(&self) -> &[String];
661
662 async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError>;
664
665 async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError>;
670
671 async fn health_check(&self) -> Result<bool, RunnerError>;
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `value` to JSON text and parses that text back into a fresh value.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn is_transient_classification() {
        for kind in [ErrorKind::Timeout, ErrorKind::ExternalService] {
            assert!(kind.is_transient(), "{kind:?} should be transient");
        }
        for kind in [
            ErrorKind::Internal,
            ErrorKind::BinaryNotFound,
            ErrorKind::AuthFailure,
            ErrorKind::Config,
            ErrorKind::Guardrail,
        ] {
            assert!(!kind.is_transient(), "{kind:?} should not be transient");
        }
    }

    #[test]
    fn tool_call_request_serde_round_trip() {
        let original = ToolCallRequest {
            id: String::from("call_1"),
            function_name: String::from("get_weather"),
            arguments: json!({"city": "Paris"}),
        };
        let restored = round_trip(&original);
        assert_eq!(restored.id, "call_1");
        assert_eq!(restored.function_name, "get_weather");
        assert_eq!(restored.arguments["city"], "Paris");
    }

    #[test]
    fn tool_definition_serde_round_trip() {
        let original = ToolDefinition {
            name: String::from("search"),
            description: String::from("Search the web"),
            parameters: Some(json!({"type": "object", "properties": {"q": {"type": "string"}}})),
        };
        let restored = round_trip(&original);
        assert_eq!(restored.name, "search");
        assert!(restored.parameters.is_some());
    }

    #[test]
    fn tool_definition_without_parameters() {
        let original = ToolDefinition {
            name: String::from("ping"),
            description: String::from("Check connectivity"),
            parameters: None,
        };
        // The raw text matters here: the key must be absent, not null.
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(!encoded.contains("parameters"));
        let restored: ToolDefinition = serde_json::from_str(&encoded).unwrap();
        assert!(restored.parameters.is_none());
    }

    #[test]
    fn tool_choice_serde_variants() {
        assert!(matches!(round_trip(&ToolChoice::Auto), ToolChoice::Auto));
        assert!(matches!(round_trip(&ToolChoice::None), ToolChoice::None));
        assert!(matches!(
            round_trip(&ToolChoice::Required),
            ToolChoice::Required
        ));

        let specific = ToolChoice::Specific {
            name: String::from("get_weather"),
        };
        assert!(matches!(
            round_trip(&specific),
            ToolChoice::Specific { name } if name == "get_weather"
        ));
    }

    #[test]
    fn response_format_serde_variants() {
        assert!(matches!(
            round_trip(&ResponseFormat::Text),
            ResponseFormat::Text
        ));
        assert!(matches!(
            round_trip(&ResponseFormat::JsonObject),
            ResponseFormat::JsonObject
        ));

        let with_schema = ResponseFormat::JsonSchema {
            name: String::from("person"),
            schema: json!({"type": "object", "properties": {"name": {"type": "string"}}}),
        };
        assert!(matches!(
            round_trip(&with_schema),
            ResponseFormat::JsonSchema { name, .. } if name == "person"
        ));
    }

    #[test]
    fn chat_message_tool_constructor() {
        let payload = r#"{"temp": 72}"#;
        let msg = ChatMessage::tool("get_weather", "call_1", payload);
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(msg.content, payload);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert_eq!(msg.name.as_deref(), Some("get_weather"));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn chat_message_regular_constructors_have_none_tool_fields() {
        let ChatMessage {
            images,
            tool_calls,
            tool_call_id,
            name,
            ..
        } = ChatMessage::user("hello");
        assert!(tool_calls.is_none());
        assert!(tool_call_id.is_none());
        assert!(name.is_none());
        assert!(images.is_none());
    }

    #[test]
    fn image_part_valid_mime_types() {
        for mime in ["image/png", "image/jpeg", "image/webp", "image/gif"] {
            let outcome = ImagePart::new("base64data", mime);
            assert!(outcome.is_ok(), "Expected {mime} to be valid");
        }
    }

    #[test]
    fn image_part_invalid_mime_type() {
        let rejected = ImagePart::new("data", "image/bmp").unwrap_err();
        assert_eq!(rejected.kind, ErrorKind::Config);
        assert!(rejected.message.contains("image/bmp"));
    }

    #[test]
    fn user_with_images_constructor() {
        let attachment = ImagePart::new("aGVsbG8=", "image/png").unwrap();
        let msg = ChatMessage::user_with_images("describe this", vec![attachment]);
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "describe this");
        let attached = msg.images.as_ref().unwrap();
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].mime_type, "image/png");
    }

    #[test]
    fn chat_request_has_images() {
        let attachment = ImagePart::new("data", "image/jpeg").unwrap();
        let with_image =
            ChatRequest::new(vec![ChatMessage::user_with_images("x", vec![attachment])]);
        assert!(with_image.has_images());

        let text_only = ChatRequest::new(vec![ChatMessage::user("text only")]);
        assert!(!text_only.has_images());
    }

    #[test]
    fn chat_request_has_images_empty_vec() {
        let req = ChatRequest::new(vec![ChatMessage::user_with_images("x", Vec::new())]);
        assert!(!req.has_images());
    }

    #[test]
    fn image_part_serde_round_trip() {
        let original = ImagePart::new("aGVsbG8=", "image/png").unwrap();
        assert_eq!(round_trip(&original), original);
    }

    #[test]
    fn chat_message_with_images_serde_round_trip() {
        let attachment = ImagePart::new("data", "image/jpeg").unwrap();
        let original = ChatMessage::user_with_images("describe", vec![attachment]);
        let restored = round_trip(&original);
        let attached = restored.images.unwrap();
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].mime_type, "image/jpeg");
    }

    #[test]
    fn chat_message_without_images_backward_compat() {
        let legacy = r#"{"role":"user","content":"hello"}"#;
        let msg: ChatMessage = serde_json::from_str(legacy).unwrap();
        assert!(msg.images.is_none());
        assert_eq!(msg.content, "hello");
    }

    #[test]
    fn chat_message_images_not_serialized_when_none() {
        let encoded = serde_json::to_string(&ChatMessage::user("hello")).unwrap();
        assert!(!encoded.contains("images"));
    }

    #[test]
    fn chat_request_builder_methods() {
        let tool = ToolDefinition {
            name: String::from("test"),
            description: String::from("test fn"),
            parameters: None,
        };
        let req = ChatRequest::new(vec![ChatMessage::user("hi")])
            .with_tools(vec![tool])
            .with_tool_choice(ToolChoice::Required)
            .with_top_p(0.9)
            .with_stop(vec![String::from("END")])
            .with_response_format(ResponseFormat::JsonObject);

        assert!(req.tools.is_some());
        assert!(matches!(req.tool_choice, Some(ToolChoice::Required)));
        assert_eq!(req.top_p, Some(0.9));
        assert_eq!(req.stop.as_ref().unwrap()[0], "END");
        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonObject)
        ));
    }

    #[test]
    fn message_role_tool_as_str() {
        assert_eq!(MessageRole::Tool.as_str(), "tool");
    }

    #[test]
    fn capability_flags_new_fields() {
        let sampling = LlmCapabilities::TOP_P
            .union(LlmCapabilities::STOP_SEQUENCES)
            .union(LlmCapabilities::RESPONSE_FORMAT);
        assert!(sampling.supports_top_p());
        assert!(sampling.supports_stop_sequences());
        assert!(sampling.supports_response_format());

        let nothing = LlmCapabilities::empty();
        assert!(!nothing.supports_top_p());
        assert!(!nothing.supports_stop_sequences());
        assert!(!nothing.supports_response_format());
    }
}