collet 0.1.0

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
//! Per-model concurrency limiting with autonomous fallback.
//!
//! Each `provider/model` pair gets an independent semaphore whose capacity is
//! derived from:
//!   1. User config: `[[models]] concurrency_limit = N`
//!   2. Built-in defaults (see `default_concurrency`)
//!
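//! When a task cannot acquire a permit (model at capacity), the limiter tries
//! fallback candidates in order:
//!   1. Agent's `providers` chain (`provider/model` pairs from .md frontmatter)
//!   2. Capability-based automatic match (same cost tier or one tier cheaper,
//!      with tool support preserved if the primary model supports tools)
//!   3. CLI providers (independent rate limit pool)
//!   4. Queue wait: `acquire_with_fallback` returns `AcquireResult::QueueWait`
//!      and the caller blocks in `acquire_wait` until a permit for the primary
//!      model becomes available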

use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;

use crate::api::model_profile::{self, CostTier};
use crate::config::Config;

/// A `provider/model` pair used as the concurrency key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ModelKey {
    /// Provider name (the part before the first `/`).
    pub provider: String,
    /// Model name (everything after the first `/`; may itself contain `/`).
    pub model: String,
}

impl ModelKey {
    /// Create a key from separate provider and model names (copied verbatim, not trimmed).
    pub fn new(provider: &str, model: &str) -> Self {
        Self {
            provider: provider.to_string(),
            model: model.to_string(),
        }
    }

    /// Parse from a `"provider/model"` string.
    ///
    /// Splits on the *first* `/`, so `"openrouter/anthropic/claude"` yields
    /// provider `openrouter` and model `anthropic/claude`. Both halves are
    /// trimmed. Returns `None` when there is no `/` or either half is empty
    /// after trimming.
    pub fn parse(s: &str) -> Option<Self> {
        let (provider, model) = s.split_once('/')?;
        let (provider, model) = (provider.trim(), model.trim());
        if provider.is_empty() || model.is_empty() {
            return None;
        }
        Some(Self::new(provider, model))
    }

    /// Format as `"provider/model"`.
    ///
    /// Despite the name this allocates a new `String`; it is equivalent to
    /// `self.to_string()` via the [`Display`](std::fmt::Display) impl.
    pub fn as_str(&self) -> String {
        self.to_string()
    }
}

/// Renders the key as `provider/model`, e.g. for logging without an extra allocation.
impl std::fmt::Display for ModelKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}/{}", self.provider, self.model)
    }
}

/// Outcome of a permit acquisition attempt.
///
/// Must be inspected: dropping an `Acquired` result drops its permit, which
/// immediately returns the slot to the model's semaphore.
#[derive(Debug)]
#[must_use = "dropping an AcquireResult releases the permit or skips the required queue wait"]
pub enum AcquireResult {
    /// Permit acquired for the requested (or fallback) model.
    ///
    /// `key` identifies which model actually granted the permit; hold `permit`
    /// for as long as the task uses that model.
    Acquired {
        key: ModelKey,
        permit: tokio::sync::OwnedSemaphorePermit,
    },
    /// All candidates exhausted; caller should queue-wait on the primary key.
    QueueWait { key: ModelKey },
}

/// Per-model concurrency limiter with fallback support.
///
/// Each `provider/model` key gets its own semaphore, created the first time
/// the key is used and sized from `configured_limits` or the built-in
/// cost-tier default.
pub struct ModelRateLimiter {
    /// Concurrency caps read from `[[models]]` entries, keyed by the entry's
    /// `name` (looked up by bare model name or by `provider/model`).
    configured_limits: HashMap<String, usize>,
    /// Lazily created semaphores, one per `provider/model` key.
    model_semaphores: DashMap<ModelKey, Arc<Semaphore>>,
    /// Global semaphore sized from `collaboration.max_agents`.
    global_semaphore: Arc<Semaphore>,
}

impl ModelRateLimiter {
    /// Build a new limiter from the current config.
    ///
    /// The global cap comes from `collaboration.max_agents`. Per-model caps come
    /// from `concurrency_limit` on `[[models]]` entries. Both are clamped to at
    /// least 1: a semaphore with zero permits can never be acquired, so
    /// `acquire_wait` on such a model would block forever.
    pub fn new(config: &Config) -> Self {
        let global_max = config.collaboration.max_agents.max(1);

        // Collect user-configured concurrency limits from [[models]].
        let configured_limits: HashMap<String, usize> = config
            .model_overrides
            .iter()
            .filter_map(|m| {
                let limit = m.concurrency_limit? as usize;
                if limit == 0 {
                    tracing::warn!(
                        model = %m.name,
                        "concurrency_limit = 0 would never admit a task; using 1"
                    );
                }
                Some((m.name.clone(), limit.max(1)))
            })
            .collect();

        Self {
            global_semaphore: Arc::new(Semaphore::new(global_max)),
            model_semaphores: DashMap::new(),
            configured_limits,
        }
    }

    /// Get or create the semaphore for a given model key.
    ///
    /// The limit is resolved only once, when the key is first seen; later
    /// calls return a clone of the cached `Arc`.
    fn get_semaphore(&self, key: &ModelKey) -> Arc<Semaphore> {
        self.model_semaphores
            .entry(key.clone())
            .or_insert_with(|| Arc::new(Semaphore::new(self.resolve_limit(key))))
            .clone()
    }

    /// Resolve the concurrency limit for a model key.
    ///
    /// Lookup order: a `[[models]]` entry named after the bare model name, then
    /// one named `provider/model`, then the built-in cost-tier default.
    fn resolve_limit(&self, key: &ModelKey) -> usize {
        self.configured_limits
            .get(&key.model)
            .or_else(|| self.configured_limits.get(&key.as_str()))
            .copied()
            .unwrap_or_else(|| default_concurrency(&key.model))
    }

    /// Try to acquire a permit, falling back through candidates if primary is full.
    ///
    /// Candidates are tried without waiting, in this order: the primary model,
    /// `fallback_chain`, capability-matched models from `config`, then CLI
    /// providers. If none has a free permit, returns
    /// [`AcquireResult::QueueWait`] for the primary key.
    ///
    /// `fallback_chain`: ordered `provider/model` pairs from the agent's `.md` `providers` field.
    /// `config`: used for capability-based matching and CLI discovery.
    ///
    /// # Panics
    /// Panics if the global semaphore has been closed (this file never closes it).
    pub async fn acquire_with_fallback(
        &self,
        primary: &ModelKey,
        fallback_chain: &[ModelKey],
        config: &Config,
    ) -> AcquireResult {
        // Acquire a global permit first. This waits while the global limit is reached.
        // NOTE(review): `_global` is dropped when this function returns, so the
        // global cap only bounds concurrent calls to this method. It does not
        // bound the lifetime of the model permit handed back to the caller.
        // Enforcing `max_agents` per task would require returning this permit
        // too, which means changing `AcquireResult`.
        let _global = self
            .global_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("global semaphore closed");

        // 1. Try primary model
        if let Ok(permit) = self.get_semaphore(primary).try_acquire_owned() {
            return AcquireResult::Acquired {
                key: primary.clone(),
                permit,
            };
        }
        tracing::debug!(
            model = %primary.as_str(),
            "Primary model at capacity, trying fallback chain"
        );

        // 2. Try agent's explicit fallback chain
        for candidate in fallback_chain.iter().filter(|c| *c != primary) {
            if let Ok(permit) = self.get_semaphore(candidate).try_acquire_owned() {
                tracing::info!(
                    primary = %primary.as_str(),
                    fallback = %candidate.as_str(),
                    "Fell back to provider chain candidate"
                );
                return AcquireResult::Acquired {
                    key: candidate.clone(),
                    permit,
                };
            }
        }

        // 3. Capability-based automatic match
        if let Some((key, permit)) = self.try_capability_match(primary, config) {
            tracing::info!(
                primary = %primary.as_str(),
                fallback = %key.as_str(),
                "Fell back via capability matching"
            );
            return AcquireResult::Acquired { key, permit };
        }

        // 4. Try CLI providers as last resort
        if let Some((key, permit)) = self.try_cli_fallback(primary, config) {
            tracing::info!(
                primary = %primary.as_str(),
                cli = %key.as_str(),
                "Fell back to CLI provider"
            );
            return AcquireResult::Acquired { key, permit };
        }

        // 5. All fallbacks exhausted — caller should queue-wait
        tracing::warn!(
            model = %primary.as_str(),
            "All fallbacks exhausted, queuing for primary model"
        );
        AcquireResult::QueueWait {
            key: primary.clone(),
        }
    }

    /// Blocking acquire on a specific model (used when QueueWait is returned).
    ///
    /// Waits only on the model's own semaphore; it does not take a global permit.
    ///
    /// # Panics
    /// Panics if the model semaphore has been closed (this file never closes it).
    pub async fn acquire_wait(&self, key: &ModelKey) -> tokio::sync::OwnedSemaphorePermit {
        let sem = self.get_semaphore(key);
        sem.acquire_owned().await.expect("model semaphore closed")
    }

    /// Try to find a compatible non-CLI model via capability matching.
    ///
    /// A candidate qualifies if its cost tier equals the primary's, or is
    /// exactly one tier cheaper (Premium→Standard, Standard→Cheap). It must
    /// also support tool use whenever the primary does. The first qualifying
    /// candidate with a free permit wins, in config order.
    fn try_capability_match(
        &self,
        primary: &ModelKey,
        config: &Config,
    ) -> Option<(ModelKey, tokio::sync::OwnedSemaphorePermit)> {
        let primary_profile = model_profile::profile_for(&primary.model);
        let primary_tier = primary_profile.cost_tier;
        let primary_tools = primary_profile.supports_tool_use;

        // Walk every non-CLI provider/model combination declared in config.
        for provider in config.providers.iter().filter(|p| !p.is_cli()) {
            for model_name in &provider.models {
                let candidate_key = ModelKey::new(&provider.name, model_name);
                if candidate_key == *primary {
                    continue;
                }

                let profile = model_profile::profile_for(model_name);

                // Match criteria: same cost tier + tool support compatibility
                let tier_compatible = match (primary_tier, profile.cost_tier) {
                    (a, b) if a == b => true,
                    // Allow one-tier downgrade (Premium→Standard, Standard→Cheap)
                    (CostTier::Premium, CostTier::Standard) => true,
                    (CostTier::Standard, CostTier::Cheap) => true,
                    _ => false,
                };

                if !tier_compatible {
                    continue;
                }
                if primary_tools && !profile.supports_tool_use {
                    continue;
                }

                if let Ok(permit) = self.get_semaphore(&candidate_key).try_acquire_owned() {
                    return Some((candidate_key, permit));
                }
            }
        }
        None
    }

    /// Try CLI providers as fallback.
    ///
    /// Each CLI provider contributes one candidate: its first listed model, or
    /// its provider name if it lists none. Cost tier is ignored. Tool support
    /// is still required when the primary model uses tools. The primary key
    /// itself is skipped, because it was already found to be at capacity.
    fn try_cli_fallback(
        &self,
        primary: &ModelKey,
        config: &Config,
    ) -> Option<(ModelKey, tokio::sync::OwnedSemaphorePermit)> {
        let primary_profile = model_profile::profile_for(&primary.model);

        for provider in config.providers.iter().filter(|p| p.is_cli()) {
            // CLI providers typically have a single implicit model
            let cli_model = provider
                .models
                .first()
                .cloned()
                .unwrap_or_else(|| provider.name.clone());
            let candidate_key = ModelKey::new(&provider.name, &cli_model);
            if candidate_key == *primary {
                continue;
            }

            let profile = model_profile::profile_for(&cli_model);

            // CLI fallback is more lenient — just check tool support
            if primary_profile.supports_tool_use && !profile.supports_tool_use {
                continue;
            }

            if let Ok(permit) = self.get_semaphore(&candidate_key).try_acquire_owned() {
                return Some((candidate_key, permit));
            }
        }
        None
    }
}

/// Built-in concurrency limit for a model, chosen from its cost tier.
///
/// Pricier tiers get fewer parallel slots: Premium 3, Standard 5, Cheap 8.
fn default_concurrency(model: &str) -> usize {
    let tier = model_profile::profile_for(model).cost_tier;
    match tier {
        CostTier::Cheap => 8,
        CostTier::Standard => 5,
        CostTier::Premium => 3,
    }
}

/// Parse a comma-separated `providers` value from agent `.md` frontmatter
/// into an ordered list of `ModelKey` pairs.
///
/// Segments that are not valid `provider/model` pairs (including empty
/// segments) are skipped; the relative order of valid entries is kept.
///
/// Input format: `"zai-coding/glm-5,claude/opus,openai/gpt-4o"`
pub fn parse_providers_chain(value: &str) -> Vec<ModelKey> {
    let mut chain = Vec::new();
    for entry in value.split(',') {
        if let Some(key) = ModelKey::parse(entry.trim()) {
            chain.push(key);
        }
    }
    chain
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_key_parse() {
        let key = ModelKey::parse("zai-coding/glm-5").unwrap();
        assert_eq!(key.provider, "zai-coding");
        assert_eq!(key.model, "glm-5");
    }

    #[test]
    fn test_model_key_parse_trims_and_splits_on_first_slash() {
        // Model names may contain '/', so only the first slash separates the provider.
        let key = ModelKey::parse("  openrouter / anthropic/claude-3  ").unwrap();
        assert_eq!(key.provider, "openrouter");
        assert_eq!(key.model, "anthropic/claude-3");
        assert_eq!(key.as_str(), "openrouter/anthropic/claude-3");
    }

    #[test]
    fn test_model_key_parse_invalid() {
        assert!(ModelKey::parse("no-slash").is_none());
        assert!(ModelKey::parse("/no-provider").is_none());
        assert!(ModelKey::parse("no-model/").is_none());
        assert!(ModelKey::parse("").is_none());
        assert!(ModelKey::parse("/").is_none());
        assert!(ModelKey::parse("   /model").is_none());
        assert!(ModelKey::parse("provider/   ").is_none());
    }

    #[test]
    fn test_parse_providers_chain() {
        let chain = parse_providers_chain("zai-coding/glm-5, claude/opus, openai/gpt-4o");
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0].provider, "zai-coding");
        assert_eq!(chain[0].model, "glm-5");
        assert_eq!(chain[1].provider, "claude");
        assert_eq!(chain[1].model, "opus");
        assert_eq!(chain[2].provider, "openai");
        assert_eq!(chain[2].model, "gpt-4o");
    }

    #[test]
    fn test_parse_providers_chain_skips_invalid_entries() {
        let chain = parse_providers_chain("a/b,, no-slash ,  ,c/d,");
        assert_eq!(chain, vec![ModelKey::new("a", "b"), ModelKey::new("c", "d")]);
        assert!(parse_providers_chain("").is_empty());
    }

    #[test]
    fn test_default_concurrency() {
        // glm-5 is Premium → 3
        assert_eq!(default_concurrency("glm-5"), 3);
        // glm-4.7 is Standard → 5
        assert_eq!(default_concurrency("glm-4.7"), 5);
        // glm-4.7-flash is Cheap → 8
        assert_eq!(default_concurrency("glm-4.7-flash"), 8);
    }
}