siumai 0.10.3

A unified LLM interface library for Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
//! Error handling module
//!
//! Defines all error types used in the LLM library.

use thiserror::Error;

// Static assertions to ensure error types are Send + Sync
// This is important for async code and multi-threading
static_assertions::assert_impl_all!(LlmError: Send, Sync, Clone);

/// Error category for better error handling and recovery strategies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Network-related errors (connection, timeout, etc.)
    Network,
    /// Authentication and authorization errors
    Authentication,
    /// Rate limiting and quota errors
    RateLimit,
    /// Client-side errors (4xx HTTP status codes)
    Client,
    /// Server-side errors (5xx HTTP status codes)
    Server,
    /// Data parsing and serialization errors
    Parsing,
    /// Input validation errors
    Validation,
    /// Configuration errors
    Configuration,
    /// Unsupported operations or models
    Unsupported,
    /// Streaming-related errors
    Stream,
    /// Provider-specific errors
    Provider,
    /// Unknown or uncategorized errors
    Unknown,
}

/// The primary error type for the LLM library.
#[derive(Error, Debug, Clone)]
pub enum LlmError {
    /// HTTP request error
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// JSON serialization/deserialization error
    #[error("JSON error: {0}")]
    JsonError(String),

    /// Response parsing error
    #[error("Parse error: {0}")]
    ParseError(String),

    /// Invalid input error
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// IO error
    #[error("IO error: {0}")]
    IoError(String),

    /// Not found error
    #[error("Not found: {0}")]
    NotFound(String),

    /// Missing API key
    #[error("Missing API key: {0}")]
    MissingApiKey(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// API error response
    #[error("API error: {code} - {message}")]
    ApiError {
        code: u16,
        message: String,
        details: Option<serde_json::Value>,
    },

    /// Authentication error
    #[error("Authentication failed: {0}")]
    AuthenticationError(String),

    /// Rate limit error
    #[error("Rate limit exceeded: {0}")]
    RateLimitError(String),

    /// Quota exceeded error
    #[error("Quota exceeded: {0}")]
    QuotaExceededError(String),

    /// Model not supported error
    #[error("Model not supported: {0}")]
    ModelNotSupported(String),

    /// Stream processing error
    #[error("Stream error: {0}")]
    StreamError(String),

    /// Timeout error
    #[error("Request timeout: {0}")]
    TimeoutError(String),

    /// Network connection error
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// Provider-specific error
    #[error("Provider error ({provider}): {message}")]
    ProviderError {
        provider: String,
        message: String,
        error_code: Option<String>,
    },

    /// Configuration error
    #[error("Configuration error: {0}")]
    ConfigurationError(String),

    /// Internal error
    #[error("Internal error: {0}")]
    InternalError(String),

    /// Unsupported operation
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),

    /// Processing error
    #[error("Processing error: {0}")]
    ProcessingError(String),

    /// Tool call error
    #[error("Tool call error: {0}")]
    ToolCallError(String),

    /// Tool validation error
    #[error("Tool validation error: {0}")]
    ToolValidationError(String),

    /// Unsupported tool type
    #[error("Unsupported tool type: {0}")]
    UnsupportedToolType(String),

    /// Context-aware error with additional metadata
    #[error("Error in {context}: {message}")]
    ContextualError {
        context: String,
        message: String,
        source_error: Option<Box<LlmError>>,
        metadata: std::collections::HashMap<String, String>,
    },

    /// Other errors
    #[error("Other error: {0}")]
    Other(String),
}

impl LlmError {
    /// Creates a new API error.
    pub fn api_error(code: u16, message: impl Into<String>) -> Self {
        Self::ApiError {
            code,
            message: message.into(),
            details: None,
        }
    }

    /// Creates a new API error with details.
    pub fn api_error_with_details(
        code: u16,
        message: impl Into<String>,
        details: serde_json::Value,
    ) -> Self {
        Self::ApiError {
            code,
            message: message.into(),
            details: Some(details),
        }
    }

    /// Creates a new provider error.
    pub fn provider_error(provider: impl Into<String>, message: impl Into<String>) -> Self {
        Self::ProviderError {
            provider: provider.into(),
            message: message.into(),
            error_code: None,
        }
    }

