collet 0.1.0

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
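//! LLM provider abstraction for multi-provider support.
//!
//! Defines [`OpenAiCompatibleProvider`], a concrete provider type.
//! - Its chat methods are inherent `async fn`s.
//! - Its capability queries (`context_window`, `supports_tools`,
//!   `supports_reasoning`) are plain methods.
//!
//! This lets collet talk to different LLM providers without requiring the
//! `async_trait` crate. Request failures are reported as [`ApiCallError`].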

use std::time::Duration;

use crate::api::models::{ChatRequest, ChatResponse};

/// Structured error returned by [`OpenAiCompatibleProvider::chat`] and
/// [`OpenAiCompatibleProvider::chat_stream`].
///
/// The retry layer in `agent::r#loop::iter` inspects this enum (rather than
/// matching error strings) to decide whether a failure is transient. Keeping
/// the HTTP status and `Retry-After` header in the type means the retry
/// policy can be unit-tested as a pure function and stays correct as
/// providers change their wording.
///
/// The `#[error(...)]` strings on each variant are its `Display` output.
#[derive(Debug, thiserror::Error)]
pub enum ApiCallError {
    /// HTTP non-2xx response from the upstream provider.
    #[error("API status {status}: {body}")]
    Status {
        /// Numeric HTTP status code of the response.
        status: u16,
        /// Delay from a seconds-form `Retry-After` header; `None` when the
        /// header is absent or not a plain non-negative integer.
        retry_after: Option<Duration>,
        /// Response body text; empty if the body could not be read.
        body: String,
    },
    /// reqwest-level transport failure (DNS, TCP, TLS, idle pool reset…).
    /// Any `reqwest::Error` propagated with `?` lands here via `#[from]`.
    #[error("network error: {0}")]
    Network(#[from] reqwest::Error),
    /// Response body could not be deserialized into the expected shape;
    /// carries the decoder's error message as text.
    #[error("response decode error: {0}")]
    Decode(String),
}

impl ApiCallError {
    /// Whether this error is transient, so the identical request may be retried
    /// without re-issuing user intent.
    ///
    /// Retryable:
    /// - HTTP 408 Request Timeout, 425 Too Early, 429 Too Many Requests.
    /// - The transient server errors 500, 502, 503 and 504. Other 5xx codes
    ///   (e.g. 501 Not Implemented, 505 HTTP Version Not Supported) are
    ///   treated as permanent.
    /// - reqwest errors that report `is_connect`, `is_timeout`, `is_request`
    ///   or `is_body`.
    ///
    /// Everything else is permanent and surfaced immediately so the user can
    /// fix the request. That covers all other statuses, decode failures, and
    /// other reqwest error kinds.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Status { status, .. } => {
                matches!(*status, 408 | 425 | 429 | 500 | 502 | 503 | 504)
            }
            Self::Network(e) => e.is_connect() || e.is_timeout() || e.is_request() || e.is_body(),
            Self::Decode(_) => false,
        }
    }

    /// Server-suggested wait when the upstream sent a `Retry-After` header.
    ///
    /// Only [`ApiCallError::Status`] can carry one; every other variant
    /// returns `None`.
    pub fn retry_after(&self) -> Option<Duration> {
        // Exhaustive on purpose: a new variant must decide whether it carries
        // a server-provided delay.
        match self {
            Self::Status { retry_after, .. } => *retry_after,
            Self::Network(_) | Self::Decode(_) => None,
        }
    }

    /// Whether this is a 429 rate-limit response (used for nicer status text).
    pub fn is_rate_limit(&self) -> bool {
        matches!(self, Self::Status { status: 429, .. })
    }
}

/// Read the delay from a `Retry-After` header, if one is usable (RFC 9110).
///
/// Only the delay-seconds form (`Retry-After: 30`) is understood. Surrounding
/// whitespace is ignored. The HTTP-date form is deliberately not supported,
/// because API providers rarely send it.
///
/// Returns `None` when the header is missing, is not valid text, or is not a
/// whole non-negative number of seconds. The caller then falls back to its
/// own backoff schedule.
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    headers
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.trim().parse::<u64>().ok())
        .map(Duration::from_secs)
}

/// Client for any service that implements the OpenAI chat completions API.
///
/// Known targets include Z.ai, OpenAI, DeepSeek and local Ollama.
///
/// This is the primary provider type. Other wire formats (e.g. Anthropic)
/// are expected to live in separate structs that expose the same public
/// methods.
#[derive(Clone)]
pub struct OpenAiCompatibleProvider {
    /// Identifier of the currently selected model.
    pub model: String,
    /// Output-token budget for this provider.
    pub max_tokens: u32,

    /// Reusable HTTP client (connection pool shared across clones).
    http: reqwest::Client,
    /// Endpoint prefix; `/chat/completions` is appended for each call.
    base_url: String,
    /// Credential sent as a bearer token.
    api_key: String,
    /// Extra headers added to every chat request.
    extra_headers: std::collections::HashMap<String, String>,

    /// Context window of the model, in tokens.
    context_window_size: usize,
    /// Whether the model supports tool use (function calling).
    has_tool_support: bool,
    /// Whether the model supports reasoning/thinking tokens.
    has_reasoning: bool,
}

impl OpenAiCompatibleProvider {
    /// Build the HTTP client shared by every constructor.
    ///
    /// Centralised so the timeout and connection-pool settings cannot drift
    /// between [`Self::new`] and [`Self::from_config`].
    fn build_http_client() -> anyhow::Result<reqwest::Client> {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(300))
            .pool_max_idle_per_host(10)
            .pool_idle_timeout(Duration::from_secs(60))
            .tcp_keepalive(Duration::from_secs(30))
            .build()?;
        Ok(client)
    }

    /// Create a provider for `model` served at `base_url`.
    ///
    /// `max_tokens` and the tool/reasoning capability flags come from the
    /// model profile. `context_window_size` is used exactly as given. No
    /// extra headers are attached.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn new(
        base_url: String,
        api_key: String,
        model: String,
        context_window_size: usize,
    ) -> anyhow::Result<Self> {
        let profile = crate::api::model_profile::profile_for(&model);
        let http = Self::build_http_client()?;

        Ok(Self {
            http,
            base_url,
            api_key,
            max_tokens: profile.max_output_tokens,
            model,
            context_window_size,
            has_tool_support: profile.supports_tool_use,
            has_reasoning: profile.supports_reasoning,
            extra_headers: std::collections::HashMap::new(),
        })
    }

    /// Create from the global [`Config`][crate::config::Config].
    ///
    /// Unlike [`Self::new`], `max_tokens` is taken from the config and the
    /// context window from the model profile. The config's proxy headers are
    /// sent with every chat request.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn from_config(config: &crate::config::Config) -> anyhow::Result<Self> {
        let profile = crate::api::model_profile::profile_for(&config.model);
        let http = Self::build_http_client()?;

        Ok(Self {
            http,
            base_url: config.base_url.clone(),
            api_key: config.api_key.clone(),
            model: config.model.clone(),
            max_tokens: config.max_tokens,
            context_window_size: profile.context_window,
            has_tool_support: profile.supports_tool_use,
            has_reasoning: profile.supports_reasoning,
            extra_headers: config.proxy_headers.clone(),
        })
    }

    /// Create from a stored `ProviderEntry` using the given model name.
    /// Model capabilities are derived from the model profile; use `AgentDef`
    /// overrides at call sites for per-agent customization.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn from_entry(
        entry: &crate::config::ProviderEntry,
        api_key: &str,
        model: &str,
        extra_headers: std::collections::HashMap<String, String>,
    ) -> anyhow::Result<Self> {
        let profile = crate::api::model_profile::profile_for(model);
        // Only the entry's base URL is used here. Everything else comes from
        // `new` (i.e. the model profile), plus the caller-supplied headers.
        let mut provider = Self::new(
            entry.base_url.clone(),
            api_key.to_string(),
            model.to_string(),
            profile.context_window,
        )?;
        provider.extra_headers = extra_headers;
        Ok(provider)
    }

    /// Switch to a different provider at runtime.
    ///
    /// `base_url`, `api_key`, `model` and `max_tokens` are replaced.
    ///
    /// When the model name changes, the capability fields (context window,
    /// tool support, reasoning) are refreshed from the new model's profile.
    /// The capability queries therefore describe the model actually in use.
    /// Keeping the same model leaves them untouched, which preserves a custom
    /// context window passed to [`Self::new`].
    ///
    /// The HTTP client and `extra_headers` are kept as they are.
    /// NOTE(review): extra headers therefore follow the switch to the new
    /// `base_url` — confirm that is intended when it points at another host.
    pub fn switch_provider(
        &mut self,
        base_url: String,
        api_key: String,
        model: String,
        max_tokens: u32,
    ) {
        if model != self.model {
            let profile = crate::api::model_profile::profile_for(&model);
            self.context_window_size = profile.context_window;
            self.has_tool_support = profile.supports_tool_use;
            self.has_reasoning = profile.supports_reasoning;
        }
        self.base_url = base_url;
        self.api_key = api_key;
        self.model = model;
        self.max_tokens = max_tokens;
    }

    /// Get the base URL.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Get the API key.
    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    /// Full URL of the chat completions endpoint.
    ///
    /// A trailing `/` on `base_url` is ignored, so `https://host/v1/` does not
    /// produce `https://host/v1//chat/completions`.
    fn completions_url(&self) -> String {
        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
    }

    /// POST builder for the completions endpoint with the JSON body attached.
    ///
    /// Headers are applied in this order: bearer auth, then content type,
    /// then `extra_headers`. `bearer_auth` also marks the `Authorization`
    /// value as sensitive.
    fn completions_request(&self, request: &ChatRequest) -> reqwest::RequestBuilder {
        let mut req = self
            .http
            .post(self.completions_url())
            .bearer_auth(&self.api_key)
            .header("Content-Type", "application/json");
        for (k, v) in &self.extra_headers {
            req = req.header(k.as_str(), v.as_str());
        }
        req.json(request)
    }

    /// Pass a 2xx response through unchanged; turn anything else into
    /// [`ApiCallError::Status`].
    ///
    /// The error records the status code and any seconds-form `Retry-After`
    /// delay. It also records the body, read best-effort: an unreadable body
    /// becomes an empty string.
    async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, ApiCallError> {
        let status = resp.status();
        if status.is_success() {
            return Ok(resp);
        }
        let retry_after = parse_retry_after(resp.headers());
        let body = resp.text().await.unwrap_or_default();
        Err(ApiCallError::Status {
            status: status.as_u16(),
            retry_after,
            body,
        })
    }

    /// Non-streaming chat completion.
    ///
    /// # Errors
    /// - [`ApiCallError::Network`] for transport failures, including failures
    ///   while reading the response body.
    /// - [`ApiCallError::Status`] for non-2xx responses.
    /// - [`ApiCallError::Decode`] when the body is not a valid `ChatResponse`.
    pub async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, ApiCallError> {
        tracing::trace!(base_url = self.base_url(), model = %self.model, "chat request");
        let resp = self.completions_request(request).send().await?;
        let resp = Self::check_status(resp).await?;

        // reqwest reports JSON deserialization failures as `is_decode()`.
        // Any other error here happened while reading the body, so it stays a
        // `Network` error and the retry policy can still judge it.
        let response: ChatResponse = resp.json().await.map_err(|e| {
            if e.is_decode() {
                ApiCallError::Decode(e.to_string())
            } else {
                ApiCallError::Network(e)
            }
        })?;
        tracing::trace!(
            response_id = %response.id,
            base_url = %self.base_url(),
            finish_reason = ?response.choices.first().and_then(|c| c.finish_reason.as_deref()),
            first_choice_index = ?response.choices.first().map(|c| c.index),
            total_tokens = ?response.usage.as_ref().map(|u| u.total_tokens),
            choices = response.choices.len(),
            "Non-streaming API response"
        );
        Ok(response)
    }

    /// Streaming chat completion — returns the raw HTTP response for SSE processing.
    ///
    /// # Errors
    /// - [`ApiCallError::Network`] if the request cannot be sent.
    /// - [`ApiCallError::Status`] for non-2xx responses.
    ///
    /// The body is not read here.
    pub async fn chat_stream(
        &self,
        request: &ChatRequest,
    ) -> Result<reqwest::Response, ApiCallError> {
        let resp = self.completions_request(request).send().await?;
        Self::check_status(resp).await
    }

    /// The current model identifier.
    pub fn model_name(&self) -> &str {
        &self.model
    }

    /// Context window size in tokens.
    pub fn context_window(&self) -> usize {
        self.context_window_size
    }

    /// Whether this provider/model supports tool use (function calling).
    pub fn supports_tools(&self) -> bool {
        self.has_tool_support
    }

    /// Whether this provider/model supports reasoning/thinking tokens.
    pub fn supports_reasoning(&self) -> bool {
        self.has_reasoning
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Status` error with the given code, no `Retry-After`, and an
    /// empty body.
    fn status_err(status: u16) -> ApiCallError {
        ApiCallError::Status {
            status,
            retry_after: None,
            body: String::new(),
        }
    }

    /// Run `parse_retry_after` on a header map holding only `Retry-After: raw`.
    fn parse_header(raw: &str) -> Option<Duration> {
        let mut h = reqwest::header::HeaderMap::new();
        h.insert(reqwest::header::RETRY_AFTER, raw.parse().unwrap());
        parse_retry_after(&h)
    }

    #[test]
    fn test_is_retryable_5xx_and_429() {
        for code in [408u16, 425, 429, 500, 502, 503, 504] {
            assert!(
                status_err(code).is_retryable(),
                "status {code} must be retryable"
            );
        }
    }

    #[test]
    fn test_is_retryable_skips_4xx_permanent() {
        for code in [400u16, 401, 403, 404, 409, 422] {
            assert!(
                !status_err(code).is_retryable(),
                "status {code} must NOT be retryable"
            );
        }
    }

    #[test]
    fn test_is_retryable_skips_permanent_5xx() {
        // Not every 5xx is transient: these signal a server that will never
        // accept the request as sent.
        for code in [501u16, 505] {
            assert!(
                !status_err(code).is_retryable(),
                "status {code} must NOT be retryable"
            );
        }
    }

    #[test]
    fn test_is_rate_limit_only_429() {
        assert!(status_err(429).is_rate_limit());
        assert!(!status_err(503).is_rate_limit());
        assert!(!status_err(500).is_rate_limit());
    }

    #[test]
    fn test_retry_after_propagates() {
        let err = ApiCallError::Status {
            status: 429,
            retry_after: Some(Duration::from_secs(7)),
            body: String::new(),
        };
        assert_eq!(err.retry_after(), Some(Duration::from_secs(7)));
        assert_eq!(status_err(429).retry_after(), None);
    }

    #[test]
    fn test_decode_error_not_retryable() {
        assert!(!ApiCallError::Decode("bad json".into()).is_retryable());
    }

    #[test]
    fn test_decode_error_has_no_retry_metadata() {
        let err = ApiCallError::Decode("bad json".into());
        assert_eq!(err.retry_after(), None);
        assert!(!err.is_rate_limit());
    }

    #[test]
    fn test_display_messages() {
        let err = ApiCallError::Status {
            status: 503,
            retry_after: None,
            body: "overloaded".into(),
        };
        assert_eq!(err.to_string(), "API status 503: overloaded");
        assert_eq!(
            ApiCallError::Decode("bad json".into()).to_string(),
            "response decode error: bad json"
        );
    }

    #[test]
    fn test_parse_retry_after_seconds() {
        assert_eq!(parse_header("12"), Some(Duration::from_secs(12)));
    }

    #[test]
    fn test_parse_retry_after_edge_values() {
        assert_eq!(parse_header(" 5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_header("0"), Some(Duration::ZERO));
        assert_eq!(parse_header("-1"), None);
        assert_eq!(parse_header("1.5"), None);
    }

    #[test]
    fn test_parse_retry_after_missing_or_garbage() {
        let empty = reqwest::header::HeaderMap::new();
        assert_eq!(parse_retry_after(&empty), None);

        // HTTP-date form is intentionally unsupported.
        assert_eq!(parse_header("Wed, 21 Oct 2026 07:28:00 GMT"), None);
    }

    #[test]
    fn test_openai_compatible_provider_creation() {
        let provider = OpenAiCompatibleProvider::new(
            "https://api.example.com/v1".to_string(),
            "test-key".to_string(),
            "glm-4.7".to_string(),
            128_000,
        );
        assert!(provider.is_ok());
        let p = provider.unwrap();
        assert_eq!(p.model_name(), "glm-4.7");
        assert_eq!(p.context_window(), 128_000);
        assert!(p.supports_tools());
    }
}