// ai_lib/provider/strategies/failover.rs

1use async_trait::async_trait;
2use futures::stream::Stream;
3use tracing::warn;
4
5use crate::{
6    api::{ChatCompletionChunk, ChatProvider, ModelInfo},
7    types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse},
8};
9
/// Composite provider that tries an ordered chain of inner providers,
/// advancing to the next one when a call fails with a retryable error.
pub struct FailoverProvider {
    // Composed human-readable name, e.g. `failover[a->b->c]` (built in `new`).
    name: String,
    // Ordered failover chain; `new` guarantees this is non-empty.
    providers: Vec<Box<dyn ChatProvider>>,
}
14
15impl FailoverProvider {
16    pub fn new(providers: Vec<Box<dyn ChatProvider>>) -> Result<Self, AiLibError> {
17        if providers.is_empty() {
18            return Err(AiLibError::ConfigurationError(
19                "failover strategy requires at least one provider".to_string(),
20            ));
21        }
22
23        let composed_name = providers
24            .iter()
25            .map(|p| p.name().to_string())
26            .collect::<Vec<_>>()
27            .join("->");
28
29        Ok(Self {
30            name: format!("failover[{composed_name}]"),
31            providers,
32        })
33    }
34
35    fn should_retry(error: &AiLibError) -> bool {
36        error.is_retryable() || matches!(error, AiLibError::TimeoutError(_))
37    }
38
39    fn log_fail_event(provider: &dyn ChatProvider, error: &AiLibError) {
40        warn!(
41            target = "ai_lib.failover",
42            provider = provider.name(),
43            error_code = %error.error_code_with_severity(),
44            "failover candidate returned an error"
45        );
46    }
47}
48
49#[async_trait]
50impl ChatProvider for FailoverProvider {
51    fn name(&self) -> &str {
52        &self.name
53    }
54
55    async fn chat(
56        &self,
57        request: ChatCompletionRequest,
58    ) -> Result<ChatCompletionResponse, AiLibError> {
59        let fallback_template = request.clone();
60        let mut providers_iter = self.providers.iter();
61
62        let first = providers_iter
63            .next()
64            .expect("validated during construction");
65
66        let mut last_error = match first.chat(request).await {
67            Ok(resp) => return Ok(resp),
68            Err(err) => {
69                if !Self::should_retry(&err) {
70                    return Err(err);
71                }
72                Self::log_fail_event(first.as_ref(), &err);
73                err
74            }
75        };
76
77        for provider in providers_iter {
78            match provider.chat(fallback_template.clone()).await {
79                Ok(resp) => return Ok(resp),
80                Err(err) => {
81                    if !Self::should_retry(&err) {
82                        return Err(err);
83                    }
84                    Self::log_fail_event(provider.as_ref(), &err);
85                    last_error = err;
86                }
87            }
88        }
89
90        Err(last_error)
91    }
92
93    async fn stream(
94        &self,
95        request: ChatCompletionRequest,
96    ) -> Result<
97        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
98        AiLibError,
99    > {
100        let fallback_template = request.clone();
101        let mut providers_iter = self.providers.iter();
102
103        let first = providers_iter
104            .next()
105            .expect("validated during construction");
106
107        let mut last_error = match first.stream(request).await {
108            Ok(resp) => return Ok(resp),
109            Err(err) => {
110                if !Self::should_retry(&err) {
111                    return Err(err);
112                }
113                Self::log_fail_event(first.as_ref(), &err);
114                err
115            }
116        };
117
118        for provider in providers_iter {
119            match provider.stream(fallback_template.clone()).await {
120                Ok(resp) => return Ok(resp),
121                Err(err) => {
122                    if !Self::should_retry(&err) {
123                        return Err(err);
124                    }
125                    Self::log_fail_event(provider.as_ref(), &err);
126                    last_error = err;
127                }
128            }
129        }
130
131        Err(last_error)
132    }
133
134    async fn batch(
135        &self,
136        requests: Vec<ChatCompletionRequest>,
137        concurrency_limit: Option<usize>,
138    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
139        let mut providers_iter = self.providers.iter();
140        let first = providers_iter
141            .next()
142            .expect("validated during construction");
143
144        let mut last_error = match first.batch(requests.clone(), concurrency_limit).await {
145            Ok(resp) => return Ok(resp),
146            Err(err) => {
147                if !Self::should_retry(&err) {
148                    return Err(err);
149                }
150                Self::log_fail_event(first.as_ref(), &err);
151                err
152            }
153        };
154
155        for provider in providers_iter {
156            match provider.batch(requests.clone(), concurrency_limit).await {
157                Ok(resp) => return Ok(resp),
158                Err(err) => {
159                    if !Self::should_retry(&err) {
160                        return Err(err);
161                    }
162                    Self::log_fail_event(provider.as_ref(), &err);
163                    last_error = err;
164                }
165            }
166        }
167
168        Err(last_error)
169    }
170
171    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
172        let mut last_error = None;
173        for provider in &self.providers {
174            match provider.list_models().await {
175                Ok(models) => return Ok(models),
176                Err(err) => {
177                    if !Self::should_retry(&err) {
178                        return Err(err);
179                    }
180                    Self::log_fail_event(provider.as_ref(), &err);
181                    last_error = Some(err);
182                }
183            }
184        }
185
186        Err(last_error.unwrap_or_else(|| {
187            AiLibError::ConfigurationError(
188                "failover strategy could not contact any provider".to_string(),
189            )
190        }))
191    }
192
193    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
194        for provider in &self.providers {
195            match provider.get_model_info(model_id).await {
196                Ok(info) => return Ok(info),
197                Err(err) => {
198                    if matches!(err, AiLibError::ModelNotFound(_)) {
199                        continue;
200                    }
201                    return Err(err);
202                }
203            }
204        }
205
206        Err(AiLibError::ModelNotFound(format!(
207            "model {model_id} not available in failover chain"
208        )))
209    }
210}