llmkit-anthropic 0.1.0

Anthropic (Claude) Messages API provider adapter for llmkit-rs
Documentation
//! [`AnthropicProvider`] — implements [`LlmProvider`] against the Messages API.

use std::time::{Duration, Instant};

use async_trait::async_trait;
use llmkit_core::{
    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
    LlmError, LlmProvider, LlmResult,
};

use crate::types::{ApiError, MessagesResponse};
use crate::{chat, stream};

/// Value sent in the `anthropic-version` header unless overridden with `version()`.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// API root used by `new()`; `/messages` is appended to it for every request.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com/v1";
/// Model used when neither the builder nor the individual request names one.
const DEFAULT_MODEL: &str = "claude-opus-4-8";

/// Anthropic (Claude) provider over the `/v1/messages` API.
///
/// Create one with [`AnthropicProvider::new`] or [`AnthropicProvider::from_env`], then adjust it
/// with the consuming builder methods (`model`, `base_url`, `version`, `with_client`).
#[derive(Clone)]
pub struct AnthropicProvider {
    /// HTTP client used for every request.
    http: reqwest::Client,
    /// Sent verbatim as the `x-api-key` header; redacted from `Debug` output.
    api_key: String,
    /// API root; `/messages` is appended when building requests.
    base_url: String,
    /// Default model, used when a request does not specify one.
    model: String,
    /// Value of the `anthropic-version` header.
    version: String,
}

// Hand-written rather than derived so the API key never ends up in logs or panic messages.
impl std::fmt::Debug for AnthropicProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnthropicProvider")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("version", &self.version)
            .finish_non_exhaustive()
    }
}

impl AnthropicProvider {
    /// Construct with an explicit API key.
    ///
    /// Uses a fresh [`reqwest::Client`], the public Anthropic endpoint
    /// (`https://api.anthropic.com/v1`), the crate's default model and the default
    /// `anthropic-version`. Each of these can be changed with the builder methods below.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            http: reqwest::Client::new(),
            api_key: api_key.into(),
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            version: ANTHROPIC_VERSION.to_string(),
        }
    }

    /// Construct from the `ANTHROPIC_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::Auth`] if the variable is unset, is not valid Unicode, or is empty.
    pub fn from_env() -> LlmResult<Self> {
        let key = std::env::var("ANTHROPIC_API_KEY").map_err(|e| match e {
            std::env::VarError::NotPresent => LlmError::Auth("ANTHROPIC_API_KEY not set".into()),
            std::env::VarError::NotUnicode(_) => {
                LlmError::Auth("ANTHROPIC_API_KEY is not valid Unicode".into())
            }
        })?;
        // An empty key can never authenticate; fail here with a clear message instead of
        // surfacing an opaque 401 on the first request.
        if key.is_empty() {
            return Err(LlmError::Auth("ANTHROPIC_API_KEY is empty".into()));
        }
        Ok(Self::new(key))
    }

    /// Set the default model, used when a [`ChatRequest`] does not specify one.
    ///
    /// This inherent method shares its name with the read-only [`LlmProvider::model`]. On an
    /// owned `AnthropicProvider`, `provider.model()` resolves to this builder, so call
    /// `LlmProvider::model(&provider)` to read the current default.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Override the base URL (for example, for a proxy or a mock server).
    ///
    /// `/messages` is appended per request. A trailing `/` on the base URL is tolerated.
    #[must_use]
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Override the `anthropic-version` header.
    #[must_use]
    pub fn version(mut self, version: impl Into<String>) -> Self {
        self.version = version.into();
        self
    }

    /// Provide a custom [`reqwest::Client`] (timeouts, proxies, TLS settings, ...).
    #[must_use]
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.http = client;
        self
    }

    /// Model for this request: the request's own model if set, otherwise the provider default.
    fn resolved_model(&self, req: &ChatRequest) -> String {
        req.model.clone().unwrap_or_else(|| self.model.clone())
    }

    /// Build an authenticated `POST {base_url}/messages` request with `body` as JSON.
    fn request(&self, body: &impl serde::Serialize) -> reqwest::RequestBuilder {
        // Strip any trailing slash so a base URL like ".../v1/" does not yield ".../v1//messages".
        let base = self.base_url.trim_end_matches('/');
        self.http
            .post(format!("{base}/messages"))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", &self.version)
            .json(body)
    }
}

#[async_trait]
impl LlmProvider for AnthropicProvider {
    /// Send a non-streaming Messages request and map the reply into a [`ChatResponse`].
    ///
    /// The latency passed to `chat::map_response` covers sending, the status check and JSON
    /// parsing. `cost` is computed from the model reported in the response. It is `None` when
    /// no pricing is known for that model.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, false);

