use anyllm_translate::anthropic::messages::MessageResponse;
use anyllm_translate::anthropic::streaming::StreamEvent;
use anyllm_translate::anthropic::MessageCreateRequest;
use anyllm_translate::openai::{ChatCompletionRequest, ChatCompletionResponse};
use anyllm_translate::{translate_request, translate_response, TranslationConfig};
use futures::Stream;
use crate::error::ClientError;
use crate::http::{build_http_client, HttpClientConfig};
use crate::rate_limit::RateLimitHeaders;
use crate::retry::{self, RetryableError};
use crate::streaming::SseTranslatingStream;
/// Credentials attached to backend requests.
///
/// `Debug` is implemented by hand rather than derived so that secret
/// material (the bearer token or the header value) is never written to logs
/// or panic messages; only the variant name and, for `Header`, the header
/// name are shown.
#[derive(Clone)]
pub enum Auth {
    /// A bearer token.
    Bearer(String),
    /// A custom header carrying the credential.
    Header { name: String, value: String },
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `format_args!` prints the placeholder without surrounding quotes.
            Auth::Bearer(_) => f
                .debug_tuple("Bearer")
                .field(&format_args!("<redacted>"))
                .finish(),
            Auth::Header { name, .. } => f
                .debug_struct("Header")
                .field("name", name)
                .field("value", &format_args!("<redacted>"))
                .finish(),
        }
    }
}
/// Settings a [`Client`] uses to reach the backend and translate payloads.
#[derive(Clone, Debug)]
pub struct ClientConfig {
/// URL handed to `retry::send_with_retry` for every backend request.
pub chat_completions_url: String,
/// Credentials; converted to `retry::RequestAuth` before each request.
pub auth: Auth,
/// HTTP settings passed to `build_http_client` by [`Client::new`].
pub http: HttpClientConfig,
/// Options passed to `translate_request` when converting requests.
pub translation: TranslationConfig,
}
impl ClientConfig {
pub fn builder() -> ClientConfigBuilder {
ClientConfigBuilder::default()
}
}
/// Incremental builder for [`ClientConfig`].
///
/// Fields left unset are filled in by `build`.
#[derive(Default)]
pub struct ClientConfigBuilder {
/// Becomes `ClientConfig::chat_completions_url`; empty unless set.
backend_url: String,
/// Credentials; when `None`, `build` logs a warning and uses an empty bearer token.
auth: Option<Auth>,
/// HTTP settings; when `None`, `build` uses `HttpClientConfig::default()`.
http: Option<HttpClientConfig>,
/// Translation settings; when `None`, `build` uses `TranslationConfig::default()`.
translation: Option<TranslationConfig>,
}
impl ClientConfigBuilder {
    /// Sets the URL that ends up in `ClientConfig::chat_completions_url`.
    pub fn backend_url(self, url: impl Into<String>) -> Self {
        Self {
            backend_url: url.into(),
            ..self
        }
    }

    /// Sets the credentials used for backend requests.
    pub fn auth(self, auth: Auth) -> Self {
        Self {
            auth: Some(auth),
            ..self
        }
    }

    /// Overrides the HTTP client settings.
    pub fn http(self, http: HttpClientConfig) -> Self {
        Self {
            http: Some(http),
            ..self
        }
    }

    /// Overrides the translation settings.
    pub fn translation(self, translation: TranslationConfig) -> Self {
        Self {
            translation: Some(translation),
            ..self
        }
    }

