async_anthropic/
client.rs

1use backoff::{Error as BackoffError, ExponentialBackoff, ExponentialBackoffBuilder};
2use derive_builder::Builder;
3use reqwest::StatusCode;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use secrecy::ExposeSecret;
6use serde::{de::DeserializeOwned, Serialize};
7use std::{pin::Pin, time::Duration};
8use tokio_stream::{Stream, StreamExt as _};
9
10use crate::{
11    errors::{map_deserialization_error, AnthropicError, StreamError},
12    messages::Messages,
13    models::Models,
14};
15
16const BASE_URL: &str = "https://api.anthropic.com";
17
18/// Main entry point for the Anthropic API
19///
20/// By default will use the `ANTHROPIC_API_KEY` environment variable
21///
22/// # Example
23///
24/// ```no_run
25/// # use async_anthropic::types::*;
26/// # async fn run() {
27/// let client = async_anthropic::Client::default();
28///
29/// let request = CreateMessagesRequestBuilder::default()
30///    .model("claude-3.5-sonnet")
31///    .messages(vec![MessageBuilder::default()
32///        .role(MessageRole::User)
33///        .content("Hello world!")
34///        .build()
35///        .unwrap()])
36///    .build()
37///    .unwrap();
38///
39/// client.messages().create(request).await.unwrap();
40/// # }
41/// ```
42#[derive(Clone, Debug, Builder)]
43#[builder(setter(into, strip_option))]
44pub struct Client {
45    #[builder(default)]
46    http_client: reqwest::Client,
47    #[builder(default)]
48    base_url: String,
49    #[builder(default = default_api_key())]
50    api_key: secrecy::SecretString,
51    #[builder(default)]
52    version: String,
53    #[builder(default)]
54    beta: Option<String>,
55    #[builder(default)]
56    backoff: ExponentialBackoff,
57}
58
59impl Default for Client {
60    fn default() -> Self {
61        // Load backoff settings from configuration
62        let backoff = ExponentialBackoffBuilder::default()
63            .with_initial_interval(Duration::from_secs(15))
64            .with_multiplier(2.0)
65            .with_randomization_factor(0.05)
66            .with_max_elapsed_time(Some(Duration::from_secs(120)))
67            .build();
68
69        Self {
70            http_client: reqwest::Client::new(),
71            api_key: default_api_key(), // Default env?
72            version: "2023-06-01".to_string(),
73            beta: None,
74            base_url: BASE_URL.to_string(),
75            backoff,
76        }
77    }
78}
79
80fn default_api_key() -> secrecy::SecretString {
81    if cfg!(test) {
82        return "test".into();
83    }
84    std::env::var("ANTHROPIC_API_KEY")
85        .unwrap_or_else(|_| {
86            tracing::warn!("Default Anthropic client initialized without api key");
87            String::new()
88        })
89        .into()
90}
91
92impl Client {
93    /// Build a new client from an API key
94    pub fn from_api_key(api_key: impl Into<secrecy::SecretString>) -> Self {
95        Self {
96            api_key: api_key.into(),
97            ..Default::default()
98        }
99    }
100
101    /// Create a new client builder
102    pub fn builder() -> ClientBuilder {
103        ClientBuilder::default()
104    }
105
106    /// Set a custom backoff strategy
107    pub fn with_backoff(mut self, backoff: ExponentialBackoff) -> Self {
108        self.backoff = backoff;
109        self
110    }
111
112    /// Call the messages api
113    pub fn messages(&self) -> Messages {
114        Messages::new(self)
115    }
116
117    pub fn models(&self) -> Models {
118        Models::new(self)
119    }
120
121    fn headers(&self) -> reqwest::header::HeaderMap {
122        let mut headers = reqwest::header::HeaderMap::new();
123        headers.insert("x-api-key", self.api_key.expose_secret().parse().unwrap());
124        headers.insert("anthropic-version", self.version.parse().unwrap());
125        if let Some(beta_value) = &self.beta {
126            headers.insert("anthropic-beta", beta_value.parse().unwrap());
127        }
128        headers
129    }
130
131    fn format_url(&self, path: &str) -> String {
132        format!(
133            "{}/{}",
134            &self.base_url.trim_end_matches('/'),
135            &path.trim_start_matches('/')
136        )
137    }
138
139    pub async fn get<O>(&self, path: &str) -> Result<O, AnthropicError>
140    where
141        O: DeserializeOwned,
142    {
143        backoff::future::retry(self.backoff.clone(), || async {
144            let response = self
145                .http_client
146                .get(self.format_url(path))
147                .headers(self.headers())
148                .send()
149                .await
150                .map_err(AnthropicError::NetworkError)
151                .map_err(backoff::Error::Permanent)?;
152
153            let status = response.status();
154
155            match status {
156                StatusCode::OK => {
157                    let response = response
158                        .json::<O>()
159                        .await
160                        .map_err(AnthropicError::NetworkError)
161                        .map_err(backoff::Error::Permanent)?;
162                    Ok(response)
163                }
164                StatusCode::BAD_REQUEST => {
165                    let text = response
166                        .text()
167                        .await
168                        .map_err(AnthropicError::NetworkError)
169                        .map_err(backoff::Error::Permanent)?;
170                    Err(BackoffError::Permanent(AnthropicError::BadRequest(text)))
171                }
172                StatusCode::UNAUTHORIZED => {
173                    Err(BackoffError::Permanent(AnthropicError::Unauthorized))
174                }
175                _ => {
176                    let text = response
177                        .text()
178                        .await
179                        .map_err(AnthropicError::NetworkError)
180                        .map_err(backoff::Error::Permanent)?;
181                    Err(BackoffError::Permanent(AnthropicError::Unknown(text)))
182                }
183            }
184        })
185        .await
186    }
187
188    /// Make post request to the API
189    ///
190    /// This includes all headers and error handling
191    pub async fn post<I, O>(&self, path: &str, request: I) -> Result<O, AnthropicError>
192    where
193        I: Serialize,
194        O: DeserializeOwned,
195    {
196        backoff::future::retry(self.backoff.clone(), || async {
197            let mut request = self
198                .http_client
199                .post(self.format_url(path))
200                .headers(self.headers())
201                .json(&request);
202
203            if let Some(beta_value) = &self.beta {
204                request = request.header("anthropic-beta", beta_value);
205            }
206
207            let response = request
208                .send()
209                .await
210                .map_err(AnthropicError::NetworkError)
211                .map_err(backoff::Error::Permanent)?;
212            let status = response.status();
213
214            // 529 is the status code for overloaded requests
215            let overloaded_status = StatusCode::from_u16(529).expect("529 is a valid status code");
216
217            match status {
218                StatusCode::OK => {
219                    let response = response
220                        .json::<O>()
221                        .await
222                        .map_err(AnthropicError::NetworkError)
223                        .map_err(backoff::Error::Permanent)?;
224                    Ok(response)
225                }
226                StatusCode::BAD_REQUEST => {
227                    let text = response
228                        .text()
229                        .await
230                        .map_err(AnthropicError::NetworkError)
231                        .map_err(backoff::Error::Permanent)?;
232                    Err(BackoffError::Permanent(AnthropicError::BadRequest(text)))
233                }
234                StatusCode::UNAUTHORIZED => {
235                    Err(BackoffError::Permanent(AnthropicError::Unauthorized))
236                }
237
238                _ if status == StatusCode::TOO_MANY_REQUESTS || status == overloaded_status => {
239                    let text = response
240                        .text()
241                        .await
242                        .map_err(AnthropicError::NetworkError)
243                        .map_err(backoff::Error::Permanent)?;
244
245                    // Rate limited retry...
246                    tracing::warn!("Rate limited: {}", text);
247                    Err(backoff::Error::Transient {
248                        err: AnthropicError::ApiError(text),
249                        retry_after: None,
250                    })
251                }
252                _ => {
253                    let text = response
254                        .text()
255                        .await
256                        .map_err(AnthropicError::NetworkError)
257                        .map_err(backoff::Error::Permanent)?;
258                    Err(BackoffError::Permanent(AnthropicError::Unknown(text)))
259                }
260            }
261        })
262        .await
263    }
264
265    pub(crate) async fn post_stream<I, O, const N: usize>(
266        &self,
267        path: &str,
268        request: I,
269        event_types: [&'static str; N],
270    ) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
271    where
272        I: Serialize,
273        O: DeserializeOwned + Send + 'static,
274    {
275        let event_source = self
276            .http_client
277            .post(self.format_url(path))
278            .headers(self.headers())
279            .json(&request)
280            .eventsource()
281            .unwrap();
282
283        stream(event_source, event_types).await
284    }
285}
286
287async fn stream<O, const N: usize>(
288    mut event_source: EventSource,
289    event_types: [&'static str; N],
290) -> Pin<Box<dyn Stream<Item = Result<O, AnthropicError>> + Send>>
291where
292    O: DeserializeOwned + Send + 'static,
293{
294    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
295
296    tokio::spawn(async move {
297        while let Some(ev) = event_source.next().await {
298            tracing::trace!("Streaming event: {ev:?}");
299            match ev {
300                Ok(event) => match event {
301                    Event::Open => continue,
302                    Event::Message(message) => {
303                        let event = message.event.as_str();
304                        if event == "ping" {
305                            continue;
306                        }
307
308                        let response = if event == "error" {
309                            match serde_json::from_str::<StreamError>(&message.data) {
310                                Ok(e) => Err(AnthropicError::StreamError(e)),
311                                Err(e) => {
312                                    Err(map_deserialization_error(e, message.data.as_bytes()))
313                                }
314                            }
315                        } else if event_types.contains(&event) {
316                            match serde_json::from_str::<O>(&message.data) {
317                                Ok(output) => Ok(output),
318                                Err(e) => {
319                                    Err(map_deserialization_error(e, message.data.as_bytes()))
320                                }
321                            }
322                        } else {
323                            Err(AnthropicError::StreamError(StreamError {
324                                error_type: "unknown_event_type".to_string(),
325                                message: format!("Unknown event type: {event}"),
326                            }))
327                        };
328                        let cancel = response.is_err();
329                        if tx.send(response).is_err() || cancel {
330                            // rx dropped or other error
331                            break;
332                        }
333                    }
334                },
335                Err(e) => {
336                    if let reqwest_eventsource::Error::StreamEnded = e {
337                        break;
338                    }
339                    if tx
340                        .send(Err(AnthropicError::StreamError(StreamError {
341                            error_type: "sse_error".to_string(),
342                            message: e.to_string(),
343                        })))
344                        .is_err()
345                    {
346                        // rx dropped
347                        break;
348                    }
349                }
350            }
351        }
352
353        event_source.close();
354    });
355
356    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
357}