use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::api::model_profile::{self, CostTier};
use crate::config::Config;
/// Identifies a model as a `provider` / `model` pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ModelKey {
    /// Provider name: the text before the first `/` in the string form.
    pub provider: String,
    /// Model name: everything after the first `/` in the string form.
    pub model: String,
}

impl ModelKey {
    /// Creates a key by copying `provider` and `model` into owned strings.
    pub fn new(provider: &str, model: &str) -> Self {
        Self {
            provider: provider.to_owned(),
            model: model.to_owned(),
        }
    }

    /// Parses a `"provider/model"` string, splitting at the first `/`.
    ///
    /// Both halves are trimmed. Returns `None` when `s` contains no `/` or
    /// when either half is empty after trimming. Any further `/` characters
    /// stay inside the model name.
    pub fn parse(s: &str) -> Option<Self> {
        let (raw_provider, raw_model) = s.split_once('/')?;
        match (raw_provider.trim(), raw_model.trim()) {
            ("", _) | (_, "") => None,
            (provider, model) => Some(Self::new(provider, model)),
        }
    }

    /// Renders the key as `"provider/model"` in a newly allocated `String`.
    pub fn as_str(&self) -> String {
        [self.provider.as_str(), self.model.as_str()].join("/")
    }
}
/// Outcome of `ModelRateLimiter::acquire_with_fallback`.
pub enum AcquireResult {
    /// A per-model permit was obtained without waiting on a model semaphore.
    Acquired {
        /// The model the permit belongs to: the primary or one of its
        /// fallbacks.
        key: ModelKey,
        /// Permit taken from that model's semaphore.
        permit: tokio::sync::OwnedSemaphorePermit,
    },
    /// Neither the primary nor any fallback had a free permit.
    QueueWait {
        /// The primary model that was requested.
        key: ModelKey,
    },
}
/// Concurrency limiter holding one global semaphore plus one semaphore per
/// [`ModelKey`].
pub struct ModelRateLimiter {
    /// Global semaphore, sized from `config.collaboration.max_agents`.
    global_semaphore: Arc<Semaphore>,
    /// Per-model semaphores, inserted the first time a key is requested.
    model_semaphores: DashMap<ModelKey, Arc<Semaphore>>,
    /// Explicit limits collected from `config.model_overrides`, keyed by the
    /// override's `name`.
    configured_limits: HashMap<String, usize>,
}
impl ModelRateLimiter {
    /// Builds a limiter from `config`.
    ///
    /// The global semaphore gets `config.collaboration.max_agents` permits
    /// (at least 1). Every entry of `config.model_overrides` that sets a
    /// `concurrency_limit` contributes an explicit limit keyed by the
    /// override's `name`.
    ///
    /// Explicit limits are clamped to at least 1: a semaphore created with
    /// zero permits can never be acquired, so a configured `0` would make
    /// `acquire_wait` for that model wait forever.
    pub fn new(config: &Config) -> Self {
        let global_max = config.collaboration.max_agents.max(1);
        let configured_limits: HashMap<String, usize> = config
            .model_overrides
            .iter()
            .filter_map(|m| {
                // NOTE(review): `as` wraps/truncates silently if
                // `concurrency_limit` is signed or wider than `usize` —
                // confirm its declared type.
                m.concurrency_limit
                    .map(|limit| (m.name.clone(), (limit as usize).max(1)))
            })
            .collect();

        Self {
            global_semaphore: Arc::new(Semaphore::new(global_max)),
            model_semaphores: DashMap::new(),
            configured_limits,
        }
    }

    /// Returns the semaphore for `key`, creating it on first use with the
    /// permit count chosen by `resolve_limit`.
    fn get_semaphore(&self, key: &ModelKey) -> Arc<Semaphore> {
        self.model_semaphores
            .entry(key.clone())
            .or_insert_with(|| Arc::new(Semaphore::new(self.resolve_limit(key))))
            .clone()
    }

    /// Picks the permit count for `key`.
    ///
    /// Lookup order: an explicit limit keyed by the bare model name, then one
    /// keyed by `"provider/model"`, then `default_concurrency` for the model.
    fn resolve_limit(&self, key: &ModelKey) -> usize {
        self.configured_limits
            .get(&key.model)
            .or_else(|| self.configured_limits.get(&key.as_str()))
            .copied()
            .unwrap_or_else(|| default_concurrency(&key.model))
    }

