Skip to main content

oxi_ai/
error.rs

//! Error types for oxi-ai.

use std::time::Duration;

use thiserror::Error;

5/// Provider-specific errors
6#[derive(Error, Debug)]
7pub enum ProviderError {
8    /// API key is missing.
9    #[error("Missing API key")]
10    MissingApiKey,
11
12    /// Unknown provider.
13    #[error("Unknown provider: {0}")]
14    UnknownProvider(String),
15
16    /// Provider not yet implemented.
17    #[error("Provider not implemented: {0}")]
18    NotImplemented(String),
19
20    /// HTTP error (status code + message).
21    #[error("HTTP error {0}: {1}")]
22    HttpError(u16, String),
23
24    /// HTTP request failed.
25    #[error("Request failed: {0}")]
26    RequestFailed(#[from] reqwest::Error),
27
28    /// I/O error.
29    #[error("IO error: {0}")]
30    IoError(#[from] std::io::Error),
31
32    /// Invalid response from provider.
33    #[error("Invalid response: {0}")]
34    InvalidResponse(String),
35
36    /// Invalid API key format.
37    #[error("Invalid API key format")]
38    InvalidApiKey,
39
40    /// JSON parsing error.
41    #[error("JSON parse error: {0}")]
42    JsonParse(#[from] serde_json::Error),
43
44    /// Streaming error.
45    #[error("Stream error: {0}")]
46    StreamError(String),
47
48    /// Network error.
49    #[error("Network error: {0}")]
50    NetworkError(String),
51
52    /// Request timed out.
53    #[error("Request timed out")]
54    Timeout,
55
56    /// Rate limit exceeded.
57    #[error("Rate limited")]
58    RateLimited {
59        /// Wait time suggested by the server.
60        retry_after: Option<std::time::Duration>,
61    },
62}
63
64impl ProviderError {
65    /// Returns whether this error is retryable.
66    pub fn is_retryable(&self) -> bool {
67        match self {
68            Self::HttpError(status, _) => *status == 429 || *status >= 500,
69            Self::NetworkError(_) => true,
70            Self::Timeout => true,
71            Self::RateLimited { .. } => true,
72            _ => false,
73        }
74    }
75
76    /// Returns the retry wait time suggested by the server.
77    pub fn retry_after(&self) -> Option<std::time::Duration> {
78        match self {
79            Self::RateLimited { retry_after } => *retry_after,
80            Self::HttpError(429, _) => Some(std::time::Duration::from_secs(5)),
81            _ => None,
82        }
83    }
84}
85
86/// Validation errors
87#[derive(Error, Debug)]
88pub enum ValidationError {
89    #[error("Invalid JSON: {0}")]
90    InvalidJson(#[from] serde_json::Error),
91
92    #[error("Schema validation failed: {0}")]
93    SchemaValidation(String),
94
95    #[error("Missing required field: {0}")]
96    MissingRequiredField(String),
97}
98
99/// Unified error type for oxi-ai
100#[derive(Error, Debug)]
101pub enum Error {
102    /// Wraps a provider error.
103    #[error("Provider error: {0}")]
104    Provider(#[from] ProviderError),
105
106    /// Wraps a validation error.
107    #[error("Validation error: {0}")]
108    Validation(#[from] ValidationError),
109
110    /// Wraps an I/O error.
111    #[error("IO error: {0}")]
112    Io(#[from] std::io::Error),
113}
114
115/// Result type alias
116pub type Result<T> = std::result::Result<T, Error>;
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_error_display() {
        assert_eq!(ProviderError::MissingApiKey.to_string(), "Missing API key");
        assert_eq!(
            ProviderError::UnknownProvider("foo".to_string()).to_string(),
            "Unknown provider: foo"
        );
        assert_eq!(
            ProviderError::HttpError(429, "rate limited".to_string()).to_string(),
            "HTTP error 429: rate limited"
        );
        assert_eq!(
            ProviderError::InvalidResponse("bad json".to_string()).to_string(),
            "Invalid response: bad json"
        );
        assert_eq!(
            ProviderError::StreamError("disconnected".to_string()).to_string(),
            "Stream error: disconnected"
        );
        assert_eq!(
            ProviderError::NotImplemented("x".to_string()).to_string(),
            "Provider not implemented: x"
        );
    }

    #[test]
    fn retryable_classification() {
        // Transient conditions.
        assert!(ProviderError::HttpError(429, String::new()).is_retryable());
        assert!(ProviderError::HttpError(503, String::new()).is_retryable());
        assert!(ProviderError::Timeout.is_retryable());
        assert!(ProviderError::NetworkError("reset".to_string()).is_retryable());
        assert!(ProviderError::RateLimited { retry_after: None }.is_retryable());

        // Permanent conditions.
        assert!(!ProviderError::HttpError(404, String::new()).is_retryable());
        assert!(!ProviderError::MissingApiKey.is_retryable());
        assert!(!ProviderError::InvalidApiKey.is_retryable());
        assert!(!ProviderError::InvalidResponse("x".to_string()).is_retryable());
    }

    #[test]
    fn retry_after_hints() {
        let hint = std::time::Duration::from_secs(3);
        assert_eq!(
            ProviderError::RateLimited {
                retry_after: Some(hint)
            }
            .retry_after(),
            Some(hint)
        );
        assert_eq!(
            ProviderError::RateLimited { retry_after: None }.retry_after(),
            None
        );
        assert_eq!(
            ProviderError::HttpError(429, String::new()).retry_after(),
            Some(std::time::Duration::from_secs(5))
        );
        assert_eq!(
            ProviderError::HttpError(500, String::new()).retry_after(),
            None
        );
        assert_eq!(ProviderError::Timeout.retry_after(), None);
    }

    #[test]
    fn error_chain_from_provider_error() {
        let inner = ProviderError::MissingApiKey;
        let outer: Error = inner.into();
        assert!(matches!(
            outer,
            Error::Provider(ProviderError::MissingApiKey)
        ));
        assert!(outer.to_string().contains("Missing API key"));
    }

    #[test]
    fn validation_error_display() {
        let err = ValidationError::MissingRequiredField("model".to_string());
        assert_eq!(err.to_string(), "Missing required field: model");
    }

    #[test]
    fn error_chain_from_validation() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let outer: Error = ValidationError::from(json_err).into();
        assert!(matches!(
            outer,
            Error::Validation(ValidationError::InvalidJson(_))
        ));
    }

    #[test]
    fn error_chain_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let outer: Error = io_err.into();
        assert!(matches!(outer, Error::Io(_)));
    }
}