use super::{AnthropicProvider, ProviderResponse, error::ProviderError};
use crate::models::{AnthropicRequest, CountTokensRequest, CountTokensResponse};
use crate::auth::{TokenStore, OAuthClient, OAuthConfig};
use async_trait::async_trait;
use reqwest::Client;
use std::pin::Pin;
use futures::stream::Stream;
use bytes::Bytes;
/// An `AnthropicProvider` for any endpoint that speaks the Anthropic Messages API
/// (requests go to `{base_url}/v1/messages`).
///
/// Named constructors exist for Anthropic, OpenRouter, z.ai, MiniMax, ZenMux and Kimi.
/// Authentication uses either a static API key or, when `oauth_provider` is set, an
/// OAuth access token looked up in `token_store`.
pub struct AnthropicCompatibleProvider {
    /// Provider label used in log and error messages; the value `"anthropic"` also
    /// makes `count_tokens` call the remote token-count endpoint.
    name: String,
    /// Static API key; returned as the credential when `oauth_provider` is `None`.
    api_key: String,
    /// Endpoint root; request paths such as `/v1/messages` are appended to it.
    base_url: String,
    /// HTTP client used for every request made by this provider.
    client: Client,
    /// Model identifiers this provider accepts (exact string match in `supports_model`).
    models: Vec<String>,
    /// Extra `(name, value)` headers added to outgoing message requests
    /// (e.g. OpenRouter's attribution headers).
    custom_headers: Vec<(String, String)>,
    /// Key of the OAuth token in `token_store`; when set, the stored token is used
    /// instead of `api_key`.
    oauth_provider: Option<String>,
    /// Store holding OAuth tokens; only consulted when `oauth_provider` is set.
    token_store: Option<TokenStore>,
}
impl AnthropicCompatibleProvider {
pub fn new(
name: String,
api_key: String,
base_url: String,
models: Vec<String>,
oauth_provider: Option<String>,
token_store: Option<TokenStore>,
) -> Self {
Self {
name,
api_key,
base_url,
client: Client::new(),
models,
custom_headers: Vec::new(),
oauth_provider,
token_store,
}
}
pub fn with_headers(
name: String,
api_key: String,
base_url: String,
models: Vec<String>,
custom_headers: Vec<(String, String)>,
oauth_provider: Option<String>,
token_store: Option<TokenStore>,
) -> Self {
Self {
name,
api_key,
base_url,
client: Client::new(),
models,
custom_headers,
oauth_provider,
token_store,
}
}
async fn get_auth_header(&self) -> Result<String, ProviderError> {
if let Some(ref oauth_provider_id) = self.oauth_provider {
if let Some(ref token_store) = self.token_store {
if let Some(token) = token_store.get(oauth_provider_id) {
if token.needs_refresh() {
tracing::info!("🔄 Token for '{}' needs refresh, refreshing...", oauth_provider_id);
let config = OAuthConfig::anthropic();
let oauth_client = OAuthClient::new(config, token_store.clone());
match oauth_client.refresh_token(oauth_provider_id).await {
Ok(new_token) => {
tracing::info!("✅ Token refreshed successfully");
return Ok(new_token.access_token);
}
Err(e) => {
tracing::error!("❌ Failed to refresh token: {}", e);
return Err(ProviderError::AuthError(format!(
"Failed to refresh OAuth token: {}", e
)));
}
}
} else {
return Ok(token.access_token);
}
} else {
return Err(ProviderError::AuthError(format!(
"OAuth provider '{}' configured but no token found in store",
oauth_provider_id
)));
}
} else {
return Err(ProviderError::AuthError(
"OAuth provider configured but TokenStore not available".to_string()
));
}
}
Ok(self.api_key.clone())
}
fn is_oauth(&self) -> bool {
self.oauth_provider.is_some() && self.token_store.is_some()
}
pub fn anthropic(api_key: String, models: Vec<String>) -> Self {
Self::new(
"anthropic".to_string(),
api_key,
"https://api.anthropic.com".to_string(),
models,
None,
None,
)
}
pub fn openrouter(api_key: String, models: Vec<String>) -> Self {
Self::with_headers(
"openrouter".to_string(),
api_key,
"https://openrouter.ai/api".to_string(),
models,
vec![
("HTTP-Referer".to_string(), "https://github.com/bahkchanhee/claude-code-mux".to_string()),
("X-Title".to_string(), "Claude Code Mux".to_string()),
],
None,
None,
)
}
pub fn zai(api_key: String, models: Vec<String>, token_store: Option<TokenStore>) -> Self {
Self::new(
"z.ai".to_string(),
api_key,
"https://api.z.ai/api/anthropic".to_string(),
models,
None,
token_store,
)
}
pub fn minimax(api_key: String, models: Vec<String>, token_store: Option<TokenStore>) -> Self {
Self::new(
"minimax".to_string(),
api_key,
"https://api.minimax.io/anthropic".to_string(),
models,
None,
token_store,
)
}
pub fn zenmux(api_key: String, models: Vec<String>, token_store: Option<TokenStore>) -> Self {
Self::new(
"zenmux".to_string(),
api_key,
"https://zenmux.ai/api/anthropic".to_string(),
models,
None,
token_store,
)
}
pub fn kimi_coding(api_key: String, models: Vec<String>, token_store: Option<TokenStore>) -> Self {
Self::new(
"kimi-coding".to_string(),
api_key,
"https://api.kimi.com/coding".to_string(),
models,
None,
token_store,
)
}
}
/// Builds an authenticated `POST {base_url}{path}` request for `provider`.
///
/// The request carries, in order:
/// - the `anthropic-version` and JSON `Content-Type` headers;
/// - the resolved credential: `Authorization: Bearer …` plus the OAuth `anthropic-beta`
///   flags when `provider.is_oauth()`, otherwise `x-api-key`;
/// - every entry of `provider.custom_headers`.
///
/// Every endpoint goes through this helper so they all send identical headers.
///
/// # Errors
/// Propagates the `ProviderError` from `get_auth_header` (e.g. a failed OAuth refresh).
async fn build_authorized_post(
    provider: &AnthropicCompatibleProvider,
    path: &str,
) -> Result<reqwest::RequestBuilder, ProviderError> {
    const ANTHROPIC_VERSION: &str = "2023-06-01";
    const OAUTH_BETA_FLAGS: &str = "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14";

    let url = format!("{}{}", provider.base_url, path);
    let auth_value = provider.get_auth_header().await?;
    let req_builder = provider
        .client
        .post(&url)
        .header("anthropic-version", ANTHROPIC_VERSION)
        .header("Content-Type", "application/json");
    let mut req_builder = if provider.is_oauth() {
        req_builder
            .header("Authorization", format!("Bearer {}", auth_value))
            .header("anthropic-beta", OAUTH_BETA_FLAGS)
    } else {
        req_builder.header("x-api-key", auth_value)
    };
    for (key, value) in &provider.custom_headers {
        req_builder = req_builder.header(key, value);
    }
    Ok(req_builder)
}

