use crate::routing::{CredentialRouter, RoutingStrategy};
use crate::versions::VersionStore;
use async_trait::async_trait;
use byokey_auth::AuthManager;
use byokey_config::KeyRoutingStrategy;
use byokey_types::{
ChatRequest, ProviderId, RateLimitStore,
traits::{ProviderExecutor, ProviderResponse, Result},
};
use rquest::Client;
use std::collections::HashMap;
use std::{sync::Arc, time::Duration};
const COOLDOWN_DURATION: Duration = Duration::from_secs(30);
pub struct RetryExecutor {
provider: ProviderId,
router: Arc<CredentialRouter>,
base_urls: HashMap<String, Option<String>>,
auth: Arc<AuthManager>,
http: Client,
models: Vec<String>,
ratelimit: Option<Arc<RateLimitStore>>,
versions: VersionStore,
}
impl RetryExecutor {
#[allow(clippy::too_many_arguments)]
pub fn new(
provider: ProviderId,
credentials: Vec<(String, Option<String>)>,
strategy: KeyRoutingStrategy,
auth: Arc<AuthManager>,
http: Client,
models: Vec<String>,
ratelimit: Option<Arc<RateLimitStore>>,
versions: VersionStore,
) -> Self {
let keys: Vec<String> = credentials.iter().map(|(k, _)| k.clone()).collect();
let base_urls: HashMap<String, Option<String>> = credentials.into_iter().collect();
let routing_strategy = match strategy {
KeyRoutingStrategy::RoundRobin => RoutingStrategy::RoundRobin,
KeyRoutingStrategy::Priority => RoutingStrategy::FillFirst,
};
Self {
provider,
router: Arc::new(
CredentialRouter::new(keys, COOLDOWN_DURATION).with_strategy(routing_strategy),
),
base_urls,
auth,
http,
models,
ratelimit,
versions,
}
}
}
#[async_trait]
impl ProviderExecutor for RetryExecutor {
async fn chat_completion(&self, request: ChatRequest) -> Result<ProviderResponse> {
let max_attempts = self
.router
.max_retry()
.unwrap_or_else(|| self.router.len().min(3));
let mut last_err = None;
for _ in 0..max_attempts {
let key = match self.router.next_key() {
Some(k) => k.to_string(),
None => break, };
let base_url = self.base_urls.get(&key).cloned().flatten();
let executor = crate::factory::make_executor(
&self.provider,
Some(key.clone()),
base_url,
Arc::clone(&self.auth),
self.http.clone(),
self.ratelimit.clone(),
&self.versions,
);
let Some(executor) = executor else {
break;
};
match executor.chat_completion(request.clone()).await {
Ok(resp) => return Ok(resp),
Err(e) if e.is_retryable() => {
tracing::warn!(
provider = %self.provider,
error = %e,
"retryable error, rotating key"
);
if let Some(delay) = e.retry_after() {
self.router.mark_error_with_delay(&key, delay);
} else {
self.router.mark_error(&key);
}
last_err = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last_err.unwrap_or_else(|| {
byokey_types::ByokError::Http(format!(
"{}: all API keys exhausted or in cooldown",
self.provider
))
}))
}
fn supported_models(&self) -> Vec<String> {
self.models.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use byokey_store::InMemoryTokenStore;
fn make_auth() -> Arc<AuthManager> {
Arc::new(AuthManager::new(
Arc::new(InMemoryTokenStore::new()),
rquest::Client::new(),
))
}
#[test]
fn test_retry_executor_models() {
let exec = RetryExecutor::new(
ProviderId::Claude,
vec![("key-1".into(), None)],
KeyRoutingStrategy::default(),
make_auth(),
Client::new(),
vec!["claude-opus-4-5".into()],
None,
VersionStore::empty(),
);
assert_eq!(exec.supported_models(), vec!["claude-opus-4-5"]);
}
#[test]
fn test_retry_executor_multiple_keys() {
let exec = RetryExecutor::new(
ProviderId::Claude,
vec![
("key-1".into(), None),
("key-2".into(), None),
("key-3".into(), None),
],
KeyRoutingStrategy::default(),
make_auth(),
Client::new(),
vec!["claude-opus-4-5".into()],
None,
VersionStore::empty(),
);
assert_eq!(exec.supported_models().len(), 1);
}
}