Skip to main content

anthropic_rust_sdk/
client.rs

1//! Anthropic API 客户端,对齐上游 `src/client.ts`。
2
3use crate::core::error::{ApiError, ConnectionError, ConnectionTimeoutError, Error};
4use crate::internal::backoff::{default_retry_timeout_ms, retry_after_ms};
5use crate::internal::headers::{default_headers, merge_headers};
6use crate::resources::beta::Beta;
7use crate::resources::completions::Completions;
8use crate::resources::messages::Messages;
9use crate::resources::models::Models;
10use http::HeaderMap;
11use reqwest::{Client as HttpClient, Method, Response};
12use serde::de::DeserializeOwned;
13use serde::Serialize;
14use std::collections::HashMap;
15use std::time::Duration;
16
/// Legacy Text Completions prompt marker for the human turn.
pub const HUMAN_PROMPT: &str = "\n\nHuman:";
/// Legacy Text Completions prompt marker for the assistant turn.
pub const AI_PROMPT: &str = "\n\nAssistant:";

/// Base URL used when neither `ClientOptions::base_url` nor `ANTHROPIC_BASE_URL` is set.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// HTTP client timeout used when `ClientOptions::timeout` is `None` (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Retry budget used when `ClientOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRIES: u32 = 2;
24
/// Configuration accepted by `Anthropic::with_options`.
///
/// The `Debug` implementation masks `api_key`, `auth_token` and every
/// `default_headers` value (header names are still listed), so options can be
/// logged without leaking credentials.
#[derive(Clone)]
pub struct ClientOptions {
    /// API key; `None` lets the client fall back to `ANTHROPIC_API_KEY`.
    pub api_key: Option<String>,
    /// Optional bearer token, sent as `Authorization: Bearer <token>`.
    pub auth_token: Option<String>,
    /// Base URL; `None` lets the client fall back to `ANTHROPIC_BASE_URL` or its built-in default.
    pub base_url: Option<String>,
    /// Timeout passed to the HTTP client builder; `None` selects the client default.
    pub timeout: Option<Duration>,
    /// Retries allowed after the first attempt; `None` selects the client default.
    pub max_retries: Option<u32>,
    /// Extra headers merged into every request's headers.
    pub default_headers: HashMap<String, String>,
    /// Default query parameters, stored on the client.
    pub default_query: HashMap<String, String>,
}

