kalosm_language_model/claude/
chat.rs

1use super::{AnthropicCompatibleClient, NoAnthropicAPIKeyError};
2use crate::{
3    ChatMessage, ChatModel, ChatSession, CreateChatSession, GenerationParameters, ModelBuilder,
4};
5use futures_util::StreamExt;
6use kalosm_model_types::ModelLoadingProgress;
7use reqwest_eventsource::{Event, RequestBuilderExt};
8use serde::{Deserialize, Serialize};
9use std::{future::Future, sync::Arc};
10use thiserror::Error;
11
/// Shared configuration behind an [`AnthropicCompatibleChatModel`].
#[derive(Debug)]
struct AnthropicCompatibleChatModelInner {
    /// The model identifier sent as `model` in every request.
    model: String,
    /// Upper bound on `max_tokens`; each request sends the smaller of this and the
    /// sampler's `max_length`.
    max_tokens: u32,
    /// The client used to resolve the API key, base URL and API version, and to send requests.
    client: AnthropicCompatibleClient,
}
18
/// A chat model that runs remotely through Anthropic's API.
///
/// Cloning is cheap: clones share the same configuration through an [`Arc`].
#[derive(Debug, Clone)]
pub struct AnthropicCompatibleChatModel {
    inner: Arc<AnthropicCompatibleChatModelInner>,
}
24
25impl AnthropicCompatibleChatModel {
26    /// Create a new builder for the Anthropic compatible chat model.
27    pub fn builder() -> AnthropicCompatibleChatModelBuilder<false> {
28        AnthropicCompatibleChatModelBuilder::new()
29    }
30}
31
32/// A builder for an Anthropic compatible chat model.
33#[derive(Debug, Default)]
34pub struct AnthropicCompatibleChatModelBuilder<const WITH_NAME: bool> {
35    model: Option<String>,
36    max_tokens: u32,
37    client: AnthropicCompatibleClient,
38}
39
40impl AnthropicCompatibleChatModelBuilder<false> {
41    /// Creates a new builder
42    pub fn new() -> Self {
43        Self {
44            model: None,
45            max_tokens: 8192,
46            client: Default::default(),
47        }
48    }
49}
50
51impl<const WITH_NAME: bool> AnthropicCompatibleChatModelBuilder<WITH_NAME> {
52    /// Set the name of the model to use.
53    pub fn with_model(self, model: impl ToString) -> AnthropicCompatibleChatModelBuilder<true> {
54        AnthropicCompatibleChatModelBuilder {
55            model: Some(model.to_string()),
56            max_tokens: self.max_tokens,
57            client: self.client,
58        }
59    }
60
61    /// Set the default max tokens to use when generating text.
62    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
63        self.max_tokens = max_tokens;
64        self
65    }
66
67    /// Set the model to `claude-3-5-sonnet-20241022`
68    pub fn with_claude_3_5_sonnet(self) -> AnthropicCompatibleChatModelBuilder<true> {
69        self.with_model("claude-3-5-sonnet-20241022")
70    }
71
72    /// Set the model to `claude-3-5-haiku-20241022`
73    pub fn with_claude_3_5_haiku(self) -> AnthropicCompatibleChatModelBuilder<true> {
74        self.with_model("claude-3-5-haiku-20241022")
75    }
76
77    /// Set the model to `claude-3-opus-20240229`
78    pub fn with_claude_3_opus(self) -> AnthropicCompatibleChatModelBuilder<true> {
79        self.with_model("claude-3-opus-20240229")
80            .with_max_tokens(4096)
81    }
82
83    /// Set the model to `claude-3-sonnet-20240229`
84    pub fn with_claude_3_sonnet(self) -> AnthropicCompatibleChatModelBuilder<true> {
85        self.with_model("claude-3-sonnet-20240229")
86            .with_max_tokens(4096)
87    }
88
89    /// Set the model to `claude-3-haiku-20240307`
90    pub fn with_claude_3_haiku(self) -> AnthropicCompatibleChatModelBuilder<true> {
91        self.with_model("claude-3-haiku-20240307")
92            .with_max_tokens(4096)
93    }
94
95    /// Set the client used to make requests to the Anthropic API.
96    pub fn with_client(mut self, client: AnthropicCompatibleClient) -> Self {
97        self.client = client;
98        self
99    }
100}
101
102impl AnthropicCompatibleChatModelBuilder<true> {
103    /// Build the model.
104    pub fn build(self) -> AnthropicCompatibleChatModel {
105        AnthropicCompatibleChatModel {
106            inner: Arc::new(AnthropicCompatibleChatModelInner {
107                model: self.model.unwrap(),
108                max_tokens: self.max_tokens,
109                client: self.client,
110            }),
111        }
112    }
113}
114
115impl ModelBuilder for AnthropicCompatibleChatModelBuilder<true> {
116    type Model = AnthropicCompatibleChatModel;
117    type Error = std::convert::Infallible;
118
119    async fn start_with_loading_handler(
120        self,
121        _: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
122    ) -> Result<Self::Model, Self::Error> {
123        Ok(self.build())
124    }
125
126    fn requires_download(&self) -> bool {
127        false
128    }
129}
130
/// An error that can occur when running an [`AnthropicCompatibleChatModel`].
#[derive(Error, Debug)]
pub enum AnthropicCompatibleChatModelError {
    /// The API key could not be resolved.
    #[error("Error resolving API key: {0}")]
    APIKeyError(#[from] NoAnthropicAPIKeyError),
    /// A request to the Anthropic API failed.
    #[error("Error making request: {0}")]
    ReqwestError(#[from] reqwest::Error),
    /// Receiving server-sent events from the Anthropic API failed.
    #[error("Error receiving server side events: {0}")]
    EventSourceError(#[from] reqwest_eventsource::Error),
    /// A streamed event from the Anthropic API could not be deserialized.
    #[error("Failed to deserialize Anthropic API response: {0}")]
    DeserializeError(#[from] serde_json::Error),
    /// The Anthropic API sent an `error` event while streaming the response.
    #[error("Error streaming response from Anthropic API: {0}")]
    StreamError(#[from] AnthropicCompatibleChatResponseError),
}
150
151/// A chat session for the Anthropic compatible chat model.
152#[derive(Serialize, Deserialize, Clone)]
153pub struct AnthropicCompatibleChatSession {
154    messages: Vec<crate::ChatMessage>,
155}
156
157impl AnthropicCompatibleChatSession {
158    fn new() -> Self {
159        Self {
160            messages: Vec::new(),
161        }
162    }
163}
164
165impl ChatSession for AnthropicCompatibleChatSession {
166    type Error = serde_json::Error;
167
168    fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error> {
169        let json = serde_json::to_vec(self)?;
170        into.extend_from_slice(&json);
171        Ok(())
172    }
173
174    fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
175    where
176        Self: std::marker::Sized,
177    {
178        let json = serde_json::from_slice(bytes)?;
179        Ok(json)
180    }
181
182    fn history(&self) -> Vec<crate::ChatMessage> {
183        self.messages.clone()
184    }
185
186    fn try_clone(&self) -> Result<Self, Self::Error>
187    where
188        Self: std::marker::Sized,
189    {
190        Ok(self.clone())
191    }
192}
193
194impl CreateChatSession for AnthropicCompatibleChatModel {
195    type ChatSession = AnthropicCompatibleChatSession;
196    type Error = AnthropicCompatibleChatModelError;
197
198    fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error> {
199        Ok(AnthropicCompatibleChatSession::new())
200    }
201}
202
203#[derive(Serialize, Deserialize)]
204#[serde(tag = "type")]
205enum AnthropicCompatibleChatResponse {
206    #[serde(rename = "content_block_delta")]
207    ContentBlockDelta(AnthropicCompatibleChatResponseContentBlockDelta),
208    #[serde(rename = "content_block_stop")]
209    ContentBlockStop,
210    #[serde(rename = "error")]
211    Error(AnthropicCompatibleChatResponseError),
212    #[serde(other)]
213    Unknown,
214}
215
/// An error that can occur when receiving a stream from the Anthropic API.
///
/// The variant is selected by the error object's `type` field, and its `message` is kept.
#[derive(Serialize, Deserialize, Error, Debug)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AnthropicCompatibleChatResponseError {
    /// The request was invalid.
    #[serde(rename = "invalid_request_error")]
    #[error("Invalid request error: {message}")]
    InvalidRequestError {
        /// The error message.
        message: String,
    },
    /// Authentication failed.
    #[serde(rename = "authentication_error")]
    #[error("Authentication error: {message}")]
    AuthenticationError {
        /// The error message.
        message: String,
    },
    /// A permission error occurred.
    #[serde(rename = "permission_error")]
    #[error("Permission error: {message}")]
    PermissionError {
        /// The error message.
        message: String,
    },
    /// The resource was not found.
    #[serde(rename = "not_found_error")]
    #[error("Not found error: {message}")]
    NotFoundError {
        /// The error message.
        message: String,
    },
    /// The request was too large.
    #[serde(rename = "request_too_large")]
    #[error("Request too large: {message}")]
    RequestTooLarge {
        /// The error message.
        message: String,
    },
    /// The rate limit was exceeded.
    #[serde(rename = "rate_limit_error")]
    #[error("Rate limit error: {message}")]
    RateLimitError {
        /// The error message.
        message: String,
    },
    /// An API error occurred.
    #[serde(rename = "api_error")]
    #[error("API error: {message}")]
    ApiError {
        /// The error message.
        message: String,
    },
    /// The server is overloaded.
    #[serde(rename = "overloaded_error")]
    #[error("Overloaded error: {message}")]
    OverloadedError {
        /// The error message.
        message: String,
    },
    /// An error whose `type` matches none of the variants above.
    ///
    /// This is a unit variant, so the error message is not retained.
    #[serde(other)]
    #[error("Unknown error")]
    Unknown,
}
282
/// The payload of a `content_block_delta` stream event.
#[derive(Serialize, Deserialize)]
struct AnthropicCompatibleChatResponseContentBlockDelta {
    /// Index of the content block this delta belongs to (not read by this client).
    index: u32,
    /// The incremental change to that content block.
    delta: AnthropicCompatibleChatResponseContentBlockDeltaMessage,
}
288
/// The change carried by a content block delta, selected by its `type` field.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum AnthropicCompatibleChatResponseContentBlockDeltaMessage {
    /// A chunk of generated text to append to the reply.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// Any other delta type. The streaming loop only logs these at trace level.
    #[serde(other)]
    Unknown,
}
297
/// Reasons a generation may finish.
///
/// NOTE(review): this type is not referenced anywhere in this file. Its wire names
/// (`content_filter`, `function_call`, `length`, `stop`) look like OpenAI-style finish
/// reasons rather than Anthropic stop reasons. Confirm whether it is still needed.
#[derive(Serialize, Deserialize)]
enum FinishReason {
    #[serde(rename = "content_filter")]
    ContentFilter,
    #[serde(rename = "function_call")]
    FunctionCall,
    #[serde(rename = "length")]
    MaxTokens,
    #[serde(rename = "stop")]
    Stop,
}
309
/// A response message with optional content and refusal text.
///
/// NOTE(review): this type is not referenced anywhere in this file and its shape appears to
/// follow OpenAI's chat "choice" message. Confirm whether it is still needed.
#[derive(Serialize, Deserialize)]
struct AnthropicCompatibleChatResponseChoiceMessage {
    content: Option<String>,
    refusal: Option<String>,
}
315
316impl ChatModel<GenerationParameters> for AnthropicCompatibleChatModel {
317    fn add_messages_with_callback<'a>(
318        &'a self,
319        session: &'a mut Self::ChatSession,
320        messages: &[ChatMessage],
321        sampler: GenerationParameters,
322        mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
323    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
324        let mut system_prompt = None;
325        let messages: Vec<_> = messages
326            .iter()
327            .filter(|message| {
328                if let crate::MessageType::SystemPrompt = message.role() {
329                    system_prompt = Some(message.content().to_string());
330                    false
331                } else {
332                    true
333                }
334            })
335            .collect();
336        let myself = &*self.inner;
337        let mut json = serde_json::json!({
338            "model": myself.model,
339            "messages": messages,
340            "stream": true,
341            "top_p": sampler.top_p,
342            "top_k": sampler.top_k,
343            "temperature": sampler.temperature,
344            "max_tokens": sampler.max_length.min(myself.max_tokens),
345        });
346
347        async move {
348            let api_key = myself.client.resolve_api_key()?;
349            if let Some(stop_on) = sampler.stop_on.as_ref() {
350                json["stop"] = vec![stop_on.clone()].into();
351            }
352            if let Some(system) = system_prompt {
353                json["system"] = system.into();
354            }
355            let mut event_source = myself
356                .client
357                .reqwest_client
358                .post(format!("{}/messages", myself.client.base_url()))
359                .header("Content-Type", "application/json")
360                .header("x-api-key", api_key)
361                .header("anthropic-version", myself.client.version())
362                .json(&json)
363                .eventsource()
364                .unwrap();
365
366            let mut new_message_text = String::new();
367
368            while let Some(event) = event_source.next().await {
369                match event? {
370                    Event::Open => {}
371                    Event::Message(message) => {
372                        let data =
373                            serde_json::from_str::<AnthropicCompatibleChatResponse>(&message.data)?;
374                        match data {
375                            AnthropicCompatibleChatResponse::ContentBlockDelta(
376                                anthropic_compatible_chat_response_content_block_delta,
377                            ) => {
378                                match anthropic_compatible_chat_response_content_block_delta.delta {
379                                AnthropicCompatibleChatResponseContentBlockDeltaMessage::TextDelta { text } => {
380                                        new_message_text += &text;
381                                        on_token(text)?;
382                                },
383                                AnthropicCompatibleChatResponseContentBlockDeltaMessage::Unknown => tracing::trace!("Unknown delta from Anthropic API: {:?}", message.data),
384                            }
385                            }
386                            AnthropicCompatibleChatResponse::ContentBlockStop => {
387                                break;
388                            }
389                            AnthropicCompatibleChatResponse::Error(
390                                anthropic_compatible_chat_response_error,
391                            ) => {
392                                return Err(AnthropicCompatibleChatModelError::StreamError(
393                                    anthropic_compatible_chat_response_error,
394                                ))
395                            }
396                            AnthropicCompatibleChatResponse::Unknown => tracing::trace!(
397                                "Unknown response from Anthropic API: {:?}",
398                                message.data
399                            ),
400                        }
401                    }
402                }
403            }
404
405            let new_message =
406                crate::ChatMessage::new(crate::MessageType::UserMessage, new_message_text);
407
408            session.messages.push(new_message);
409
410            Ok(())
411        }
412    }
413}
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use super::{
        AnthropicCompatibleChatModelBuilder, ChatModel, CreateChatSession, GenerationParameters,
    };

    /// Streams a reply from Claude 3.5 Haiku and checks that some text arrived.
    ///
    /// This calls the live Anthropic API, so it needs network access and an API key the
    /// client can resolve.
    #[tokio::test]
    async fn test_claude_3_5_haiku() {
        let model = AnthropicCompatibleChatModelBuilder::new()
            .with_claude_3_5_haiku()
            .build();
        let mut session = model.new_chat_session().unwrap();

        let system = crate::ChatMessage::new(
            crate::MessageType::SystemPrompt,
            "Respond like a pirate.".to_string(),
        );
        let greeting =
            crate::ChatMessage::new(crate::MessageType::UserMessage, "Hello, world!".to_string());
        let prompt = [system, greeting];

        let streamed = Arc::new(RwLock::new(String::new()));
        let sink = Arc::clone(&streamed);
        model
            .add_messages_with_callback(
                &mut session,
                &prompt,
                GenerationParameters::default(),
                move |token| {
                    sink.write().unwrap().push_str(&token);
                    print!("{token}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    Ok(())
                },
            )
            .await
            .unwrap();

        let streamed = streamed.read().unwrap();
        println!("{streamed}");

        assert!(!streamed.is_empty());
    }
}