    /// Creates a new provider error with an error code.
    pub fn provider_error_with_code(
        provider: impl Into<String>,
        message: impl Into<String>,
        error_code: impl Into<String>,
    ) -> Self {
        Self::ProviderError {
            provider: provider.into(),
            message: message.into(),
            error_code: Some(error_code.into()),
        }
    }

    /// Creates a new contextual error with metadata.
    pub fn contextual_error(context: impl Into<String>, message: impl Into<String>) -> Self {
        Self::ContextualError {
            context: context.into(),
            message: message.into(),
            source_error: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Creates a contextual error with a source error.
    pub fn contextual_error_with_source(
        context: impl Into<String>,
        message: impl Into<String>,
        source: LlmError,
    ) -> Self {
        Self::ContextualError {
            context: context.into(),
            message: message.into(),
            source_error: Some(Box::new(source)),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Creates a contextual error with metadata.
    pub fn contextual_error_with_metadata(
        context: impl Into<String>,
        message: impl Into<String>,
        metadata: std::collections::HashMap<String, String>,
    ) -> Self {
        Self::ContextualError {
            context: context.into(),
            message: message.into(),
            source_error: None,
            metadata,
        }
    }

    /// Checks if the error is retryable with more sophisticated logic.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::HttpError(e) => {
                // More comprehensive check for retryable HTTP errors
                let retryable_keywords = [
                    "timeout",
                    "connect",
                    "network",
                    "dns",
                    "socket",
                    "connection reset",
                    "connection refused",
                    "temporary failure",
                ];
                retryable_keywords
                    .iter()
                    .any(|keyword| e.to_lowercase().contains(keyword))
            }
            Self::ApiError { code, .. } => {
                // More nuanced retry logic based on HTTP status codes
                match *code {
                    // 4xx errors that are retryable
                    408 | 429 => true, // Request Timeout, Too Many Requests
                    // 5xx errors are generally retryable
                    500..=599 => true,
                    // Other 4xx errors are not retryable
                    400..=499 => false,
                    // Unexpected codes, be conservative
                    _ => false,
                }
            }
            Self::RateLimitError(_) => true,
            Self::TimeoutError(_) => true,
            Self::ConnectionError(_) => true,
            Self::QuotaExceededError(_) => false, // Don't retry quota errors
            Self::AuthenticationError(_) => false, // Don't retry auth errors
            Self::ModelNotSupported(_) => false,  // Don't retry unsupported models
            Self::InvalidInput(_) | Self::InvalidParameter(_) => false, // Don't retry validation errors
            Self::ContextualError {
                source_error: Some(source),
                ..
            } => source.is_retryable(),
            Self::ContextualError { .. } => false, // Conservative approach for unknown contextual errors
            _ => false,
        }
    }

    /// Checks if the error is an authentication-related error.
    pub const fn is_auth_error(&self) -> bool {
        match self {
            Self::AuthenticationError(_) => true,
            Self::ApiError { code, .. } => *code == 401 || *code == 403,
            _ => false,
        }
    }

    /// Checks if the error is a rate limit error.
    pub const fn is_rate_limit_error(&self) -> bool {
        match self {
            Self::RateLimitError(_) => true,
            Self::ApiError { code, .. } => *code == 429,
            _ => false,
        }
    }

    /// Gets the HTTP status code of the error, if available.
    pub const fn status_code(&self) -> Option<u16> {
        match self {
            Self::ApiError { code, .. } => Some(*code),
            Self::HttpError(_) => None, // Cannot get status code directly from the string form.
            _ => None,
        }
    }

    /// Gets the error category for better error handling.
    pub fn category(&self) -> ErrorCategory {
        match self {
            Self::HttpError(_) | Self::ConnectionError(_) | Self::TimeoutError(_) => {
                ErrorCategory::Network
            }
            Self::AuthenticationError(_) | Self::MissingApiKey(_) => ErrorCategory::Authentication,
            Self::RateLimitError(_) | Self::QuotaExceededError(_) => ErrorCategory::RateLimit,
            Self::ApiError { code, .. } => match *code {
                429 => ErrorCategory::RateLimit,
                400..=499 => ErrorCategory::Client,
                500..=599 => ErrorCategory::Server,
                _ => ErrorCategory::Unknown,
            },
            Self::JsonError(_) | Self::ParseError(_) => ErrorCategory::Parsing,
            Self::InvalidInput(_) | Self::InvalidParameter(_) | Self::ToolValidationError(_) => {
                ErrorCategory::Validation
            }
            Self::ConfigurationError(_) => ErrorCategory::Configuration,
            Self::ModelNotSupported(_)
            | Self::UnsupportedOperation(_)
            | Self::UnsupportedToolType(_) => ErrorCategory::Unsupported,
            Self::StreamError(_) => ErrorCategory::Stream,
            Self::ProviderError { .. } | Self::ToolCallError(_) => ErrorCategory::Provider,
            Self::ContextualError {
                source_error: Some(source),
                ..
            } => source.category(),
            Self::ContextualError { .. } => ErrorCategory::Unknown,
            _ => ErrorCategory::Unknown,
        }
    }

    /// Gets a user-friendly error message.
    pub fn user_message(&self) -> String {
        match self {
            Self::AuthenticationError(_) | Self::MissingApiKey(_) => {
                "Authentication failed. Please check your API key.".to_string()
            }
            Self::RateLimitError(_) => {
                "Rate limit exceeded. Please wait before making more requests.".to_string()
            }
            Self::QuotaExceededError(_) => {
                "API quota exceeded. Please check your usage limits.".to_string()
            }
            Self::ModelNotSupported(model) => {
                format!("The model '{model}' is not supported by this provider.")
            }
            Self::ConnectionError(_) | Self::TimeoutError(_) => {
                "Network connection failed. Please check your internet connection and try again."
                    .to_string()
            }
            Self::ApiError {
                code: 500..=599, ..
            } => "The service is temporarily unavailable. Please try again later.".to_string(),
            _ => self.to_string(),
        }
    }

    /// Gets suggested recovery actions for the error with more detailed guidance.
    pub fn recovery_suggestions(&self) -> Vec<String> {
        match self {
            Self::AuthenticationError(_) | Self::MissingApiKey(_) => {
                vec![
                    "Verify your API key is correct and properly formatted".to_string(),
                    "Check if your API key has the required permissions for this operation"
                        .to_string(),
                    "Ensure your API key is not expired or revoked".to_string(),
                    "Verify you're using the correct API endpoint".to_string(),
                ]
            }
            Self::RateLimitError(_) => {
                vec![
                    "Implement exponential backoff with jitter".to_string(),
                    "Reduce request frequency".to_string(),
                    "Consider upgrading your API plan for higher limits".to_string(),
                    "Use request batching if supported".to_string(),
                ]
            }
            Self::QuotaExceededError(_) => {
                vec![
                    "Check your usage dashboard for current consumption".to_string(),
                    "Upgrade your API plan for higher quotas".to_string(),
                    "Wait for quota reset (usually monthly)".to_string(),
                    "Optimize your requests to use fewer tokens".to_string(),
                ]
            }
            Self::ConnectionError(_) | Self::TimeoutError(_) => {
                vec![
                    "Check your internet connection stability".to_string(),
                    "Retry the request with exponential backoff".to_string(),
                    "Increase timeout settings for large requests".to_string(),
                    "Check if the service is experiencing outages".to_string(),
                ]
            }
            Self::ModelNotSupported(_) => {
                vec![
                    "Use a supported model from the provider's model list".to_string(),
                    "Check the provider's documentation for available models".to_string(),
                    "Verify the model name is spelled correctly".to_string(),
                ]
            }
            Self::ApiError { code: 400, .. } => {
                vec![
                    "Check your request parameters for validity".to_string(),
                    "Ensure all required fields are provided".to_string(),
                    "Verify parameter types and formats".to_string(),
                ]
            }
            Self::ApiError {
                code: 500..=599, ..
            } => {
                vec![
                    "Retry the request after a delay (server error)".to_string(),
                    "Check the service status page for outages".to_string(),
                    "Contact support if the issue persists".to_string(),
                ]
            }
            Self::StreamError(_) => {
                vec![
                    "Retry the streaming request".to_string(),
                    "Check network stability for streaming".to_string(),
                    "Consider using non-streaming mode as fallback".to_string(),
                ]
            }
            Self::InvalidInput(_) | Self::InvalidParameter(_) => {
                vec![
                    "Validate your input parameters".to_string(),
                    "Check parameter constraints and limits".to_string(),
                    "Refer to the API documentation for valid formats".to_string(),
                ]
            }
            _ => vec![
                "Check the error details and documentation".to_string(),
                "Contact support if the issue persists".to_string(),
            ],
        }
    }

    /// Gets the recommended retry delay in seconds based on error type.
    pub const fn recommended_retry_delay(&self) -> Option<u64> {
        match self {
            Self::RateLimitError(_) => Some(60), // Wait 1 minute for rate limits
            Self::ApiError { code: 429, .. } => Some(30), // Wait 30 seconds for 429
            Self::ApiError {
                code: 500..=599, ..
            } => Some(5), // Wait 5 seconds for server errors
            Self::TimeoutError(_) => Some(10),   // Wait 10 seconds for timeouts
            Self::ConnectionError(_) => Some(5), // Wait 5 seconds for connection errors
            _ => None,                           // No recommended delay for non-retryable errors
        }
    }

    /// Gets the maximum number of retry attempts recommended for this error.
    pub const fn max_retry_attempts(&self) -> u32 {
        match self {
            Self::RateLimitError(_) => 3,
            Self::ApiError { code: 429, .. } => 3,
            Self::ApiError {
                code: 500..=599, ..
            } => 5,
            Self::TimeoutError(_) => 3,
            Self::ConnectionError(_) => 3,
            _ => 0, // No retries for non-retryable errors
        }
    }
}

/// Result type alias.
pub type Result<T> = std::result::Result<T, LlmError>;

// From implementations
impl From<reqwest::Error> for LlmError {
    fn from(err: reqwest::Error) -> Self {
        Self::HttpError(err.to_string())
    }
}

impl From<serde_json::Error> for LlmError {
    fn from(err: serde_json::Error) -> Self {
        Self::JsonError(err.to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let error = LlmError::api_error(404, "Not found");
        assert!(matches!(error, LlmError::ApiError { code: 404, .. }));
    }

    #[test]
    fn test_retryable_errors() {
        let rate_limit = LlmError::RateLimitError("Too many requests".to_string());
        assert!(rate_limit.is_retryable());

        let server_error = LlmError::api_error(500, "Internal server error");
        assert!(server_error.is_retryable());

        let auth_error = LlmError::AuthenticationError("Invalid key".to_string());
        assert!(!auth_error.is_retryable());
    }

    #[test]
    fn test_auth_errors() {
        let auth_error = LlmError::AuthenticationError("Invalid key".to_string());
        assert!(auth_error.is_auth_error());

        let api_401 = LlmError::api_error(401, "Unauthorized");
        assert!(api_401.is_auth_error());
    }

    #[test]
    fn test_error_categories() {
        let auth_error = LlmError::AuthenticationError("Invalid key".to_string());
        assert_eq!(auth_error.category(), ErrorCategory::Authentication);

        let rate_limit = LlmError::RateLimitError("Too many requests".to_string());
        assert_eq!(rate_limit.category(), ErrorCategory::RateLimit);

        let server_error = LlmError::api_error(500, "Internal server error");
        assert_eq!(server_error.category(), ErrorCategory::Server);

        let client_error = LlmError::api_error(400, "Bad request");
        assert_eq!(client_error.category(), ErrorCategory::Client);

        let parse_error = LlmError::JsonError("Invalid JSON".to_string());
        assert_eq!(parse_error.category(), ErrorCategory::Parsing);
    }

    #[test]
    fn test_user_messages() {
        let auth_error = LlmError::AuthenticationError("Invalid key".to_string());
        let user_msg = auth_error.user_message();
        assert!(user_msg.contains("Authentication failed"));

        let rate_limit = LlmError::RateLimitError("Too many requests".to_string());
        let user_msg = rate_limit.user_message();
        assert!(user_msg.contains("Rate limit exceeded"));
    }

    #[test]
    fn test_recovery_suggestions() {
        let auth_error = LlmError::AuthenticationError("Invalid key".to_string());
        let suggestions = auth_error.recovery_suggestions();
        assert!(!suggestions.is_empty());
        assert!(suggestions.iter().any(|s| s.contains("API key")));

        let rate_limit = LlmError::RateLimitError("Too many requests".to_string());
        let suggestions = rate_limit.recovery_suggestions();
        assert!(suggestions.iter().any(|s| s.contains("backoff")));
    }
}