    /// Produces the final [`ClientConfig`].
    ///
    /// Missing credentials are not an error: a warning is logged and an
    /// empty bearer token is used. Missing HTTP and translation settings
    /// fall back to their `Default` values.
    pub fn build(self) -> ClientConfig {
        let Self {
            backend_url,
            auth,
            http,
            translation,
        } = self;

        let auth = match auth {
            Some(credentials) => credentials,
            None => {
                tracing::warn!(
                    "ClientConfig built without auth credentials; \
                     requests will be sent with an empty Bearer token"
                );
                Auth::Bearer(String::new())
            }
        };

        ClientConfig {
            chat_completions_url: backend_url,
            auth,
            http: http.unwrap_or_default(),
            translation: translation.unwrap_or_default(),
        }
    }
}
/// Module-private error used as the type parameter of `retry::send_with_retry`;
/// it is converted into [`ClientError`] before being returned to callers.
#[derive(Debug)]
enum InternalError {
/// Wraps a `reqwest::Error` (built by `from_request`).
Request(reqwest::Error),
/// Status code and body supplied to `from_api_response`.
ApiError { status: u16, body: String },
}
impl RetryableError for InternalError {
fn from_request(e: reqwest::Error) -> Self {
Self::Request(e)
}
fn from_api_response(status: u16, body: &str) -> Self {
Self::ApiError {
status,
body: body.to_string(),
}
}
}
impl From<InternalError> for ClientError {
fn from(e: InternalError) -> Self {
match e {
InternalError::Request(e) => ClientError::Transport(e),
InternalError::ApiError { status, body } => ClientError::ApiError {
status,
message: format!("Backend returned status {status}"),
body,
},
}
}
}
/// Builder that assembles a [`Client`] from individual settings.
pub struct ClientBuilder {
/// Backend URL; required, `build` returns an error when it is unset.
base_url: Option<String>,
/// API key used as the bearer token; `build` rejects an empty key and
/// uses an empty token when no key was set.
api_key: Option<String>,
/// Copied into `HttpClientConfig::connect_timeout` by `build`.
connect_timeout: Option<std::time::Duration>,
/// Set through `timeout`; copied into `HttpClientConfig::request_timeout`.
request_timeout: Option<std::time::Duration>,
/// Copied into `HttpClientConfig::read_timeout` by `build`.
read_timeout: Option<std::time::Duration>,
/// Retry budget; `build` falls back to `retry::MAX_RETRIES` when unset.
max_retries: Option<u32>,
}
impl ClientBuilder {
pub fn new() -> Self {
Self {
base_url: None,
api_key: None,
connect_timeout: None,
request_timeout: None,
read_timeout: None,
max_retries: None,
}
}
pub fn base_url(mut self, url: &str) -> Self {
self.base_url = Some(url.to_string());
self
}
pub fn api_key(mut self, key: &str) -> Self {
self.api_key = Some(key.to_string());
self
}
pub fn connect_timeout(mut self, duration: std::time::Duration) -> Self {
self.connect_timeout = Some(duration);
self
}
pub fn timeout(mut self, duration: std::time::Duration) -> Self {
self.request_timeout = Some(duration);
self
}
pub fn read_timeout(mut self, duration: std::time::Duration) -> Self {
self.read_timeout = Some(duration);
self
}
pub fn max_retries(mut self, n: u32) -> Self {
self.max_retries = Some(n);
self
}
pub fn build(self) -> Result<Client, ClientError> {
let base_url = self.base_url.ok_or_else(|| ClientError::ApiError {
status: 0,
message: "ClientBuilder: base_url is required".to_string(),
body: String::new(),
})?;
if let Some(ref key) = self.api_key {
if key.is_empty() {
return Err(ClientError::ApiError {
status: 0,
message: "ClientBuilder: api_key is empty".to_string(),
body: String::new(),
});
}
}
let http_config = HttpClientConfig {
connect_timeout: self.connect_timeout,
request_timeout: self.request_timeout,
read_timeout: self.read_timeout,
..HttpClientConfig::new()
};
let max_retries = self.max_retries.unwrap_or(retry::MAX_RETRIES);
let config = ClientConfig {
chat_completions_url: base_url,
auth: Auth::Bearer(self.api_key.unwrap_or_default()),
http: http_config,
translation: TranslationConfig::default(),
};
let mut client = Client::new(config);
client.max_retries = max_retries;
Ok(client)
}
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// Client that translates requests, sends them to the configured backend
/// through `retry::send_with_retry`, and translates the responses back.
#[derive(Clone)]
pub struct Client {
/// HTTP client passed to `retry::send_with_retry`.
http: reqwest::Client,
/// Backend URL, credentials and translation settings.
config: ClientConfig,
/// Retry budget passed to `retry::send_with_retry`.
max_retries: u32,
}
impl Client {
    /// Creates a client whose HTTP layer is built from `config.http`.
    ///
    /// The retry budget starts at `retry::MAX_RETRIES`.
    pub fn new(config: ClientConfig) -> Self {
        let http = build_http_client(&config.http);
        Self::with_http_client(http, config)
    }