        let start = Instant::now();
        let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
        let resp = check_status(resp).await?;
        let parsed: MessagesResponse = resp.json().await.map_err(map_reqwest_err)?;

        // Saturate instead of silently truncating the u128 millisecond count.
        let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        let mut out = chat::map_response(parsed, latency_ms)?;
        out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
        Ok(out)
    }

    /// Send a streaming Messages request.
    ///
    /// HTTP-level errors are reported before any stream is returned. The successful response
    /// body is handed to `stream::parse`.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, true);

        let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
        let resp = check_status(resp).await?;
        Ok(stream::parse(resp))
    }

    /// Always fails with [`LlmError::Unsupported`]; no embeddings endpoint is wired up here.
    async fn embed(&self, _req: EmbedRequest) -> LlmResult<EmbedResponse> {
        Err(LlmError::Unsupported(
            "Anthropic does not provide an embeddings endpoint".into(),
        ))
    }

    /// Stable provider identifier: `"anthropic"`.
    fn name(&self) -> &'static str {
        "anthropic"
    }

    /// The provider's default model (not any per-request override).
    fn model(&self) -> &str {
        &self.model
    }

    /// Rough pre-flight cost estimate for `req`.
    ///
    /// Prompt tokens are approximated as one token per four characters of text message
    /// content plus the system prompt. Non-text content is ignored. Completion tokens are
    /// taken as `max_tokens`, defaulting to 256.
    /// NOTE(review): confirm that 256 matches the default `chat::build_request` sends.
    ///
    /// Returns `None` when no pricing is known for the resolved model.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        let model = self.resolved_model(req);
        let pricing = pricing::pricing_for(&model)?;
        // Count characters, not UTF-8 bytes, so the chars-per-token heuristic is applied
        // to what it describes.
        let message_chars: usize = req
            .messages
            .iter()
            .filter_map(|m| m.content.as_text())
            .map(|t| t.chars().count())
            .sum();
        let system_chars = req.system.as_deref().map_or(0, |s| s.chars().count());
        // Saturate rather than wrap if an enormous prompt exceeds u32 tokens.
        let prompt_tokens = u32::try_from((message_chars + system_chars) / 4).unwrap_or(u32::MAX);
        let completion_tokens = req.max_tokens.unwrap_or(256);
        Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
    }
}

/// Translate a transport-level [`reqwest::Error`] into the crate-neutral [`LlmError`].
///
/// Mapping:
/// - timeouts become [`LlmError::Timeout`];
/// - body-decoding failures become [`LlmError::Serialization`];
/// - everything else becomes [`LlmError::Transport`], carrying the error's display text.
fn map_reqwest_err(e: reqwest::Error) -> LlmError {
    if e.is_timeout() {
        return LlmError::Timeout;
    }

    let text = e.to_string();
    if e.is_decode() {
        LlmError::Serialization(text)
    } else {
        LlmError::Transport(text)
    }
}

/// Pass a 2xx response through unchanged, or convert an error response into [`LlmError`].
///
/// For non-2xx statuses, the body is read on a best-effort basis. If it parses as an
/// `ApiError`, its `error.message` becomes the message. Otherwise the raw body is used. If
/// that is empty, the HTTP status line (e.g. `"500 Internal Server Error"`) is used instead,
/// so callers never receive a blank message.
///
/// Status mapping:
/// - 401 and 403 map to `Auth`.
/// - 429 maps to `RateLimited`. A `Retry-After` header in whole seconds is attached when present.
/// - 400, 404, 413 and 422 map to `InvalidRequest`.
/// - Any other status maps to `Provider`.
async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
    let status = resp.status();
    if status.is_success() {
        return Ok(resp);
    }

    // Only the delta-seconds form of Retry-After is understood; an HTTP-date yields `None`.
    let retry_after = resp
        .headers()
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.trim().parse::<u64>().ok())
        .map(Duration::from_secs);

    // Best-effort: an unreadable body is treated as empty so the status still gets reported.
    let body = resp.text().await.unwrap_or_default();
    let message = serde_json::from_str::<ApiError>(&body)
        .map(|e| e.error.message)
        .unwrap_or(body);
    let message = if message.trim().is_empty() {
        status.to_string()
    } else {
        message
    };

    Err(match status.as_u16() {
        401 | 403 => LlmError::Auth(message),
        429 => LlmError::RateLimited { retry_after, message },
        // 413 = request_too_large: a client-side problem that retrying will not fix.
        400 | 404 | 413 | 422 => LlmError::InvalidRequest(message),
        code => LlmError::Provider { status: code, message },
    })
}