1use core::fmt;
2use futures_util::{stream, Stream, StreamExt};
3use reqwest::{
4 header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE},
5 Method, RequestBuilder, Url,
6};
7use serde::{Deserialize, Serialize};
8use std::str::FromStr;
9
10use crate::{
11 completion::{
12 message::{MessageRequest, MessageResponse},
13 stream::StreamEvent,
14 },
15 config::Config,
16 error::{AnthropicError, ApiErrorResponse},
17};
18
/// Name of the header that carries the API key.
const ANTHROPIC_API_KEY_HEADER: &str = "x-api-key";
/// Name of the header that carries the comma-separated beta flags added to a `Client`.
const ANTHROPIC_BETA_HEADERS: &str = "anthropic-beta";
/// Name of the header that carries the dated Anthropic API version.
const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version";
22
23pub struct Client {
24 api_key: String,
25 api_version: ApiVersion,
26 anthropic_version: AnthropicVersion,
27 base_url: Url,
28 beta_headers: Option<String>,
29 http_client: reqwest::Client,
30}
31
32impl Client {
33 pub fn new(config: Config) -> Result<Self, AnthropicError> {
34 let mut headers = HeaderMap::new();
35 headers.insert(
36 ANTHROPIC_API_KEY_HEADER,
37 HeaderValue::from_str(config.api_key.as_str())
38 .map_err(AnthropicError::InvalidHeaderValue)?,
39 );
40 headers.insert(
41 ANTHROPIC_VERSION_HEADER,
42 HeaderValue::from_str(&config.anthropic_version.to_string())
43 .map_err(AnthropicError::InvalidHeaderValue)?,
44 );
45 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
46
47 let http_client = reqwest::Client::builder()
48 .default_headers(headers)
49 .build()?;
50
51 let base_url = Url::parse(&config.base_url)
52 .map_err(|err| AnthropicError::UrlParse(err.to_string()))?
53 .join(format!("{}/", config.api_version).as_str())
54 .map_err(|err| AnthropicError::UrlParse(err.to_string()))?;
55
56 Ok(Self {
57 anthropic_version: config.anthropic_version,
58 api_key: config.api_key,
59 api_version: config.api_version,
60 base_url,
61 beta_headers: None,
62 http_client,
63 })
64 }
65
66 pub fn anthropic_version(&self) -> &AnthropicVersion {
67 &self.anthropic_version
68 }
69
70 pub fn api_key(&self) -> &str {
71 self.api_key.as_str()
72 }
73
74 pub fn api_version(&self) -> &ApiVersion {
75 &self.api_version
76 }
77
78 pub fn base_url(&self) -> &str {
79 self.base_url.as_str()
80 }
81
82 pub fn beta_headers(&self) -> Option<&str> {
83 self.beta_headers.as_deref()
84 }
85
86 pub fn with_beta_header(&mut self, new_header: &str) {
87 if let Some(existing_headers) = &self.beta_headers {
88 self.beta_headers = Some(format!("{},{}", existing_headers, new_header));
89 } else {
90 self.beta_headers = Some(new_header.to_string());
91 }
92 }
93
94 pub fn with_beta_headers(&mut self, headers: &[&str]) {
95 if let Some(existing_headers) = &self.beta_headers {
96 let new_headers = headers.join(",");
97 self.beta_headers = Some(format!("{},{}", existing_headers, new_headers));
98 } else {
99 self.beta_headers = Some(headers.join(","));
100 }
101 }
102
103 fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, AnthropicError> {
104 let url = self
105 .base_url
106 .join(path)
107 .map_err(|err| AnthropicError::UrlParse(err.to_string()))?;
108 let req = self.http_client.request(method, url);
109 if let Some(beta_headers) = &self.beta_headers {
110 Ok(req.header(ANTHROPIC_BETA_HEADERS, beta_headers))
111 } else {
112 Ok(req)
113 }
114 }
115
116 pub async fn create_message(
117 &self,
118 payload: MessageRequest,
119 ) -> Result<MessageResponse, AnthropicError> {
120 let response = self
121 .request(Method::POST, "messages")?
122 .json(&payload)
123 .send()
124 .await?;
125
126 if !response.status().is_success() {
127 let error = response.text().await?;
128 match serde_json::from_str::<ApiErrorResponse>(&error) {
129 Ok(api_error) => return Err(AnthropicError::Api(api_error)),
130 Err(err) => return Err(AnthropicError::JsonDeserialize(err)),
131 }
132 }
133
134 response
135 .json::<MessageResponse>()
136 .await
137 .map_err(AnthropicError::from)
138 }
139
140 pub async fn stream_message(
141 &self,
142 request: MessageRequest,
143 ) -> Result<impl Stream<Item = Result<StreamEvent, AnthropicError>>, AnthropicError> {
144 let response = self
145 .request(Method::POST, "messages")?
146 .header(ACCEPT, "text/event-stream")
147 .json(&request)
148 .send()
149 .await?;
150
151 if !response.status().is_success() {
152 let error = response.text().await?;
153 match serde_json::from_str::<ApiErrorResponse>(&error) {
154 Ok(api_error) => return Err(AnthropicError::Api(api_error)),
155 Err(err) => return Err(AnthropicError::JsonDeserialize(err)),
156 }
157 }
158
159 Ok(response.bytes_stream().flat_map(move |chunk| match chunk {
160 Ok(bytes) => {
161 let events = Self::parse_stream_chunk(&bytes);
162 stream::iter(events)
163 }
164 Err(err) => stream::iter(vec![Err(AnthropicError::from(err))]),
165 }))
166 }
167
168 fn parse_stream_chunk(bytes: &[u8]) -> Vec<Result<StreamEvent, AnthropicError>> {
169 let chunk_str = match std::str::from_utf8(bytes).map_err(AnthropicError::Utf8Error) {
170 Ok(chunk_str) => chunk_str,
171 Err(err) => return vec![Err(err)],
172 };
173 chunk_str
174 .split("\n\n")
175 .filter(|event| !event.trim().is_empty())
176 .map(|event| {
177 event
178 .lines()
179 .find(|line| line.starts_with("data: "))
180 .and_then(|line| line.strip_prefix("data: "))
181 .ok_or(AnthropicError::InvalidStreamEvent)
182 .and_then(|content| {
183 StreamEvent::from_str(content)
184 .map_err(|_| AnthropicError::InvalidStreamEvent)
185 })
186 })
187 .collect()
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
192pub enum AnthropicVersion {
193 Latest,
194 Initial,
195}
196
197impl Default for AnthropicVersion {
198 fn default() -> Self {
199 Self::Latest
200 }
201}
202
203impl fmt::Display for AnthropicVersion {
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 match self {
206 Self::Latest => write!(f, "2023-06-01"),
207 Self::Initial => write!(f, "2023-01-01"),
208 }
209 }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
213pub enum ApiVersion {
214 V1,
215}
216
217impl Default for ApiVersion {
218 fn default() -> Self {
219 Self::V1
220 }
221}
222
223impl fmt::Display for ApiVersion {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 match self {
226 Self::V1 => write!(f, "v1"),
227 }
228 }
229}
230
231#[derive(Debug, PartialEq, Eq, thiserror::Error)]
232#[error("Invalid API version: {0}")]
233pub struct ApiVersionError(String);
234
235impl FromStr for ApiVersion {
236 type Err = ApiVersionError;
237
238 fn from_str(s: &str) -> Result<Self, Self::Err> {
239 match s {
240 "v1" => Ok(Self::V1),
241 _ => Err(ApiVersionError(s.to_string())),
242 }
243 }
244}