ai_lib/provider/strategies/failover.rs

use async_trait::async_trait;
use futures::stream::Stream;
use tracing::warn;

use crate::{
    api::{ChatCompletionChunk, ChatProvider, ModelInfo},
    types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse},
};

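/// A [`ChatProvider`] that chains several providers and fails over in order:
/// each call goes to the first provider, and retryable failures fall through
/// to the next candidate until one succeeds or the chain is exhausted.
///
/// A minimal usage sketch; `primary`, `secondary`, and `request` are
/// hypothetical values (boxed `ChatProvider` implementations and a
/// `ChatCompletionRequest`) assumed to be built elsewhere:
///
/// ```ignore
/// let provider = FailoverProvider::new(vec![primary, secondary])?;
/// let response = provider.chat(request).await?;
/// ```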
pub struct FailoverProvider {
    name: String,
    providers: Vec<Box<dyn ChatProvider>>,
}

impl FailoverProvider {
    pub fn new(providers: Vec<Box<dyn ChatProvider>>) -> Result<Self, AiLibError> {
        if providers.is_empty() {
            return Err(AiLibError::ConfigurationError(
                "failover strategy requires at least one provider".to_string(),
            ));
        }

        let composed_name = providers
            .iter()
            .map(|p| p.name().to_string())
            .collect::<Vec<_>>()
            .join("->");

        Ok(Self {
            name: format!("failover[{composed_name}]"),
            providers,
        })
    }

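    /// An error is handed to the next provider only when it looks transient:
    /// either the error type marks itself retryable or it is a timeout.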
    fn should_retry(error: &AiLibError) -> bool {
        error.is_retryable() || matches!(error, AiLibError::TimeoutError(_))
    }

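    /// Emits a structured warning for a provider that failed before the chain
    /// moves on to the next candidate.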
    fn log_fail_event(provider: &dyn ChatProvider, error: &AiLibError) {
        warn!(
            target: "ai_lib.failover",
            provider = provider.name(),
            error_code = %error.error_code_with_severity(),
            "failover candidate returned an error"
        );
    }
}

#[async_trait]
impl ChatProvider for FailoverProvider {
    fn name(&self) -> &str {
        &self.name
    }

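    /// Tries `chat` on each provider in order: the first success is returned,
    /// a non-retryable error aborts the chain immediately, and if every
    /// provider fails the last retryable error is surfaced.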
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Keep one clone as a template for fallback calls; the original
        // request is moved into the first provider's call without another copy.
        let fallback_template = request.clone();
        let mut providers_iter = self.providers.iter();

        let first = providers_iter
            .next()
            .expect("validated during construction");

        let mut last_error = match first.chat(request).await {
            Ok(resp) => return Ok(resp),
            Err(err) => {
                if !Self::should_retry(&err) {
                    return Err(err);
                }
                Self::log_fail_event(first.as_ref(), &err);
                err
            }
        };

        for provider in providers_iter {
            match provider.chat(fallback_template.clone()).await {
                Ok(resp) => return Ok(resp),
                Err(err) => {
                    if !Self::should_retry(&err) {
                        return Err(err);
                    }
                    Self::log_fail_event(provider.as_ref(), &err);
                    last_error = err;
                }
            }
        }

        Err(last_error)
    }

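    /// Applies the same failover policy as [`Self::chat`] to stream setup.
    /// Once a stream has been handed back to the caller, errors on individual
    /// chunks are not intercepted or retried here.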
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let fallback_template = request.clone();
        let mut providers_iter = self.providers.iter();

        let first = providers_iter
            .next()
            .expect("validated during construction");

        let mut last_error = match first.stream(request).await {
            Ok(resp) => return Ok(resp),
            Err(err) => {
                if !Self::should_retry(&err) {
                    return Err(err);
                }
                Self::log_fail_event(first.as_ref(), &err);
                err
            }
        };

        for provider in providers_iter {
            match provider.stream(fallback_template.clone()).await {
                Ok(resp) => return Ok(resp),
                Err(err) => {
                    if !Self::should_retry(&err) {
                        return Err(err);
                    }
                    Self::log_fail_event(provider.as_ref(), &err);
                    last_error = err;
                }
            }
        }

        Err(last_error)
    }

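    /// Fails over at batch granularity: the whole batch is retried on the next
    /// provider only when the batch call itself returns a retryable error.
    /// Per-request errors inside an `Ok` batch are passed through unchanged.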
    async fn batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut providers_iter = self.providers.iter();
        let first = providers_iter
            .next()
            .expect("validated during construction");

        let mut last_error = match first.batch(requests.clone(), concurrency_limit).await {
            Ok(resp) => return Ok(resp),
            Err(err) => {
                if !Self::should_retry(&err) {
                    return Err(err);
                }
                Self::log_fail_event(first.as_ref(), &err);
                err
            }
        };

        for provider in providers_iter {
            match provider.batch(requests.clone(), concurrency_limit).await {
                Ok(resp) => return Ok(resp),
                Err(err) => {
                    if !Self::should_retry(&err) {
                        return Err(err);
                    }
                    Self::log_fail_event(provider.as_ref(), &err);
                    last_error = err;
                }
            }
        }

        Err(last_error)
    }

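    /// Returns the model list from the first provider that answers; retryable
    /// errors move the query on to the next provider, anything else is
    /// returned immediately.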
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let mut last_error = None;
        for provider in &self.providers {
            match provider.list_models().await {
                Ok(models) => return Ok(models),
                Err(err) => {
                    if !Self::should_retry(&err) {
                        return Err(err);
                    }
                    Self::log_fail_event(provider.as_ref(), &err);
                    last_error = Some(err);
                }
            }
        }

        Err(last_error.unwrap_or_else(|| {
            AiLibError::ConfigurationError(
                "failover strategy could not contact any provider".to_string(),
            )
        }))
    }

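    /// Looks the model up across the chain, skipping providers that report
    /// `ModelNotFound` and propagating any other error unchanged.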
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        for provider in &self.providers {
            match provider.get_model_info(model_id).await {
                Ok(info) => return Ok(info),
                Err(err) => {
                    if matches!(err, AiLibError::ModelNotFound(_)) {
                        continue;
                    }
                    return Err(err);
                }
            }
        }

        Err(AiLibError::ModelNotFound(format!(
            "model {model_id} not available in failover chain"
        )))
    }
}