    /// Attempts to take a permit for `key` without waiting.
    ///
    /// Returns `None` when no permit could be taken right now.
    fn try_acquire_permit(&self, key: &ModelKey) -> Option<tokio::sync::OwnedSemaphorePermit> {
        self.get_semaphore(key).try_acquire_owned().ok()
    }

    /// Tries to obtain a model permit for `primary`, falling back to other
    /// models when it is at capacity.
    ///
    /// After waiting for a global permit, candidates are tried in this order,
    /// none of them waiting on a per-model semaphore:
    ///
    /// 1. `primary` itself;
    /// 2. each entry of `fallback_chain`, in order, skipping `primary`;
    /// 3. models of non-CLI providers in `config` that pass the capability
    ///    check in `try_capability_match`;
    /// 4. CLI providers in `config` (see `try_cli_fallback`).
    ///
    /// Returns [`AcquireResult::Acquired`] for the first candidate with a
    /// free permit, otherwise [`AcquireResult::QueueWait`] carrying
    /// `primary`.
    ///
    /// # Panics
    ///
    /// Panics if the global semaphore has been closed; nothing in this impl
    /// closes it.
    pub async fn acquire_with_fallback(
        &self,
        primary: &ModelKey,
        fallback_chain: &[ModelKey],
        config: &Config,
    ) -> AcquireResult {
        // NOTE(review): this permit is dropped when the function returns, so
        // it only serialises concurrent calls to this method; it does not cap
        // how many model permits are held afterwards. If `max_agents` was
        // meant to bound running agents, the permit would have to travel with
        // `AcquireResult::Acquired` — confirm the intent before changing the
        // enum.
        let _global = self
            .global_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("global semaphore closed");

        if let Some(permit) = self.try_acquire_permit(primary) {
            return AcquireResult::Acquired {
                key: primary.clone(),
                permit,
            };
        }

        tracing::debug!(
            model = %primary.as_str(),
            "Primary model at capacity, trying fallback chain"
        );

        for candidate in fallback_chain.iter().filter(|c| *c != primary) {
            if let Some(permit) = self.try_acquire_permit(candidate) {
                tracing::info!(
                    primary = %primary.as_str(),
                    fallback = %candidate.as_str(),
                    "Fell back to provider chain candidate"
                );
                return AcquireResult::Acquired {
                    key: candidate.clone(),
                    permit,
                };
            }
        }

        if let Some((key, permit)) = self.try_capability_match(primary, config) {
            tracing::info!(
                primary = %primary.as_str(),
                fallback = %key.as_str(),
                "Fell back via capability matching"
            );
            return AcquireResult::Acquired { key, permit };
        }

        if let Some((key, permit)) = self.try_cli_fallback(primary, config) {
            tracing::info!(
                primary = %primary.as_str(),
                cli = %key.as_str(),
                "Fell back to CLI provider"
            );
            return AcquireResult::Acquired { key, permit };
        }

        tracing::warn!(
            model = %primary.as_str(),
            "All fallbacks exhausted, queuing for primary model"
        );
        AcquireResult::QueueWait {
            key: primary.clone(),
        }
    }

    /// Waits until a permit for `key` is available and returns it.
    ///
    /// Only the per-model semaphore is involved; the global semaphore is not
    /// touched here.
    ///
    /// # Panics
    ///
    /// Panics if the model's semaphore has been closed; nothing in this impl
    /// closes it.
    pub async fn acquire_wait(&self, key: &ModelKey) -> tokio::sync::OwnedSemaphorePermit {
        self.get_semaphore(key)
            .acquire_owned()
            .await
            .expect("model semaphore closed")
    }

