anthropic_rs/
client.rs

1use core::fmt;
2use futures_util::{stream, Stream, StreamExt};
3use reqwest::{
4    header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE},
5    Method, RequestBuilder, Url,
6};
7use serde::{Deserialize, Serialize};
8use std::str::FromStr;
9
10use crate::{
11    completion::{
12        message::{MessageRequest, MessageResponse},
13        stream::StreamEvent,
14    },
15    config::Config,
16    error::{AnthropicError, ApiErrorResponse},
17};
18
/// Request header carrying the account API key.
const ANTHROPIC_API_KEY_HEADER: &str = "x-api-key";
/// Request header selecting the dated API revision (see `AnthropicVersion`).
const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version";
/// Request header carrying the comma-separated list of enabled beta features.
const ANTHROPIC_BETA_HEADERS: &str = "anthropic-beta";
22
23pub struct Client {
24    api_key: String,
25    api_version: ApiVersion,
26    anthropic_version: AnthropicVersion,
27    base_url: Url,
28    beta_headers: Option<String>,
29    http_client: reqwest::Client,
30}
31
32impl Client {
33    pub fn new(config: Config) -> Result<Self, AnthropicError> {
34        let mut headers = HeaderMap::new();
35        headers.insert(
36            ANTHROPIC_API_KEY_HEADER,
37            HeaderValue::from_str(config.api_key.as_str())
38                .map_err(AnthropicError::InvalidHeaderValue)?,
39        );
40        headers.insert(
41            ANTHROPIC_VERSION_HEADER,
42            HeaderValue::from_str(&config.anthropic_version.to_string())
43                .map_err(AnthropicError::InvalidHeaderValue)?,
44        );
45        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
46
47        let http_client = reqwest::Client::builder()
48            .default_headers(headers)
49            .build()?;
50
51        let base_url = Url::parse(&config.base_url)
52            .map_err(|err| AnthropicError::UrlParse(err.to_string()))?
53            .join(format!("{}/", config.api_version).as_str())
54            .map_err(|err| AnthropicError::UrlParse(err.to_string()))?;
55
56        Ok(Self {
57            anthropic_version: config.anthropic_version,
58            api_key: config.api_key,
59            api_version: config.api_version,
60            base_url,
61            beta_headers: None,
62            http_client,
63        })
64    }
65
66    pub fn anthropic_version(&self) -> &AnthropicVersion {
67        &self.anthropic_version
68    }
69
70    pub fn api_key(&self) -> &str {
71        self.api_key.as_str()
72    }
73
74    pub fn api_version(&self) -> &ApiVersion {
75        &self.api_version
76    }
77
78    pub fn base_url(&self) -> &str {
79        self.base_url.as_str()
80    }
81
82    pub fn beta_headers(&self) -> Option<&str> {
83        self.beta_headers.as_deref()
84    }
85
86    pub fn with_beta_header(&mut self, new_header: &str) {
87        if let Some(existing_headers) = &self.beta_headers {
88            self.beta_headers = Some(format!("{},{}", existing_headers, new_header));
89        } else {
90            self.beta_headers = Some(new_header.to_string());
91        }
92    }
93
94    pub fn with_beta_headers(&mut self, headers: &[&str]) {
95        if let Some(existing_headers) = &self.beta_headers {
96            let new_headers = headers.join(",");
97            self.beta_headers = Some(format!("{},{}", existing_headers, new_headers));
98        } else {
99            self.beta_headers = Some(headers.join(","));
100        }
101    }
102
103    fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, AnthropicError> {
104        let url = self
105            .base_url
106            .join(path)
107            .map_err(|err| AnthropicError::UrlParse(err.to_string()))?;
108        let req = self.http_client.request(method, url);
109        if let Some(beta_headers) = &self.beta_headers {
110            Ok(req.header(ANTHROPIC_BETA_HEADERS, beta_headers))
111        } else {
112            Ok(req)
113        }
114    }
115
116    pub async fn create_message(
117        &self,
118        payload: MessageRequest,
119    ) -> Result<MessageResponse, AnthropicError> {
120        let response = self
121            .request(Method::POST, "messages")?
122            .json(&payload)
123            .send()
124            .await?;
125
126        if !response.status().is_success() {
127            let error = response.text().await?;
128            match serde_json::from_str::<ApiErrorResponse>(&error) {
129                Ok(api_error) => return Err(AnthropicError::Api(api_error)),
130                Err(err) => return Err(AnthropicError::JsonDeserialize(err)),
131            }
132        }
133
134        response
135            .json::<MessageResponse>()
136            .await
137            .map_err(AnthropicError::from)
138    }
139
140    pub async fn stream_message(
141        &self,
142        request: MessageRequest,
143    ) -> Result<impl Stream<Item = Result<StreamEvent, AnthropicError>>, AnthropicError> {
144        let response = self
145            .request(Method::POST, "messages")?
146            .header(ACCEPT, "text/event-stream")
147            .json(&request)
148            .send()
149            .await?;
150
151        if !response.status().is_success() {
152            let error = response.text().await?;
153            match serde_json::from_str::<ApiErrorResponse>(&error) {
154                Ok(api_error) => return Err(AnthropicError::Api(api_error)),
155                Err(err) => return Err(AnthropicError::JsonDeserialize(err)),
156            }
157        }
158
159        Ok(response.bytes_stream().flat_map(move |chunk| match chunk {
160            Ok(bytes) => {
161                let events = Self::parse_stream_chunk(&bytes);
162                stream::iter(events)
163            }
164            Err(err) => stream::iter(vec![Err(AnthropicError::from(err))]),
165        }))
166    }
167
168    fn parse_stream_chunk(bytes: &[u8]) -> Vec<Result<StreamEvent, AnthropicError>> {
169        let chunk_str = match std::str::from_utf8(bytes).map_err(AnthropicError::Utf8Error) {
170            Ok(chunk_str) => chunk_str,
171            Err(err) => return vec![Err(err)],
172        };
173        chunk_str
174            .split("\n\n")
175            .filter(|event| !event.trim().is_empty())
176            .map(|event| {
177                event
178                    .lines()
179                    .find(|line| line.starts_with("data: "))
180                    .and_then(|line| line.strip_prefix("data: "))
181                    .ok_or(AnthropicError::InvalidStreamEvent)
182                    .and_then(|content| {
183                        StreamEvent::from_str(content)
184                            .map_err(|_| AnthropicError::InvalidStreamEvent)
185                    })
186            })
187            .collect()
188    }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
192pub enum AnthropicVersion {
193    Latest,
194    Initial,
195}
196
197impl Default for AnthropicVersion {
198    fn default() -> Self {
199        Self::Latest
200    }
201}
202
203impl fmt::Display for AnthropicVersion {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        match self {
206            Self::Latest => write!(f, "2023-06-01"),
207            Self::Initial => write!(f, "2023-01-01"),
208        }
209    }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
213pub enum ApiVersion {
214    V1,
215}
216
217impl Default for ApiVersion {
218    fn default() -> Self {
219        Self::V1
220    }
221}
222
223impl fmt::Display for ApiVersion {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        match self {
226            Self::V1 => write!(f, "v1"),
227        }
228    }
229}
230
231#[derive(Debug, PartialEq, Eq, thiserror::Error)]
232#[error("Invalid API version: {0}")]
233pub struct ApiVersionError(String);
234
235impl FromStr for ApiVersion {
236    type Err = ApiVersionError;
237
238    fn from_str(s: &str) -> Result<Self, Self::Err> {
239        match s {
240            "v1" => Ok(Self::V1),
241            _ => Err(ApiVersionError(s.to_string())),
242        }
243    }
244}