1use crate::types::{ProviderRequest, ProviderResponse};
8use std::future::Future;
9use thiserror::Error;
10
11#[non_exhaustive]
13#[derive(Debug, Error)]
14pub enum ProviderError {
15 #[error("request failed: {0}")]
17 RequestFailed(String),
18
19 #[error("rate limited")]
21 RateLimited,
22
23 #[error("auth failed: {0}")]
25 AuthFailed(String),
26
27 #[error("invalid response: {0}")]
29 InvalidResponse(String),
30
31 #[error("{0}")]
33 Other(#[from] Box<dyn std::error::Error + Send + Sync>),
34}
35
36impl ProviderError {
37 pub fn is_retryable(&self) -> bool {
39 matches!(
40 self,
41 ProviderError::RateLimited | ProviderError::RequestFailed(_)
42 )
43 }
44}
45
46pub trait Provider: Send + Sync {
56 fn complete(
58 &self,
59 request: ProviderRequest,
60 ) -> impl Future<Output = Result<ProviderResponse, ProviderError>> + Send;
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for the boxed error type accepted by `ProviderError::Other`.
    type BoxedError = Box<dyn std::error::Error + Send + Sync>;

    #[test]
    fn provider_error_display() {
        // Each variant paired with its exact rendered message.
        let cases = [
            (
                ProviderError::RequestFailed("timeout".into()),
                "request failed: timeout",
            ),
            (ProviderError::RateLimited, "rate limited"),
            (
                ProviderError::AuthFailed("bad key".into()),
                "auth failed: bad key",
            ),
            (
                ProviderError::InvalidResponse("bad json".into()),
                "invalid response: bad json",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[test]
    fn provider_error_retryable() {
        let retryable = [
            ProviderError::RateLimited,
            ProviderError::RequestFailed("timeout".into()),
        ];
        let terminal = [
            ProviderError::AuthFailed("bad key".into()),
            ProviderError::InvalidResponse("x".into()),
        ];
        for err in &retryable {
            assert!(err.is_retryable(), "{err:?} should be retryable");
        }
        for err in &terminal {
            assert!(!err.is_retryable(), "{err:?} should not be retryable");
        }
    }

    #[test]
    fn provider_error_from_boxed() {
        let boxed: BoxedError = "some error".into();
        let converted: ProviderError = boxed.into();
        assert!(matches!(converted, ProviderError::Other(_)));
        assert!(!converted.is_retryable());
    }

    #[test]
    fn provider_error_other_display() {
        let converted = ProviderError::from(BoxedError::from("custom error"));
        assert_eq!(converted.to_string(), "custom error");
    }
}