rig/providers/anthropic/
client.rs

//! Anthropic client API implementation
2use bytes::Bytes;
3use http_client::{Method, Request};
4
5use super::completion::{ANTHROPIC_VERSION_LATEST, CompletionModel};
6use crate::{
7    client::{
8        ClientBuilderError, CompletionClient, ProviderClient, ProviderValue, VerifyClient,
9        VerifyError, impl_conversion_traits,
10    },
11    http_client::{self, HttpClientExt},
12};
13
14// ================================================================
15// Main Anthropic Client
16// ================================================================
/// Base URL the client builders start with; `ClientBuilder::base_url` can replace it.
const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com";
18
19pub struct ClientBuilder<'a, T = reqwest::Client> {
20    api_key: &'a str,
21    base_url: &'a str,
22    anthropic_version: &'a str,
23    anthropic_betas: Option<Vec<&'a str>>,
24    http_client: T,
25}
26
27impl<'a> ClientBuilder<'a, reqwest::Client> {
28    pub fn new(api_key: &'a str) -> Self {
29        ClientBuilder {
30            api_key,
31            base_url: ANTHROPIC_API_BASE_URL,
32            anthropic_version: ANTHROPIC_VERSION_LATEST,
33            anthropic_betas: None,
34            http_client: Default::default(),
35        }
36    }
37}
38
39/// Create a new anthropic client using the builder
40///
41/// # Example
42/// ```
43/// use rig::providers::anthropic::{ClientBuilder, self};
44///
45/// // Initialize the Anthropic client
46/// let anthropic_client = ClientBuilder::new("your-claude-api-key")
47///    .anthropic_version(ANTHROPIC_VERSION_LATEST)
48///    .anthropic_beta("prompt-caching-2024-07-31")
49///    .build()
50/// ```
51impl<'a, T> ClientBuilder<'a, T>
52where
53    T: HttpClientExt,
54{
55    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
56        Self {
57            api_key,
58            base_url: ANTHROPIC_API_BASE_URL,
59            anthropic_version: ANTHROPIC_VERSION_LATEST,
60            anthropic_betas: None,
61            http_client,
62        }
63    }
64
65    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
66        ClientBuilder {
67            api_key: self.api_key,
68            base_url: self.base_url,
69            anthropic_version: self.anthropic_version,
70            anthropic_betas: self.anthropic_betas,
71            http_client,
72        }
73    }
74
75    pub fn base_url(mut self, base_url: &'a str) -> Self {
76        self.base_url = base_url;
77        self
78    }
79
80    pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self {
81        self.anthropic_version = anthropic_version;
82        self
83    }
84
85    pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self {
86        if let Some(mut betas) = self.anthropic_betas {
87            betas.push(anthropic_beta);
88            self.anthropic_betas = Some(betas);
89        } else {
90            self.anthropic_betas = Some(vec![anthropic_beta]);
91        }
92        self
93    }
94
95    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
96        let mut default_headers = reqwest::header::HeaderMap::new();
97        default_headers.insert(
98            "anthropic-version",
99            self.anthropic_version
100                .parse()
101                .map_err(|_| ClientBuilderError::InvalidProperty("anthropic-version"))?,
102        );
103
104        if let Some(betas) = self.anthropic_betas {
105            default_headers.insert(
106                "anthropic-beta",
107                betas
108                    .join(",")
109                    .parse()
110                    .map_err(|_| ClientBuilderError::InvalidProperty("anthropic-beta"))?,
111            );
112        };
113
114        Ok(Client {
115            base_url: self.base_url.to_string(),
116            api_key: self.api_key.to_string(),
117            default_headers,
118            http_client: self.http_client,
119        })
120    }
121}
122
123#[derive(Clone)]
124pub struct Client<T = reqwest::Client> {
125    /// The base URL
126    base_url: String,
127    /// The API key
128    api_key: String,
129    /// The underlying HTTP client
130    http_client: T,
131    /// Default headers that will be automatically added to any given request with this client (API key, Anthropic Version and any betas that have been added)
132    default_headers: reqwest::header::HeaderMap,
133}
134
135impl<T> std::fmt::Debug for Client<T>
136where
137    T: HttpClientExt + std::fmt::Debug,
138{
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("Client")
141            .field("base_url", &self.base_url)
142            .field("http_client", &self.http_client)
143            .field("api_key", &"<REDACTED>")
144            .field("default_headers", &self.default_headers)
145            .finish()
146    }
147}
148
149impl<T> Client<T>
150where
151    T: HttpClientExt + Clone + Default,
152{
153    pub async fn send<U, V>(
154        &self,
155        req: http_client::Request<U>,
156    ) -> Result<http_client::Response<http_client::LazyBody<V>>, http_client::Error>
157    where
158        U: Into<Bytes> + Send,
159        V: From<Bytes> + Send + 'static,
160    {
161        self.http_client.send(req).await
162    }
163
164    pub async fn send_streaming<U>(
165        &self,
166        req: Request<U>,
167    ) -> Result<http_client::StreamingResponse, http_client::Error>
168    where
169        U: Into<Bytes>,
170    {
171        self.http_client.send_streaming(req).await
172    }
173
174    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
175        let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
176
177        let mut headers = self.default_headers.clone();
178
179        headers.insert(
180            "X-Api-Key",
181            http_client::HeaderValue::from_str(&self.api_key).unwrap(),
182        );
183
184        let mut req = http_client::Request::builder()
185            .method(Method::POST)
186            .uri(uri);
187
188        if let Some(hs) = req.headers_mut() {
189            *hs = headers;
190        }
191
192        req
193    }
194
195    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
196        let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
197
198        let mut headers = self.default_headers.clone();
199        headers.insert(
200            "X-Api-Key",
201            http_client::HeaderValue::from_str(&self.api_key).unwrap(),
202        );
203
204        let mut req = http_client::Request::builder().method(Method::GET).uri(uri);
205
206        if let Some(hs) = req.headers_mut() {
207            *hs = headers;
208        }
209
210        req
211    }
212}
213
214impl Client<reqwest::Client> {
215    /// Create a new Anthropic client. For more control, use the `builder` method.
216    ///
217    /// # Panics
218    /// - If the API key or version cannot be parsed as a Json value from a String.
219    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
220    pub fn new(api_key: &str) -> Self {
221        ClientBuilder::new(api_key)
222            .build()
223            .expect("Anthropic client should build")
224    }
225}
226
227impl ProviderClient for Client<reqwest::Client> {
228    /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable.
229    /// Panics if the environment variable is not set.
230    fn from_env() -> Self {
231        let api_key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set");
232
233        Client::new(&api_key)
234    }
235
236    fn from_val(input: crate::client::ProviderValue) -> Self {
237        let ProviderValue::Simple(api_key) = input else {
238            panic!("Incorrect provider value type")
239        };
240
241        Client::new(&api_key)
242    }
243}
244
245impl CompletionClient for Client<reqwest::Client> {
246    type CompletionModel = CompletionModel<reqwest::Client>;
247
248    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
249        CompletionModel::new(self.clone(), model)
250    }
251}
252
253impl VerifyClient for Client<reqwest::Client> {
254    #[cfg_attr(feature = "worker", worker::send)]
255    async fn verify(&self) -> Result<(), VerifyError> {
256        let req = self
257            .get("/v1/models")
258            .body(http_client::NoBody)
259            .map_err(http_client::Error::from)?;
260
261        let response = HttpClientExt::send(&self.http_client, req).await?;
262
263        match response.status() {
264            http::StatusCode::OK => Ok(()),
265            http::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
266                Err(VerifyError::InvalidAuthentication)
267            }
268            http::StatusCode::INTERNAL_SERVER_ERROR => {
269                let text = http_client::text(response).await?;
270                Err(VerifyError::ProviderError(text))
271            }
272            status if status.as_u16() == 529 => {
273                let text = http_client::text(response).await?;
274                Err(VerifyError::ProviderError(text))
275            }
276            _ => {
277                let status = response.status();
278
279                if status.is_success() {
280                    Ok(())
281                } else {
282                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
283                    Err(VerifyError::HttpError(http_client::Error::Instance(
284                        format!("Failed with '{status}': {text}").into(),
285                    )))
286                }
287            }
288        }
289    }
290}
291
292impl_conversion_traits!(
293    AsTranscription,
294    AsEmbeddings,
295    AsImageGeneration,
296    AsAudioGeneration
297    for Client<T>
298);