anthropic_sdk/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use reqwest::{Client as ReqwestClient, Error as ReqwestError, RequestBuilder, StatusCode};
3use serde::Deserialize;
4use serde_json::{json, Value};
5use types::AnthropicChatCompletionChunk;
6mod types;
7use std::collections::HashMap;
8
9use crate::types::AnthropicErrorMessage;
10pub use types::ToolChoice;
11
12pub struct Client {
13    client: ReqwestClient,
14    secret_key: String,
15    model: String,
16    messages: Value,
17    tools: Value,
18    tool_choice: Option<types::ToolChoice>,
19    metadata: Value,
20    max_tokens: i32,
21    stream: bool,
22    verbose: bool,
23    temperature: f32,
24    system: String,
25    version: String,
26    stop_sequences: Vec<String>,
27    beta: Option<String>,
28    top_k: Option<i32>,
29    top_p: Option<f64>,
30}
31
32#[derive(Deserialize)]
33struct JsonResponse {
34    content: Vec<Content>,
35}
36
37#[derive(Deserialize)]
38struct Content {
39    #[serde(rename = "type")]
40    content_type: String,
41    text: String,
42}
43
44impl Client {
45    pub fn new() -> Self {
46        Self {
47            client: ReqwestClient::new(),
48            secret_key: String::new(),
49            model: String::new(),
50            messages: Value::Null,
51            tools: Value::Null,
52            tool_choice: None,
53            metadata: Value::Null,
54            max_tokens: 1024,
55            stream: false,
56            verbose: false,
57            temperature: 0.0,
58            system: String::new(),
59            version: "2023-06-01".to_string(),
60            stop_sequences: Vec::new(),
61            beta: None,
62            top_k: None,
63            top_p: None,
64        }
65    }
66
67    pub fn auth(mut self, secret_key: &str) -> Self {
68        self.secret_key = secret_key.to_owned();
69        self
70    }
71
72    pub fn model(mut self, model: &str) -> Self {
73        self.model = model.to_owned();
74        self
75    }
76
77    pub fn messages(mut self, messages: &Value) -> Self {
78        self.messages = messages.clone();
79        self
80    }
81
82    pub fn tools(mut self, tools: &Value) -> Self {
83        self.tools = tools.clone();
84        self
85    }
86
87    pub fn tool_choice(mut self, tool_choice: types::ToolChoice) -> Self {
88        self.tool_choice = Some(tool_choice);
89        self
90    }
91
92    pub fn metadata(mut self, metadata: &Value) -> Self {
93        self.metadata = metadata.clone();
94        self
95    }
96
97    pub fn max_tokens(mut self, max_tokens: i32) -> Self {
98        self.max_tokens = max_tokens;
99        self
100    }
101
102    pub fn temperature(mut self, temperature: f32) -> Self {
103        self.temperature = temperature.to_owned();
104        self
105    }
106
107    pub fn system(mut self, system: &str) -> Self {
108        self.system = system.to_owned();
109        self
110    }
111    pub fn version(mut self, version: &str) -> Self {
112        self.version = version.to_owned();
113        self
114    }
115
116    pub fn stream(mut self, stream: bool) -> Self {
117        self.stream = stream;
118        self
119    }
120
121    pub fn verbose(mut self, verbose: bool) -> Self {
122        self.verbose = verbose;
123        self
124    }
125
126    pub fn beta(mut self, beta: &str) -> Self {
127        self.beta = Some(beta.to_owned());
128        self
129    }
130
131    pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
132        self.stop_sequences = stop_sequences;
133        self
134    }
135
136    pub fn top_k(mut self, top_k: i32) -> Self {
137        self.top_k = Some(top_k);
138        self
139    }
140
141    pub fn top_p(mut self, top_p: f64) -> Self {
142        self.top_p = Some(top_p);
143        self
144    }
145
146    pub fn build(self) -> Result<Request, ReqwestError> {
147        let mut body_map: HashMap<&str, Value> = HashMap::new();
148        body_map.insert("model", json!(self.model));
149        body_map.insert("max_tokens", json!(self.max_tokens));
150        body_map.insert("messages", json!(self.messages));
151        body_map.insert("stream", json!(self.stream));
152        body_map.insert("temperature", json!(self.temperature));
153        body_map.insert("system", json!(self.system));
154
155        if self.tools != Value::Null {
156            body_map.insert("tools", self.tools.clone());
157        }
158        if let Some(tool_choice) = self.tool_choice {
159            body_map.insert("tool_choice", json!(tool_choice));
160        }
161
162        if self.metadata != Value::Null {
163            body_map.insert("metadata", self.metadata.clone());
164        }
165
166        if self.stop_sequences.len() > 0 {
167            body_map.insert("stop_sequences", json!(self.stop_sequences));
168        }
169
170        if let Some(top_k) = self.top_k {
171            body_map.insert("top_k", json!(top_k));
172        }
173
174        if let Some(top_p) = self.top_p {
175            body_map.insert("top_p", json!(top_p));
176        }
177
178        let mut request_builder = self
179            .client
180            .post("https://api.anthropic.com/v1/messages")
181            .header("x-api-key", self.secret_key)
182            .header("anthropic-version", self.version)
183            .header("content-type", "application/json")
184            .json(&body_map);
185
186        if let Some(beta_value) = self.beta {
187            request_builder = request_builder.header("anthropic-beta", beta_value);
188        }
189
190        Ok(Request {
191            request_builder,
192            stream: self.stream,
193            verbose: self.verbose,
194            tools: self.tools,
195        })
196    }
197
198    pub fn builder(self) -> Result<RequestBuilder, ReqwestError> {
199        let mut body_map: HashMap<&str, Value> = HashMap::new();
200        body_map.insert("model", json!(self.model));
201        body_map.insert("max_tokens", json!(self.max_tokens));
202        body_map.insert("messages", json!(self.messages));
203        body_map.insert("stream", json!(self.stream));
204        body_map.insert("temperature", json!(self.temperature));
205        body_map.insert("system", json!(self.system));
206
207        if self.tools != Value::Null {
208            body_map.insert("tools", self.tools.clone());
209        }
210
211        if self.metadata != Value::Null {
212            body_map.insert("metadata", self.metadata.clone());
213        }
214
215        if self.stop_sequences.len() > 0 {
216            body_map.insert("stop_sequences", json!(self.stop_sequences));
217        }
218
219        if let Some(top_k) = self.top_k {
220            body_map.insert("top_k", json!(top_k));
221        }
222
223        if let Some(top_p) = self.top_p {
224            body_map.insert("top_p", json!(top_p));
225        }
226
227        let mut request_builder = self
228            .client
229            .post("https://api.anthropic.com/v1/messages")
230            .header("x-api-key", self.secret_key)
231            .header("anthropic-version", self.version)
232            .header("content-type", "application/json")
233            .json(&body_map);
234
235        if let Some(beta_value) = self.beta {
236            request_builder = request_builder.header("anthropic-beta", beta_value);
237        }
238
239        Ok(request_builder)
240    }
241}
242
243pub struct Request {
244    request_builder: RequestBuilder,
245    stream: bool,
246    verbose: bool,
247    tools: Value,
248}
249
250impl Request {
251    pub async fn execute<F, Fut>(self, mut callback: F) -> Result<()>
252    where
253        F: FnMut(String) -> Fut,
254        Fut: std::future::Future<Output = ()> + Send,
255    {
256        let mut response = self
257            .request_builder
258            .send()
259            .await
260            .context("Failed to send request")?;
261
262        match response.status() {
263            StatusCode::OK => {
264                if self.stream {
265                    let mut buffer = String::new();
266                    while let Some(chunk) = response.chunk().await? {
267                        let s = match std::str::from_utf8(&chunk) {
268                            Ok(v) => v,
269                            Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
270                        };
271                        buffer.push_str(s);
272                        loop {
273                            if let Some(index) = buffer.find("\n\n") {
274                                let chunk = buffer[..index].to_string();
275                                buffer.drain(..=index + 1);
276
277                                if self.verbose {
278                                    callback(chunk.clone()).await;
279                                } else {
280                                    if chunk == "data: [DONE]" {
281                                        break;
282                                    }
283                                    let processed_chunk = chunk
284                                        .trim_start_matches("event: message_start")
285                                        .trim_start_matches("event: content_block_start")
286                                        .trim_start_matches("event: ping")
287                                        .trim_start_matches("event: content_block_delta")
288                                        .trim_start_matches("event: content_block_stop")
289                                        .trim_start_matches("event: message_delta")
290                                        .trim_start_matches("event: message_stop")
291                                        .to_string();
292                                    let cleaned_string = &processed_chunk
293                                        .trim_start()
294                                        .strip_prefix("data: ")
295                                        .unwrap_or(&processed_chunk);
296                                    match serde_json::from_str::<AnthropicChatCompletionChunk>(
297                                        &cleaned_string,
298                                    ) {
299                                        Ok(d) => {
300                                            if let Some(delta) = d.delta {
301                                                if let Some(content) = delta.text {
302                                                    callback(content).await;
303                                                }
304                                            }
305                                        }
306                                        Err(_) => {
307                                            let processed_chunk = cleaned_string
308                                                .trim_start_matches("event: error")
309                                                .to_string();
310                                            let cleaned_string = &processed_chunk
311                                                .trim_start()
312                                                .strip_prefix("data: ")
313                                                .unwrap_or(&processed_chunk);
314                                            match serde_json::from_str::<AnthropicErrorMessage>(
315                                                &cleaned_string,
316                                            ) {
317                                                Ok(error_message) => {
318                                                    return Err(anyhow!("{}: {}", error_message.error.error_type, error_message.error.message));
319                                                }
320                                                Err(_) => {
321                                                    eprintln!(
322                                                        "Couldn't parse AnthropicChatCompletionChunk or AnthropicErrorMessage: {}",
323                                                        &cleaned_string
324                                                    );
325                                                }
326                                            }
327                                        }
328                                    }
329                                }
330                            } else {
331                                break;
332                            }
333                        }
334                    }
335                } else {
336                    let json_text = response
337                        .text()
338                        .await
339                        .context("Failed to read response text")?;
340                    if self.tools == Value::Null && !self.verbose {
341                        match serde_json::from_str::<JsonResponse>(&json_text) {
342                            Ok(parsed_json) => {
343                                if let Some(content) = parsed_json
344                                    .content
345                                    .iter()
346                                    .find(|c| c.content_type == "text")
347                                {
348                                    callback(content.text.clone()).await;
349                                }
350                            }
351                            Err(_) => return Err(anyhow!("Unable to parse JSON")),
352                        }
353                    } else {
354                        callback(json_text).await;
355                    }
356                }
357                Ok(())
358            }
359            StatusCode::BAD_REQUEST => Err(anyhow!(
360                "Bad request. Check your request parameters. {}",
361                response.text().await?
362            )),
363            StatusCode::UNAUTHORIZED => Err(anyhow!("Unauthorized. Check your authorization key.")),
364            _ => {
365                let error_message = format!("Unexpected status code: {:?}", response.text().await?);
366                Err(anyhow!(error_message))
367            }
368        }
369    }
370}