Skip to main content

oxi_ai/
error.rs

1//! Error types for oxi-ai
2
3use thiserror::Error;
4
5/// Provider-specific errors
6#[derive(Error, Debug)]
7pub enum ProviderError {
8    /// API key is missing.
9    #[error("Missing API key")]
10    MissingApiKey,
11
12    /// Unknown provider.
13    #[error("Unknown provider: {0}")]
14    UnknownProvider(String),
15
16    /// Provider not yet implemented.
17    #[error("Provider not implemented: {0}")]
18    NotImplemented(String),
19
20    /// HTTP error (status code + message).
21    #[error("HTTP error {0}: {1}")]
22    HttpError(u16, String),
23
24    /// HTTP request failed.
25    #[error("Request failed: {0}")]
26    RequestFailed(#[from] reqwest::Error),
27
28    /// I/O error.
29    #[error("IO error: {0}")]
30    IoError(#[from] std::io::Error),
31
32    /// Invalid response from provider.
33    #[error("Invalid response: {0}")]
34    InvalidResponse(String),
35
36    /// Invalid API key format.
37    #[error("Invalid API key format")]
38    InvalidApiKey,
39
40    /// JSON parsing error.
41    #[error("JSON parse error: {0}")]
42    JsonParse(#[from] serde_json::Error),
43
44    /// Streaming error.
45    #[error("Stream error: {0}")]
46    StreamError(String),
47
48    /// Network error.
49    #[error("Network error: {0}")]
50    NetworkError(String),
51
52    /// Context window overflow.
53    #[error("Context overflow")]
54    ContextOverflow,
55
56    /// Request timed out.
57    #[error("Request timed out")]
58    Timeout,
59
60    /// Rate limit exceeded.
61    #[error("Rate limited")]
62    RateLimited {
63        /// Wait time suggested by the server.
64        retry_after: Option<std::time::Duration>,
65    },
66}
67
68impl ProviderError {
69    /// Returns whether this error is retryable.
70    pub fn is_retryable(&self) -> bool {
71        match self {
72            Self::HttpError(status, _) => *status == 429 || *status >= 500,
73            Self::NetworkError(_) => true,
74            Self::Timeout => true,
75            Self::RateLimited { .. } => true,
76            _ => false,
77        }
78    }
79
80    /// Returns the retry wait time suggested by the server.
81    pub fn retry_after(&self) -> Option<std::time::Duration> {
82        match self {
83            Self::RateLimited { retry_after } => *retry_after,
84            Self::HttpError(429, _) => Some(std::time::Duration::from_secs(5)),
85            _ => None,
86        }
87    }
88}
89
90/// Validation errors
91#[derive(Error, Debug)]
92pub enum ValidationError {
93    #[error("Invalid JSON: {0}")]
94    InvalidJson(#[from] serde_json::Error),
95
96    #[error("Schema validation failed: {0}")]
97    SchemaValidation(String),
98
99    #[error("Missing required field: {0}")]
100    MissingRequiredField(String),
101}
102
103/// Unified error type for oxi-ai
104#[derive(Error, Debug)]
105pub enum Error {
106    /// Wraps a provider error.
107    #[error("Provider error: {0}")]
108    Provider(#[from] ProviderError),
109
110    /// Wraps a validation error.
111    #[error("Validation error: {0}")]
112    Validation(#[from] ValidationError),
113
114    /// Wraps an I/O error.
115    #[error("IO error: {0}")]
116    Io(#[from] std::io::Error),
117}
118
119/// Result type alias
120pub type Result<T> = std::result::Result<T, Error>;
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed provider error renders exactly the message from its `#[error]` attribute.
    #[test]
    fn provider_error_display() {
        let expectations = [
            (ProviderError::MissingApiKey, "Missing API key"),
            (
                ProviderError::UnknownProvider("foo".to_string()),
                "Unknown provider: foo",
            ),
            (
                ProviderError::HttpError(429, "rate limited".to_string()),
                "HTTP error 429: rate limited",
            ),
            (
                ProviderError::InvalidResponse("bad json".to_string()),
                "Invalid response: bad json",
            ),
            (
                ProviderError::StreamError("disconnected".to_string()),
                "Stream error: disconnected",
            ),
            (
                ProviderError::NotImplemented("x".to_string()),
                "Provider not implemented: x",
            ),
        ];

        for (err, rendered) in expectations.iter() {
            assert_eq!(err.to_string(), *rendered);
        }
    }

    /// A `ProviderError` converts into `Error::Provider` and keeps its message.
    #[test]
    fn error_chain_from_provider_error() {
        let wrapped = Error::from(ProviderError::MissingApiKey);
        assert!(matches!(
            wrapped,
            Error::Provider(ProviderError::MissingApiKey)
        ));
        assert!(wrapped.to_string().contains("Missing API key"));
    }

    /// `MissingRequiredField` renders the field name after the fixed prefix.
    #[test]
    fn validation_error_display() {
        let rendered = ValidationError::MissingRequiredField("model".to_string()).to_string();
        assert_eq!(rendered, "Missing required field: model");
    }

    /// A `std::io::Error` converts into `Error::Io`.
    #[test]
    fn error_chain_from_io() {
        let source = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        assert!(matches!(Error::from(source), Error::Io(_)));
    }
}