    /// Returns a fresh [`ClientBuilder`].
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Creates a client around a caller-supplied `reqwest::Client`.
    ///
    /// The retry budget starts at `retry::MAX_RETRIES`.
    pub fn with_http_client(http: reqwest::Client, config: ClientConfig) -> Self {
        Self {
            http,
            config,
            max_retries: retry::MAX_RETRIES,
        }
    }

    /// Borrows the configured credentials in the form the retry layer expects.
    fn auth(&self) -> retry::RequestAuth<'_> {
        match self.config.auth {
            Auth::Bearer(ref token) => retry::RequestAuth::Bearer(token),
            Auth::Header {
                ref name,
                ref value,
            } => retry::RequestAuth::Header { name, value },
        }
    }

    /// Sends `req` to the configured URL through `retry::send_with_retry`.
    ///
    /// Shared by the buffered and streaming request paths; any
    /// `InternalError` is converted into a [`ClientError`].
    async fn send_chat_request(
        &self,
        req: &ChatCompletionRequest,
    ) -> Result<reqwest::Response, ClientError> {
        let auth = self.auth();
        let outcome = retry::send_with_retry::<InternalError>(
            &self.http,
            &self.config.chat_completions_url,
            &auth,
            req,
            "backend",
            self.max_retries,
        )
        .await;
        outcome.map_err(ClientError::from)
    }

    /// Translates `req`, sends it to the backend, and translates the reply
    /// back into a [`MessageResponse`].
    ///
    /// # Errors
    ///
    /// Propagates request-translation failures and any error returned by
    /// [`Client::chat_completion`].
    pub async fn messages(
        &self,
        req: &MessageCreateRequest,
    ) -> Result<MessageResponse, ClientError> {
        let backend_request = translate_request(req, &self.config.translation)?;
        let (completion, _status, _rate_limits) = self.chat_completion(&backend_request).await?;
        Ok(translate_response(&completion, &req.model))
    }

    /// Streaming variant of [`Client::messages`].
    ///
    /// The translated request is sent with `stream = Some(true)`; the raw
    /// response is wrapped in an [`SseTranslatingStream`] and returned
    /// together with the rate-limit headers of the response.
    ///
    /// # Errors
    ///
    /// Propagates request-translation failures and errors from sending the
    /// request.
    pub async fn messages_stream(
        &self,
        req: &MessageCreateRequest,
    ) -> Result<
        (
            impl Stream<Item = Result<StreamEvent, ClientError>>,
            RateLimitHeaders,
        ),
        ClientError,
    > {
        let mut backend_request = translate_request(req, &self.config.translation)?;
        backend_request.stream = Some(true);
        let (response, rate_limits) = self.chat_completion_stream_raw(&backend_request).await?;
        let events = SseTranslatingStream::new(response, req.model.clone());
        Ok((events, rate_limits))
    }

    /// Sends a chat completion request and decodes the JSON reply.
    ///
    /// Returns the decoded body, the HTTP status code, and the rate-limit
    /// headers of the response.
    ///
    /// # Errors
    ///
    /// Returns the converted send error, or `ClientError::Deserialization`
    /// when the body cannot be decoded as a [`ChatCompletionResponse`].
    pub async fn chat_completion(
        &self,
        req: &ChatCompletionRequest,
    ) -> Result<(ChatCompletionResponse, u16, RateLimitHeaders), ClientError> {
        let response = self.send_chat_request(req).await?;
        // Status and headers must be read before `json()` consumes the response.
        let status = response.status().as_u16();
        let rate_limits = RateLimitHeaders::from_openai_headers(response.headers());
        match response.json::<ChatCompletionResponse>().await {
            Ok(body) => Ok((body, status, rate_limits)),
            Err(e) => Err(ClientError::Deserialization(e.to_string())),
        }
    }

