async_anthropic/
client.rs1use backoff::{Error as BackoffError, ExponentialBackoff, ExponentialBackoffBuilder};
2use derive_builder::Builder;
3use reqwest::StatusCode;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use secrecy::ExposeSecret;
6use serde::{de::DeserializeOwned, Serialize};
7use std::{pin::Pin, time::Duration};
8use tokio_stream::{Stream, StreamExt as _};
9
10use crate::{
11 errors::{map_deserialization_error, AnthropicError, StreamError},
12 messages::Messages,
13 models::Models,
14};
15
/// Root URL of the Anthropic API; `Client::default()` uses it as the client's base URL.
const BASE_URL: &str = "https://api.anthropic.com";
17
18#[derive(Clone, Debug, Builder)]
43#[builder(setter(into, strip_option))]
44pub struct Client {
45 #[builder(default)]
46 http_client: reqwest::Client,
47 #[builder(default)]
48 base_url: String,
49 #[builder(default = default_api_key())]
50 api_key: secrecy::SecretString,
51 #[builder(default)]
52 version: String,
53 #[builder(default)]
54 beta: Option<String>,
55 #[builder(default)]
56 backoff: ExponentialBackoff,
57}
58
59impl Default for Client {
60 fn default() -> Self {
61 let backoff = ExponentialBackoffBuilder::default()
63 .with_initial_interval(Duration::from_secs(15))
64 .with_multiplier(2.0)
65 .with_randomization_factor(0.05)
66 .with_max_elapsed_time(Some(Duration::from_secs(120)))
67 .build();
68
69 Self {
70 http_client: reqwest::Client::new(),
71 api_key: default_api_key(), version: "2023-06-01".to_string(),
73 beta: None,
74 base_url: BASE_URL.to_string(),
75 backoff,
76 }
77 }
78}
79
80fn default_api_key() -> secrecy::SecretString {
81 if cfg!(test) {
82 return "test".into();
83 }
84 std::env::var("ANTHROPIC_API_KEY")
85 .unwrap_or_else(|_| {
86 tracing::warn!("Default Anthropic client initialized without api key");
87 String::new()
88 })
89 .into()
90}
91
92impl Client {
93 pub fn from_api_key(api_key: impl Into<secrecy::SecretString>) -> Self {
95 Self {
96 api_key: api_key.into(),
97 ..Default::default()
98 }
99 }
100
101 pub fn builder() -> ClientBuilder {
103 ClientBuilder::default()
104 }
105
106 pub fn with_backoff(mut self, backoff: ExponentialBackoff) -> Self {
108 self.backoff = backoff;
109 self
110 }
111
112 pub fn messages(&self) -> Messages {
114 Messages::new(self)
115 }
116
117 pub fn models(&self) -> Models {
118 Models::new(self)
119 }
120
121 fn headers(&self) -> reqwest::header::HeaderMap {
122 let mut headers = reqwest::header::HeaderMap::new();
123 headers.insert("x-api-key", self.api_key.expose_secret().parse().unwrap());
124 headers.insert("anthropic-version", self.version.parse().unwrap());
125 if let Some(beta_value) = &self.beta {
126 headers.insert("anthropic-beta", beta_value.parse().unwrap());
127 }
128 headers
129 }
130
131 fn format_url(&self, path: &str) -> String {
132 format!(
133 "{}/{}",
134 &self.base_url.trim_end_matches('/'),
135 &path.trim_start_matches('/')
136 )
137 }
138
139 pub async fn get<O>(&self, path: &str) -> Result<O, AnthropicError>
140 where
141 O: DeserializeOwned,
142 {
143 backoff::future::retry(self.backoff.clone(), || async {
144 let response = self
145 .http_client
146 .get(self.format_url(path))
147 .headers(self.headers())
148 .send()
149 .await
150 .map_err(AnthropicError::NetworkError)
151 .map_err(backoff::Error::Permanent)?;
152
153 let status = response.status();
154
155 match status {
156 StatusCode::OK => {
157 let response = response
158 .json::<O>()
159 .await
160 .map_err(AnthropicError::NetworkError)
161 .map_err(backoff::Error::Permanent)?;
162 Ok(response)
163 }
164 StatusCode::BAD_REQUEST => {
165 let text = response
166 .text()
167 .await
168 .map_err(AnthropicError::NetworkError)
169 .map_err(backoff::Error::Permanent)?;
170 Err(BackoffError::Permanent(AnthropicError::BadRequest(text)))
171 }
172 StatusCode::UNAUTHORIZED => {
173 Err(BackoffError::Permanent(AnthropicError::Unauthorized))
174 }
175 _ => {
176 let text = response
177 .text()
178 .await
179 .map_err(AnthropicError::NetworkError)
180 .map_err(backoff::Error::Permanent)?;
181 Err(BackoffError::Permanent(AnthropicError::Unknown(text)))
182 }
183 }
184 })
185 .await
186 }
187
188 pub async fn post<I, O>(&self, path: &str, request: I) -> Result<O, AnthropicError>
192 where
193 I: Serialize,
194 O: DeserializeOwned,
195 {
196 backoff::future::retry(self.backoff.clone(), || async {
197 let mut request = self
198 .http_client
199 .post(self.format_url(path))
200 .headers(self.headers())
201 .json(&request);
202
203 if let Some(beta_value) = &self.beta {
204 request = request.header("anthropic-beta", beta_value);
205 }
206
207 let response = request
208 .send()
209 .await
210 .map_err(AnthropicError::NetworkError)
211 .map_err(backoff::Error::Permanent)?;
212 let status = response.status();
213
214 let overloaded_status = StatusCode::from_u16(529).expect("529 is a valid status code");
216
217 match status {
218 StatusCode::OK => {
219 let response = response
220 .json::<O>()
221 .await
222 .map_err(AnthropicError::NetworkError)
223 .map_err(backoff::Error::Permanent)?;
224 Ok(response)
225 }
226 StatusCode::BAD_REQUEST => {
227 let text = response
228 .text()
229 .await
230 .map_err(AnthropicError::NetworkError)
231 .map_err(backoff::Error::Permanent)?;
232 Err(BackoffError::Permanent(AnthropicError::BadRequest(text)))
233 }
234 StatusCode::UNAUTHORIZED => {
235 Err(BackoffError::Permanent(AnthropicError::Unauthorized))
236 }
237
238 _ if status == StatusCode::TOO_MANY_REQUESTS || status == overloaded_status => {
239 let text = response
240 .text()
241 .await
242 .map_err(AnthropicError::NetworkError)
243 .map_err(backoff::Error::Permanent)?;
244
245 tracing::warn!("Rate limited: {}", text);
247 Err(backoff::Error::Transient {
248 err: AnthropicError::ApiError(text),
249 retry_after: None,
250 })
251 }
252 _ => {
253 let text = response
254 .text()
255 .await
256 .map_err(AnthropicError::NetworkError)
257 .map_err(backoff::Error::Permanent)?;
258 Err(BackoffError::Permanent(AnthropicError::Unknown(text)))
259 }
260 }
261 })
262 .await
263 }
264
265 pub(crate) async fn post_stream<I, O, const N: usize>(
266 &self,
267 path: &str,
268 request: I,
269 event_types: [&'static str; N],
270 ) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
271 where
272 I: Serialize,
273 O: DeserializeOwned + Send + 'static,
274 {
275 let event_source = self
276 .http_client
277 .post(self.format_url(path))
278 .headers(self.headers())
279 .json(&request)
280 .eventsource()
281 .unwrap();
282
283 stream(event_source, event_types).await
284 }
285}
286
287async fn stream<O, const N: usize>(
288 mut event_source: EventSource,
289 event_types: [&'static str; N],
290) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
291where
292 O: DeserializeOwned + Send + 'static,
293{
294 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
295
296 tokio::spawn(async move {
297 while let Some(ev) = event_source.next().await {
298 tracing::trace!("Streaming event: {ev:?}");
299 match ev {
300 Ok(event) => match event {
301 Event::Open => continue,
302 Event::Message(message) => {
303 let event = message.event.as_str();
304 if event == "ping" {
305 continue;
306 }
307
308 let response = if event == "error" {
309 match serde_json::from_str::<StreamError>(&message.data) {
310 Ok(e) => Err(AnthropicError::StreamError(e)),
311 Err(e) => {
312 Err(map_deserialization_error(e, message.data.as_bytes()))
313 }
314 }
315 } else if event_types.contains(&event) {
316 match serde_json::from_str::<O>(&message.data) {
317 Ok(output) => Ok(output),
318 Err(e) => {
319 Err(map_deserialization_error(e, message.data.as_bytes()))
320 }
321 }
322 } else {
323 Err(AnthropicError::StreamError(StreamError {
324 error_type: "unknown_event_type".to_string(),
325 message: format!("Unknown event type: {event}"),
326 }))
327 };
328 let cancel = response.is_err();
329 if tx.send(response).is_err() || cancel {
330 break;
332 }
333 }
334 },
335 Err(e) => {
336 if let reqwest_eventsource::Error::StreamEnded = e {
337 break;
338 }
339 if tx
340 .send(Err(AnthropicError::StreamError(StreamError {
341 error_type: "sse_error".to_string(),
342 message: e.to_string(),
343 })))
344 .is_err()
345 {
346 break;
348 }
349 }
350 }
351 }
352
353 event_source.close();
354 });
355
356 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
357}