1use anyhow::{anyhow, Context, Result};
2use reqwest::{Client as ReqwestClient, Error as ReqwestError, RequestBuilder, StatusCode};
3use serde::Deserialize;
4use serde_json::{json, Value};
5use types::AnthropicChatCompletionChunk;
6mod types;
7use std::collections::HashMap;
8
9use crate::types::AnthropicErrorMessage;
10pub use types::ToolChoice;
11
12pub struct Client {
13 client: ReqwestClient,
14 secret_key: String,
15 model: String,
16 messages: Value,
17 tools: Value,
18 tool_choice: Option<types::ToolChoice>,
19 metadata: Value,
20 max_tokens: i32,
21 stream: bool,
22 verbose: bool,
23 temperature: f32,
24 system: String,
25 version: String,
26 stop_sequences: Vec<String>,
27 beta: Option<String>,
28 top_k: Option<i32>,
29 top_p: Option<f64>,
30}
31
32#[derive(Deserialize)]
33struct JsonResponse {
34 content: Vec<Content>,
35}
36
37#[derive(Deserialize)]
38struct Content {
39 #[serde(rename = "type")]
40 content_type: String,
41 text: String,
42}
43
44impl Client {
45 pub fn new() -> Self {
46 Self {
47 client: ReqwestClient::new(),
48 secret_key: String::new(),
49 model: String::new(),
50 messages: Value::Null,
51 tools: Value::Null,
52 tool_choice: None,
53 metadata: Value::Null,
54 max_tokens: 1024,
55 stream: false,
56 verbose: false,
57 temperature: 0.0,
58 system: String::new(),
59 version: "2023-06-01".to_string(),
60 stop_sequences: Vec::new(),
61 beta: None,
62 top_k: None,
63 top_p: None,
64 }
65 }
66
67 pub fn auth(mut self, secret_key: &str) -> Self {
68 self.secret_key = secret_key.to_owned();
69 self
70 }
71
72 pub fn model(mut self, model: &str) -> Self {
73 self.model = model.to_owned();
74 self
75 }
76
77 pub fn messages(mut self, messages: &Value) -> Self {
78 self.messages = messages.clone();
79 self
80 }
81
82 pub fn tools(mut self, tools: &Value) -> Self {
83 self.tools = tools.clone();
84 self
85 }
86
87 pub fn tool_choice(mut self, tool_choice: types::ToolChoice) -> Self {
88 self.tool_choice = Some(tool_choice);
89 self
90 }
91
92 pub fn metadata(mut self, metadata: &Value) -> Self {
93 self.metadata = metadata.clone();
94 self
95 }
96
97 pub fn max_tokens(mut self, max_tokens: i32) -> Self {
98 self.max_tokens = max_tokens;
99 self
100 }
101
102 pub fn temperature(mut self, temperature: f32) -> Self {
103 self.temperature = temperature.to_owned();
104 self
105 }
106
107 pub fn system(mut self, system: &str) -> Self {
108 self.system = system.to_owned();
109 self
110 }
111 pub fn version(mut self, version: &str) -> Self {
112 self.version = version.to_owned();
113 self
114 }
115
116 pub fn stream(mut self, stream: bool) -> Self {
117 self.stream = stream;
118 self
119 }
120
121 pub fn verbose(mut self, verbose: bool) -> Self {
122 self.verbose = verbose;
123 self
124 }
125
126 pub fn beta(mut self, beta: &str) -> Self {
127 self.beta = Some(beta.to_owned());
128 self
129 }
130
131 pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
132 self.stop_sequences = stop_sequences;
133 self
134 }
135
136 pub fn top_k(mut self, top_k: i32) -> Self {
137 self.top_k = Some(top_k);
138 self
139 }
140
141 pub fn top_p(mut self, top_p: f64) -> Self {
142 self.top_p = Some(top_p);
143 self
144 }
145
146 pub fn build(self) -> Result<Request, ReqwestError> {
147 let mut body_map: HashMap<&str, Value> = HashMap::new();
148 body_map.insert("model", json!(self.model));
149 body_map.insert("max_tokens", json!(self.max_tokens));
150 body_map.insert("messages", json!(self.messages));
151 body_map.insert("stream", json!(self.stream));
152 body_map.insert("temperature", json!(self.temperature));
153 body_map.insert("system", json!(self.system));
154
155 if self.tools != Value::Null {
156 body_map.insert("tools", self.tools.clone());
157 }
158 if let Some(tool_choice) = self.tool_choice {
159 body_map.insert("tool_choice", json!(tool_choice));
160 }
161
162 if self.metadata != Value::Null {
163 body_map.insert("metadata", self.metadata.clone());
164 }
165
166 if self.stop_sequences.len() > 0 {
167 body_map.insert("stop_sequences", json!(self.stop_sequences));
168 }
169
170 if let Some(top_k) = self.top_k {
171 body_map.insert("top_k", json!(top_k));
172 }
173
174 if let Some(top_p) = self.top_p {
175 body_map.insert("top_p", json!(top_p));
176 }
177
178 let mut request_builder = self
179 .client
180 .post("https://api.anthropic.com/v1/messages")
181 .header("x-api-key", self.secret_key)
182 .header("anthropic-version", self.version)
183 .header("content-type", "application/json")
184 .json(&body_map);
185
186 if let Some(beta_value) = self.beta {
187 request_builder = request_builder.header("anthropic-beta", beta_value);
188 }
189
190 Ok(Request {
191 request_builder,
192 stream: self.stream,
193 verbose: self.verbose,
194 tools: self.tools,
195 })
196 }
197
198 pub fn builder(self) -> Result<RequestBuilder, ReqwestError> {
199 let mut body_map: HashMap<&str, Value> = HashMap::new();
200 body_map.insert("model", json!(self.model));
201 body_map.insert("max_tokens", json!(self.max_tokens));
202 body_map.insert("messages", json!(self.messages));
203 body_map.insert("stream", json!(self.stream));
204 body_map.insert("temperature", json!(self.temperature));
205 body_map.insert("system", json!(self.system));
206
207 if self.tools != Value::Null {
208 body_map.insert("tools", self.tools.clone());
209 }
210
211 if self.metadata != Value::Null {
212 body_map.insert("metadata", self.metadata.clone());
213 }
214
215 if self.stop_sequences.len() > 0 {
216 body_map.insert("stop_sequences", json!(self.stop_sequences));
217 }
218
219 if let Some(top_k) = self.top_k {
220 body_map.insert("top_k", json!(top_k));
221 }
222
223 if let Some(top_p) = self.top_p {
224 body_map.insert("top_p", json!(top_p));
225 }
226
227 let mut request_builder = self
228 .client
229 .post("https://api.anthropic.com/v1/messages")
230 .header("x-api-key", self.secret_key)
231 .header("anthropic-version", self.version)
232 .header("content-type", "application/json")
233 .json(&body_map);
234
235 if let Some(beta_value) = self.beta {
236 request_builder = request_builder.header("anthropic-beta", beta_value);
237 }
238
239 Ok(request_builder)
240 }
241}
242
243pub struct Request {
244 request_builder: RequestBuilder,
245 stream: bool,
246 verbose: bool,
247 tools: Value,
248}
249
250impl Request {
251 pub async fn execute<F, Fut>(self, mut callback: F) -> Result<()>
252 where
253 F: FnMut(String) -> Fut,
254 Fut: std::future::Future<Output = ()> + Send,
255 {
256 let mut response = self
257 .request_builder
258 .send()
259 .await
260 .context("Failed to send request")?;
261
262 match response.status() {
263 StatusCode::OK => {
264 if self.stream {
265 let mut buffer = String::new();
266 while let Some(chunk) = response.chunk().await? {
267 let s = match std::str::from_utf8(&chunk) {
268 Ok(v) => v,
269 Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
270 };
271 buffer.push_str(s);
272 loop {
273 if let Some(index) = buffer.find("\n\n") {
274 let chunk = buffer[..index].to_string();
275 buffer.drain(..=index + 1);
276
277 if self.verbose {
278 callback(chunk.clone()).await;
279 } else {
280 if chunk == "data: [DONE]" {
281 break;
282 }
283 let processed_chunk = chunk
284 .trim_start_matches("event: message_start")
285 .trim_start_matches("event: content_block_start")
286 .trim_start_matches("event: ping")
287 .trim_start_matches("event: content_block_delta")
288 .trim_start_matches("event: content_block_stop")
289 .trim_start_matches("event: message_delta")
290 .trim_start_matches("event: message_stop")
291 .to_string();
292 let cleaned_string = &processed_chunk
293 .trim_start()
294 .strip_prefix("data: ")
295 .unwrap_or(&processed_chunk);
296 match serde_json::from_str::<AnthropicChatCompletionChunk>(
297 &cleaned_string,
298 ) {
299 Ok(d) => {
300 if let Some(delta) = d.delta {
301 if let Some(content) = delta.text {
302 callback(content).await;
303 }
304 }
305 }
306 Err(_) => {
307 let processed_chunk = cleaned_string
308 .trim_start_matches("event: error")
309 .to_string();
310 let cleaned_string = &processed_chunk
311 .trim_start()
312 .strip_prefix("data: ")
313 .unwrap_or(&processed_chunk);
314 match serde_json::from_str::<AnthropicErrorMessage>(
315 &cleaned_string,
316 ) {
317 Ok(error_message) => {
318 return Err(anyhow!("{}: {}", error_message.error.error_type, error_message.error.message));
319 }
320 Err(_) => {
321 eprintln!(
322 "Couldn't parse AnthropicChatCompletionChunk or AnthropicErrorMessage: {}",
323 &cleaned_string
324 );
325 }
326 }
327 }
328 }
329 }
330 } else {
331 break;
332 }
333 }
334 }
335 } else {
336 let json_text = response
337 .text()
338 .await
339 .context("Failed to read response text")?;
340 if self.tools == Value::Null && !self.verbose {
341 match serde_json::from_str::<JsonResponse>(&json_text) {
342 Ok(parsed_json) => {
343 if let Some(content) = parsed_json
344 .content
345 .iter()
346 .find(|c| c.content_type == "text")
347 {
348 callback(content.text.clone()).await;
349 }
350 }
351 Err(_) => return Err(anyhow!("Unable to parse JSON")),
352 }
353 } else {
354 callback(json_text).await;
355 }
356 }
357 Ok(())
358 }
359 StatusCode::BAD_REQUEST => Err(anyhow!(
360 "Bad request. Check your request parameters. {}",
361 response.text().await?
362 )),
363 StatusCode::UNAUTHORIZED => Err(anyhow!("Unauthorized. Check your authorization key.")),
364 _ => {
365 let error_message = format!("Unexpected status code: {:?}", response.text().await?);
366 Err(anyhow!(error_message))
367 }
368 }
369 }
370}