rig/providers/anthropic/client.rs

1//! Anthropic client api implementation
2use bytes::Bytes;
3use http_client::Method;
4
5use super::completion::{ANTHROPIC_VERSION_LATEST, CompletionModel};
6use crate::{
7    client::{
8        ClientBuilderError, CompletionClient, ProviderClient, ProviderValue, VerifyClient,
9        VerifyError, impl_conversion_traits,
10    },
11    http_client::{self, HttpClientExt},
12    wasm_compat::WasmCompatSend,
13};
14
// ================================================================
// Main Anthropic Client
// ================================================================

/// Base URL used by new builders unless overridden with [`ClientBuilder::base_url`].
const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com";
19
20pub struct ClientBuilder<'a, T = reqwest::Client> {
21    api_key: &'a str,
22    base_url: &'a str,
23    anthropic_version: &'a str,
24    anthropic_betas: Option<Vec<&'a str>>,
25    http_client: T,
26}
27
28impl<'a, T> ClientBuilder<'a, T>
29where
30    T: Default,
31{
32    pub fn new(api_key: &'a str) -> Self {
33        ClientBuilder {
34            api_key,
35            base_url: ANTHROPIC_API_BASE_URL,
36            anthropic_version: ANTHROPIC_VERSION_LATEST,
37            anthropic_betas: None,
38            http_client: Default::default(),
39        }
40    }
41}
42
43/// Create a new anthropic client using the builder
44///
45/// # Example
46/// ```
47/// use rig::providers::anthropic::{ClientBuilder, self};
48///
49/// // Initialize the Anthropic client
50/// let anthropic_client = ClientBuilder::new("your-claude-api-key")
51///    .anthropic_version(ANTHROPIC_VERSION_LATEST)
52///    .anthropic_beta("prompt-caching-2024-07-31")
53///    .build()
54/// ```
55impl<'a, T> ClientBuilder<'a, T>
56where
57    T: HttpClientExt,
58{
59    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
60        Self {
61            api_key,
62            base_url: ANTHROPIC_API_BASE_URL,
63            anthropic_version: ANTHROPIC_VERSION_LATEST,
64            anthropic_betas: None,
65            http_client,
66        }
67    }
68
69    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
70        ClientBuilder {
71            api_key: self.api_key,
72            base_url: self.base_url,
73            anthropic_version: self.anthropic_version,
74            anthropic_betas: self.anthropic_betas,
75            http_client,
76        }
77    }
78
79    pub fn base_url(mut self, base_url: &'a str) -> Self {
80        self.base_url = base_url;
81        self
82    }
83
84    pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self {
85        self.anthropic_version = anthropic_version;
86        self
87    }
88
89    pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self {
90        if let Some(mut betas) = self.anthropic_betas {
91            betas.push(anthropic_beta);
92            self.anthropic_betas = Some(betas);
93        } else {
94            self.anthropic_betas = Some(vec![anthropic_beta]);
95        }
96        self
97    }
98
99    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
100        let mut default_headers = reqwest::header::HeaderMap::new();
101        default_headers.insert(
102            "anthropic-version",
103            self.anthropic_version
104                .parse()
105                .map_err(|_| ClientBuilderError::InvalidProperty("anthropic-version"))?,
106        );
107
108        if let Some(betas) = self.anthropic_betas {
109            default_headers.insert(
110                "anthropic-beta",
111                betas
112                    .join(",")
113                    .parse()
114                    .map_err(|_| ClientBuilderError::InvalidProperty("anthropic-beta"))?,
115            );
116        };
117
118        Ok(Client {
119            base_url: self.base_url.to_string(),
120            api_key: self.api_key.to_string(),
121            default_headers,
122            http_client: self.http_client,
123        })
124    }
125}
126
127#[derive(Clone)]
128pub struct Client<T = reqwest::Client> {
129    /// The base URL
130    base_url: String,
131    /// The API key
132    api_key: String,
133    /// The underlying HTTP client
134    pub http_client: T,
135    /// Default headers that will be automatically added to any given request with this client (API key, Anthropic Version and any betas that have been added)
136    default_headers: reqwest::header::HeaderMap,
137}
138
139impl<T> std::fmt::Debug for Client<T>
140where
141    T: HttpClientExt + std::fmt::Debug,
142{
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        f.debug_struct("Client")
145            .field("base_url", &self.base_url)
146            .field("http_client", &self.http_client)
147            .field("api_key", &"<REDACTED>")
148            .field("default_headers", &self.default_headers)
149            .finish()
150    }
151}
152
153impl<T> Client<T>
154where
155    T: HttpClientExt + Clone + Default,
156{
157    pub async fn send<U, V>(
158        &self,
159        req: http_client::Request<U>,
160    ) -> Result<http_client::Response<http_client::LazyBody<V>>, http_client::Error>
161    where
162        U: Into<Bytes> + Send,
163        V: From<Bytes> + Send + 'static,
164    {
165        self.http_client.send(req).await
166    }
167
168    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
169        let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
170
171        let mut headers = self.default_headers.clone();
172
173        headers.insert(
174            "X-Api-Key",
175            http_client::HeaderValue::from_str(&self.api_key).unwrap(),
176        );
177
178        let mut req = http_client::Request::builder()
179            .method(Method::POST)
180            .uri(uri);
181
182        if let Some(hs) = req.headers_mut() {
183            *hs = headers;
184        }
185
186        req
187    }
188
189    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
190        let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
191
192        let mut headers = self.default_headers.clone();
193        headers.insert(
194            "X-Api-Key",
195            http_client::HeaderValue::from_str(&self.api_key).unwrap(),
196        );
197
198        let mut req = http_client::Request::builder().method(Method::GET).uri(uri);
199
200        if let Some(hs) = req.headers_mut() {
201            *hs = headers;
202        }
203
204        req
205    }
206}
207
208impl Client<reqwest::Client> {
209    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
210        ClientBuilder::new(api_key)
211    }
212
213    pub fn from_env() -> Self {
214        <Self as ProviderClient>::from_env()
215    }
216    /// Create a new Anthropic client. For more control, use the `builder` method.
217    ///
218    /// # Panics
219    /// - If the API key or version cannot be parsed as a Json value from a String.
220    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
221    pub fn new(api_key: &str) -> Self {
222        ClientBuilder::new(api_key)
223            .build()
224            .expect("Anthropic client should build")
225    }
226}
227
228impl<T> ProviderClient for Client<T>
229where
230    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
231{
232    /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable.
233    /// Panics if the environment variable is not set.
234    fn from_env() -> Self {
235        let api_key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set");
236
237        ClientBuilder::<T>::new(&api_key).build().unwrap()
238    }
239
240    fn from_val(input: crate::client::ProviderValue) -> Self {
241        let ProviderValue::Simple(api_key) = input else {
242            panic!("Incorrect provider value type")
243        };
244
245        ClientBuilder::<T>::new(&api_key).build().unwrap()
246    }
247}
248
249impl<T> CompletionClient for Client<T>
250where
251    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
252{
253    type CompletionModel = CompletionModel<T>;
254
255    fn completion_model(&self, model: &str) -> CompletionModel<T> {
256        CompletionModel::new(self.clone(), model)
257    }
258}
259
260impl<T> VerifyClient for Client<T>
261where
262    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
263{
264    #[cfg_attr(feature = "worker", worker::send)]
265    async fn verify(&self) -> Result<(), VerifyError> {
266        let req = self
267            .get("/v1/models")
268            .body(http_client::NoBody)
269            .map_err(http_client::Error::from)?;
270
271        let response = HttpClientExt::send(&self.http_client, req).await?;
272
273        match response.status() {
274            http::StatusCode::OK => Ok(()),
275            http::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
276                Err(VerifyError::InvalidAuthentication)
277            }
278            http::StatusCode::INTERNAL_SERVER_ERROR => {
279                let text = http_client::text(response).await?;
280                Err(VerifyError::ProviderError(text))
281            }
282            status if status.as_u16() == 529 => {
283                let text = http_client::text(response).await?;
284                Err(VerifyError::ProviderError(text))
285            }
286            _ => {
287                let status = response.status();
288
289                if status.is_success() {
290                    Ok(())
291                } else {
292                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
293                    Err(VerifyError::HttpError(http_client::Error::Instance(
294                        format!("Failed with '{status}': {text}").into(),
295                    )))
296                }
297            }
298        }
299    }
300}
301
302impl_conversion_traits!(
303    AsTranscription,
304    AsEmbeddings,
305    AsImageGeneration,
306    AsAudioGeneration
307    for Client<T>
308);