impl std::fmt::Debug for ClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep *presence* of a secret visible (useful when debugging config) but never its value.
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| "<redacted>");
        // Header values may carry credentials too; list names only, sorted for stable output.
        let mut header_names: Vec<&str> =
            self.default_headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();

        f.debug_struct("ClientOptions")
            .field("api_key", &redact(&self.api_key))
            .field("auth_token", &redact(&self.auth_token))
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("default_headers", &header_names)
            .field("default_query", &self.default_query)
            .finish()
    }
}
36
37impl Default for ClientOptions {
38    fn default() -> Self {
39        Self {
40            api_key: std::env::var("ANTHROPIC_API_KEY").ok(),
41            auth_token: std::env::var("ANTHROPIC_AUTH_TOKEN").ok(),
42            base_url: std::env::var("ANTHROPIC_BASE_URL").ok(),
43            timeout: None,
44            max_retries: None,
45            default_headers: HashMap::new(),
46            default_query: HashMap::new(),
47        }
48    }
49}
50
51/// Anthropic API 客户端。
52#[derive(Clone)]
53pub struct Anthropic {
54    http: HttpClient,
55    api_key: String,
56    auth_token: Option<String>,
57    base_url: String,
58    #[allow(dead_code)]
59    timeout: Duration,
60    max_retries: u32,
61    default_headers: HashMap<String, String>,
62    #[allow(dead_code)]
63    default_query: HashMap<String, String>,
64    #[allow(dead_code)]
65    middleware: Vec<std::sync::Arc<dyn crate::core::middleware::Middleware>>,
66}
67
68impl Anthropic {
69    pub fn new() -> Result<Self, Error> {
70        Self::with_options(ClientOptions::default())
71    }
72
73    pub fn with_options(options: ClientOptions) -> Result<Self, Error> {
74        let api_key = options
75            .api_key
76            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
77            .ok_or_else(|| {
78                Error::Anthropic(crate::core::error::AnthropicError(
79                    "Missing API key: set ANTHROPIC_API_KEY or pass api_key in ClientOptions".into(),
80                ))
81            })?;
82
83        let base_url = options
84            .base_url
85            .or_else(|| std::env::var("ANTHROPIC_BASE_URL").ok())
86            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
87
88        let timeout = options.timeout.unwrap_or(DEFAULT_TIMEOUT);
89        let max_retries = options.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
90
91        let http = HttpClient::builder()
92            .timeout(timeout)
93            .build()
94            .map_err(|e| {
95                Error::Connection(ConnectionError {
96                    message: e.to_string(),
97                    source: Some(Box::new(e)),
98                })
99            })?;
100
101        Ok(Self {
102            http,
103            api_key,
104            auth_token: options.auth_token,
105            base_url: base_url.trim_end_matches('/').to_string(),
106            timeout,
107            max_retries,
108            default_headers: options.default_headers,
109            default_query: options.default_query,
110            middleware: Vec::new(),
111        })
112    }
113
114    pub fn with_api_key(api_key: impl Into<String>) -> Result<Self, Error> {
115        Self::with_options(ClientOptions {
116            api_key: Some(api_key.into()),
117            ..Default::default()
118        })
119    }
120
121    pub fn base_url(&self) -> &str {
122        &self.base_url
123    }
124
125    pub fn max_retries(&self) -> u32 {
126        self.max_retries
127    }
128
129    pub fn messages(&self) -> Messages<'_> {
130        Messages::new(self)
131    }
132
133    pub fn models(&self) -> Models<'_> {
134        Models::new(self)
135    }
136
137    pub fn completions(&self) -> Completions<'_> {
138        Completions::new(self)
139    }
140
141    pub fn beta(&self) -> Beta<'_> {
142        Beta::new(self)
143    }
144
145    pub(crate) async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
146        self.request(Method::GET, path, None::<&()>, false).await
147    }
148
149    pub(crate) async fn get_with_query<T: DeserializeOwned>(
150        &self,
151        path: &str,
152        query: Option<&[(&str, &str)]>,
153    ) -> Result<T, Error> {
154        let mut url = self.build_url(path);
155        if let Some(q) = query {
156            let qs: Vec<String> = q
157                .iter()
158                .map(|(k, v)| format!("{}={}", k, urlencoding_encode(v)))
159                .collect();
160            if !qs.is_empty() {
161                url.push('?');
162                url.push_str(&qs.join("&"));
163            }
164        }
165        self.request_url(Method::GET, &url, None::<&()>, false)
166            .await
167    }
168
169    pub(crate) async fn post<T, B>(&self, path: &str, body: &B) -> Result<T, Error>
170    where
171        T: DeserializeOwned,
172        B: Serialize + ?Sized,
173    {
174        self.request(Method::POST, path, Some(body), false).await
175    }
176
177    pub(crate) async fn post_streaming<B>(
178        &self,
179        path: &str,
180        body: &B,
181    ) -> Result<Response, Error>
182    where
183        B: Serialize + ?Sized,
184    {
185        self.request_raw(Method::POST, path, Some(body), true)
186            .await
187    }
188
189    pub(crate) async fn post_empty<T>(&self, path: &str) -> Result<T, Error>
190    where
191        T: DeserializeOwned,
192    {
193        self.request(Method::POST, path, None::<&()>, false).await
194    }
195
196    #[allow(dead_code)]
197    pub(crate) async fn delete<T>(&self, path: &str) -> Result<T, Error>
198    where
199        T: DeserializeOwned,
200    {
201        self.request(Method::DELETE, path, None::<&()>, false).await
202    }
203
204    #[allow(dead_code)]
205    pub(crate) async fn patch<T, B>(&self, path: &str, body: &B) -> Result<T, Error>
206    where
207        T: DeserializeOwned,
208        B: Serialize + ?Sized,
209    {
210        self.request(Method::PATCH, path, Some(body), false).await
211    }
212
213    pub(crate) async fn get_beta<T>(
214        &self,
215        path: &str,
216        beta_headers: &[String],
217        query: Option<&[(&str, &str)]>,
218    ) -> Result<T, Error>
219    where
220        T: DeserializeOwned,
221    {
222        self.request_beta(Method::GET, path, None::<&()>, beta_headers, query, false)
223            .await
224    }
225
226    pub(crate) async fn post_beta<T, B>(
227        &self,
228        path: &str,
229        body: &B,
230        beta_headers: &[String],
231    ) -> Result<T, Error>
232    where
233        T: DeserializeOwned,
234        B: Serialize + ?Sized,
235    {
236        self.request_beta(Method::POST, path, Some(body), beta_headers, None, false)
237            .await
238    }
239
240    pub(crate) async fn delete_beta<T>(
241        &self,
242        path: &str,
243        beta_headers: &[String],
244    ) -> Result<T, Error>
245    where
246        T: DeserializeOwned,
247    {
248        self.request_beta(
249            Method::DELETE,
250            path,
251            None::<&()>,
252            beta_headers,
253            None,
254            false,
255        )
256        .await
257    }
258
259    async fn request_beta<T, B>(
260        &self,
261        method: Method,
262        path: &str,
263        body: Option<&B>,
264        beta_headers: &[String],
265        query: Option<&[(&str, &str)]>,
266        stream: bool,
267    ) -> Result<T, Error>
268    where
269        T: DeserializeOwned,
270        B: Serialize + ?Sized,
271    {
272        let mut extra_headers = self.default_headers.clone();
273        if !beta_headers.is_empty() {
274            extra_headers.insert("anthropic-beta".to_string(), beta_headers.join(","));
275        }
276        let url = self.build_url(path);
277        let mut full_url = url;
278        if let Some(q) = query {
279            let qs: Vec<String> = q
280                .iter()
281                .map(|(k, v)| format!("{}={}", k, urlencoding_encode(v)))
282                .collect();
283            if !qs.is_empty() {
284                full_url.push('?');
285                full_url.push_str(&qs.join("&"));
286            }
287        }
288
289        let response = self
290            .make_request_with_retries_beta(method, &full_url, body, stream, self.max_retries, &extra_headers)
291            .await?;
292        let status = response.status().as_u16();
293        let headers = response.headers().clone();
294        let bytes = response.bytes().await.map_err(|e| {
295            Error::Connection(ConnectionError {
296                message: e.to_string(),
297                source: Some(Box::new(e)),
298            })
299        })?;
300
301        if !(200..300).contains(&status) {
302            let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
303            return Err(ApiError::generate(
304                Some(status),
305                body_json,
306                None,
307                header_map_from_reqwest(&headers),
308            ));
309        }
310
311        serde_json::from_slice(&bytes).map_err(|e| {
312            Error::Anthropic(crate::core::error::AnthropicError(format!(
313                "failed to parse response JSON: {e}"
314            )))
315        })
316    }
317
318    async fn make_request_with_retries_beta<B>(
319        &self,
320        method: Method,
321        url: &str,
322        body: Option<&B>,
323        stream: bool,
324        mut retries_remaining: u32,
325        extra_headers: &HashMap<String, String>,
326    ) -> Result<Response, Error>
327    where
328        B: Serialize + ?Sized,
329    {
330        loop {
331            let mut headers = self.build_headers(stream)?;
332            merge_headers(&mut headers, extra_headers);
333            let mut req = self.http.request(method.clone(), url).headers(headers);
334
335            if let Some(b) = body {
336                req = req.json(b);
337            }
338
339            let response = match req.send().await {
340                Ok(r) => r,
341                Err(e) => {
342                    if e.is_timeout() {
343                        return Err(Error::ConnectionTimeout(ConnectionTimeoutError(
344                            e.to_string(),
345                        )));
346                    }
347                    if retries_remaining == 0 {
348                        return Err(Error::Connection(ConnectionError {
349                            message: e.to_string(),
350                            source: Some(Box::new(e)),
351                        }));
352                    }
353                    retries_remaining -= 1;
354                    tokio::time::sleep(Duration::from_millis(default_retry_timeout_ms(
355                        retries_remaining,
356                        self.max_retries,
357                    )))
358                    .await;
359                    continue;
360                }
361            };
362
363            let status = response.status().as_u16();
364            if (200..300).contains(&status) {
365                return Ok(response);
366            }
367
368            if retries_remaining == 0 || !should_retry(status, response.headers()) {
369                return Ok(response);
370            }
371
372            let wait = retry_after_ms(response.headers()).unwrap_or_else(|| {
373                default_retry_timeout_ms(retries_remaining - 1, self.max_retries)
374            });
375            retries_remaining -= 1;
376            tokio::time::sleep(Duration::from_millis(wait)).await;
377        }
378    }
379
380    async fn request<T, B>(
381        &self,
382        method: Method,
383        path: &str,
384        body: Option<&B>,
385        stream: bool,
386    ) -> Result<T, Error>
387    where
388        T: DeserializeOwned,
389        B: Serialize + ?Sized,
390    {
391        let response = self
392            .request_raw(method.clone(), path, body, stream)
393            .await?;
394        let status = response.status().as_u16();
395        let headers = response.headers().clone();
396        let bytes = response.bytes().await.map_err(|e| {
397            Error::Connection(ConnectionError {
398                message: e.to_string(),
399                source: Some(Box::new(e)),
400            })
401        })?;
402
403        if !(200..300).contains(&status) {
404            let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
405            return Err(ApiError::generate(
406                Some(status),
407                body_json,
408                None,
409                header_map_from_reqwest(&headers),
410            ));
411        }
412
413        serde_json::from_slice(&bytes).map_err(|e| {
414            Error::Anthropic(crate::core::error::AnthropicError(format!(
415                "failed to parse response JSON: {e}"
416            )))
417        })
418    }
419
420    async fn request_url<T, B>(
421        &self,
422        method: Method,
423        url: &str,
424        body: Option<&B>,
425        stream: bool,
426    ) -> Result<T, Error>
427    where
428        T: DeserializeOwned,
429        B: Serialize + ?Sized,
430    {
431        let response = self
432            .make_request_with_retries(method, url, body, stream, self.max_retries)
433            .await?;
434        let status = response.status().as_u16();
435        let headers = response.headers().clone();
436        let bytes = response.bytes().await.map_err(|e| {
437            Error::Connection(ConnectionError {
438                message: e.to_string(),
439                source: Some(Box::new(e)),
440            })
441        })?;
442
443        if !(200..300).contains(&status) {
444            let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
445            return Err(ApiError::generate(
446                Some(status),
447                body_json,
448                None,
449                header_map_from_reqwest(&headers),
450            ));
451        }
452
453        serde_json::from_slice(&bytes).map_err(|e| {
454            Error::Anthropic(crate::core::error::AnthropicError(format!(
455                "failed to parse response JSON: {e}"
456            )))
457        })
458    }
459
460    async fn request_raw<B>(
461        &self,
462        method: Method,
463        path: &str,
464        body: Option<&B>,
465        stream: bool,
466    ) -> Result<Response, Error>
467    where
468        B: Serialize + ?Sized,
469    {
470        let url = self.build_url(path);
471        self.make_request_with_retries(method, &url, body, stream, self.max_retries)
472            .await
473    }
474
475    fn build_url(&self, path: &str) -> String {
476        format!("{}{}", self.base_url, path)
477    }
478
479    fn build_headers(&self, stream: bool) -> Result<HeaderMap, Error> {
480        let mut headers = default_headers(&self.api_key);
481        if let Some(token) = &self.auth_token {
482            headers.insert(
483                "authorization",
484                format!("Bearer {token}").parse().unwrap(),
485            );
486        }
487        if stream {
488            headers.insert("accept", "text/event-stream".parse().unwrap());
489        } else {
490            headers.insert("accept", "application/json".parse().unwrap());
491        }
492        headers.insert("content-type", "application/json".parse().unwrap());
493        merge_headers(&mut headers, &self.default_headers);
494        Ok(headers)
495    }
496
497    async fn make_request_with_retries<B>(
498        &self,
499        method: Method,
500        url: &str,
501        body: Option<&B>,
502        stream: bool,
503        mut retries_remaining: u32,
504    ) -> Result<Response, Error>
505    where
506        B: Serialize + ?Sized,
507    {
508        loop {
509            let headers = self.build_headers(stream)?;
510            let mut req = self.http.request(method.clone(), url).headers(headers);
511
512            if let Some(b) = body {
513                req = req.json(b);
514            }
515
516            let response = match req.send().await {
517                Ok(r) => r,
518                Err(e) => {
519                    if e.is_timeout() {
520                        return Err(Error::ConnectionTimeout(ConnectionTimeoutError(
521                            e.to_string(),
522                        )));
523                    }
524                    if retries_remaining == 0 {
525                        return Err(Error::Connection(ConnectionError {
526                            message: e.to_string(),
527                            source: Some(Box::new(e)),
528                        }));
529                    }
530                    retries_remaining -= 1;
531                    tokio::time::sleep(Duration::from_millis(default_retry_timeout_ms(
532                        retries_remaining,
533                        self.max_retries,
534                    )))
535                    .await;
536                    continue;
537                }
538            };
539
540            let status = response.status().as_u16();
541            if (200..300).contains(&status) {
542                return Ok(response);
543            }
544
545            if retries_remaining == 0 || !should_retry(status, response.headers()) {
546                return Ok(response);
547            }
548
549            let wait = retry_after_ms(response.headers()).unwrap_or_else(|| {
550                default_retry_timeout_ms(retries_remaining - 1, self.max_retries)
551            });
552            retries_remaining -= 1;
553            tokio::time::sleep(Duration::from_millis(wait)).await;
554        }
555    }
556}
557
558impl Default for Anthropic {
559    fn default() -> Self {
560        Self::new().expect("ANTHROPIC_API_KEY must be set for default client")
561    }
562}
563
564fn should_retry(status: u16, headers: &reqwest::header::HeaderMap) -> bool {
565    if let Some(v) = headers.get("x-should-retry") {
566        if let Ok(s) = v.to_str() {
567            if s == "true" {
568                return true;
569            }
570            if s == "false" {
571                return false;
572            }
573        }
574    }
575
576    matches!(status, 408 | 409 | 429) || (500..600).contains(&status)
577}
578
579fn header_map_from_reqwest(headers: &reqwest::header::HeaderMap) -> HeaderMap {
580    let mut map = HeaderMap::new();
581    for (k, v) in headers.iter() {
582        if let Ok(val) = http::HeaderValue::from_bytes(v.as_bytes()) {
583            map.insert(k.clone(), val);
584        }
585    }
586    map
587}
588
/// Percent-encodes `s` for use in a URL query component (RFC 3986).
///
/// Unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through. Every other byte of
/// the UTF-8 encoding becomes `%XX` with uppercase hex, so non-ASCII text such as
/// `é` encodes as `%C3%A9` rather than a truncated code point.
fn urlencoding_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Header map carrying only the given `x-should-retry` value.
    fn should_retry_header(value: &'static str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "x-should-retry",
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    #[test]
    fn should_retry_rate_limit() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(should_retry(429, &headers));
    }

    #[test]
    fn should_not_retry_bad_request() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(!should_retry(400, &headers));
    }

    #[test]
    fn should_retry_transient_statuses_only() {
        let headers = reqwest::header::HeaderMap::new();
        for status in [408, 409, 500, 502, 503, 529, 599] {
            assert!(should_retry(status, &headers), "status {status} should be retried");
        }
        for status in [401, 403, 404, 422] {
            assert!(!should_retry(status, &headers), "status {status} should not be retried");
        }
    }

    #[test]
    fn should_retry_header_overrides_status() {
        assert!(should_retry(400, &should_retry_header("true")));
        assert!(!should_retry(503, &should_retry_header("false")));
    }

    #[test]
    fn unrecognised_should_retry_header_falls_back_to_status() {
        assert!(should_retry(429, &should_retry_header("maybe")));
        assert!(!should_retry(404, &should_retry_header("maybe")));
    }
}