use super::deployment::DeploymentId;
use super::error::{CooldownReason, RouterError};
use super::execution::{
build_execution_result, calculate_retry_delay, infer_cooldown_reason, is_retryable_error,
provider_error_to_router_error, router_error_to_provider_error,
};
use super::fallback::{ExecutionResult, FallbackType};
use super::unified::Router;
use crate::core::providers::unified_provider::ProviderError;
impl Router {
/// Runs `operation` against a deployment selected for `model_name`, retrying retryable errors.
///
/// Every attempt selects a deployment with `select_deployment`, calls `operation` with a
/// clone of the selected id, releases the deployment, and records the outcome on the router.
/// At most `config.num_retries + 1` attempts are made; before each retry the task sleeps for
/// the delay returned by `calculate_retry_delay(&self.config, attempt)`.
///
/// `operation` must resolve to `(value, tokens_used)` on success; `tokens_used` is forwarded
/// to `record_success`.
///
/// # Returns
///
/// `(value, deployment_id, attempts, latency_us)`, where `attempts` is the 1-based number of
/// the attempt that succeeded and `latency_us` is the wall-clock time of that attempt only
/// (measured from just before deployment selection).
///
/// # Errors
///
/// Returns `(error, attempts)` as soon as an error is not retryable according to
/// `is_retryable_error`, or when the attempt budget is exhausted. A failure of
/// `select_deployment` is converted with `router_error_to_provider_error` and is subject to
/// the same retry rule.
pub async fn execute_with_retry<T, F, Fut>(
    &self,
    model_name: &str,
    operation: F,
) -> Result<(T, DeploymentId, u32, u64), (ProviderError, u32)>
where
    F: Fn(DeploymentId) -> Fut + Clone,
    Fut: std::future::Future<Output = Result<(T, u64), ProviderError>>,
{
    // Saturate rather than overflow: with `num_retries == u32::MAX` a plain `+ 1` panics in
    // debug builds and wraps to zero attempts in release builds. Saturating guarantees at
    // least one attempt, so every exit from the loop below is an explicit `return`.
    let max_attempts = self.config.num_retries.saturating_add(1);
    let mut attempt: u32 = 0;

    loop {
        attempt += 1;
        let has_budget = attempt < max_attempts;
        let start = std::time::Instant::now();

        let deployment_id = match self.select_deployment(model_name) {
            Ok(id) => id,
            Err(router_err) => {
                let provider_err = router_error_to_provider_error(router_err);
                if is_retryable_error(&provider_err) && has_budget {
                    let delay = calculate_retry_delay(&self.config, attempt);
                    tokio::time::sleep(delay).await;
                    continue;
                }
                return Err((provider_err, attempt));
            }
        };

        // NOTE(review): if this future is dropped while `operation` is pending, the
        // `release_deployment` call below never runs for the selected deployment. Confirm
        // whether `select_deployment` acquires state that must always be released.
        let result = operation(deployment_id.clone()).await;
        // `as u64` cannot truncate in practice: u64 microseconds span roughly 584,000 years.
        let latency_us = start.elapsed().as_micros() as u64;

        // Release before recording the outcome, on success and failure alike.
        self.release_deployment(&deployment_id);

        match result {
            Ok((value, tokens_used)) => {
                self.record_success(&deployment_id, tokens_used, latency_us);
                return Ok((value, deployment_id, attempt, latency_us));
            }
            Err(err) => {
                if is_retryable_error(&err) && has_budget {
                    // A failure that will be retried is always recorded as
                    // `ConsecutiveFailures`; only the terminal failure below uses the
                    // reason inferred from the error itself.
                    self.record_failure_with_reason(
                        &deployment_id,
                        CooldownReason::ConsecutiveFailures,
                    );
                    let delay = calculate_retry_delay(&self.config, attempt);
                    tokio::time::sleep(delay).await;
                } else {
                    let cooldown_reason = infer_cooldown_reason(&err);
                    self.record_failure_with_reason(&deployment_id, cooldown_reason);
                    return Err((err, attempt));
                }
            }
        }
    }
}
/// Runs `operation` for `model_name`, moving on to fallback models when a model fails.
///
/// The candidate list comes from
/// `get_models_with_fallbacks(model_name, FallbackType::General)`. Duplicates are removed
/// (first occurrence wins, order preserved) and the list is capped at
/// `1 + config.max_fallbacks` entries. Each candidate is run through
/// [`Router::execute_with_retry`]; the first success is returned. The first candidate is
/// treated as the primary model and every later one as a fallback, which increments
/// `fallback_triggered_count` and emits an info-level trace event.
///
/// # Returns
///
/// The value built by `build_execution_result`, carrying the attempts summed over all
/// candidates tried, the model of the deployment that served the request (or the candidate
/// name when the deployment can no longer be looked up), whether a fallback was used, and
/// the total elapsed time in microseconds including retries and fallbacks.
///
/// # Errors
///
/// Returns the last provider error, converted with
/// `provider_error_to_router_error(err, model_name)`, when every candidate failed, or
/// `RouterError::NoAvailableDeployment` when there was no candidate to try.
pub async fn execute<T, F, Fut>(
    &self,
    model_name: &str,
    operation: F,
) -> Result<ExecutionResult<T>, RouterError>
where
    F: Fn(DeploymentId) -> Fut + Clone,
    Fut: std::future::Future<Output = Result<(T, u64), ProviderError>>,
{
    let start = std::time::Instant::now();

    // Primary model plus at most `max_fallbacks` alternatives; saturate so an extreme
    // configuration value cannot overflow the addition.
    let max_models = 1usize.saturating_add(self.config.max_fallbacks as usize);
    let mut seen = std::collections::HashSet::new();
    let models_to_try: Vec<_> = self
        .get_models_with_fallbacks(model_name, FallbackType::General)
        .into_iter()
        // `insert` returns false for a repeated name, which drops the duplicate.
        .filter(|m| seen.insert(m.clone()))
        .take(max_models)
        .collect();

    // Annotated explicitly: the log statement below calls a method on the contained error
    // before any assignment would let the compiler infer the type.
    let mut last_error: Option<ProviderError> = None;
    let mut total_attempts: u32 = 0;

    for (model_idx, model) in models_to_try.iter().enumerate() {
        let is_fallback = model_idx > 0;
        if is_fallback {
            self.fallback_triggered_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            tracing::info!(
                original_model = %model_name,
                fallback_model = %model,
                fallback_index = model_idx,
                // Lazy default: the "unknown" string is only built when no error is recorded.
                error_type = %last_error
                    .as_ref()
                    .map_or_else(|| "unknown".to_string(), |e| e.to_string()),
                "fallback triggered, trying next model"
            );
        }

        match self.execute_with_retry(model, operation.clone()).await {
            Ok((result, deployment_id, attempts, _latency_us)) => {
                total_attempts += attempts;
                // Report end-to-end latency, not just that of the winning attempt.
                let total_latency_us = start.elapsed().as_micros() as u64;
                let model_used = self
                    .get_deployment(&deployment_id)
                    .map_or_else(|| model.clone(), |deployment| deployment.model.clone());
                return Ok(build_execution_result(
                    result,
                    deployment_id,
                    total_attempts,
                    model_used,
                    is_fallback,
                    total_latency_us,
                ));
            }
            Err((err, attempts)) => {
                total_attempts += attempts;
                last_error = Some(err);
            }
        }
    }

    match last_error {
        Some(err) => Err(provider_error_to_router_error(err, model_name)),
        // Only reachable when the candidate list was empty.
        None => Err(RouterError::NoAvailableDeployment(model_name.to_string())),
    }
}
/// Runs `operation` exactly once against a deployment selected for `model_name`.
///
/// There is no retry and no fallback here: a single deployment is selected, `operation` is
/// awaited, the deployment is released, and the outcome is recorded on the router.
///
/// # Returns
///
/// The value built by `build_execution_result` with an attempt count of `1`, the fallback
/// flag set to `false`, and the elapsed time in microseconds measured from just before
/// deployment selection. The reported model is the selected deployment's `model`, or
/// `model_name` when the deployment can no longer be looked up.
///
/// # Errors
///
/// Propagates the `RouterError` from `select_deployment` unchanged. A failure of `operation`
/// is recorded with the cooldown reason from `infer_cooldown_reason` and returned after
/// conversion with `provider_error_to_router_error(err, model_name)`.
pub async fn execute_once<T, F, Fut>(
    &self,
    model_name: &str,
    operation: F,
) -> Result<ExecutionResult<T>, RouterError>
where
    F: FnOnce(DeploymentId) -> Fut,
    Fut: std::future::Future<Output = Result<(T, u64), ProviderError>>,
{
    let started = std::time::Instant::now();
    let deployment_id = self.select_deployment(model_name)?;

    // NOTE(review): if this future is dropped while `operation` is pending, the release
    // below is skipped. Confirm whether callers may cancel this future.
    let outcome = operation(deployment_id.clone()).await;
    let elapsed_us = started.elapsed().as_micros() as u64;

    // Release first, then record the outcome, whichever way the call went.
    self.release_deployment(&deployment_id);

    let (value, tokens_used) = match outcome {
        Ok(success) => success,
        Err(err) => {
            let reason = infer_cooldown_reason(&err);
            self.record_failure_with_reason(&deployment_id, reason);
            return Err(provider_error_to_router_error(err, model_name));
        }
    };

    self.record_success(&deployment_id, tokens_used, elapsed_us);
    let model_used = self
        .get_deployment(&deployment_id)
        .map_or_else(|| model_name.to_string(), |deployment| deployment.model.clone());

    Ok(build_execution_result(
        value,
        deployment_id,
        1,
        model_used,
        false,
        elapsed_us,
    ))
}
}