// kalosm_language_model/claude/chat.rs
use super::{AnthropicCompatibleClient, NoAnthropicAPIKeyError};
2use crate::{
3 ChatMessage, ChatModel, ChatSession, CreateChatSession, GenerationParameters, ModelBuilder,
4};
5use futures_util::StreamExt;
6use kalosm_model_types::ModelLoadingProgress;
7use reqwest_eventsource::{Event, RequestBuilderExt};
8use serde::{Deserialize, Serialize};
9use std::{future::Future, sync::Arc};
10use thiserror::Error;
11
/// Shared, immutable state behind [`AnthropicCompatibleChatModel`]; stored in
/// an `Arc` so the public handle stays cheap to clone.
#[derive(Debug)]
struct AnthropicCompatibleChatModelInner {
    /// Model identifier sent in the request body (e.g. "claude-3-5-haiku-20241022").
    model: String,
    /// Upper bound applied to the `max_tokens` field of every request.
    max_tokens: u32,
    /// Client wrapper holding the base URL, API-key resolution, and version header.
    client: AnthropicCompatibleClient,
}
18
/// A chat model backed by an Anthropic-compatible HTTP API.
///
/// Cloning is cheap: all state lives behind a shared `Arc`.
#[derive(Debug, Clone)]
pub struct AnthropicCompatibleChatModel {
    inner: Arc<AnthropicCompatibleChatModelInner>,
}
24
impl AnthropicCompatibleChatModel {
    /// Create a builder for this model. The `false` const parameter records
    /// that no model name has been chosen yet; `build` only becomes available
    /// after one is set.
    pub fn builder() -> AnthropicCompatibleChatModelBuilder<false> {
        AnthropicCompatibleChatModelBuilder::new()
    }
}
31
/// Builder for [`AnthropicCompatibleChatModel`]. The `WITH_NAME` const
/// parameter is a type-state flag: `build` is only implemented for the
/// `true` state, reached by supplying a model name.
///
/// NOTE(review): the derived `Default` yields `max_tokens: 0`, which differs
/// from `new()`'s 8192 — confirm nothing constructs this builder via
/// `Default` expecting the documented default cap.
#[derive(Debug, Default)]
pub struct AnthropicCompatibleChatModelBuilder<const WITH_NAME: bool> {
    /// Model identifier; `Some` once `with_model` has been called.
    model: Option<String>,
    /// Cap applied to the `max_tokens` request field.
    max_tokens: u32,
    /// HTTP client configuration used for requests.
    client: AnthropicCompatibleClient,
}
39
impl AnthropicCompatibleChatModelBuilder<false> {
    /// Create a builder with no model selected, a `max_tokens` cap of 8192,
    /// and a default client (API key resolved later, when a request is made).
    pub fn new() -> Self {
        Self {
            model: None,
            max_tokens: 8192,
            client: Default::default(),
        }
    }
}
50
51impl<const WITH_NAME: bool> AnthropicCompatibleChatModelBuilder<WITH_NAME> {
52 pub fn with_model(self, model: impl ToString) -> AnthropicCompatibleChatModelBuilder<true> {
54 AnthropicCompatibleChatModelBuilder {
55 model: Some(model.to_string()),
56 max_tokens: self.max_tokens,
57 client: self.client,
58 }
59 }
60
61 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
63 self.max_tokens = max_tokens;
64 self
65 }
66
67 pub fn with_claude_3_5_sonnet(self) -> AnthropicCompatibleChatModelBuilder<true> {
69 self.with_model("claude-3-5-sonnet-20241022")
70 }
71
72 pub fn with_claude_3_5_haiku(self) -> AnthropicCompatibleChatModelBuilder<true> {
74 self.with_model("claude-3-5-haiku-20241022")
75 }
76
77 pub fn with_claude_3_opus(self) -> AnthropicCompatibleChatModelBuilder<true> {
79 self.with_model("claude-3-opus-20240229")
80 .with_max_tokens(4096)
81 }
82
83 pub fn with_claude_3_sonnet(self) -> AnthropicCompatibleChatModelBuilder<true> {
85 self.with_model("claude-3-sonnet-20240229")
86 .with_max_tokens(4096)
87 }
88
89 pub fn with_claude_3_haiku(self) -> AnthropicCompatibleChatModelBuilder<true> {
91 self.with_model("claude-3-haiku-20240307")
92 .with_max_tokens(4096)
93 }
94
95 pub fn with_client(mut self, client: AnthropicCompatibleClient) -> Self {
97 self.client = client;
98 self
99 }
100}
101
102impl AnthropicCompatibleChatModelBuilder<true> {
103 pub fn build(self) -> AnthropicCompatibleChatModel {
105 AnthropicCompatibleChatModel {
106 inner: Arc::new(AnthropicCompatibleChatModelInner {
107 model: self.model.unwrap(),
108 max_tokens: self.max_tokens,
109 client: self.client,
110 }),
111 }
112 }
113}
114
impl ModelBuilder for AnthropicCompatibleChatModelBuilder<true> {
    type Model = AnthropicCompatibleChatModel;
    // Building a remote handle cannot fail, so the error type is uninhabited.
    type Error = std::convert::Infallible;

