1use std::error::Error;
2use std::future::Future;
3
4use tracing::{debug, warn};
5
6use crate::model_config::{ModelConfigResolver, ResolvedModelConfig};
7
8use crate::error::AgentError;
9use crate::model::ModelError;
10
11pub fn classify_error_kind(error: &(dyn Error + 'static)) -> Option<&'static str> {
12 if let Some(agent_error) = error.downcast_ref::<AgentError>() {
13 return classify_agent_error(agent_error);
14 }
15 if let Some(model_error) = error.downcast_ref::<ModelError>() {
16 return classify_model_error(model_error);
17 }
18 None
19}
20
21fn classify_agent_error(error: &AgentError) -> Option<&'static str> {
22 match error {
23 AgentError::Model(model_error) => classify_model_error(model_error),
24 _ => None,
25 }
26}
27
28fn classify_model_error(error: &ModelError) -> Option<&'static str> {
29 match error {
30 ModelError::Timeout => Some("timeout"),
31 ModelError::Transport(_) => Some("connect_error"),
32 ModelError::HttpStatus { status } => match *status {
33 401 => Some("http_401"),
34 403 => Some("http_403"),
35 429 => Some("http_429"),
36 status if status >= 500 => Some("http_5xx"),
37 _ => None,
38 },
39 ModelError::Provider(_) | ModelError::Serialization(_) => Some("model_error"),
40 ModelError::Unsupported(_) => None,
41 }
42}
43
/// Outcome of a successful invocation performed through the failover helpers.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverResult<T> {
    /// Value returned by the successful invocation.
    pub value: T,
    /// Name of the model whose invocation succeeded.
    pub model_used: String,
    /// `true` when the value came from the backup model rather than the primary.
    pub failed_over: bool,
    /// Number of attempts made against the primary model.
    pub primary_attempts: u32,
}
51
52pub async fn run_with_failover<T, E, F, Fut>(
53 resolver: &dyn ModelConfigResolver,
54 agent_name: &str,
55 requested_model: Option<&str>,
56 environment: Option<&str>,
57 invoke: F,
58) -> Result<FailoverResult<T>, E>
59where
60 E: Error + Send + Sync + 'static,
61 F: FnMut(&str) -> Fut,
62 Fut: Future<Output = Result<T, E>>,
63{
64 run_with_failover_with_classifier(
65 resolver,
66 agent_name,
67 requested_model,
68 environment,
69 invoke,
70 |error| classify_error_kind(error),
71 )
72 .await
73}
74
75pub async fn run_with_failover_with_classifier<T, E, F, Fut, C>(
76 resolver: &dyn ModelConfigResolver,
77 agent_name: &str,
78 requested_model: Option<&str>,
79 environment: Option<&str>,
80 invoke: F,
81 classifier: C,
82) -> Result<FailoverResult<T>, E>
83where
84 E: Error + Send + Sync + 'static,
85 F: FnMut(&str) -> Fut,
86 Fut: Future<Output = Result<T, E>>,
87 C: Fn(&E) -> Option<&'static str>,
88{
89 let config = resolver.resolve_model_config(agent_name, requested_model, environment);
90 run_with_config_and_classifier(config, invoke, classifier).await
91}
92
93pub async fn run_with_utility_failover<T, E, F, Fut>(
94 resolver: &dyn ModelConfigResolver,
95 utility_name: &str,
96 environment: Option<&str>,
97 invoke: F,
98) -> Result<FailoverResult<T>, E>
99where
100 E: Error + Send + Sync + 'static,
101 F: FnMut(&str) -> Fut,
102 Fut: Future<Output = Result<T, E>>,
103{
104 run_with_utility_failover_with_classifier(
105 resolver,
106 utility_name,
107 environment,
108 invoke,
109 |error| classify_error_kind(error),
110 )
111 .await
112}
113
114pub async fn run_with_utility_failover_with_classifier<T, E, F, Fut, C>(
115 resolver: &dyn ModelConfigResolver,
116 utility_name: &str,
117 environment: Option<&str>,
118 invoke: F,
119 classifier: C,
120) -> Result<FailoverResult<T>, E>
121where
122 E: Error + Send + Sync + 'static,
123 F: FnMut(&str) -> Fut,
124 Fut: Future<Output = Result<T, E>>,
125 C: Fn(&E) -> Option<&'static str>,
126{
127 let config = resolver.resolve_utility_config(utility_name, environment);
128 run_with_config_and_classifier(config, invoke, classifier).await
129}
130
131pub async fn run_with_config<T, E, F, Fut>(
132 config: ResolvedModelConfig,
133 invoke: F,
134) -> Result<FailoverResult<T>, E>
135where
136 E: Error + Send + Sync + 'static,
137 F: FnMut(&str) -> Fut,
138 Fut: Future<Output = Result<T, E>>,
139{
140 run_with_config_and_classifier(config, invoke, |error| classify_error_kind(error)).await
141}
142
143pub async fn run_with_config_and_classifier<T, E, F, Fut, C>(
144 config: ResolvedModelConfig,
145 mut invoke: F,
146 classifier: C,
147) -> Result<FailoverResult<T>, E>
148where
149 E: Error + Send + Sync + 'static,
150 F: FnMut(&str) -> Fut,
151 Fut: Future<Output = Result<T, E>>,
152 C: Fn(&E) -> Option<&'static str>,
153{
154 let mut last_kind = None;
155 let mut last_error = None;
156
157 for attempt in 0..=config.retry_limit {
158 match invoke(&config.primary).await {
159 Ok(value) => {
160 return Ok(FailoverResult {
161 value,
162 model_used: config.primary.clone(),
163 failed_over: false,
164 primary_attempts: attempt + 1,
165 });
166 }
167 Err(error) => {
168 let kind = classifier(&error);
169 last_kind = kind;
170 if !kind.is_some_and(|kind| config.failover_on.contains(kind)) {
171 debug!(
172 model = config.primary.as_str(),
173 attempt = attempt + 1,
174 error_kind = kind.unwrap_or(""),
175 "primary request failed without failover"
176 );
177 return Err(error);
178 }
179 last_error = Some(error);
180 if attempt < config.retry_limit {
181 debug!(
182 model = config.primary.as_str(),
183 attempt = attempt + 1,
184 error_kind = kind.unwrap_or(""),
185 "primary request failed, retrying"
186 );
187 continue;
188 }
189 break;
190 }
191 }
192 }
193
194 let should_failover =
195 config.backup.is_some() && last_kind.is_some_and(|kind| config.failover_on.contains(kind));
196 if !should_failover && let Some(error) = last_error {
197 warn!(
198 model = config.primary.as_str(),
199 error_kind = last_kind.unwrap_or(""),
200 "primary request failed and no failover configured"
201 );
202 return Err(error);
203 }
204
205 let backup = config.backup.clone().unwrap_or_default();
206 warn!(
207 primary = config.primary.as_str(),
208 backup = backup.as_str(),
209 error_kind = last_kind.unwrap_or(""),
210 "failing over to backup model"
211 );
212 let result = invoke(&backup).await?;
213 Ok(FailoverResult {
214 value: result,
215 model_used: backup,
216 failed_over: true,
217 primary_attempts: config.retry_limit + 1,
218 })
219}