// ABOUTME: Core types for CLI LLM runners — standalone definitions independent of pierre-core
// ABOUTME: Provides LlmProvider trait, ChatRequest/Response, error types, and capability flags
//
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 dravr.ai

//! # Core Types
//!
//! Self-contained type definitions for the CLI LLM runners library.
//! These types mirror the LLM provider contract without requiring
//! any external platform dependency.

use std::any::Any;
use std::fmt;
use std::pin::Pin;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio_stream::Stream;

// ============================================================================
// Error Type
// ============================================================================
/// Error type for CLI LLM runner operations.
///
/// Pairs a machine-readable [`ErrorKind`] with a human-readable message.
#[derive(Debug, Clone)]
pub struct RunnerError {
    /// Error category
    pub kind: ErrorKind,
    /// Human-readable error message
    pub message: String,
}

/// Categories of errors produced by CLI runners
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Internal runner error (bug, unexpected state)
    Internal,
    /// External service error (CLI tool failure, bad response)
    ExternalService,
    /// CLI command exceeded its configured timeout
    Timeout,
    /// Binary not found or not executable
    BinaryNotFound,
    /// Authentication or authorization failure
    AuthFailure,
    /// Configuration error
    Config,
}

impl RunnerError {
    /// Create an internal error (bug or unexpected state in the runner itself)
    #[must_use]
    pub fn internal(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Internal,
            message: message.into(),
        }
    }

    /// Create an external service error; the message is prefixed with the
    /// failing service's name ("service: message")
    #[must_use]
    pub fn external_service(service: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::ExternalService,
            message: format!("{}: {}", service.into(), message.into()),
        }
    }

    /// Create a binary-not-found error for the named executable
    #[must_use]
    pub fn binary_not_found(binary: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::BinaryNotFound,
            message: format!("Binary not found: {}", binary.into()),
        }
    }

    /// Create an auth failure error
    #[must_use]
    pub fn auth_failure(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::AuthFailure,
            message: message.into(),
        }
    }

    /// Create a config error
    #[must_use]
    pub fn config(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Config,
            message: message.into(),
        }
    }

    /// Create a timeout error
    #[must_use]
    pub fn timeout(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Timeout,
            message: message.into(),
        }
    }
}

impl fmt::Display for RunnerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format as "Kind: message", relying on the derived Debug for the kind.
        write!(f, "{:?}: {}", self.kind, self.message)
    }
}

impl std::error::Error for RunnerError {}
// ============================================================================
// Capability Flags
// ============================================================================

113bitflags::bitflags! {
114    /// LLM provider capability flags using bitflags for efficient storage
115    ///
116    /// Indicates which features a provider supports. Used by the system to
117    /// select appropriate providers and configure request handling.
118    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
119    pub struct LlmCapabilities: u8 {
120        /// Provider supports streaming responses
121        const STREAMING = 0b0000_0001;
122        /// Provider supports function/tool calling
123        const FUNCTION_CALLING = 0b0000_0010;
124        /// Provider supports vision/image input
125        const VISION = 0b0000_0100;
126        /// Provider supports JSON mode output
127        const JSON_MODE = 0b0000_1000;
128        /// Provider supports system messages
129        const SYSTEM_MESSAGES = 0b0001_0000;
130        /// Provider supports SDK-managed tool calling (tool loop handled by SDK, not by caller)
131        const SDK_TOOL_CALLING = 0b0010_0000;
132    }
133}
134
135impl LlmCapabilities {
136    /// Create capabilities for a basic text-only provider
137    #[must_use]
138    pub const fn text_only() -> Self {
139        Self::STREAMING.union(Self::SYSTEM_MESSAGES)
140    }
141
142    /// Create capabilities for a full-featured provider (like Gemini Pro)
143    #[must_use]
144    pub const fn full_featured() -> Self {
145        Self::STREAMING
146            .union(Self::FUNCTION_CALLING)
147            .union(Self::VISION)
148            .union(Self::JSON_MODE)
149            .union(Self::SYSTEM_MESSAGES)
150    }
151
152    /// Check if streaming is supported
153    #[must_use]
154    pub const fn supports_streaming(&self) -> bool {
155        self.contains(Self::STREAMING)
156    }
157
158    /// Check if function calling is supported
159    #[must_use]
160    pub const fn supports_function_calling(&self) -> bool {
161        self.contains(Self::FUNCTION_CALLING)
162    }
163
164    /// Check if vision is supported
165    #[must_use]
166    pub const fn supports_vision(&self) -> bool {
167        self.contains(Self::VISION)
168    }
169
170    /// Check if JSON mode is supported
171    #[must_use]
172    pub const fn supports_json_mode(&self) -> bool {
173        self.contains(Self::JSON_MODE)
174    }
175
176    /// Check if system messages are supported
177    #[must_use]
178    pub const fn supports_system_messages(&self) -> bool {
179        self.contains(Self::SYSTEM_MESSAGES)
180    }
181
182    /// Check if SDK-managed tool calling is supported
183    #[must_use]
184    pub const fn supports_sdk_tool_calling(&self) -> bool {
185        self.contains(Self::SDK_TOOL_CALLING)
186    }
187}
188
// ============================================================================
// Message Types
// ============================================================================