/// Length in bytes of the parts whose lengths are given, as if joined with `"\n"`.
///
/// Equivalent to `parts.join("\n").len()` but without allocating: `n` parts contribute
/// `n - 1` separator bytes, and zero parts yield 0.
fn newline_joined_len(part_lens: impl Iterator<Item = usize>) -> usize {
    let (count, total) = part_lens.fold((0usize, 0usize), |(count, total), len| {
        (count + 1, total + len)
    });
    total + count.saturating_sub(1)
}

#[async_trait]
impl AnthropicProvider for AnthropicCompatibleProvider {
    /// Sends a non-streaming request to `{base_url}/v1/messages` and parses the JSON body.
    ///
    /// # Errors
    /// - `ProviderError::ApiError` on a non-success status; the message is prefixed with
    ///   the provider name.
    /// - The error converted from the transport failure or the JSON parse failure.
    async fn send_message(&self, request: AnthropicRequest) -> Result<ProviderResponse, ProviderError> {
        let req_builder = build_authorized_post(self, "/v1/messages").await?;
        if self.is_oauth() {
            tracing::debug!("🔐 Using OAuth Bearer token for {}", self.name);
        }
        let response = req_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            if status == 401 && self.is_oauth() {
                tracing::warn!("🔄 Received 401, OAuth token may be invalid or expired");
            }
            return Err(ProviderError::ApiError {
                status,
                message: format!("{} API error: {}", self.name, error_text),
            });
        }

