Skip to main content

embacle/
types.rs

1// ABOUTME: Core types for CLI LLM runners — standalone definitions independent of pierre-core
2// ABOUTME: Provides LlmProvider trait, ChatRequest/Response, error types, and capability flags
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7//! # Core Types
8//!
9//! Self-contained type definitions for the CLI LLM runners library.
10//! These types mirror the LLM provider contract without requiring
11//! any external platform dependency.
12
13use std::fmt;
14use std::pin::Pin;
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use tokio_stream::Stream;
19
20// ============================================================================
21// Error Type
22// ============================================================================
23
24/// Error type for CLI LLM runner operations
25#[derive(Debug, Clone)]
26#[must_use]
27pub struct RunnerError {
28    /// Error category
29    pub kind: ErrorKind,
30    /// Human-readable error message
31    pub message: String,
32}
33
/// Coarse classification attached to every [`RunnerError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Unexpected state inside the runner itself (typically a bug)
    Internal,
    /// The external CLI tool failed or returned an unusable response
    ExternalService,
    /// The CLI command ran longer than its configured timeout
    Timeout,
    /// The CLI binary is missing or cannot be executed
    BinaryNotFound,
    /// Authentication or authorization was rejected
    AuthFailure,
    /// Invalid or missing configuration
    Config,
    /// A guardrail policy rejected the request or the response
    Guardrail,
}