    /// Looks for a substitute model among the non-CLI providers in `config`.
    ///
    /// A candidate qualifies when it is not `primary`, its cost tier is the
    /// same as the primary's or exactly one step cheaper (Premium → Standard,
    /// Standard → Cheap), and it supports tool use whenever the primary does.
    /// The first qualifying candidate with a free permit wins.
    fn try_capability_match(
        &self,
        primary: &ModelKey,
        config: &Config,
    ) -> Option<(ModelKey, tokio::sync::OwnedSemaphorePermit)> {
        let primary_profile = model_profile::profile_for(&primary.model);
        let primary_tier = primary_profile.cost_tier;
        let primary_tools = primary_profile.supports_tool_use;

        for provider in &config.providers {
            if provider.is_cli() {
                continue;
            }
            for model_name in &provider.models {
                let candidate_key = ModelKey::new(&provider.name, model_name);
                if candidate_key == *primary {
                    continue;
                }

                let profile = model_profile::profile_for(model_name);
                let tier_compatible = match (primary_tier, profile.cost_tier) {
                    (a, b) if a == b => true,
                    (CostTier::Premium, CostTier::Standard) => true,
                    (CostTier::Standard, CostTier::Cheap) => true,
                    _ => false,
                };
                if !tier_compatible || (primary_tools && !profile.supports_tool_use) {
                    continue;
                }

                if let Some(permit) = self.try_acquire_permit(&candidate_key) {
                    return Some((candidate_key, permit));
                }
            }
        }
        None
    }

    /// Looks for a CLI provider in `config` that can stand in for `primary`.
    ///
    /// Each CLI provider contributes one candidate: its first listed model,
    /// or the provider name itself when it lists none. The candidate is
    /// skipped when it is `primary` (a model is never its own fallback, as in
    /// the other fallback paths) or when the primary supports tool use and
    /// the candidate does not.
    fn try_cli_fallback(
        &self,
        primary: &ModelKey,
        config: &Config,
    ) -> Option<(ModelKey, tokio::sync::OwnedSemaphorePermit)> {
        let primary_profile = model_profile::profile_for(&primary.model);

        for provider in &config.providers {
            if !provider.is_cli() {
                continue;
            }

            // Borrow the name; `ModelKey::new` makes its own owned copy.
            let cli_model = provider.models.first().unwrap_or(&provider.name);
            let candidate_key = ModelKey::new(&provider.name, cli_model);
            if candidate_key == *primary {
                continue;
            }

            let profile = model_profile::profile_for(cli_model);
            if primary_profile.supports_tool_use && !profile.supports_tool_use {
                continue;
            }

            if let Some(permit) = self.try_acquire_permit(&candidate_key) {
                return Some((candidate_key, permit));
            }
        }
        None
    }
}
/// Default permit count for `model`, chosen from the cost tier of its
/// profile: 3 for premium, 5 for standard, 8 for cheap.
fn default_concurrency(model: &str) -> usize {
    match model_profile::profile_for(model).cost_tier {
        CostTier::Premium => 3,
        CostTier::Standard => 5,
        CostTier::Cheap => 8,
    }
}
/// Parses a comma-separated list of `provider/model` entries.
///
/// Each entry is trimmed and passed to [`ModelKey::parse`]; entries it
/// rejects are left out of the result. Order of the remaining entries is
/// preserved.
pub fn parse_providers_chain(value: &str) -> Vec<ModelKey> {
    let mut chain = Vec::new();
    for entry in value.split(',') {
        if let Some(key) = ModelKey::parse(entry.trim()) {
            chain.push(key);
        }
    }
    chain
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed `provider/model` string yields both parts.
    #[test]
    fn test_model_key_parse() {
        let parsed = ModelKey::parse("zai-coding/glm-5").unwrap();
        assert_eq!(parsed, ModelKey::new("zai-coding", "glm-5"));
    }

    /// Strings with no slash, no provider, or no model are rejected.
    #[test]
    fn test_model_key_parse_invalid() {
        for input in ["no-slash", "/no-provider", "no-model/"] {
            assert!(ModelKey::parse(input).is_none());
        }
    }

    /// A comma-separated chain keeps every entry, in order, with the
    /// whitespace around entries stripped.
    #[test]
    fn test_parse_providers_chain() {
        let chain = parse_providers_chain("zai-coding/glm-5, claude/opus, openai/gpt-4o");
        let pairs: Vec<(&str, &str)> = chain
            .iter()
            .map(|key| (key.provider.as_str(), key.model.as_str()))
            .collect();
        assert_eq!(
            pairs,
            [
                ("zai-coding", "glm-5"),
                ("claude", "opus"),
                ("openai", "gpt-4o"),
            ]
        );
    }

    /// Expected default permit counts for three sample model names.
    #[test]
    fn test_default_concurrency() {
        let cases: [(&str, usize); 3] = [("glm-5", 3), ("glm-4.7", 5), ("glm-4.7-flash", 8)];
        for (model, expected) in cases {
            assert_eq!(default_concurrency(model), expected);
        }
    }
}