use std::pin::Pin;
use reqwest::header::{HeaderMap, ACCEPT, CONTENT_TYPE};
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio_stream::{Stream, StreamExt};
use crate::config::AnthropicConfig;
use crate::error::{map_deserialization_error, AnthropicError, WrappedError};
use crate::types::{
CompleteRequest, CompleteResponse, CompleteResponseStream, MessagesRequest, MessagesResponse,
MessagesResponseStream, StreamError,
};
use crate::{
API_VERSION, API_VERSION_HEADER_KEY, AUTHORIZATION_HEADER_KEY, CLIENT_ID, CLIENT_ID_HEADER_KEY, DEFAULT_API_BASE,
DEFAULT_MODEL,
};
/// Client for the Anthropic HTTP API.
///
/// Bundles the credentials, base URL, underlying `reqwest` client, and retry
/// policy used by the request methods implemented on this type.
#[derive(Builder, Debug)]
pub struct Client {
/// API key; attached to every request (see `headers` and `post`).
pub api_key: String,
/// Base URL that request paths such as `/v1/messages` are appended to.
#[builder(default = "DEFAULT_API_BASE.to_string()")]
pub api_base: String,
/// Default model name. Not read by any method visible in this file.
#[builder(default = "DEFAULT_MODEL.to_string()")]
pub default_model: String,
/// Underlying HTTP client. The builder exposes no setter for it (`setter(skip)`).
#[builder(setter(skip))]
pub http_client: reqwest::Client,
/// Retry policy handed to `backoff::future::retry` for non-streaming requests.
#[builder(default = "Default::default()")]
pub backoff: backoff::ExponentialBackoff,
}
impl Client {
/// Sends a non-streaming request to `/v1/complete` and returns the decoded response.
///
/// # Errors
///
/// Returns `AnthropicError::InvalidArgument` when `request.stream` is set; streaming
/// callers must use `complete_stream`. Otherwise any error produced by `post` is
/// returned unchanged.
pub async fn complete(&self, request: CompleteRequest) -> Result<CompleteResponse, AnthropicError> {
    if request.stream {
        Err(AnthropicError::InvalidArgument("When stream is true, use complete_stream() instead".into()))
    } else {
        self.post("/v1/complete", request).await
    }
}
/// Starts a streaming request to `/v1/complete`.
///
/// Only `completion` events are registered with `post_stream` as decodable items.
///
/// # Errors
///
/// Returns `AnthropicError::InvalidArgument` when `request.stream` is not set;
/// non-streaming callers must use `complete`.
pub async fn complete_stream(&self, request: CompleteRequest) -> Result<CompleteResponseStream, AnthropicError> {
    if request.stream {
        let events: CompleteResponseStream = self.post_stream("/v1/complete", request, ["completion"]).await;
        Ok(events)
    } else {
        Err(AnthropicError::InvalidArgument("When stream is false, use complete() instead".into()))
    }
}
/// Sends a non-streaming request to `/v1/messages` and returns the decoded response.
///
/// # Errors
///
/// Returns `AnthropicError::InvalidArgument` when `request.stream` is set; streaming
/// callers must use `messages_stream`. Otherwise any error produced by `post` is
/// returned unchanged.
pub async fn messages(&self, request: MessagesRequest) -> Result<MessagesResponse, AnthropicError> {
    if request.stream {
        // Point callers at the Messages streaming method, not the Completions one.
        return Err(AnthropicError::InvalidArgument("When stream is true, use messages_stream() instead".into()));
    }
    self.post("/v1/messages", request).await
}
/// Starts a streaming request to `/v1/messages`.
///
/// The event names listed below are handed to `post_stream` as the set of
/// server-sent event types whose payloads should be decoded into stream items.
///
/// # Errors
///
/// Returns `AnthropicError::InvalidArgument` when `request.stream` is not set;
/// non-streaming callers must use `messages`.
pub async fn messages_stream(&self, request: MessagesRequest) -> Result<MessagesResponseStream, AnthropicError> {
    if !request.stream {
        // Point callers at the Messages non-streaming method, not the Completions one.
        return Err(AnthropicError::InvalidArgument("When stream is false, use messages() instead".into()));
    }
    Ok(self
        .post_stream(
            "/v1/messages",
            request,
            [
                "message_start",
                "message_delta",
                "message_stop",
                "content_block_start",
                "content_block_delta",
                "content_block_stop",
            ],
        )
        .await)
}
/// Returns the API key this client was configured with.
pub fn api_key(&self) -> &str {
    &self.api_key
}
/// Returns the base URL that request paths are appended to.
pub fn api_base(&self) -> &str {
    &self.api_base
}
pub fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION_HEADER_KEY, self.api_key().parse().unwrap());
headers.insert(CLIENT_ID_HEADER_KEY, CLIENT_ID.as_str().parse().unwrap());
headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
headers.insert(ACCEPT, "application/json".parse().unwrap());
headers.insert(API_VERSION_HEADER_KEY, API_VERSION.parse().unwrap());
headers
}
/// POSTs `request` as JSON to `path` (relative to the API base) and decodes the reply into `O`.
///
/// The request carries both bearer authentication and the headers from `headers`,
/// and is dispatched through `execute`.
///
/// # Errors
///
/// Returns an error if the HTTP request cannot be built, or whatever `execute` reports.
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, AnthropicError>
where
    I: Serialize,
    O: DeserializeOwned,
{
    let url = format!("{}{path}", self.api_base());
    let builder = self
        .http_client
        .post(url)
        .bearer_auth(self.api_key())
        .headers(self.headers())
        .json(&request);
    let http_request = builder.build()?;
    self.execute(http_request).await
}
/// POSTs `request` as JSON to `path` and returns the reply as a stream of decoded events.
///
/// `event_types` names the server-sent event types that `stream` should decode into `O`.
///
/// # Panics
///
/// Panics if the event source cannot be created from the request builder.
pub(crate) async fn post_stream<I, O, const N: usize>(
    &self,
    path: &str,
    request: I,
    event_types: [&'static str; N],
) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
where
    I: Serialize,
    O: DeserializeOwned + Send + 'static,
{
    let url = format!("{}{path}", self.api_base());
    let builder = self
        .http_client
        .post(url)
        .bearer_auth(self.api_key())
        .headers(self.headers())
        .json(&request);
    let source = builder.eventsource().unwrap();
    stream(source, event_types).await
}
async fn process_response<O>(&self, response: reqwest::Response) -> Result<O, AnthropicError>
where
O: DeserializeOwned,
{
let status = response.status();
let bytes = response.bytes().await?;
if !status.is_success() {
let wrapped_error: WrappedError =
serde_json::from_slice(bytes.as_ref()).map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
return Err(AnthropicError::ApiError(wrapped_error.error));
}
let response: O =
serde_json::from_slice(bytes.as_ref()).map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
Ok(response)
}
async fn execute<O>(&self, request: reqwest::Request) -> Result<O, AnthropicError>
where
O: DeserializeOwned,
{
let client = self.http_client.clone();
match request.try_clone() {
Some(request) => {
backoff::future::retry(self.backoff.clone(), || async {
let response = client
.execute(request.try_clone().unwrap())
.await
.map_err(AnthropicError::Reqwest)
.map_err(backoff::Error::Permanent)?;
let status = response.status();
let bytes =
response.bytes().await.map_err(AnthropicError::Reqwest).map_err(backoff::Error::Permanent)?;
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
if status.as_u16() == 429
&& wrapped_error.error.r#type != "insufficient_quota"
{
return Err(backoff::Error::Transient {
err: AnthropicError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(AnthropicError::ApiError(wrapped_error.error)));
}
}
let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
Ok(response)
})
.await
}
None => {
let response = client.execute(request).await?;
self.process_response(response).await
}
}
}
}
async fn stream<O, const N: usize>(
mut event_source: EventSource,
event_types: [&'static str; N],
) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
where
O: DeserializeOwned + Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Ok(event) => match event {
Event::Open => continue,
Event::Message(message) => {
let event = message.event.as_str();
if event == "ping" {
continue;
}
let response = if event == "error" {
match serde_json::from_str::<StreamError>(&message.data) {
Ok(e) => Err(AnthropicError::StreamError(e)),
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
}
} else if event_types.contains(&event) {
match serde_json::from_str::<O>(&message.data) {
Ok(output) => Ok(output),
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
}
} else {
Err(AnthropicError::StreamError(StreamError {
error_type: "unknown_event_type".to_string(),
message: format!("Unknown event type: {event}"),
}))
};
let cancel = response.is_err();
if tx.send(response).is_err() || cancel {
break;
}
}
},
Err(e) => {
if let reqwest_eventsource::Error::StreamEnded = e {
break;
}
if tx
.send(Err(AnthropicError::StreamError(StreamError {
error_type: "sse_error".to_string(),
message: e.to_string(),
})))
.is_err()
{
break;
}
}
}
}
event_source.close();
});
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
impl TryFrom<AnthropicConfig> for Client {
type Error = AnthropicError;
fn try_from(value: AnthropicConfig) -> Result<Self, Self::Error> {
Ok(Self {
api_key: value.api_key,
api_base: value.api_base.unwrap_or_else(|| DEFAULT_API_BASE.to_string()),
default_model: value.default_model.unwrap_or_else(|| DEFAULT_MODEL.to_string()),
http_client: reqwest::Client::new(),
backoff: Default::default(),
})
}
}
impl Default for Client {
fn default() -> Self {
Self::try_from(AnthropicConfig::default()).unwrap()
}
}