use std::collections::HashMap;
use std::env;
use std::fmt;
use std::sync::Arc;
use defect_core::error::BoxError;
use defect_core::llm::{
Capabilities, CompletionRequest, FeatureSupport, LlmProvider, ModelCapabilityOverrides,
ModelInfo, ProtocolId, ProviderError, ProviderErrorKind, ProviderInfo, ProviderStream,
RateLimitScope, ThinkingEcho, TimeoutPhase,
};
use futures::FutureExt;
use futures::future::BoxFuture;
use http::{HeaderName, HeaderValue};
use toac::security::AuthFuture;
use toac::{ApiClient, AuthSelector, CallError, MakeRequest, Operation, Request as ToacRequest};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tower::Service;
use crate::protocol::anthropic_messages::{self, ThinkingWireFormat};
use crate::wire::anthropic::{
components as wire,
operations::v1::{messages, models},
};
use defect_http::{HttpStack, HttpStackConfig, HttpStackError, build_http_stack};
/// Base URL used when neither the config nor the `ANTHROPIC_BASE_URL`
/// environment variable supplies one.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Environment variable read for the API key when the config does not name
/// a different variable.
const API_KEY_ENV: &str = "ANTHROPIC_API_KEY";
/// Environment variable read for the base URL when the config leaves it unset.
const BASE_URL_ENV: &str = "ANTHROPIC_BASE_URL";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Header that carries the API key unless the config overrides the name.
const DEFAULT_AUTH_HEADER: &str = "x-api-key";
/// Vendor string reported in `ProviderInfo` when the config gives none.
const DEFAULT_VENDOR: &str = "anthropic";
/// Display name reported in `ProviderInfo` when the config gives none.
const DEFAULT_DISPLAY_NAME: &str = "Anthropic Claude";
/// The generated API client bound to this crate's HTTP stack.
type Client = ApiClient<HttpStack>;
/// User-facing settings from which an `AnthropicProvider` is built.
///
/// Every field is optional in the sense that `Default` yields a usable
/// starting point; unset values fall back to the module-level defaults or to
/// environment variables when the provider is constructed.
#[derive(Debug, Default, Clone)]
pub struct AnthropicConfig {
/// Explicit API key. When set it is used as-is and no environment lookup
/// takes place.
pub api_key: Option<String>,
/// Name of the environment variable to read the API key from when
/// `api_key` is `None`; falls back to `ANTHROPIC_API_KEY`.
pub api_key_env: Option<String>,
/// Base URL override; when `None`, `ANTHROPIC_BASE_URL` is consulted and
/// then the built-in default.
pub base_url: Option<String>,
/// Vendor string reported in `ProviderInfo`; defaults to `"anthropic"`.
pub vendor: Option<String>,
/// Display name reported in `ProviderInfo`; defaults to
/// `"Anthropic Claude"`.
pub display_name: Option<String>,
/// Name of the header that carries the API key; defaults to `x-api-key`.
pub auth_header: Option<String>,
/// Additional headers added to every request the provider builds.
pub headers: HashMap<HeaderName, HeaderValue>,
/// Thinking wire format keyed by exact model id; models without an entry
/// use `ThinkingWireFormat::default()`.
pub thinking_formats: HashMap<String, ThinkingWireFormat>,
/// Settings handed to `build_http_stack` when the provider is created.
pub http: HttpStackConfig,
}
impl AnthropicConfig {
pub fn from_env() -> Self {
Self {
api_key: env::var(API_KEY_ENV).ok(),
api_key_env: None,
base_url: env::var(BASE_URL_ENV).ok(),
vendor: None,
display_name: None,
auth_header: None,
headers: HashMap::new(),
thinking_formats: HashMap::new(),
http: HttpStackConfig::default(),
}
}
fn resolve_api_key(&self) -> Result<String, ProviderError> {
if let Some(api_key) = self.api_key.clone() {
return Ok(api_key);
}
let env_name = self.api_key_env.as_deref().unwrap_or(API_KEY_ENV);
env::var(env_name).map_err(|_| {
ProviderError::new(ProviderErrorKind::AuthMissing {
var_hint: Some(env_name.into()),
})
})
}
fn resolve_base_url(&self) -> String {
self.base_url
.clone()
.or_else(|| env::var(BASE_URL_ENV).ok())
.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned())
}
}
/// `LlmProvider` implementation that talks to an Anthropic Messages
/// endpoint through the generated API client.
pub struct AnthropicProvider {
/// Authenticated API client; cloned for each outgoing call.
client: Client,
/// Identity reported through `LlmProvider::info`.
info: ProviderInfo,
/// Feature set reported through `LlmProvider::capabilities`.
capabilities: Capabilities,
/// Configured headers copied onto every wrapped request.
extra_headers: HashMap<HeaderName, HeaderValue>,
/// Thinking wire format keyed by exact model id.
thinking_formats: HashMap<String, ThinkingWireFormat>,
/// Model list cache; `None` until `list_models` has succeeded once.
models: Arc<RwLock<Option<Vec<ModelInfo>>>>,
}
impl fmt::Debug for AnthropicProvider {
    /// Prints only `info` and `capabilities`; the remaining fields are
    /// elided and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AnthropicProvider");
        builder.field("info", &self.info);
        builder.field("capabilities", &self.capabilities);
        builder.finish_non_exhaustive()
    }
}
impl AnthropicProvider {
pub fn new(config: AnthropicConfig) -> Result<Self, ProviderError> {
let token = config.resolve_api_key()?;
let base_url = config.resolve_base_url();
let header_name = config.auth_header.as_deref().unwrap_or(DEFAULT_AUTH_HEADER);
let auth = ConfigurableApiKeyAuth::new(header_name, token)?;
let http = build_http_stack(config.http)
.map_err(|e| ProviderError::new(ProviderErrorKind::Transport(BoxError::new(e))))?;
let client = ApiClient::new(http, base_url).with_auth(auth);
Ok(Self {
client,
info: ProviderInfo {
vendor: config.vendor.unwrap_or_else(|| DEFAULT_VENDOR.into()),
protocol: ProtocolId::AnthropicMessages,
display_name: config
.display_name
.unwrap_or_else(|| DEFAULT_DISPLAY_NAME.into()),
},
capabilities: Capabilities {
tool_calls: FeatureSupport::Supported,
parallel_tool_calls: FeatureSupport::Supported,
thinking: FeatureSupport::Supported,
vision: FeatureSupport::Supported,
prompt_cache: FeatureSupport::Supported,
thinking_echo: ThinkingEcho::Required,
},
extra_headers: config.headers,
thinking_formats: config.thinking_formats,
models: Arc::default(),
})
}
fn thinking_format_for(&self, model_id: &str) -> ThinkingWireFormat {
self.thinking_formats
.get(model_id)
.copied()
.unwrap_or_default()
}
}
impl LlmProvider for AnthropicProvider {
    /// Returns a copy of the provider identity fixed at construction.
    fn info(&self) -> ProviderInfo {
        self.info.clone()
    }