    /// Sends a chat completion request and returns the undecoded response
    /// together with its rate-limit headers.
    async fn chat_completion_stream_raw(
        &self,
        req: &ChatCompletionRequest,
    ) -> Result<(reqwest::Response, RateLimitHeaders), ClientError> {
        let response = self.send_chat_request(req).await?;
        let rate_limits = RateLimitHeaders::from_openai_headers(response.headers());
        Ok((response, rate_limits))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend URL used by most tests.
    const OPENAI_URL: &str = "https://api.openai.com/v1/chat/completions";
    /// Minimal URL for tests that only need a URL to be present.
    const EXAMPLE_URL: &str = "https://example.com";

    /// The config builder stores the URL and credentials it was given.
    #[test]
    fn client_config_builder_defaults() {
        let built = ClientConfig::builder()
            .backend_url(OPENAI_URL)
            .auth(Auth::Bearer("sk-test".into()))
            .build();

        assert_eq!(built.chat_completions_url, OPENAI_URL);
        let has_expected_token = matches!(&built.auth, Auth::Bearer(token) if token == "sk-test");
        assert!(has_expected_token);
    }

    /// A custom translation config survives the builder.
    #[test]
    fn client_config_builder_with_translation() {
        let model_mapping = TranslationConfig::builder()
            .model_map("haiku", "gpt-4o-mini")
            .model_map("sonnet", "gpt-4o")
            .build();

        let built = ClientConfig::builder()
            .backend_url(OPENAI_URL)
            .auth(Auth::Bearer("sk-test".into()))
            .translation(model_mapping)
            .build();

        let mapped = built.translation.map_model("claude-3-haiku");
        assert!(mapped.is_ok());
    }

    /// Constructing a client from a config must not panic.
    #[test]
    fn client_creates_without_panic() {
        let http = HttpClientConfig {
            ssrf_protection: false,
            ..Default::default()
        };
        let built = ClientConfig::builder()
            .backend_url(OPENAI_URL)
            .auth(Auth::Bearer("sk-test".into()))
            .http(http)
            .build();

        let _client = Client::new(built);
    }

    /// A fully specified builder yields a client.
    #[test]
    fn client_builder_success() {
        let outcome = ClientBuilder::new()
            .base_url(OPENAI_URL)
            .api_key("sk-test")
            .connect_timeout(std::time::Duration::from_secs(5))
            .timeout(std::time::Duration::from_secs(120))
            .read_timeout(std::time::Duration::from_secs(30))
            .max_retries(2)
            .build();

        assert!(outcome.is_ok());
    }

    /// An explicitly empty API key is rejected.
    #[test]
    fn client_builder_empty_api_key_rejected() {
        let outcome = ClientBuilder::new()
            .base_url(OPENAI_URL)
            .api_key("")
            .build();

        assert!(outcome.is_err());
    }

    /// The retry budget given to the builder ends up on the client.
    #[test]
    fn client_builder_max_retries_stored() {
        let built = ClientBuilder::new()
            .base_url(OPENAI_URL)
            .max_retries(7)
            .build()
            .unwrap();

        assert_eq!(built.max_retries, 7);
    }

    /// Building without a base URL fails.
    #[test]
    fn client_builder_missing_url() {
        let outcome = ClientBuilder::new().api_key("sk-test").build();
        assert!(outcome.is_err());
    }

    /// Omitting the API key altogether is allowed.
    #[test]
    fn client_builder_default_api_key() {
        let outcome = ClientBuilder::new().base_url(EXAMPLE_URL).build();
        assert!(outcome.is_ok());
    }

    /// `Client::builder` is an entry point to the same builder.
    #[test]
    fn client_builder_via_client() {
        let outcome = Client::builder().base_url(EXAMPLE_URL).build();
        assert!(outcome.is_ok());
    }

    /// `Default` produces an unconfigured builder.
    #[test]
    fn client_builder_default_trait() {
        let fresh = ClientBuilder::default();
        assert!(fresh.base_url.is_none());
    }
}