1use serde::{Deserialize, Serialize};
7use std::future::Future;
8use std::pin::Pin;
9use std::time::Duration;
10
/// Heap-allocated, pinned, type-erased future that is `Send` and may borrow
/// data for `'a`. Used as the return type of object-safe async trait methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
13
14#[derive(Debug, Clone)]
16pub struct ChatRequest {
17 pub messages: Vec<ChatMessage>,
18 pub system: Option<String>,
19 pub tools: Vec<ToolDefinition>,
20 pub response_format: ResponseFormat,
21 pub max_tokens: Option<u32>,
22 pub temperature: Option<f32>,
23 pub stop_sequences: Vec<String>,
24 pub model: Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ChatMessage {
29 pub role: ChatRole,
30 pub content: String,
31 #[serde(default, skip_serializing_if = "Vec::is_empty")]
32 pub tool_calls: Vec<ToolCall>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub tool_call_id: Option<String>,
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum ChatRole {
40 System,
41 User,
42 Assistant,
43 Tool,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolDefinition {
48 pub name: String,
49 pub description: String,
50 pub parameters: serde_json::Value,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolCall {
55 pub id: String,
56 pub name: String,
57 pub arguments: String,
58}
59
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
62#[serde(rename_all = "snake_case")]
63pub enum ResponseFormat {
64 #[default]
65 Text,
66 Markdown,
67 Json,
68 Yaml,
69 Toml,
70}
71
72impl ResponseFormat {
73 #[must_use]
74 pub fn default_structured() -> Self {
75 Self::Yaml
76 }
77
78 #[must_use]
79 pub fn fallback(self) -> Option<Self> {
80 match self {
81 Self::Json | Self::Text => None,
82 Self::Yaml | Self::Toml | Self::Markdown => Some(Self::Json),
83 }
84 }
85
86 #[must_use]
87 pub fn system_instruction(self) -> Option<&'static str> {
88 match self {
89 Self::Text => None,
90 Self::Markdown => Some(
91 "You MUST respond with valid Markdown only. Use headings, lists, and tables to structure the data. Do NOT wrap output in code fences or return serialized JSON/YAML. Present data as readable Markdown.",
92 ),
93 Self::Json => Some("You MUST respond with valid JSON only. No other text."),
94 Self::Yaml => Some(
95 "You MUST respond with valid YAML only. No anchors, no aliases, no custom tags. No other text or code fences.",
96 ),
97 Self::Toml => Some(
98 "You MUST respond with valid TOML only. Use sections and key-value pairs. No inline tables for complex data. No other text or code fences.",
99 ),
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
105pub struct ChatResponse {
106 pub content: String,
107 pub tool_calls: Vec<ToolCall>,
108 pub usage: Option<TokenUsage>,
109 pub model: Option<String>,
110 pub finish_reason: Option<FinishReason>,
111 pub metadata: std::collections::HashMap<String, String>,
112}
113
114#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
115pub struct TokenUsage {
116 pub prompt_tokens: u32,
117 pub completion_tokens: u32,
118 pub total_tokens: u32,
119}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
122#[serde(rename_all = "snake_case")]
123pub enum FinishReason {
124 Stop,
125 Length,
126 ContentFilter,
127 StopSequence,
128 ToolCalls,
129}
130
131#[derive(Debug, Clone)]
133pub enum LlmError {
134 RateLimited {
135 retry_after: Duration,
136 message: Option<String>,
137 },
138 Timeout {
139 elapsed: Duration,
140 deadline: Duration,
141 },
142 AuthDenied {
143 message: String,
144 },
145 InvalidRequest {
146 message: String,
147 },
148 ModelNotFound {
149 model: String,
150 },
151 ContextLengthExceeded {
152 max_tokens: u32,
153 request_tokens: u32,
154 },
155 ContentFiltered {
156 reason: String,
157 },
158 ResponseFormatMismatch {
159 expected: ResponseFormat,
160 message: String,
161 },
162 ProviderError {
163 message: String,
164 code: Option<String>,
165 },
166 NetworkError {
167 message: String,
168 },
169}
170
171impl std::fmt::Display for LlmError {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 match self {
174 Self::RateLimited {
175 retry_after,
176 message,
177 } => {
178 write!(f, "rate limited (retry after {:?})", retry_after)?;
179 if let Some(message) = message {
180 write!(f, ": {message}")?;
181 }
182 Ok(())
183 }
184 Self::Timeout { elapsed, deadline } => {
185 write!(f, "timeout after {:?} (deadline: {:?})", elapsed, deadline)
186 }
187 Self::AuthDenied { message } => write!(f, "authentication denied: {message}"),
188 Self::InvalidRequest { message } => write!(f, "invalid request: {message}"),
189 Self::ModelNotFound { model } => write!(f, "model not found: {model}"),
190 Self::ContextLengthExceeded {
191 max_tokens,
192 request_tokens,
193 } => {
194 write!(
195 f,
196 "context length exceeded: {request_tokens} tokens (max: {max_tokens})"
197 )
198 }
199 Self::ContentFiltered { reason } => write!(f, "content filtered: {reason}"),
200 Self::ResponseFormatMismatch { expected, message } => {
201 write!(f, "response format mismatch for {:?}: {message}", expected)
202 }
203 Self::ProviderError { message, code } => {
204 write!(f, "provider error: {message}")?;
205 if let Some(code) = code {
206 write!(f, " (code: {code})")?;
207 }
208 Ok(())
209 }
210 Self::NetworkError { message } => write!(f, "network error: {message}"),
211 }
212 }
213}
214
215impl std::error::Error for LlmError {}
216
/// Statically dispatched chat backend whose future type is named through a
/// generic associated type, so calls need no boxing.
///
/// Because of the GAT this trait cannot be used as `dyn ChatBackend`; use
/// [`DynChatBackend`] for dynamic dispatch (every `ChatBackend` implements it via
/// the blanket impl in this module). Implementors must be `Send + Sync`.
pub trait ChatBackend: Send + Sync {
    /// Future returned by [`ChatBackend::chat`]: `Send`, may borrow the backend
    /// for `'a`, and resolves to a [`ChatResponse`] or an [`LlmError`].
    type ChatFut<'a>: Future<Output = Result<ChatResponse, LlmError>> + Send + 'a
    where
        Self: 'a;

    /// Starts a chat call for `req` (taken by value); the returned future
    /// borrows `self` for `'a`.
    fn chat<'a>(&'a self, req: ChatRequest) -> Self::ChatFut<'a>;
}
225
/// Object-safe counterpart of [`ChatBackend`] that returns a boxed future, so it
/// can be used as `dyn DynChatBackend` (e.g. behind `Box` or `Arc`).
pub trait DynChatBackend: Send + Sync {
    /// Starts a chat call for `req`; the boxed future may borrow `self`.
    fn chat(&self, req: ChatRequest) -> BoxFuture<'_, Result<ChatResponse, LlmError>>;
}
230
231impl<T: ChatBackend> DynChatBackend for T {
232 fn chat(&self, req: ChatRequest) -> BoxFuture<'_, Result<ChatResponse, LlmError>> {
233 Box::pin(ChatBackend::chat(self, req))
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn response_format_default_structured_is_yaml() {
        assert_eq!(ResponseFormat::default_structured(), ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_fallback() {
        assert_eq!(ResponseFormat::Text.fallback(), None);
        assert_eq!(ResponseFormat::Json.fallback(), None);
        assert_eq!(ResponseFormat::Yaml.fallback(), Some(ResponseFormat::Json));
        assert_eq!(ResponseFormat::Toml.fallback(), Some(ResponseFormat::Json));
        assert_eq!(
            ResponseFormat::Markdown.fallback(),
            Some(ResponseFormat::Json)
        );
    }

    #[test]
    fn response_format_system_instruction_text_is_none() {
        assert!(ResponseFormat::Text.system_instruction().is_none());
    }

    #[test]
    fn response_format_system_instruction_json() {
        let instr = ResponseFormat::Json.system_instruction().unwrap();
        assert!(instr.contains("JSON"));
    }

    #[test]
    fn response_format_system_instruction_yaml() {
        let instr = ResponseFormat::Yaml.system_instruction().unwrap();
        assert!(instr.contains("YAML"));
    }

    #[test]
    fn response_format_system_instruction_toml() {
        let instr = ResponseFormat::Toml.system_instruction().unwrap();
        assert!(instr.contains("TOML"));
    }

    #[test]
    fn response_format_system_instruction_markdown() {
        let instr = ResponseFormat::Markdown.system_instruction().unwrap();
        assert!(instr.contains("Markdown"));
    }

    #[test]
    fn response_format_default_is_text() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Text);
    }

    // Pins the snake_case wire format produced by `rename_all`; a rename here
    // would silently break every provider payload.
    #[test]
    fn response_format_serde_uses_snake_case() {
        for (format, wire) in [
            (ResponseFormat::Text, "\"text\""),
            (ResponseFormat::Markdown, "\"markdown\""),
            (ResponseFormat::Json, "\"json\""),
            (ResponseFormat::Yaml, "\"yaml\""),
            (ResponseFormat::Toml, "\"toml\""),
        ] {
            assert_eq!(serde_json::to_string(&format).unwrap(), wire);
            assert_eq!(
                serde_json::from_str::<ResponseFormat>(wire).unwrap(),
                format
            );
        }
    }

    // Replaces a test that only constructed the variants and asserted nothing.
    #[test]
    fn chat_role_serde_uses_snake_case() {
        for (role, wire) in [
            (ChatRole::System, "\"system\""),
            (ChatRole::User, "\"user\""),
            (ChatRole::Assistant, "\"assistant\""),
            (ChatRole::Tool, "\"tool\""),
        ] {
            assert_eq!(serde_json::to_string(&role).unwrap(), wire);
            assert_eq!(serde_json::from_str::<ChatRole>(wire).unwrap(), role);
        }
    }

    #[test]
    fn finish_reason_serde_uses_snake_case() {
        for (reason, wire) in [
            (FinishReason::Stop, "\"stop\""),
            (FinishReason::Length, "\"length\""),
            (FinishReason::ContentFilter, "\"content_filter\""),
            (FinishReason::StopSequence, "\"stop_sequence\""),
            (FinishReason::ToolCalls, "\"tool_calls\""),
        ] {
            assert_eq!(serde_json::to_string(&reason).unwrap(), wire);
            assert_eq!(
                serde_json::from_str::<FinishReason>(wire).unwrap(),
                reason
            );
        }
    }

    #[test]
    fn chat_message_omits_empty_tool_fields() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: "hi".into(),
            tool_calls: Vec::new(),
            tool_call_id: None,
        };
        assert_eq!(
            serde_json::to_string(&msg).unwrap(),
            r#"{"role":"user","content":"hi"}"#
        );

        // Missing optional fields must deserialize to their empty values.
        let parsed: ChatMessage =
            serde_json::from_str(r#"{"role":"tool","content":"ok"}"#).unwrap();
        assert_eq!(parsed.role, ChatRole::Tool);
        assert_eq!(parsed.content, "ok");
        assert!(parsed.tool_calls.is_empty());
        assert!(parsed.tool_call_id.is_none());
    }

    #[test]
    fn llm_error_display_rate_limited() {
        let err = LlmError::RateLimited {
            retry_after: Duration::from_secs(30),
            message: Some("too many requests".into()),
        };
        let s = err.to_string();
        assert!(s.contains("rate limited"));
        assert!(s.contains("too many requests"));
    }

    #[test]
    fn llm_error_display_rate_limited_no_message() {
        let err = LlmError::RateLimited {
            retry_after: Duration::from_secs(5),
            message: None,
        };
        let s = err.to_string();
        assert!(s.contains("rate limited"));
        assert!(!s.contains(":"));
    }

    #[test]
    fn llm_error_display_timeout() {
        let err = LlmError::Timeout {
            elapsed: Duration::from_secs(10),
            deadline: Duration::from_secs(5),
        };
        let s = err.to_string();
        assert!(s.contains("timeout"));
        assert!(s.contains("deadline"));
    }

    #[test]
    fn llm_error_display_auth_denied() {
        let err = LlmError::AuthDenied {
            message: "bad key".into(),
        };
        assert!(err.to_string().contains("authentication denied"));
    }

    #[test]
    fn llm_error_display_invalid_request() {
        let err = LlmError::InvalidRequest {
            message: "missing model".into(),
        };
        assert!(err.to_string().contains("invalid request"));
    }

    #[test]
    fn llm_error_display_model_not_found() {
        let err = LlmError::ModelNotFound {
            model: "gpt-5".into(),
        };
        assert!(err.to_string().contains("gpt-5"));
    }

    #[test]
    fn llm_error_display_context_length() {
        let err = LlmError::ContextLengthExceeded {
            max_tokens: 4096,
            request_tokens: 8000,
        };
        let s = err.to_string();
        assert!(s.contains("8000"));
        assert!(s.contains("4096"));
    }

    #[test]
    fn llm_error_display_content_filtered() {
        let err = LlmError::ContentFiltered {
            reason: "safety".into(),
        };
        assert!(err.to_string().contains("safety"));
    }

    #[test]
    fn llm_error_display_response_format_mismatch() {
        let err = LlmError::ResponseFormatMismatch {
            expected: ResponseFormat::Json,
            message: "got yaml".into(),
        };
        let s = err.to_string();
        assert!(s.contains("format mismatch"));
        assert!(s.contains("got yaml"));
    }

    #[test]
    fn llm_error_display_provider_error_with_code() {
        let err = LlmError::ProviderError {
            message: "internal".into(),
            code: Some("500".into()),
        };
        let s = err.to_string();
        assert!(s.contains("provider error"));
        assert!(s.contains("500"));
    }

    #[test]
    fn llm_error_display_provider_error_no_code() {
        let err = LlmError::ProviderError {
            message: "oops".into(),
            code: None,
        };
        let s = err.to_string();
        assert!(s.contains("oops"));
        assert!(!s.contains("code"));
    }

    #[test]
    fn llm_error_display_network() {
        let err = LlmError::NetworkError {
            message: "dns failed".into(),
        };
        assert!(err.to_string().contains("dns failed"));
    }
}