// langchain_rust/llm/claude/client.rs
use crate::{
    language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage},
    llm::AnthropicError,
    schemas::{Message, MessageType, StreamData},
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde_json::Value;
use std::{collections::HashMap, pin::Pin};

use super::models::{ApiResponse, ClaudeMessage, Payload};
13
/// Claude model identifiers accepted by the Anthropic Messages API.
///
/// NOTE(review): `Claude3pus20240229` looks like a typo for "Opus", but the
/// variant name is public API and is kept for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClaudeModel {
    Claude3pus20240229,
    Claude3sonnet20240229,
    Claude3haiku20240307,
    Claude3_5sonnet20240620,
}

// Implementing `Display` (instead of `ToString` directly) is the idiomatic
// form; the blanket impl still provides `.to_string()` to existing callers.
impl std::fmt::Display for ClaudeModel {
    /// Writes the exact model string the Anthropic API expects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = match self {
            ClaudeModel::Claude3pus20240229 => "claude-3-opus-20240229",
            ClaudeModel::Claude3sonnet20240229 => "claude-3-sonnet-20240229",
            ClaudeModel::Claude3haiku20240307 => "claude-3-haiku-20240307",
            ClaudeModel::Claude3_5sonnet20240620 => "claude-3-5-sonnet-20240620",
        };
        f.write_str(id)
    }
}
31
/// Client for Anthropic's Claude models via the Messages API.
#[derive(Clone)]
pub struct Claude {
    // Model identifier sent in the request payload (e.g. "claude-3-opus-20240229").
    model: String,
    // Per-call settings (max_tokens, temperature, stop words, streaming callback, ...).
    options: CallOptions,
    // API key sent as the `x-api-key` header.
    api_key: String,
    // Value of the `anthropic-version` request header (e.g. "2023-06-01").
    anthropic_version: String,
}
39
40impl Default for Claude {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl Claude {
47 pub fn new() -> Self {
48 Self {
49 model: ClaudeModel::Claude3pus20240229.to_string(),
50 options: CallOptions::default(),
51 api_key: std::env::var("CLAUDE_API_KEY").unwrap_or_default(),
52 anthropic_version: "2023-06-01".to_string(),
53 }
54 }
55
56 pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
57 self.model = model.into();
58 self
59 }
60
61 pub fn with_options(mut self, options: CallOptions) -> Self {
62 self.options = options;
63 self
64 }
65
66 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
67 self.api_key = api_key.into();
68 self
69 }
70
71 pub fn with_anthropic_version<S: Into<String>>(mut self, version: S) -> Self {
72 self.anthropic_version = version.into();
73 self
74 }
75
76 async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError> {
77 let client = Client::new();
78 let is_stream = self.options.streaming_func.is_some();
79
80 let payload = self.build_payload(messages, is_stream);
81 let res = client
82 .post("https://api.anthropic.com/v1/messages")
83 .header("x-api-key", &self.api_key)
84 .header("anthropic-version", self.anthropic_version.clone())
85 .header("content-type", "application/json; charset=utf-8")
86 .json(&payload)
87 .send()
88 .await?;
89 let res = match res.status().as_u16() {
90 401 => Err(LLMError::AnthropicError(
91 AnthropicError::AuthenticationError("Invalid API Key".to_string()),
92 )),
93 403 => Err(LLMError::AnthropicError(AnthropicError::PermissionError(
94 "Permission Denied".to_string(),
95 ))),
96 404 => Err(LLMError::AnthropicError(AnthropicError::NotFoundError(
97 "Not Found".to_string(),
98 ))),
99 429 => Err(LLMError::AnthropicError(AnthropicError::RateLimitError(
100 "Rate Limit Exceeded".to_string(),
101 ))),
102 503 => Err(LLMError::AnthropicError(AnthropicError::OverloadedError(
103 "Service Unavailable".to_string(),
104 ))),
105 _ => Ok(res.json::<ApiResponse>().await?),
106 }?;
107
108 let generation = res
109 .content
110 .first()
111 .map(|c| c.text.clone())
112 .unwrap_or_default();
113
114 let tokens = Some(TokenUsage {
115 prompt_tokens: res.usage.input_tokens,
116 completion_tokens: res.usage.output_tokens,
117 total_tokens: res.usage.input_tokens + res.usage.output_tokens,
118 });
119
120 Ok(GenerateResult { tokens, generation })
121 }
122
123 fn build_payload(&self, messages: &[Message], stream: bool) -> Payload {
124 let (system_message, other_messages): (Vec<_>, Vec<_>) = messages
125 .into_iter()
126 .partition(|m| m.message_type == MessageType::SystemMessage);
127 let mut payload = Payload {
128 model: self.model.clone(),
129 system: system_message.get(0).map(|m| m.content.clone()),
130 messages: other_messages
131 .into_iter()
132 .map(ClaudeMessage::from_message)
133 .collect::<Vec<_>>(),
134 max_tokens: self.options.max_tokens.unwrap_or(1024),
135 stream: None,
136 stop_sequences: self.options.stop_words.clone(),
137 temperature: self.options.temperature,
138 top_p: self.options.top_p,
139 top_k: self.options.top_k,
140 };
141 if stream {
142 payload.stream = Some(true);
143 }
144 payload
145 }
146}
147
148#[async_trait]
149impl LLM for Claude {
150 async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError> {
151 match &self.options.streaming_func {
152 Some(func) => {
153 let mut complete_response = String::new();
154 let mut stream = self.stream(messages).await?;
155 while let Some(data) = stream.next().await {
156 match data {
157 Ok(value) => {
158 let mut func = func.lock().await;
159 complete_response.push_str(&value.content);
160 let _ = func(value.content).await;
161 }
162 Err(e) => return Err(e),
163 }
164 }
165 let mut generate_result = GenerateResult::default();
166 generate_result.generation = complete_response;
167 Ok(generate_result)
168 }
169 None => self.generate(messages).await,
170 }
171 }
172 async fn stream(
173 &self,
174 messages: &[Message],
175 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError> {
176 let client = Client::new();
177 let payload = self.build_payload(messages, true);
178 let request = client
179 .post("https://api.anthropic.com/v1/messages")
180 .header("x-api-key", &self.api_key)
181 .header("anthropic-version", &self.anthropic_version)
182 .header("content-type", "application/json; charset=utf-8")
183 .json(&payload)
184 .build()?;
185
186 let stream = client.execute(request).await?;
188 let stream = stream.bytes_stream();
189 let processed_stream = stream.then(move |result| {
191 async move {
192 match result {
193 Ok(bytes) => {
194 let value: Value = parse_sse_to_json(&String::from_utf8_lossy(&bytes))?;
195 if value["type"].as_str().unwrap_or("") == "content_block_delta" {
196 let content = value["delta"]["text"].clone();
197 Ok(StreamData::new(value, None, content.as_str().unwrap_or("")))
200 } else {
201 Ok(StreamData::new(value, None, ""))
202 }
203 }
204 Err(e) => Err(LLMError::RequestError(e)),
205 }
206 }
207 });
208
209 Ok(Box::pin(processed_stream))
210 }
211
212 fn add_options(&mut self, options: CallOptions) {
213 self.options.merge_options(options)
214 }
215}
216
217fn parse_sse_to_json(sse_data: &str) -> Result<Value, LLMError> {
218 if let Ok(json) = serde_json::from_str::<Value>(sse_data) {
219 return parse_error(&json);
220 }
221
222 let lines: Vec<&str> = sse_data.trim().split('\n').collect();
223 let mut event_data: HashMap<&str, String> = HashMap::new();
224
225 for line in lines {
226 if let Some((key, value)) = line.split_once(": ") {
227 event_data.insert(key, value.to_string());
228 }
229 }
230
231 if let Some(data) = event_data.get("data") {
232 let data: Value = serde_json::from_str(data)?;
233 return match data["type"].as_str() {
234 Some("error") => parse_error(&data),
235 _ => Ok(data),
236 };
237 }
238 log::error!("No data field in the SSE event");
239 Err(LLMError::ContentNotFound("data".to_string()))
240}
241
242fn parse_error(json: &Value) -> Result<Value, LLMError> {
243 let error_type = json["error"]["type"].as_str().unwrap_or("");
244 let message = json["error"]["message"].as_str().unwrap_or("").to_string();
245 match error_type {
246 "invalid_request_error" => Err(AnthropicError::InvalidRequestError(message))?,
247 "authentication_error" => Err(AnthropicError::AuthenticationError(message))?,
248 "permission_error" => Err(AnthropicError::PermissionError(message))?,
249 "not_found_error" => Err(AnthropicError::NotFoundError(message))?,
250 "rate_limit_error" => Err(AnthropicError::RateLimitError(message))?,
251 "api_error" => Err(AnthropicError::ApiError(message))?,
252 "overloaded_error" => Err(AnthropicError::OverloadedError(message))?,
253 _ => Err(LLMError::OtherError("Unknown error".to_string())),
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::test;

    // Hits the live API; needs a valid CLAUDE_API_KEY. Run with `--ignored`.
    #[test]
    #[ignore]
    async fn test_cloudia_generate() {
        let client = Claude::new();

        let result = client
            .generate(&[Message::new_human_message("Hi, how are you doing")])
            .await
            .unwrap();

        println!("{:?}", result)
    }

    // Hits the live API; needs a valid CLAUDE_API_KEY. Run with `--ignored`.
    #[test]
    #[ignore]
    async fn test_cloudia_stream() {
        let client = Claude::new();
        let mut stream = client
            .stream(&[Message::new_human_message("Hi, how are you doing")])
            .await
            .unwrap();
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => data.to_stdout().unwrap(),
                Err(e) => panic!("Error invoking LLMChain: {:?}", e),
            }
        }
    }
}