// langchain_rust/llm/claude/client.rs

1use crate::{
2    language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage},
3    llm::AnthropicError,
4    schemas::{Message, MessageType, StreamData},
5};
6use async_trait::async_trait;
7use futures::{Stream, StreamExt};
8use reqwest::Client;
9use serde_json::Value;
10use std::{collections::HashMap, pin::Pin};
11
12use super::models::{ApiResponse, ClaudeMessage, Payload};
13
/// Identifiers for the Anthropic Claude models supported by this client.
///
/// NOTE(review): `Claude3pus20240229` looks like a typo for "Opus" (its
/// string form is "claude-3-opus-..."), but the variant name is public API,
/// so it is kept for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClaudeModel {
    Claude3pus20240229,
    Claude3sonnet20240229,
    Claude3haiku20240307,
    Claude3_5sonnet20240620,
}

/// Implementing `Display` (rather than `ToString` directly) is the idiomatic
/// form: the blanket `impl<T: Display> ToString for T` keeps the existing
/// `.to_string()` call sites working unchanged.
impl std::fmt::Display for ClaudeModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // API model identifiers as documented by Anthropic.
        let name = match self {
            ClaudeModel::Claude3pus20240229 => "claude-3-opus-20240229",
            ClaudeModel::Claude3sonnet20240229 => "claude-3-sonnet-20240229",
            ClaudeModel::Claude3haiku20240307 => "claude-3-haiku-20240307",
            ClaudeModel::Claude3_5sonnet20240620 => "claude-3-5-sonnet-20240620",
        };
        f.write_str(name)
    }
}
31
/// Client configuration for calls to the Anthropic Claude Messages API.
#[derive(Clone)]
pub struct Claude {
    // Model identifier sent in the request payload, e.g. "claude-3-opus-20240229".
    model: String,
    // Generation options (max_tokens, temperature, stop words, streaming callback, ...).
    options: CallOptions,
    // Anthropic API key; sent as the `x-api-key` header on every request.
    api_key: String,
    // Value for the `anthropic-version` header (defaults to "2023-06-01").
    anthropic_version: String,
}
39
40impl Default for Claude {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl Claude {
47    pub fn new() -> Self {
48        Self {
49            model: ClaudeModel::Claude3pus20240229.to_string(),
50            options: CallOptions::default(),
51            api_key: std::env::var("CLAUDE_API_KEY").unwrap_or_default(),
52            anthropic_version: "2023-06-01".to_string(),
53        }
54    }
55
56    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
57        self.model = model.into();
58        self
59    }
60
61    pub fn with_options(mut self, options: CallOptions) -> Self {
62        self.options = options;
63        self
64    }
65
66    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
67        self.api_key = api_key.into();
68        self
69    }
70
71    pub fn with_anthropic_version<S: Into<String>>(mut self, version: S) -> Self {
72        self.anthropic_version = version.into();
73        self
74    }
75
76    async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError> {
77        let client = Client::new();
78        let is_stream = self.options.streaming_func.is_some();
79
80        let payload = self.build_payload(messages, is_stream);
81        let res = client
82            .post("https://api.anthropic.com/v1/messages")
83            .header("x-api-key", &self.api_key)
84            .header("anthropic-version", self.anthropic_version.clone())
85            .header("content-type", "application/json; charset=utf-8")
86            .json(&payload)
87            .send()
88            .await?;
89        let res = match res.status().as_u16() {
90            401 => Err(LLMError::AnthropicError(
91                AnthropicError::AuthenticationError("Invalid API Key".to_string()),
92            )),
93            403 => Err(LLMError::AnthropicError(AnthropicError::PermissionError(
94                "Permission Denied".to_string(),
95            ))),
96            404 => Err(LLMError::AnthropicError(AnthropicError::NotFoundError(
97                "Not Found".to_string(),
98            ))),
99            429 => Err(LLMError::AnthropicError(AnthropicError::RateLimitError(
100                "Rate Limit Exceeded".to_string(),
101            ))),
102            503 => Err(LLMError::AnthropicError(AnthropicError::OverloadedError(
103                "Service Unavailable".to_string(),
104            ))),
105            _ => Ok(res.json::<ApiResponse>().await?),
106        }?;
107
108        let generation = res
109            .content
110            .first()
111            .map(|c| c.text.clone())
112            .unwrap_or_default();
113
114        let tokens = Some(TokenUsage {
115            prompt_tokens: res.usage.input_tokens,
116            completion_tokens: res.usage.output_tokens,
117            total_tokens: res.usage.input_tokens + res.usage.output_tokens,
118        });
119
120        Ok(GenerateResult { tokens, generation })
121    }
122
123    fn build_payload(&self, messages: &[Message], stream: bool) -> Payload {
124        let (system_message, other_messages): (Vec<_>, Vec<_>) = messages
125            .into_iter()
126            .partition(|m| m.message_type == MessageType::SystemMessage);
127        let mut payload = Payload {
128            model: self.model.clone(),
129            system: system_message.get(0).map(|m| m.content.clone()),
130            messages: other_messages
131                .into_iter()
132                .map(ClaudeMessage::from_message)
133                .collect::<Vec<_>>(),
134            max_tokens: self.options.max_tokens.unwrap_or(1024),
135            stream: None,
136            stop_sequences: self.options.stop_words.clone(),
137            temperature: self.options.temperature,
138            top_p: self.options.top_p,
139            top_k: self.options.top_k,
140        };
141        if stream {
142            payload.stream = Some(true);
143        }
144        payload
145    }
146}
147
148#[async_trait]
149impl LLM for Claude {
150    async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError> {
151        match &self.options.streaming_func {
152            Some(func) => {
153                let mut complete_response = String::new();
154                let mut stream = self.stream(messages).await?;
155                while let Some(data) = stream.next().await {
156                    match data {
157                        Ok(value) => {
158                            let mut func = func.lock().await;
159                            complete_response.push_str(&value.content);
160                            let _ = func(value.content).await;
161                        }
162                        Err(e) => return Err(e),
163                    }
164                }
165                let mut generate_result = GenerateResult::default();
166                generate_result.generation = complete_response;
167                Ok(generate_result)
168            }
169            None => self.generate(messages).await,
170        }
171    }
172    async fn stream(
173        &self,
174        messages: &[Message],
175    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError> {
176        let client = Client::new();
177        let payload = self.build_payload(messages, true);
178        let request = client
179            .post("https://api.anthropic.com/v1/messages")
180            .header("x-api-key", &self.api_key)
181            .header("anthropic-version", &self.anthropic_version)
182            .header("content-type", "application/json; charset=utf-8")
183            .json(&payload)
184            .build()?;
185
186        // Instead of sending the request directly, return a stream wrapper
187        let stream = client.execute(request).await?;
188        let stream = stream.bytes_stream();
189        // Process each chunk as it arrives
190        let processed_stream = stream.then(move |result| {
191            async move {
192                match result {
193                    Ok(bytes) => {
194                        let value: Value = parse_sse_to_json(&String::from_utf8_lossy(&bytes))?;
195                        if value["type"].as_str().unwrap_or("") == "content_block_delta" {
196                            let content = value["delta"]["text"].clone();
197                            // Return StreamData based on the parsed content
198                            // TODO get tokens from the response
199                            Ok(StreamData::new(value, None, content.as_str().unwrap_or("")))
200                        } else {
201                            Ok(StreamData::new(value, None, ""))
202                        }
203                    }
204                    Err(e) => Err(LLMError::RequestError(e)),
205                }
206            }
207        });
208
209        Ok(Box::pin(processed_stream))
210    }
211
212    fn add_options(&mut self, options: CallOptions) {
213        self.options.merge_options(options)
214    }
215}
216
217fn parse_sse_to_json(sse_data: &str) -> Result<Value, LLMError> {
218    if let Ok(json) = serde_json::from_str::<Value>(sse_data) {
219        return parse_error(&json);
220    }
221
222    let lines: Vec<&str> = sse_data.trim().split('\n').collect();
223    let mut event_data: HashMap<&str, String> = HashMap::new();
224
225    for line in lines {
226        if let Some((key, value)) = line.split_once(": ") {
227            event_data.insert(key, value.to_string());
228        }
229    }
230
231    if let Some(data) = event_data.get("data") {
232        let data: Value = serde_json::from_str(data)?;
233        return match data["type"].as_str() {
234            Some("error") => parse_error(&data),
235            _ => Ok(data),
236        };
237    }
238    log::error!("No data field in the SSE event");
239    Err(LLMError::ContentNotFound("data".to_string()))
240}
241
242fn parse_error(json: &Value) -> Result<Value, LLMError> {
243    let error_type = json["error"]["type"].as_str().unwrap_or("");
244    let message = json["error"]["message"].as_str().unwrap_or("").to_string();
245    match error_type {
246        "invalid_request_error" => Err(AnthropicError::InvalidRequestError(message))?,
247        "authentication_error" => Err(AnthropicError::AuthenticationError(message))?,
248        "permission_error" => Err(AnthropicError::PermissionError(message))?,
249        "not_found_error" => Err(AnthropicError::NotFoundError(message))?,
250        "rate_limit_error" => Err(AnthropicError::RateLimitError(message))?,
251        "api_error" => Err(AnthropicError::ApiError(message))?,
252        "overloaded_error" => Err(AnthropicError::OverloadedError(message))?,
253        _ => Err(LLMError::OtherError("Unknown error".to_string())),
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::test;

    // Live-API tests: require CLAUDE_API_KEY and are `#[ignore]`d by
    // default; run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    async fn test_cloudia_generate() {
        let claude = Claude::new();

        let result = claude
            .generate(&[Message::new_human_message("Hi, how are you doing")])
            .await
            .unwrap();

        println!("{:?}", result)
    }

    #[test]
    #[ignore]
    async fn test_cloudia_stream() {
        let claude = Claude::new();
        let mut stream = claude
            .stream(&[Message::new_human_message("Hi, how are you doing")])
            .await
            .unwrap();
        while let Some(chunk) = stream.next().await {
            let data = chunk.unwrap_or_else(|e| panic!("Error invoking LLMChain: {:?}", e));
            data.to_stdout().unwrap();
        }
    }
}