193/// Role of a message in the conversation
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
195#[serde(rename_all = "lowercase")]
196pub enum MessageRole {
197    /// System instruction message
198    System,
199    /// User input message
200    User,
201    /// Assistant response message
202    Assistant,
203}
204
205impl MessageRole {
206    /// Convert to string representation for API calls
207    #[must_use]
208    pub const fn as_str(&self) -> &'static str {
209        match self {
210            Self::System => "system",
211            Self::User => "user",
212            Self::Assistant => "assistant",
213        }
214    }
215}
216
217/// A single message in a chat conversation
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct ChatMessage {
220    /// Role of the message sender
221    pub role: MessageRole,
222    /// Content of the message
223    pub content: String,
224}
225
226impl ChatMessage {
227    /// Create a new chat message
228    #[must_use]
229    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
230        Self {
231            role,
232            content: content.into(),
233        }
234    }
235
236    /// Create a system message
237    #[must_use]
238    pub fn system(content: impl Into<String>) -> Self {
239        Self::new(MessageRole::System, content)
240    }
241
242    /// Create a user message
243    #[must_use]
244    pub fn user(content: impl Into<String>) -> Self {
245        Self::new(MessageRole::User, content)
246    }
247
248    /// Create an assistant message
249    #[must_use]
250    pub fn assistant(content: impl Into<String>) -> Self {
251        Self::new(MessageRole::Assistant, content)
252    }
253}
254
// ============================================================================
// Request/Response Types
// ============================================================================

259/// Configuration for a chat completion request
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct ChatRequest {
262    /// Conversation messages
263    pub messages: Vec<ChatMessage>,
264    /// Model identifier (provider-specific)
265    pub model: Option<String>,
266    /// Temperature for response randomness (0.0 - 2.0).
267    ///
268    /// Not all CLI runners support this. Currently no runner propagates
269    /// temperature; it is accepted for API compatibility and ignored by
270    /// runners that lack a corresponding CLI flag.
271    pub temperature: Option<f32>,
272    /// Maximum tokens to generate.
273    ///
274    /// Runner support varies:
275    /// - **Claude Code**: propagated via `CLAUDE_CODE_MAX_OUTPUT_TOKENS` env var.
276    /// - **Copilot / Cursor Agent / `OpenCode`**: not supported by the CLI;
277    ///   the field is accepted but ignored.
278    pub max_tokens: Option<u32>,
279    /// Whether to stream the response
280    pub stream: bool,
281}
282
283impl ChatRequest {
284    /// Create a new chat request with messages
285    #[must_use]
286    pub const fn new(messages: Vec<ChatMessage>) -> Self {
287        Self {
288            messages,
289            model: None,
290            temperature: None,
291            max_tokens: None,
292            stream: false,
293        }
294    }
295
296    /// Set the model to use
297    #[must_use]
298    pub fn with_model(mut self, model: impl Into<String>) -> Self {
299        self.model = Some(model.into());
300        self
301    }
302
303    /// Set the temperature
304    #[must_use]
305    pub const fn with_temperature(mut self, temperature: f32) -> Self {
306        self.temperature = Some(temperature);
307        self
308    }
309
310    /// Set the maximum tokens
311    #[must_use]
312    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
313        self.max_tokens = Some(max_tokens);
314        self
315    }
316
317    /// Enable streaming
318    #[must_use]
319    pub const fn with_streaming(mut self) -> Self {
320        self.stream = true;
321        self
322    }
323}
324
325/// Response from a chat completion
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct ChatResponse {
328    /// Generated message content
329    pub content: String,
330    /// Model used for generation
331    pub model: String,
332    /// Token usage statistics
333    pub usage: Option<TokenUsage>,
334    /// Finish reason (stop, length, etc.)
335    pub finish_reason: Option<String>,
336}
337
338/// Token usage statistics
339#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct TokenUsage {
341    /// Number of tokens in the prompt
342    pub prompt_tokens: u32,
343    /// Number of tokens in the completion
344    pub completion_tokens: u32,
345    /// Total tokens used
346    pub total_tokens: u32,
347}
348
349/// A chunk of a streaming response
350#[derive(Debug, Clone, Serialize, Deserialize)]
351pub struct StreamChunk {
352    /// Content delta for this chunk
353    pub delta: String,
354    /// Whether this is the final chunk
355    pub is_final: bool,
356    /// Finish reason if final
357    pub finish_reason: Option<String>,
358}
359
360/// Stream type for chat completion responses
361pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, RunnerError>> + Send>>;
362
// ============================================================================
// Provider Trait
// ============================================================================

367/// LLM provider trait for chat completion
368///
369/// Implement this trait to add a new LLM runner. Each runner wraps
370/// a CLI tool and translates between the chat protocol and the
371/// tool's native interface.
372#[async_trait]
373pub trait LlmProvider: Send + Sync {
374    /// Unique provider identifier (e.g., `claude_code`, `copilot`)
375    fn name(&self) -> &'static str;
376
377    /// Human-readable display name for the provider
378    fn display_name(&self) -> &'static str;
379
380    /// Provider capabilities (streaming, function calling, etc.)
381    fn capabilities(&self) -> LlmCapabilities;
382
383    /// Default model to use if not specified in request
384    fn default_model(&self) -> &str;
385
386    /// Available models for this provider
387    fn available_models(&self) -> &[String];
388
389    /// Perform a chat completion (non-streaming)
390    async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError>;
391
392    /// Perform a streaming chat completion
393    ///
394    /// Returns a stream of chunks that can be consumed incrementally.
395    /// Falls back to non-streaming if not supported.
396    async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError>;
397
398    /// Check if the provider is healthy and ready to serve requests
399    async fn health_check(&self) -> Result<bool, RunnerError>;
400
401    /// Downcast to a concrete type for provider-specific operations
402    fn as_any(&self) -> &dyn Any;
403}