Skip to main content

converge_provider/
chat.rs

// Copyright 2024-2026 Reflective Labs
// SPDX-License-Identifier: MIT

//! Canonical chat capability contracts for provider consumers and adapters.
5
6use serde::{Deserialize, Serialize};
7use std::future::Future;
8use std::pin::Pin;
9use std::time::Duration;
10
/// Heap-allocated, pinned future that is `Send`, yields `T`, and may borrow for `'a`.
///
/// Serves as the return type of dyn-compatible capability traits, where a trait
/// object cannot expose a concrete associated future type.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
13
14/// Request for chat completion.
15#[derive(Debug, Clone)]
16pub struct ChatRequest {
17    pub messages: Vec<ChatMessage>,
18    pub system: Option<String>,
19    pub tools: Vec<ToolDefinition>,
20    pub response_format: ResponseFormat,
21    pub max_tokens: Option<u32>,
22    pub temperature: Option<f32>,
23    pub stop_sequences: Vec<String>,
24    pub model: Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ChatMessage {
29    pub role: ChatRole,
30    pub content: String,
31    #[serde(default, skip_serializing_if = "Vec::is_empty")]
32    pub tool_calls: Vec<ToolCall>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub tool_call_id: Option<String>,
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum ChatRole {
40    System,
41    User,
42    Assistant,
43    Tool,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolDefinition {
48    pub name: String,
49    pub description: String,
50    pub parameters: serde_json::Value,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolCall {
55    pub id: String,
56    pub name: String,
57    pub arguments: String,
58}
59
60/// Requested output format for a chat completion.
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
62#[serde(rename_all = "snake_case")]
63pub enum ResponseFormat {
64    #[default]
65    Text,
66    Markdown,
67    Json,
68    Yaml,
69    Toml,
70}
71
72impl ResponseFormat {
73    #[must_use]
74    pub fn default_structured() -> Self {
75        Self::Yaml
76    }
77
78    #[must_use]
79    pub fn fallback(self) -> Option<Self> {
80        match self {
81            Self::Json | Self::Text => None,
82            Self::Yaml | Self::Toml | Self::Markdown => Some(Self::Json),
83        }
84    }
85
86    #[must_use]
87    pub fn system_instruction(self) -> Option<&'static str> {
88        match self {
89            Self::Text => None,
90            Self::Markdown => Some(
91                "You MUST respond with valid Markdown only. Use headings, lists, and tables to structure the data. Do NOT wrap output in code fences or return serialized JSON/YAML. Present data as readable Markdown.",
92            ),
93            Self::Json => Some("You MUST respond with valid JSON only. No other text."),
94            Self::Yaml => Some(
95                "You MUST respond with valid YAML only. No anchors, no aliases, no custom tags. No other text or code fences.",
96            ),
97            Self::Toml => Some(
98                "You MUST respond with valid TOML only. Use sections and key-value pairs. No inline tables for complex data. No other text or code fences.",
99            ),
100        }
101    }
102}
103
104#[derive(Debug, Clone)]
105pub struct ChatResponse {
106    pub content: String,
107    pub tool_calls: Vec<ToolCall>,
108    pub usage: Option<TokenUsage>,
109    pub model: Option<String>,
110    pub finish_reason: Option<FinishReason>,
111    pub metadata: std::collections::HashMap<String, String>,
112}
113
114#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
115pub struct TokenUsage {
116    pub prompt_tokens: u32,
117    pub completion_tokens: u32,
118    pub total_tokens: u32,
119}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
122#[serde(rename_all = "snake_case")]
123pub enum FinishReason {
124    Stop,
125    Length,
126    ContentFilter,
127    StopSequence,
128    ToolCalls,
129}
130
131/// Error type for chat operations.
132#[derive(Debug, Clone)]
133pub enum LlmError {
134    RateLimited {
135        retry_after: Duration,
136        message: Option<String>,
137    },
138    Timeout {
139        elapsed: Duration,
140        deadline: Duration,
141    },
142    AuthDenied {
143        message: String,
144    },
145    InvalidRequest {
146        message: String,
147    },
148    ModelNotFound {
149        model: String,
150    },
151    ContextLengthExceeded {
152        max_tokens: u32,
153        request_tokens: u32,
154    },
155    ContentFiltered {
156        reason: String,
157    },
158    ResponseFormatMismatch {
159        expected: ResponseFormat,
160        message: String,
161    },
162    ProviderError {
163        message: String,
164        code: Option<String>,
165    },
166    NetworkError {
167        message: String,
168    },
169}
170
171impl std::fmt::Display for LlmError {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        match self {
174            Self::RateLimited {
175                retry_after,
176                message,
177            } => {
178                write!(f, "rate limited (retry after {:?})", retry_after)?;
179                if let Some(message) = message {
180                    write!(f, ": {message}")?;
181                }
182                Ok(())
183            }
184            Self::Timeout { elapsed, deadline } => {
185                write!(f, "timeout after {:?} (deadline: {:?})", elapsed, deadline)
186            }
187            Self::AuthDenied { message } => write!(f, "authentication denied: {message}"),
188            Self::InvalidRequest { message } => write!(f, "invalid request: {message}"),
189            Self::ModelNotFound { model } => write!(f, "model not found: {model}"),
190            Self::ContextLengthExceeded {
191                max_tokens,
192                request_tokens,
193            } => {
194                write!(
195                    f,
196                    "context length exceeded: {request_tokens} tokens (max: {max_tokens})"
197                )
198            }
199            Self::ContentFiltered { reason } => write!(f, "content filtered: {reason}"),
200            Self::ResponseFormatMismatch { expected, message } => {
201                write!(f, "response format mismatch for {:?}: {message}", expected)
202            }
203            Self::ProviderError { message, code } => {
204                write!(f, "provider error: {message}")?;
205                if let Some(code) = code {
206                    write!(f, " (code: {code})")?;
207                }
208                Ok(())
209            }
210            Self::NetworkError { message } => write!(f, "network error: {message}"),
211        }
212    }
213}
214
215impl std::error::Error for LlmError {}
216
/// Chat completion capability with a statically named future type.
///
/// Implementors choose the concrete future through the generic associated type
/// `ChatFut`, so calls on a concrete backend return that type directly without
/// boxing. Implementors must be `Send + Sync`, and the returned future must be
/// `Send`. For runtime polymorphism use `DynChatBackend`; the blanket impl in
/// this module provides it for every `ChatBackend`.
pub trait ChatBackend: Send + Sync {
    /// Future returned by `chat`; it may borrow from `self` for `'a`.
    type ChatFut<'a>: Future<Output = Result<ChatResponse, LlmError>> + Send + 'a
    where
        Self: 'a;

    /// Starts a chat completion for `req`, resolving to a `ChatResponse` or an `LlmError`.
    fn chat<'a>(&'a self, req: ChatRequest) -> Self::ChatFut<'a>;
}
225
/// Dyn-compatible chat backend for runtime polymorphism (e.g. `Box<dyn DynChatBackend>`).
///
/// The future is returned as a boxed `BoxFuture` that borrows `self`. That lets
/// the trait be used as an object even though `ChatBackend`'s associated
/// future type cannot be named through a trait object.
pub trait DynChatBackend: Send + Sync {
    /// Starts a chat completion for `req`, resolving to a `ChatResponse` or an `LlmError`.
    fn chat(&self, req: ChatRequest) -> BoxFuture<'_, Result<ChatResponse, LlmError>>;
}
230
231impl<T: ChatBackend> DynChatBackend for T {
232    fn chat(&self, req: ChatRequest) -> BoxFuture<'_, Result<ChatResponse, LlmError>> {
233        Box::pin(ChatBackend::chat(self, req))
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    // ----- ResponseFormat --------------------------------------------------

    #[test]
    fn response_format_default_structured_is_yaml() {
        assert_eq!(ResponseFormat::default_structured(), ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_fallback() {
        assert_eq!(ResponseFormat::Text.fallback(), None);
        assert_eq!(ResponseFormat::Json.fallback(), None);
        assert_eq!(ResponseFormat::Yaml.fallback(), Some(ResponseFormat::Json));
        assert_eq!(ResponseFormat::Toml.fallback(), Some(ResponseFormat::Json));
        assert_eq!(
            ResponseFormat::Markdown.fallback(),
            Some(ResponseFormat::Json)
        );
    }

    #[test]
    fn response_format_system_instruction_text_is_none() {
        assert!(ResponseFormat::Text.system_instruction().is_none());
    }

    #[test]
    fn response_format_system_instruction_json() {
        let instr = ResponseFormat::Json.system_instruction().unwrap();
        assert!(instr.contains("JSON"));
    }

    #[test]
    fn response_format_system_instruction_yaml() {
        let instr = ResponseFormat::Yaml.system_instruction().unwrap();
        assert!(instr.contains("YAML"));
    }

    #[test]
    fn response_format_system_instruction_toml() {
        let instr = ResponseFormat::Toml.system_instruction().unwrap();
        assert!(instr.contains("TOML"));
    }

    #[test]
    fn response_format_system_instruction_markdown() {
        let instr = ResponseFormat::Markdown.system_instruction().unwrap();
        assert!(instr.contains("Markdown"));
    }

    #[test]
    fn response_format_default_is_text() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Text);
    }

    #[test]
    fn response_format_serializes_snake_case() {
        let cases = [
            (ResponseFormat::Text, "text"),
            (ResponseFormat::Markdown, "markdown"),
            (ResponseFormat::Json, "json"),
            (ResponseFormat::Yaml, "yaml"),
            (ResponseFormat::Toml, "toml"),
        ];
        for (format, tag) in cases {
            let encoded = serde_json::to_value(format).unwrap();
            assert_eq!(encoded, serde_json::Value::from(tag));
            let decoded: ResponseFormat = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, format);
        }
    }

    // ----- Message / response types ------------------------------------------

    #[test]
    fn chat_role_serializes_snake_case() {
        // Covers every variant, so adding a role without updating the wire
        // format is noticed here.
        let cases = [
            (ChatRole::System, "system"),
            (ChatRole::User, "user"),
            (ChatRole::Assistant, "assistant"),
            (ChatRole::Tool, "tool"),
        ];
        for (role, tag) in cases {
            let encoded = serde_json::to_value(role).unwrap();
            assert_eq!(encoded, serde_json::Value::from(tag));
            let decoded: ChatRole = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, role);
        }
    }

    #[test]
    fn finish_reason_serializes_snake_case() {
        let cases = [
            (FinishReason::Stop, "stop"),
            (FinishReason::Length, "length"),
            (FinishReason::ContentFilter, "content_filter"),
            (FinishReason::StopSequence, "stop_sequence"),
            (FinishReason::ToolCalls, "tool_calls"),
        ];
        for (reason, tag) in cases {
            let encoded = serde_json::to_value(reason).unwrap();
            assert_eq!(encoded, serde_json::Value::from(tag));
            let decoded: FinishReason = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, reason);
        }
    }

    #[test]
    fn chat_message_omits_empty_tool_fields() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: "hello".into(),
            tool_calls: Vec::new(),
            tool_call_id: None,
        };
        let encoded = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            encoded,
            serde_json::json!({ "role": "user", "content": "hello" })
        );
    }

    #[test]
    fn chat_message_deserializes_without_tool_fields() {
        let decoded: ChatMessage =
            serde_json::from_str(r#"{"role":"user","content":"hi"}"#).unwrap();
        assert_eq!(decoded.role, ChatRole::User);
        assert_eq!(decoded.content, "hi");
        assert!(decoded.tool_calls.is_empty());
        assert!(decoded.tool_call_id.is_none());
    }

    #[test]
    fn chat_message_roundtrips_tool_fields() {
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: String::new(),
            tool_calls: vec![ToolCall {
                id: "call_1".into(),
                name: "lookup".into(),
                arguments: r#"{"q":1}"#.into(),
            }],
            tool_call_id: Some("call_0".into()),
        };
        let encoded = serde_json::to_string(&msg).unwrap();
        let decoded: ChatMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.role, ChatRole::Assistant);
        assert_eq!(decoded.tool_calls.len(), 1);
        assert_eq!(decoded.tool_calls[0].id, "call_1");
        assert_eq!(decoded.tool_calls[0].name, "lookup");
        assert_eq!(decoded.tool_calls[0].arguments, r#"{"q":1}"#);
        assert_eq!(decoded.tool_call_id.as_deref(), Some("call_0"));
    }

    #[test]
    fn token_usage_deserializes_fields() {
        let usage: TokenUsage = serde_json::from_str(
            r#"{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7}"#,
        )
        .unwrap();
        assert_eq!(usage.prompt_tokens, 3);
        assert_eq!(usage.completion_tokens, 4);
        assert_eq!(usage.total_tokens, 7);
    }

    // ----- LlmError ------------------------------------------------------

    #[test]
    fn llm_error_display_rate_limited() {
        let err = LlmError::RateLimited {
            retry_after: Duration::from_secs(30),
            message: Some("too many requests".into()),
        };
        assert_eq!(
            err.to_string(),
            "rate limited (retry after 30s): too many requests"
        );
    }

    #[test]
    fn llm_error_display_rate_limited_no_message() {
        let err = LlmError::RateLimited {
            retry_after: Duration::from_secs(5),
            message: None,
        };
        // Exact match: no trailing ": detail" when the message is absent.
        assert_eq!(err.to_string(), "rate limited (retry after 5s)");
    }

    #[test]
    fn llm_error_display_timeout() {
        let err = LlmError::Timeout {
            elapsed: Duration::from_secs(10),
            deadline: Duration::from_secs(5),
        };
        assert_eq!(err.to_string(), "timeout after 10s (deadline: 5s)");
    }

    #[test]
    fn llm_error_display_auth_denied() {
        let err = LlmError::AuthDenied {
            message: "bad key".into(),
        };
        assert_eq!(err.to_string(), "authentication denied: bad key");
    }

    #[test]
    fn llm_error_display_invalid_request() {
        let err = LlmError::InvalidRequest {
            message: "missing model".into(),
        };
        assert_eq!(err.to_string(), "invalid request: missing model");
    }

    #[test]
    fn llm_error_display_model_not_found() {
        let err = LlmError::ModelNotFound {
            model: "gpt-5".into(),
        };
        assert_eq!(err.to_string(), "model not found: gpt-5");
    }

    #[test]
    fn llm_error_display_context_length() {
        let err = LlmError::ContextLengthExceeded {
            max_tokens: 4096,
            request_tokens: 8000,
        };
        assert_eq!(
            err.to_string(),
            "context length exceeded: 8000 tokens (max: 4096)"
        );
    }

    #[test]
    fn llm_error_display_content_filtered() {
        let err = LlmError::ContentFiltered {
            reason: "safety".into(),
        };
        assert_eq!(err.to_string(), "content filtered: safety");
    }

    #[test]
    fn llm_error_display_response_format_mismatch() {
        let err = LlmError::ResponseFormatMismatch {
            expected: ResponseFormat::Json,
            message: "got yaml".into(),
        };
        assert_eq!(
            err.to_string(),
            "response format mismatch for Json: got yaml"
        );
    }

    #[test]
    fn llm_error_display_provider_error_with_code() {
        let err = LlmError::ProviderError {
            message: "internal".into(),
            code: Some("500".into()),
        };
        assert_eq!(err.to_string(), "provider error: internal (code: 500)");
    }

    #[test]
    fn llm_error_display_provider_error_no_code() {
        let err = LlmError::ProviderError {
            message: "oops".into(),
            code: None,
        };
        assert_eq!(err.to_string(), "provider error: oops");
    }

    #[test]
    fn llm_error_display_network() {
        let err = LlmError::NetworkError {
            message: "dns failed".into(),
        };
        assert_eq!(err.to_string(), "network error: dns failed");
    }

    #[test]
    fn llm_error_is_std_error() {
        let boxed: Box<dyn std::error::Error> = Box::new(LlmError::NetworkError {
            message: "dns failed".into(),
        });
        assert_eq!(boxed.to_string(), "network error: dns failed");
        assert!(boxed.source().is_none());
    }
}