    /// "Start" the remote model. Nothing is loaded or downloaded, so the
    /// progress handler is ignored and this returns immediately.
    async fn start_with_loading_handler(
        self,
        _: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self::Model, Self::Error> {
        Ok(self.build())
    }

    /// Remote models never require a local download.
    fn requires_download(&self) -> bool {
        false
    }
}
130
/// Errors that can occur while making a streaming chat request to an
/// Anthropic-compatible API.
#[derive(Error, Debug)]
pub enum AnthropicCompatibleChatModelError {
    /// No API key could be resolved by the client.
    #[error("Error resolving API key: {0}")]
    APIKeyError(#[from] NoAnthropicAPIKeyError),
    /// The underlying HTTP request failed.
    #[error("Error making request: {0}")]
    ReqwestError(#[from] reqwest::Error),
    /// The server-sent-events stream failed or disconnected.
    #[error("Error receiving server side events: {0}")]
    EventSourceError(#[from] reqwest_eventsource::Error),
    /// An event payload could not be parsed as the expected JSON shape.
    #[error("Failed to deserialize Anthropic API response: {0}")]
    DeserializeError(#[from] serde_json::Error),
    /// The API reported an error event inside the stream.
    #[error("Error streaming response from Anthropic API: {0}")]
    StreamError(#[from] AnthropicCompatibleChatResponseError),
}
150
/// Chat session state for [`AnthropicCompatibleChatModel`]: the ordered
/// message history, serializable so sessions can be persisted and restored.
#[derive(Serialize, Deserialize, Clone)]
pub struct AnthropicCompatibleChatSession {
    /// Conversation history in chronological order.
    messages: Vec<crate::ChatMessage>,
}
156
157impl AnthropicCompatibleChatSession {
158 fn new() -> Self {
159 Self {
160 messages: Vec::new(),
161 }
162 }
163}
164
165impl ChatSession for AnthropicCompatibleChatSession {
166 type Error = serde_json::Error;
167
168 fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error> {
169 let json = serde_json::to_vec(self)?;
170 into.extend_from_slice(&json);
171 Ok(())
172 }
173
174 fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
175 where
176 Self: std::marker::Sized,
177 {
178 let json = serde_json::from_slice(bytes)?;
179 Ok(json)
180 }
181
182 fn history(&self) -> Vec<crate::ChatMessage> {
183 self.messages.clone()
184 }
185
186 fn try_clone(&self) -> Result<Self, Self::Error>
187 where
188 Self: std::marker::Sized,
189 {
190 Ok(self.clone())
191 }
192}
193
impl CreateChatSession for AnthropicCompatibleChatModel {
    type ChatSession = AnthropicCompatibleChatSession;
    type Error = AnthropicCompatibleChatModelError;

    /// Create a fresh, empty session; this never fails for a remote model.
    fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error> {
        Ok(AnthropicCompatibleChatSession::new())
    }
}
202
/// One server-sent event payload from the streaming `/messages` endpoint,
/// discriminated by its `type` field. Only the variants this client acts on
/// are modeled; anything else falls into `Unknown` and is logged and ignored.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum AnthropicCompatibleChatResponse {
    /// Incremental update to a content block (carries the streamed text).
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta(AnthropicCompatibleChatResponseContentBlockDelta),
    /// A content block finished streaming.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop,
    /// The API reported an error mid-stream.
    #[serde(rename = "error")]
    Error(AnthropicCompatibleChatResponseError),
    /// Any other event type (e.g. pings, message metadata) — ignored.
    #[serde(other)]
    Unknown,
}
215
/// Error payloads the Anthropic API can return inside a stream `error` event,
/// discriminated by the `type` field.
#[derive(Serialize, Deserialize, Error, Debug)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AnthropicCompatibleChatResponseError {
    /// The request body was malformed or failed validation.
    #[serde(rename = "invalid_request_error")]
    #[error("Invalid request error: {message}")]
    InvalidRequestError {
        message: String,
    },
    /// The API key is missing or invalid.
    #[serde(rename = "authentication_error")]
    #[error("Authentication error: {message}")]
    AuthenticationError {
        message: String,
    },
    /// The key is valid but not permitted to perform this action.
    #[serde(rename = "permission_error")]
    #[error("Permission error: {message}")]
    PermissionError {
        message: String,
    },
    /// The requested resource (e.g. model) does not exist.
    #[serde(rename = "not_found_error")]
    #[error("Not found error: {message}")]
    NotFoundError {
        message: String,
    },
    /// The request exceeded the maximum allowed size.
    #[serde(rename = "request_too_large")]
    #[error("Request too large: {message}")]
    RequestTooLarge {
        message: String,
    },
    /// The account hit its rate limit.
    #[serde(rename = "rate_limit_error")]
    #[error("Rate limit error: {message}")]
    RateLimitError {
        message: String,
    },
    /// An unexpected internal error on Anthropic's side.
    #[serde(rename = "api_error")]
    #[error("API error: {message}")]
    ApiError {
        message: String,
    },
    /// The API is temporarily overloaded.
    #[serde(rename = "overloaded_error")]
    #[error("Overloaded error: {message}")]
    OverloadedError {
        message: String,
    },
    /// Any error type this client does not recognize.
    #[serde(other)]
    #[error("Unknown error")]
    Unknown,
}
282
/// Payload of a `content_block_delta` event: an incremental update to the
/// content block at `index`.
#[derive(Serialize, Deserialize)]
struct AnthropicCompatibleChatResponseContentBlockDelta {
    /// Index of the content block this delta applies to (not read by this client).
    index: u32,
    /// The delta itself (text for `text_delta`, otherwise ignored).
    delta: AnthropicCompatibleChatResponseContentBlockDeltaMessage,
}
288
/// The inner `delta` of a content-block update, discriminated by `type`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum AnthropicCompatibleChatResponseContentBlockDeltaMessage {
    /// A chunk of streamed output text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// Any other delta type (e.g. tool-use input) — logged and ignored.
    #[serde(other)]
    Unknown,
}
297
/// NOTE(review): these serde names ("stop", "length", "content_filter",
/// "function_call") match the OpenAI chat-completions API, not Anthropic's
/// `stop_reason` values ("end_turn", "max_tokens", ...), and nothing visible
/// in this file reads this type — it appears to be leftover from an
/// OpenAI-compatible client. Confirm before relying on it.
#[derive(Serialize, Deserialize)]
enum FinishReason {
    #[serde(rename = "content_filter")]
    ContentFilter,
    #[serde(rename = "function_call")]
    FunctionCall,
    #[serde(rename = "length")]
    MaxTokens,
    #[serde(rename = "stop")]
    Stop,
}
309
/// NOTE(review): `content`/`refusal` mirror the OpenAI "choice message"
/// shape, not Anthropic's response format, and this type is not referenced
/// anywhere visible in this file — likely leftover code; confirm before use.
#[derive(Serialize, Deserialize)]
struct AnthropicCompatibleChatResponseChoiceMessage {
    content: Option<String>,
    refusal: Option<String>,
}
315
316impl ChatModel<GenerationParameters> for AnthropicCompatibleChatModel {
317 fn add_messages_with_callback<'a>(
318 &'a self,
319 session: &'a mut Self::ChatSession,
320 messages: &[ChatMessage],
321 sampler: GenerationParameters,
322 mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
323 ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
324 let mut system_prompt = None;
325 let messages: Vec<_> = messages
326 .iter()
327 .filter(|message| {
328 if let crate::MessageType::SystemPrompt = message.role() {
329 system_prompt = Some(message.content().to_string());
330 false
331 } else {
332 true
333 }
334 })
335 .collect();
336 let myself = &*self.inner;
337 let mut json = serde_json::json!({
338 "model": myself.model,
339 "messages": messages,
340 "stream": true,
341 "top_p": sampler.top_p,
342 "top_k": sampler.top_k,
343 "temperature": sampler.temperature,
344 "max_tokens": sampler.max_length.min(myself.max_tokens),
345 });
346
347 async move {
348 let api_key = myself.client.resolve_api_key()?;
349 if let Some(stop_on) = sampler.stop_on.as_ref() {
350 json["stop"] = vec![stop_on.clone()].into();
351 }
352 if let Some(system) = system_prompt {
353 json["system"] = system.into();
354 }
355 let mut event_source = myself
356 .client
357 .reqwest_client
358 .post(format!("{}/messages", myself.client.base_url()))
359 .header("Content-Type", "application/json")
360 .header("x-api-key", api_key)
361 .header("anthropic-version", myself.client.version())
362 .json(&json)
363 .eventsource()
364 .unwrap();
365
366 let mut new_message_text = String::new();
367
368 while let Some(event) = event_source.next().await {
369 match event? {
370 Event::Open => {}
371 Event::Message(message) => {
372 let data =
373 serde_json::from_str::<AnthropicCompatibleChatResponse>(&message.data)?;
374 match data {
375 AnthropicCompatibleChatResponse::ContentBlockDelta(
376 anthropic_compatible_chat_response_content_block_delta,
377 ) => {
378 match anthropic_compatible_chat_response_content_block_delta.delta {
379 AnthropicCompatibleChatResponseContentBlockDeltaMessage::TextDelta { text } => {
380 new_message_text += &text;
381 on_token(text)?;
382 },
383 AnthropicCompatibleChatResponseContentBlockDeltaMessage::Unknown => tracing::trace!("Unknown delta from Anthropic API: {:?}", message.data),
384 }
385 }
386 AnthropicCompatibleChatResponse::ContentBlockStop => {
387 break;
388 }
389 AnthropicCompatibleChatResponse::Error(
390 anthropic_compatible_chat_response_error,
391 ) => {
392 return Err(AnthropicCompatibleChatModelError::StreamError(
393 anthropic_compatible_chat_response_error,
394 ))
395 }
396 AnthropicCompatibleChatResponse::Unknown => tracing::trace!(
397 "Unknown response from Anthropic API: {:?}",
398 message.data
399 ),
400 }
401 }
402 }
403 }
404
405 let new_message =
406 crate::ChatMessage::new(crate::MessageType::UserMessage, new_message_text);
407
408 session.messages.push(new_message);
409
410 Ok(())
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use super::{
        AnthropicCompatibleChatModelBuilder, ChatModel, CreateChatSession, GenerationParameters,
    };

    /// End-to-end smoke test: stream one reply from Claude 3.5 Haiku and
    /// assert that some text arrived through the token callback.
    ///
    /// NOTE(review): this performs a real network request and needs a valid
    /// Anthropic API key resolvable by the default client; it will fail
    /// offline or without credentials.
    #[tokio::test]
    async fn test_claude_3_5_haiku() {
        let model = AnthropicCompatibleChatModelBuilder::new()
            .with_claude_3_5_haiku()
            .build();

        let mut session = model.new_chat_session().unwrap();

        // One system prompt (sent as the top-level `system` field) plus one
        // user message.
        let messages = vec![
            crate::ChatMessage::new(
                crate::MessageType::SystemPrompt,
                "Respond like a pirate.".to_string(),
            ),
            crate::ChatMessage::new(crate::MessageType::UserMessage, "Hello, world!".to_string()),
        ];
        // Accumulate streamed tokens; RwLock because the callback must be
        // Send + Sync + 'static.
        let all_text = Arc::new(RwLock::new(String::new()));
        model
            .add_messages_with_callback(&mut session, &messages, GenerationParameters::default(), {
                let all_text = all_text.clone();
                move |token| {
                    let mut all_text = all_text.write().unwrap();
                    all_text.push_str(&token);
                    print!("{token}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    Ok(())
                }
            })
            .await
            .unwrap();

        let all_text = all_text.read().unwrap();
        println!("{all_text}");

        assert!(!all_text.is_empty());
    }
}