        let response_text = response.text().await?;
        tracing::debug!("{} provider response body: {}", self.name, response_text);
        let provider_response: ProviderResponse = serde_json::from_str(&response_text)
            .map_err(|e| {
                tracing::error!("Failed to parse {} response: {}", self.name, e);
                tracing::error!("Response body was: {}", response_text);
                e
            })?;
        Ok(provider_response)
    }

    /// Counts input tokens for `request`.
    ///
    /// The provider named `"anthropic"` asks `{base_url}/v1/messages/count_tokens`.
    /// Every other provider gets a local estimate of roughly one token per 4 bytes of
    /// UTF-8 text. The estimate covers the system prompt, plain-text messages, and the
    /// `Text`, `ToolResult` and `Thinking` content blocks; other blocks are ignored.
    ///
    /// # Errors
    /// Only the remote path can fail: `ProviderError::ApiError` on a non-success status
    /// (raw upstream body as the message), or the converted transport/JSON error.
    async fn count_tokens(&self, request: CountTokensRequest) -> Result<CountTokensResponse, ProviderError> {
        use crate::models::{ContentBlock, MessageContent, SystemPrompt};

        if self.name == "anthropic" {
            let req_builder = build_authorized_post(self, "/v1/messages/count_tokens").await?;
            let response = req_builder.json(&request).send().await?;
            if !response.status().is_success() {
                let status = response.status().as_u16();
                let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
                if status == 401 && self.is_oauth() {
                    tracing::warn!("🔄 Received 401 on count_tokens, OAuth token may be invalid or expired");
                }
                return Err(ProviderError::ApiError {
                    status,
                    message: error_text,
                });
            }
            let count_response: CountTokensResponse = response.json().await?;
            return Ok(count_response);
        }

        // Byte lengths (`str::len`), with "\n" separators between the blocks of one item.
        let mut total_bytes = match &request.system {
            Some(SystemPrompt::Text(text)) => text.len(),
            Some(SystemPrompt::Blocks(blocks)) => {
                newline_joined_len(blocks.iter().map(|b| b.text.len()))
            }
            None => 0,
        };

        for msg in &request.messages {
            total_bytes += match &msg.content {
                MessageContent::Text(text) => text.len(),
                MessageContent::Blocks(blocks) => {
                    newline_joined_len(blocks.iter().filter_map(|block| match block {
                        ContentBlock::Text { text } => Some(text.len()),
                        ContentBlock::ToolResult { content, .. } => Some(content.to_string().len()),
                        ContentBlock::Thinking { thinking, .. } => Some(thinking.len()),
                        _ => None,
                    }))
                }
            };
        }

        // Saturate instead of silently truncating if the estimate ever exceeds u32::MAX.
        let estimated_tokens = u32::try_from(total_bytes / 4).unwrap_or(u32::MAX);
        Ok(CountTokensResponse {
            input_tokens: estimated_tokens,
        })
    }

    /// Sends a request to `{base_url}/v1/messages` and returns the raw response body as a
    /// byte stream.
    ///
    /// The response bytes are forwarded unparsed.
    ///
    /// # Errors
    /// - `ProviderError::ApiError` (provider-prefixed message) if the initial status is not
    ///   a success.
    /// - The converted transport error if sending fails.
    /// - Per-chunk transport errors, which surface as `ProviderError::HttpError` items.
    async fn send_message_stream(
        &self,
        request: AnthropicRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, ProviderError>> + Send>>, ProviderError> {
        use futures::stream::TryStreamExt;

        let req_builder = build_authorized_post(self, "/v1/messages").await?;
        if self.is_oauth() {
            tracing::debug!("🔐 Using OAuth Bearer token for streaming on {}", self.name);
        }
        let response = req_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            if status == 401 && self.is_oauth() {
                tracing::warn!("🔄 Received 401 on streaming, OAuth token may be invalid or expired");
            }
            return Err(ProviderError::ApiError {
                status,
                message: format!("{} API error: {}", self.name, error_text),
            });
        }

        let stream = response.bytes_stream().map_err(ProviderError::HttpError);
        Ok(Box::pin(stream))
    }

    /// Returns `true` if `model` exactly matches one of the configured model ids.
    fn supports_model(&self, model: &str) -> bool {
        self.models.iter().any(|m| m == model)
    }
}