impl ErrorKind {
    /// Report whether an error of this kind is transient, i.e. worth retrying.
    ///
    /// Only [`ErrorKind::Timeout`] and [`ErrorKind::ExternalService`] are
    /// transient. Every other kind describes a condition (bad config, failed
    /// auth, missing binary, internal bug, guardrail rejection) that a retry
    /// will not fix.
    #[must_use]
    pub const fn is_transient(self) -> bool {
        match self {
            Self::Timeout | Self::ExternalService => true,
            // Listed explicitly (no `_` arm) so adding a variant forces a decision here.
            Self::Internal
            | Self::BinaryNotFound
            | Self::AuthFailure
            | Self::Config
            | Self::Guardrail => false,
        }
    }
}
64
65impl RunnerError {
66    /// Create an internal error
67    pub fn internal(message: impl Into<String>) -> Self {
68        Self {
69            kind: ErrorKind::Internal,
70            message: message.into(),
71        }
72    }
73
74    /// Create an external service error
75    pub fn external_service(service: impl Into<String>, message: impl Into<String>) -> Self {
76        Self {
77            kind: ErrorKind::ExternalService,
78            message: format!("{}: {}", service.into(), message.into()),
79        }
80    }
81
82    /// Create a binary-not-found error
83    pub fn binary_not_found(binary: impl Into<String>) -> Self {
84        Self {
85            kind: ErrorKind::BinaryNotFound,
86            message: format!("Binary not found: {}", binary.into()),
87        }
88    }
89
90    /// Create an auth failure error
91    pub fn auth_failure(message: impl Into<String>) -> Self {
92        Self {
93            kind: ErrorKind::AuthFailure,
94            message: message.into(),
95        }
96    }
97
98    /// Create a config error
99    pub fn config(message: impl Into<String>) -> Self {
100        Self {
101            kind: ErrorKind::Config,
102            message: message.into(),
103        }
104    }
105
106    /// Create a timeout error
107    pub fn timeout(message: impl Into<String>) -> Self {
108        Self {
109            kind: ErrorKind::Timeout,
110            message: message.into(),
111        }
112    }
113
114    /// Create a guardrail violation error
115    pub fn guardrail(message: impl Into<String>) -> Self {
116        Self {
117            kind: ErrorKind::Guardrail,
118            message: message.into(),
119        }
120    }
121}
122
123impl fmt::Display for RunnerError {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        write!(f, "{:?}: {}", self.kind, self.message)
126    }
127}
128
129impl std::error::Error for RunnerError {}
130
131// ============================================================================
132// Capability Flags
133// ============================================================================
134
135bitflags::bitflags! {
136    /// LLM provider capability flags using bitflags for efficient storage
137    ///
138    /// Indicates which features a provider supports. Used by the system to
139    /// select appropriate providers and configure request handling.
140    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
141    pub struct LlmCapabilities: u16 {
142        /// Provider supports streaming responses
143        const STREAMING         = 0b0000_0000_0001;
144        /// Provider supports function/tool calling
145        const FUNCTION_CALLING  = 0b0000_0000_0010;
146        /// Provider supports vision/image input
147        const VISION            = 0b0000_0000_0100;
148        /// Provider supports JSON mode output
149        const JSON_MODE         = 0b0000_0000_1000;
150        /// Provider supports system messages
151        const SYSTEM_MESSAGES   = 0b0000_0001_0000;
152        /// Provider supports SDK-managed tool calling (tool loop handled by SDK, not by caller)
153        const SDK_TOOL_CALLING  = 0b0000_0010_0000;
154        /// Provider supports temperature parameter
155        const TEMPERATURE       = 0b0000_0100_0000;
156        /// Provider supports max_tokens parameter
157        const MAX_TOKENS        = 0b0000_1000_0000;
158        /// Provider supports top_p (nucleus sampling) parameter
159        const TOP_P             = 0b0001_0000_0000;
160        /// Provider supports stop sequences parameter
161        const STOP_SEQUENCES    = 0b0010_0000_0000;
162        /// Provider supports response format control (JSON mode, JSON Schema)
163        const RESPONSE_FORMAT   = 0b0100_0000_0000;
164    }
165}
166
167impl LlmCapabilities {
168    /// Create capabilities for a basic text-only provider
169    #[must_use]
170    pub const fn text_only() -> Self {
171        Self::STREAMING.union(Self::SYSTEM_MESSAGES)
172    }
173
174    /// Create capabilities for a full-featured provider (like Gemini Pro)
175    #[must_use]
176    pub const fn full_featured() -> Self {
177        Self::STREAMING
178            .union(Self::FUNCTION_CALLING)
179            .union(Self::VISION)
180            .union(Self::JSON_MODE)
181            .union(Self::SYSTEM_MESSAGES)
182    }
183
184    /// Check if streaming is supported
185    #[must_use]
186    pub const fn supports_streaming(&self) -> bool {
187        self.contains(Self::STREAMING)
188    }
189
190    /// Check if function calling is supported
191    #[must_use]
192    pub const fn supports_function_calling(&self) -> bool {
193        self.contains(Self::FUNCTION_CALLING)
194    }
195
196    /// Check if vision is supported
197    #[must_use]
198    pub const fn supports_vision(&self) -> bool {
199        self.contains(Self::VISION)
200    }
201
202    /// Check if JSON mode is supported
203    #[must_use]
204    pub const fn supports_json_mode(&self) -> bool {
205        self.contains(Self::JSON_MODE)
206    }
207
208    /// Check if system messages are supported
209    #[must_use]
210    pub const fn supports_system_messages(&self) -> bool {
211        self.contains(Self::SYSTEM_MESSAGES)
212    }
213
214    /// Check if SDK-managed tool calling is supported
215    #[must_use]
216    pub const fn supports_sdk_tool_calling(&self) -> bool {
217        self.contains(Self::SDK_TOOL_CALLING)
218    }
219
220    /// Check if temperature parameter is supported
221    #[must_use]
222    pub const fn supports_temperature(&self) -> bool {
223        self.contains(Self::TEMPERATURE)
224    }
225
226    /// Check if `max_tokens` parameter is supported
227    #[must_use]
228    pub const fn supports_max_tokens(&self) -> bool {
229        self.contains(Self::MAX_TOKENS)
230    }
231
232    /// Check if `top_p` parameter is supported
233    #[must_use]
234    pub const fn supports_top_p(&self) -> bool {
235        self.contains(Self::TOP_P)
236    }
237
238    /// Check if stop sequences parameter is supported
239    #[must_use]
240    pub const fn supports_stop_sequences(&self) -> bool {
241        self.contains(Self::STOP_SEQUENCES)
242    }
243
244    /// Check if response format control is supported
245    #[must_use]
246    pub const fn supports_response_format(&self) -> bool {
247        self.contains(Self::RESPONSE_FORMAT)
248    }
249}
250
251// ============================================================================
252// Message Types
253// ============================================================================
254
255/// Role of a message in the conversation
256#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
257#[serde(rename_all = "lowercase")]
258pub enum MessageRole {
259    /// System instruction message
260    System,
261    /// User input message
262    User,
263    /// Assistant response message
264    Assistant,
265    /// Tool result message
266    Tool,
267}
268
269impl MessageRole {
270    /// Convert to string representation for API calls
271    #[must_use]
272    pub const fn as_str(&self) -> &'static str {
273        match self {
274            Self::System => "system",
275            Self::User => "user",
276            Self::Assistant => "assistant",
277            Self::Tool => "tool",
278        }
279    }
280}
281
/// Image MIME types accepted by [`ImagePart::new`], in canonical lowercase form.
const VALID_IMAGE_MIME_TYPES: &[&str] = &[
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
];
284
285/// An image attached to a chat message
286#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
287pub struct ImagePart {
288    /// Base64-encoded image data
289    pub data: String,
290    /// MIME type (e.g., "image/png", "image/jpeg")
291    pub mime_type: String,
292}
293
294impl ImagePart {
295    /// Create a new image part, validating the MIME type.
296    ///
297    /// Accepted MIME types: `image/png`, `image/jpeg`, `image/webp`, `image/gif`.
298    ///
299    /// # Errors
300    ///
301    /// Returns [`RunnerError`] if the MIME type is not supported.
302    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Result<Self, RunnerError> {
303        let mime_type = mime_type.into();
304        if !VALID_IMAGE_MIME_TYPES.contains(&mime_type.as_str()) {
305            return Err(RunnerError::config(format!(
306                "Unsupported image MIME type '{mime_type}'; expected one of: {}",
307                VALID_IMAGE_MIME_TYPES.join(", ")
308            )));
309        }
310        Ok(Self {
311            data: data.into(),
312            mime_type,
313        })
314    }
315}
316
317/// A single message in a chat conversation
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct ChatMessage {
320    /// Role of the message sender
321    pub role: MessageRole,
322    /// Content of the message
323    pub content: String,
324    /// Images attached to the message (only meaningful for `User` role)
325    #[serde(default, skip_serializing_if = "Option::is_none")]
326    pub images: Option<Vec<ImagePart>>,
327    /// Tool calls requested by the assistant (only for `Assistant` role)
328    #[serde(default, skip_serializing_if = "Option::is_none")]
329    pub tool_calls: Option<Vec<ToolCallRequest>>,
330    /// ID of the tool call this message responds to (only for `Tool` role)
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub tool_call_id: Option<String>,
333    /// Function name for tool result messages
334    #[serde(default, skip_serializing_if = "Option::is_none")]
335    pub name: Option<String>,
336}
337
338impl ChatMessage {
339    /// Create a new chat message
340    #[must_use]
341    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
342        Self {
343            role,
344            content: content.into(),
345            images: None,
346            tool_calls: None,
347            tool_call_id: None,
348            name: None,
349        }
350    }
351
352    /// Create a system message
353    #[must_use]
354    pub fn system(content: impl Into<String>) -> Self {
355        Self::new(MessageRole::System, content)
356    }
357
358    /// Create a user message
359    #[must_use]
360    pub fn user(content: impl Into<String>) -> Self {
361        Self::new(MessageRole::User, content)
362    }
363
364    /// Create a user message with attached images
365    #[must_use]
366    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
367        Self {
368            role: MessageRole::User,
369            content: content.into(),
370            images: Some(images),
371            tool_calls: None,
372            tool_call_id: None,
373            name: None,
374        }
375    }
376
377    /// Create an assistant message
378    #[must_use]
379    pub fn assistant(content: impl Into<String>) -> Self {
380        Self::new(MessageRole::Assistant, content)
381    }
382
383    /// Create a tool result message
384    #[must_use]
385    pub fn tool(
386        name: impl Into<String>,
387        tool_call_id: impl Into<String>,
388        content: impl Into<String>,
389    ) -> Self {
390        Self {
391            role: MessageRole::Tool,
392            content: content.into(),
393            images: None,
394            tool_calls: None,
395            tool_call_id: Some(tool_call_id.into()),
396            name: Some(name.into()),
397        }
398    }
399}
400
401// ============================================================================
402// Tool Calling Types
403// ============================================================================
404
405/// A tool call requested by the assistant
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct ToolCallRequest {
408    /// Unique identifier for this tool call
409    pub id: String,
410    /// Name of the function to call
411    pub function_name: String,
412    /// JSON-encoded arguments for the function
413    pub arguments: serde_json::Value,
414}
415
416/// Definition of a tool that can be called by the model
417#[derive(Debug, Clone, Serialize, Deserialize)]
418pub struct ToolDefinition {
419    /// Name of the function
420    pub name: String,
421    /// Description of what the function does
422    pub description: String,
423    /// JSON Schema describing the function parameters
424    #[serde(skip_serializing_if = "Option::is_none")]
425    pub parameters: Option<serde_json::Value>,
426}
427
428/// Controls which tools the model may call
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub enum ToolChoice {
431    /// Model decides whether to call tools
432    Auto,
433    /// Model will not call any tools
434    None,
435    /// Model must call at least one tool
436    Required,
437    /// Model must call the specified function
438    Specific {
439        /// Name of the function to call
440        name: String,
441    },
442}
443
444/// Controls the response format from the model
445#[derive(Debug, Clone, Serialize, Deserialize)]
446pub enum ResponseFormat {
447    /// Default text response
448    Text,
449    /// Force JSON object output
450    JsonObject,
451    /// Force JSON output conforming to a specific schema
452    JsonSchema {
453        /// Schema name for identification
454        name: String,
455        /// JSON Schema the response must conform to
456        schema: serde_json::Value,
457    },
458}
459
460// ============================================================================
461// Request/Response Types
462// ============================================================================
463
464/// Configuration for a chat completion request
465#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct ChatRequest {
467    /// Conversation messages
468    pub messages: Vec<ChatMessage>,
469    /// Model identifier (provider-specific)
470    pub model: Option<String>,
471    /// Temperature for response randomness (0.0 - 2.0).
472    ///
473    /// Support depends on each provider's [`LlmCapabilities::TEMPERATURE`] flag.
474    /// Use [`validate_capabilities`](crate::validate_capabilities) to check
475    /// before dispatch.
476    pub temperature: Option<f32>,
477    /// Maximum tokens to generate.
478    ///
479    /// Support depends on each provider's [`LlmCapabilities::MAX_TOKENS`] flag.
480    /// Use [`validate_capabilities`](crate::validate_capabilities) to check
481    /// before dispatch.
482    pub max_tokens: Option<u32>,
483    /// Whether to stream the response
484    pub stream: bool,
485    /// Tool definitions available for the model to call
486    #[serde(default, skip_serializing_if = "Option::is_none")]
487    pub tools: Option<Vec<ToolDefinition>>,
488    /// Controls which tools the model may call
489    #[serde(default, skip_serializing_if = "Option::is_none")]
490    pub tool_choice: Option<ToolChoice>,
491    /// Nucleus sampling parameter (0.0 - 1.0)
492    #[serde(default, skip_serializing_if = "Option::is_none")]
493    pub top_p: Option<f32>,
494    /// Stop sequences that halt generation
495    #[serde(default, skip_serializing_if = "Option::is_none")]
496    pub stop: Option<Vec<String>>,
497    /// Control over the response format (text, JSON, or schema-validated JSON)
498    #[serde(default, skip_serializing_if = "Option::is_none")]
499    pub response_format: Option<ResponseFormat>,
500}
501
502impl ChatRequest {
503    /// Create a new chat request with messages
504    #[must_use]
505    pub const fn new(messages: Vec<ChatMessage>) -> Self {
506        Self {
507            messages,
508            model: None,
509            temperature: None,
510            max_tokens: None,
511            stream: false,
512            tools: None,
513            tool_choice: None,
514            top_p: None,
515            stop: None,
516            response_format: None,
517        }
518    }
519
520    /// Set the model to use
521    #[must_use]
522    pub fn with_model(mut self, model: impl Into<String>) -> Self {
523        self.model = Some(model.into());
524        self
525    }
526
527    /// Set the temperature
528    #[must_use]
529    pub const fn with_temperature(mut self, temperature: f32) -> Self {
530        self.temperature = Some(temperature);
531        self
532    }
533
534    /// Set the maximum tokens
535    #[must_use]
536    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
537        self.max_tokens = Some(max_tokens);
538        self
539    }
540
541    /// Enable streaming
542    #[must_use]
543    pub const fn with_streaming(mut self) -> Self {
544        self.stream = true;
545        self
546    }
547
548    /// Set the tool definitions
549    #[must_use]
550    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
551        self.tools = Some(tools);
552        self
553    }
554
555    /// Set the tool choice
556    #[must_use]
557    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
558        self.tool_choice = Some(tool_choice);
559        self
560    }
561
562    /// Set the `top_p` (nucleus sampling) parameter
563    #[must_use]
564    pub const fn with_top_p(mut self, top_p: f32) -> Self {
565        self.top_p = Some(top_p);
566        self
567    }
568
569    /// Set stop sequences
570    #[must_use]
571    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
572        self.stop = Some(stop);
573        self
574    }
575
576    /// Set the response format
577    #[must_use]
578    pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
579        self.response_format = Some(response_format);
580        self
581    }
582
583    /// Check whether any message in this request contains images
584    #[must_use]
585    pub fn has_images(&self) -> bool {
586        self.messages
587            .iter()
588            .any(|m| m.images.as_ref().is_some_and(|imgs| !imgs.is_empty()))
589    }
590}
591
592/// Response from a chat completion
593#[derive(Debug, Clone, Serialize, Deserialize)]
594pub struct ChatResponse {
595    /// Generated message content
596    pub content: String,
597    /// Model used for generation
598    pub model: String,
599    /// Token usage statistics
600    pub usage: Option<TokenUsage>,
601    /// Finish reason (stop, length, etc.)
602    pub finish_reason: Option<String>,
603    /// Warnings about unsupported request parameters
604    #[serde(skip_serializing_if = "Option::is_none")]
605    pub warnings: Option<Vec<String>>,
606    /// Tool calls requested by the model (populated by providers with native function calling)
607    #[serde(default, skip_serializing_if = "Option::is_none")]
608    pub tool_calls: Option<Vec<ToolCallRequest>>,
609}
610
611/// Token usage statistics
612#[derive(Debug, Clone, Serialize, Deserialize)]
613pub struct TokenUsage {
614    /// Number of tokens in the prompt
615    pub prompt_tokens: u32,
616    /// Number of tokens in the completion
617    pub completion_tokens: u32,
618    /// Total tokens used
619    pub total_tokens: u32,
620}
621
622/// A chunk of a streaming response
623#[derive(Debug, Clone, Serialize, Deserialize)]
624pub struct StreamChunk {
625    /// Content delta for this chunk
626    pub delta: String,
627    /// Whether this is the final chunk
628    pub is_final: bool,
629    /// Finish reason if final
630    pub finish_reason: Option<String>,
631}
632
633/// Stream type for chat completion responses
634pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, RunnerError>> + Send>>;
635
636// ============================================================================
637// Provider Trait
638// ============================================================================
639
640/// LLM provider trait for chat completion
641///
642/// Implement this trait to add a new LLM runner. Each runner wraps
643/// a CLI tool and translates between the chat protocol and the
644/// tool's native interface.
645#[async_trait]
646pub trait LlmProvider: Send + Sync {
647    /// Unique provider identifier (e.g., `claude_code`, `copilot`)
648    fn name(&self) -> &'static str;
649
650    /// Human-readable display name for the provider
651    fn display_name(&self) -> &str;
652
653    /// Provider capabilities (streaming, function calling, etc.)
654    fn capabilities(&self) -> LlmCapabilities;
655
656    /// Default model to use if not specified in request
657    fn default_model(&self) -> &str;
658
659    /// Available models for this provider
660    fn available_models(&self) -> &[String];
661
662    /// Perform a chat completion (non-streaming)
663    async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError>;
664
665    /// Perform a streaming chat completion
666    ///
667    /// Returns a stream of chunks that can be consumed incrementally.
668    /// Falls back to non-streaming if not supported.
669    async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError>;
670
671    /// Check if the provider is healthy and ready to serve requests
672    async fn health_check(&self) -> Result<bool, RunnerError>;
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serialize `value` to a JSON string and parse it straight back.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn is_transient_classification() {
        let transient = [ErrorKind::Timeout, ErrorKind::ExternalService];
        let permanent = [
            ErrorKind::Internal,
            ErrorKind::BinaryNotFound,
            ErrorKind::AuthFailure,
            ErrorKind::Config,
            ErrorKind::Guardrail,
        ];
        for kind in transient {
            assert!(kind.is_transient(), "{kind:?} should be transient");
        }
        for kind in permanent {
            assert!(!kind.is_transient(), "{kind:?} should not be transient");
        }
    }

    #[test]
    fn tool_call_request_serde_round_trip() {
        let original = ToolCallRequest {
            id: "call_1".to_owned(),
            function_name: "get_weather".to_owned(),
            arguments: json!({"city": "Paris"}),
        };
        let decoded = round_trip(&original);
        assert_eq!(decoded.id, "call_1");
        assert_eq!(decoded.function_name, "get_weather");
        assert_eq!(decoded.arguments["city"], "Paris");
    }

    #[test]
    fn tool_definition_serde_round_trip() {
        let original = ToolDefinition {
            name: "search".to_owned(),
            description: "Search the web".to_owned(),
            parameters: Some(json!({"type": "object", "properties": {"q": {"type": "string"}}})),
        };
        let decoded = round_trip(&original);
        assert_eq!(decoded.name, "search");
        assert!(decoded.parameters.is_some());
    }

    #[test]
    fn tool_definition_without_parameters() {
        let original = ToolDefinition {
            name: "ping".to_owned(),
            description: "Check connectivity".to_owned(),
            parameters: None,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(!encoded.contains("parameters"));
        let decoded: ToolDefinition = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.parameters.is_none());
    }

    #[test]
    fn tool_choice_serde_variants() {
        assert!(matches!(round_trip(&ToolChoice::Auto), ToolChoice::Auto));
        assert!(matches!(round_trip(&ToolChoice::None), ToolChoice::None));
        assert!(matches!(
            round_trip(&ToolChoice::Required),
            ToolChoice::Required
        ));
        let specific = ToolChoice::Specific {
            name: "get_weather".to_owned(),
        };
        assert!(matches!(
            round_trip(&specific),
            ToolChoice::Specific { name } if name == "get_weather"
        ));
    }

    #[test]
    fn response_format_serde_variants() {
        assert!(matches!(
            round_trip(&ResponseFormat::Text),
            ResponseFormat::Text
        ));
        assert!(matches!(
            round_trip(&ResponseFormat::JsonObject),
            ResponseFormat::JsonObject
        ));
        let schema_format = ResponseFormat::JsonSchema {
            name: "person".to_owned(),
            schema: json!({"type": "object", "properties": {"name": {"type": "string"}}}),
        };
        assert!(matches!(
            round_trip(&schema_format),
            ResponseFormat::JsonSchema { name, .. } if name == "person"
        ));
    }

    #[test]
    fn chat_message_tool_constructor() {
        let message = ChatMessage::tool("get_weather", "call_1", r#"{"temp": 72}"#);
        assert_eq!(message.role, MessageRole::Tool);
        assert_eq!(message.content, r#"{"temp": 72}"#);
        assert_eq!(message.tool_call_id.as_deref(), Some("call_1"));
        assert_eq!(message.name.as_deref(), Some("get_weather"));
        assert!(message.tool_calls.is_none());
    }

    #[test]
    fn chat_message_regular_constructors_have_none_tool_fields() {
        let message = ChatMessage::user("hello");
        assert!(message.tool_calls.is_none());
        assert!(message.tool_call_id.is_none());
        assert!(message.name.is_none());
        assert!(message.images.is_none());
    }

    #[test]
    fn image_part_valid_mime_types() {
        for mime in ["image/png", "image/jpeg", "image/webp", "image/gif"] {
            assert!(
                ImagePart::new("base64data", mime).is_ok(),
                "Expected {mime} to be valid"
            );
        }
    }

    #[test]
    fn image_part_invalid_mime_type() {
        let error = ImagePart::new("data", "image/bmp").unwrap_err();
        assert_eq!(error.kind, ErrorKind::Config);
        assert!(error.message.contains("image/bmp"));
    }

    #[test]
    fn user_with_images_constructor() {
        let image = ImagePart::new("aGVsbG8=", "image/png").unwrap();
        let message = ChatMessage::user_with_images("describe this", vec![image]);
        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "describe this");
        let attached = message.images.as_deref().unwrap();
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].mime_type, "image/png");
    }

    #[test]
    fn chat_request_has_images() {
        let image = ImagePart::new("data", "image/jpeg").unwrap();
        let with_image = ChatRequest::new(vec![ChatMessage::user_with_images("x", vec![image])]);
        assert!(with_image.has_images());

        let text_only = ChatRequest::new(vec![ChatMessage::user("text only")]);
        assert!(!text_only.has_images());
    }

    #[test]
    fn chat_request_has_images_empty_vec() {
        let request = ChatRequest::new(vec![ChatMessage::user_with_images("x", vec![])]);
        assert!(!request.has_images());
    }

    #[test]
    fn image_part_serde_round_trip() {
        let image = ImagePart::new("aGVsbG8=", "image/png").unwrap();
        assert_eq!(round_trip(&image), image);
    }

    #[test]
    fn chat_message_with_images_serde_round_trip() {
        let image = ImagePart::new("data", "image/jpeg").unwrap();
        let decoded = round_trip(&ChatMessage::user_with_images("describe", vec![image]));
        let attached = decoded.images.unwrap();
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].mime_type, "image/jpeg");
    }

    #[test]
    fn chat_message_without_images_backward_compat() {
        let legacy = r#"{"role":"user","content":"hello"}"#;
        let message: ChatMessage = serde_json::from_str(legacy).unwrap();
        assert!(message.images.is_none());
        assert_eq!(message.content, "hello");
    }

    #[test]
    fn chat_message_images_not_serialized_when_none() {
        let encoded = serde_json::to_string(&ChatMessage::user("hello")).unwrap();
        assert!(!encoded.contains("images"));
    }

    #[test]
    fn chat_request_builder_methods() {
        let tool = ToolDefinition {
            name: "test".to_owned(),
            description: "test fn".to_owned(),
            parameters: None,
        };
        let request = ChatRequest::new(vec![ChatMessage::user("hi")])
            .with_tools(vec![tool])
            .with_tool_choice(ToolChoice::Required)
            .with_top_p(0.9)
            .with_stop(vec!["END".to_owned()])
            .with_response_format(ResponseFormat::JsonObject);

        assert!(request.tools.is_some());
        assert!(matches!(request.tool_choice, Some(ToolChoice::Required)));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stop.as_deref().unwrap()[0], "END");
        assert!(matches!(
            request.response_format,
            Some(ResponseFormat::JsonObject)
        ));
    }

    #[test]
    fn message_role_tool_as_str() {
        assert_eq!(MessageRole::Tool.as_str(), "tool");
    }

    #[test]
    fn capability_flags_new_fields() {
        let enabled = LlmCapabilities::TOP_P
            | LlmCapabilities::STOP_SEQUENCES
            | LlmCapabilities::RESPONSE_FORMAT;
        let empty = LlmCapabilities::empty();
        let checks: [fn(&LlmCapabilities) -> bool; 3] = [
            LlmCapabilities::supports_top_p,
            LlmCapabilities::supports_stop_sequences,
            LlmCapabilities::supports_response_format,
        ];
        for check in checks {
            assert!(check(&enabled));
            assert!(!check(&empty));
        }
    }
}