    /// Returns the capability set fixed at construction.
    fn capabilities(&self) -> Capabilities {
        self.capabilities
    }

    /// Lists the models offered by the endpoint.
    ///
    /// A previously fetched list is served from the in-memory cache. On a
    /// cache miss the models endpoint is queried once, the result is mapped
    /// to `ModelInfo` and stored for subsequent calls.
    ///
    /// # Errors
    ///
    /// Transport/encoding failures are mapped by `call_error_to_provider`;
    /// non-200 responses are mapped by `error_response` and tagged with the
    /// response's request id when one is present.
    fn list_models(&self) -> BoxFuture<'_, Result<Vec<ModelInfo>, ProviderError>> {
        async move {
            let cached = self.models.read().await.clone();
            if let Some(models) = cached {
                return Ok(models);
            }

            // NOTE(review): no cursor or limit is sent and the response is
            // not inspected for further pages, so only the first page is
            // cached — confirm that is sufficient for this endpoint.
            let request = self.with_anthropic_headers(models::get::Request {
                before_id: None,
                after_id: None,
                limit: None,
            });
            let mut client = self.client.clone();
            let resp = client
                .call(request)
                .await
                .map_err(call_error_to_provider)?;

            let request_id = extract_request_id(&resp.headers);
            let listing = match resp.body {
                models::get::ResponseBody::Status200(page) => Ok(page),
                models::get::ResponseBody::Status400(e) => Err(error_response(400, &e)),
                models::get::ResponseBody::Status401(e) => Err(error_response(401, &e)),
                models::get::ResponseBody::Status429(e) => Err(error_response(429, &e)),
                models::get::ResponseBody::Status500(e) => Err(error_response(500, &e)),
            }
            .map_err(|err| err.with_request_id_opt(request_id))?;

            // Only id and display name come from the wire; the other
            // fields are filled with neutral values.
            let mapped: Vec<ModelInfo> = listing
                .data
                .into_iter()
                .map(|m| ModelInfo {
                    id: m.id,
                    display_name: Some(m.display_name),
                    context_window: None,
                    max_output_tokens: None,
                    deprecated: false,
                    capabilities_overrides: ModelCapabilityOverrides::default(),
                })
                .collect();

            let snapshot = Some(mapped.clone());
            {
                let mut slot = self.models.write().await;
                *slot = snapshot;
            }
            Ok(mapped)
        }
        .boxed()
    }

    /// Looks `model_id` up in the cached model list without waiting.
    ///
    /// Returns `None` when the cache lock cannot be taken immediately, when
    /// the cache has not been populated yet, or when no model has that id.
    fn model_info(&self, model_id: &str) -> Option<ModelInfo> {
        let guard = self.models.try_read().ok()?;
        let cached = guard.as_ref()?;
        let found = cached.iter().find(|m| m.id == model_id).cloned();
        found
    }

    /// Starts a streaming completion and returns the decoded event stream.
    ///
    /// The request is encoded with the thinking wire format configured for
    /// `req.model` and sent with `Accept: text/event-stream`. Cancellation
    /// is checked first while waiting for the response and is also handed
    /// to the stream decoder.
    ///
    /// # Errors
    ///
    /// * `Canceled` when `cancel` fires before the response arrives.
    /// * `ProtocolViolation` when the server answers with JSON instead of
    ///   an event stream.
    /// * Whatever `call_error_to_provider` / `error_response` produce for
    ///   transport failures and non-200 statuses; response-derived errors
    ///   carry the request id when present.
    fn complete(
        &self,
        req: CompletionRequest,
        cancel: CancellationToken,
    ) -> BoxFuture<'_, Result<ProviderStream, ProviderError>> {
        async move {
            let wire_format = self.thinking_format_for(&req.model);
            let payload = anthropic_messages::encode_request(&req, wire_format);
            let request = self
                .with_anthropic_headers(messages::post::Request { body: payload })
                .with_accept(HeaderValue::from_static("text/event-stream"));

            let mut client = self.client.clone();
            let resp = tokio::select! {
                biased;
                _ = cancel.cancelled() => {
                    return Err(ProviderError::new(ProviderErrorKind::Canceled));
                }
                outcome = client.call(request) => outcome.map_err(call_error_to_provider)?,
            };

            let request_id = extract_request_id(&resp.headers);
            let events = match resp.body {
                messages::post::ResponseBody::Status200Sse(s) => Ok(s),
                messages::post::ResponseBody::Status200Json(_) => {
                    Err(ProviderError::new(ProviderErrorKind::ProtocolViolation {
                        hint: "server returned application/json despite Accept: text/event-stream"
                            .into(),
                    }))
                }
                messages::post::ResponseBody::Status400(e) => Err(error_response(400, &e)),
                messages::post::ResponseBody::Status401(e) => Err(error_response(401, &e)),
                messages::post::ResponseBody::Status403(e) => Err(error_response(403, &e)),
                messages::post::ResponseBody::Status404(e) => Err(error_response(404, &e)),
                messages::post::ResponseBody::Status413(e) => Err(error_response(413, &e)),
                messages::post::ResponseBody::Status429(e) => Err(error_response(429, &e)),
                messages::post::ResponseBody::Status500(e) => Err(error_response(500, &e)),
                messages::post::ResponseBody::Status529(e) => Err(error_response(529, &e)),
            }
            .map_err(|err| err.with_request_id_opt(request_id))?;

            let decoded = anthropic_messages::decode_stream(events, cancel);
            Ok(Box::pin(decoded) as ProviderStream)
        }
        .boxed()
    }
}
/// Auth selector that sends the API key in a single, configurable header.
#[derive(Debug, Clone)]
struct ConfigurableApiKeyAuth {
/// Validated name of the header that carries the key.
name: HeaderName,
/// The key as a header value, flagged sensitive at construction.
value: HeaderValue,
}
impl ConfigurableApiKeyAuth {
    /// Validates `name` as a header name and `value` as a header value and
    /// marks the value sensitive.
    ///
    /// # Errors
    ///
    /// `ProviderErrorKind::AuthMalformed` with a hint describing which of
    /// the two inputs was rejected.
    fn new(name: &str, value: String) -> Result<Self, ProviderError> {
        let malformed = |hint: String| {
            ProviderError::new(ProviderErrorKind::AuthMalformed { hint: Some(hint) })
        };

        let header = HeaderName::from_bytes(name.as_bytes())
            .map_err(|e| malformed(format!("invalid auth header name `{name}`: {e}")))?;
        let mut secret = HeaderValue::from_str(&value)
            .map_err(|e| malformed(format!("api key is not a valid header value: {e}")))?;
        secret.set_sensitive(true);

        Ok(Self {
            name: header,
            value: secret,
        })
    }
}
impl AuthSelector for ConfigurableApiKeyAuth {
    /// Adds the API-key header to `req` when the operation declares at least
    /// one security requirement; otherwise returns `req` untouched.
    fn apply_for(
        &self,
        mut req: ToacRequest,
        requirements: &'static [&'static [&'static str]],
    ) -> AuthFuture<'_> {
        Box::pin(async move {
            // An empty requirement list means no credential is attached.
            if requirements.is_empty() {
                return Ok(req);
            }
            let (name, value) = (self.name.clone(), self.value.clone());
            req.headers_mut().insert(name, value);
            Ok(req)
        })
    }
}
/// Wraps an operation so that the `anthropic-version` header and the
/// configured extra headers are added to the request it produces.
#[derive(Debug, Clone)]
struct WithAnthropicHeaders<Op> {
/// The wrapped operation that builds the underlying request.
op: Op,
/// Headers to add on top of the wrapped operation's own headers.
extra: HashMap<HeaderName, HeaderValue>,
}
impl AnthropicProvider {
    /// Wraps `op` together with a copy of the provider's configured extra
    /// headers.
    fn with_anthropic_headers<Op>(&self, op: Op) -> WithAnthropicHeaders<Op> {
        let extra = self.extra_headers.clone();
        WithAnthropicHeaders { op, extra }
    }
}
impl<Op> MakeRequest for WithAnthropicHeaders<Op>
where
    Op: MakeRequest + Send,
{
    type Error = Op::Error;

    /// Builds the wrapped operation's request, then sets the
    /// `anthropic-version` header followed by each configured extra header.
    ///
    /// Headers are written with `insert`, so a configured header replaces
    /// any value already present under the same name (including the default
    /// `anthropic-version`) instead of being sent as an additional value.
    ///
    /// # Errors
    ///
    /// Propagates the wrapped operation's `make_request` error unchanged.
    // Kept as `fn -> impl Future + Send` rather than `async fn` so the
    // `Send` bound on the returned future is spelled out in the signature.
    #[allow(clippy::manual_async_fn)]
    fn make_request(
        self,
    ) -> impl std::future::Future<Output = Result<ToacRequest, Self::Error>> + Send {
        async move {
            let mut req = self.op.make_request().await?;
            let headers = req.headers_mut();
            headers.insert(
                http::HeaderName::from_static("anthropic-version"),
                HeaderValue::from_static(ANTHROPIC_VERSION),
            );
            // `HeaderMap::extend` would append to existing names; insert so
            // each configured header ends up with exactly one value.
            for (name, value) in self.extra {
                headers.insert(name, value);
            }
            Ok(req)
        }
    }
}
/// The wrapper is itself an operation whose response type is that of the
/// wrapped operation.
impl<Op> Operation for WithAnthropicHeaders<Op>
where
Op: Operation + Send,
{
type Response = Op::Response;
}
impl<Op> WithAnthropicHeaders<Op> {
    /// Wraps this operation in `toac::WithAccept`, passing `accept` along.
    fn with_accept(self, accept: HeaderValue) -> toac::WithAccept<Self> {
        let inner = self;
        toac::WithAccept::new(inner, accept)
    }
}
/// Reads the request id from `request-id`, or from `x-request-id` when the
/// former is absent.
///
/// Returns `None` when neither header exists or when the selected header
/// value is not valid visible ASCII.
fn extract_request_id(headers: &http::HeaderMap) -> Option<String> {
    let raw = match headers.get("request-id") {
        Some(value) => value,
        None => headers.get("x-request-id")?,
    };
    raw.to_str().ok().map(str::to_owned)
}
/// Extension for attaching a request id that may or may not be known.
trait WithRequestIdOpt {
/// Returns `self`, tagged with `id` when it is `Some`.
fn with_request_id_opt(self, id: Option<String>) -> Self;
}
impl WithRequestIdOpt for ProviderError {
    /// Attaches `id` via `with_request_id` when present; otherwise returns
    /// the error unchanged.
    fn with_request_id_opt(self, id: Option<String>) -> Self {
        if let Some(request_id) = id {
            self.with_request_id(request_id)
        } else {
            self
        }
    }
}
/// Translates a client-level call failure into a `ProviderError`.
///
/// Encoding failures become `BadRequest`, auth failures `AuthMalformed`,
/// HTTP-stack timeouts `Timeout` (with the phase mapped), any other
/// transport failure `Transport`, and decode failures `Malformed`.
fn call_error_to_provider(err: CallError<HttpStackError>) -> ProviderError {
    let kind = match err {
        CallError::Encode(cause) => ProviderErrorKind::BadRequest {
            hint: Some(cause.to_string()),
        },
        CallError::Auth(cause) => ProviderErrorKind::AuthMalformed {
            hint: Some(cause.to_string()),
        },
        // Must precede the general transport arm so timeouts keep their phase.
        CallError::Transport(HttpStackError::Timeout { phase }) => ProviderErrorKind::Timeout {
            phase: map_timeout_phase(phase),
        },
        CallError::Transport(cause) => ProviderErrorKind::Transport(BoxError::new(cause)),
        CallError::Decode(cause) => ProviderErrorKind::Malformed(BoxError::new(cause)),
    };
    ProviderError::new(kind)
}
fn map_timeout_phase(phase: defect_http::TimeoutPhase) -> TimeoutPhase {
match phase {
defect_http::TimeoutPhase::Connect => TimeoutPhase::Connect,
defect_http::TimeoutPhase::ReadHeaders => TimeoutPhase::ReadHeaders,
defect_http::TimeoutPhase::ReadBody => TimeoutPhase::ReadBody,
defect_http::TimeoutPhase::Idle => TimeoutPhase::Idle,
defect_http::TimeoutPhase::Total => TimeoutPhase::Total,
_ => TimeoutPhase::Total,
}
}
/// Classifies an upstream error body, given the HTTP `status` it arrived
/// with, into a `ProviderError`.
///
/// The status code decides first; the upstream error type only refines the
/// 400 case (max-tokens complaints), the 404 case (unknown model) and the
/// fallback (overloaded). The upstream message is kept as the hint except
/// where a fixed hint or no hint is used (413, 429, 529/overloaded,
/// max-tokens, model-not-found).
fn error_response(status: u16, e: &wire::ErrorResponse) -> ProviderError {
    let message = e.error.message.clone();

    let kind = match (status, e.error.r#type.as_str()) {
        (400, "invalid_request_error") if contains_max_tokens(&message) => {
            ProviderErrorKind::MaxTokensInvalid {
                requested: None,
                limit: None,
            }
        }
        (400, _) => ProviderErrorKind::BadRequest {
            hint: Some(message),
        },
        // Both authentication and permission failures are reported as a
        // rejected credential.
        (401 | 403, _) => ProviderErrorKind::AuthRejected {
            hint: Some(message),
        },
        (404, "not_found_error") => ProviderErrorKind::ModelNotFound {
            model: extract_model(&message).unwrap_or_else(|| "unknown".into()),
        },
        (404, _) => ProviderErrorKind::ServerError {
            status: Some(404),
            hint: Some(message),
        },
        (413, _) => ProviderErrorKind::BadRequest {
            hint: Some("payload too large".into()),
        },
        (429, _) => ProviderErrorKind::RateLimit {
            retry_after: None,
            scope: RateLimitScope::Unspecified,
        },
        (529, _) => ProviderErrorKind::ServerError {
            status: Some(529),
            hint: Some("overloaded".into()),
        },
        (code, "overloaded_error") => ProviderErrorKind::ServerError {
            status: Some(code),
            hint: Some("overloaded".into()),
        },
        (code, _) => ProviderErrorKind::ServerError {
            status: Some(code),
            hint: Some(message),
        },
    };

    ProviderError::new(kind)
}
/// Reports whether `msg` mentions `max_tokens` or `max tokens`, ignoring
/// ASCII case.
fn contains_max_tokens(msg: &str) -> bool {
    let haystack = msg.to_ascii_lowercase();
    ["max_tokens", "max tokens"]
        .iter()
        .any(|needle| haystack.contains(needle))
}
/// Best-effort extraction of a model id from an upstream error message.
///
/// Finds the first occurrence of `model` (ASCII case-insensitive), skips
/// any following whitespace, `:`, `=` and quote characters, and takes the
/// text up to the next whitespace, comma or quote. Double quotes, single
/// quotes and backticks all count as quotes, and a trailing sentence period
/// is dropped, so `` The model `foo` does not exist `` and `model: foo.`
/// both yield `foo`.
///
/// Returns `None` when the message does not mention `model` or nothing
/// usable follows it.
fn extract_model(msg: &str) -> Option<String> {
    const QUOTES: [char; 3] = ['"', '\'', '`'];

    // ASCII lowercasing keeps byte offsets, so `idx` is valid for `msg` too.
    let lower = msg.to_ascii_lowercase();
    let idx = lower.find("model")?;
    let tail = &msg[idx + "model".len()..];

    let trimmed = tail.trim_start_matches(|c: char| {
        c.is_whitespace() || c == ':' || c == '=' || QUOTES.contains(&c)
    });
    let end = trimmed
        .find(|c: char| c.is_whitespace() || c == ',' || QUOTES.contains(&c))
        .unwrap_or(trimmed.len());

    // Drop a sentence-ending period that is not part of the id.
    let candidate = trimmed[..end].trim_end_matches('.');
    (!candidate.is_empty()).then(|| candidate.to_owned())
}