kalosm_llama/chat.rs

use std::{
    future::Future,
    sync::{Arc, RwLock},
};

use crate::{model::LlamaModelError, session::LlamaSessionLoadingError, Llama, LlamaSession};
use kalosm_common::accelerated_device_if_available;
use kalosm_language_model::{
    ChatMessage, ChatModel, ChatSession, CreateChatSession, CreateTextCompletionSession,
    MessageType, StructuredChatModel, StructuredTextCompletionModel, TextCompletionModel,
};
use kalosm_sample::{CreateParserState, Parser};
use llm_samplers::types::Sampler;
use minijinja::ErrorKind;

#[cfg(test)]
use pretty_assertions::assert_eq;

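/// Format the full chat history (with `messages` appended) using the model's chat template and
/// return only the text that was added since the last call, so the model only has to process the
/// new tokens rather than re-reading the entire conversation.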
fn get_new_tokens(
    messages: &[ChatMessage],
    session: &mut LlamaChatSession,
    model: &Llama,
) -> Result<String, LlamaModelError> {
    let chat_template = model
        .config
        .chat_template
        .as_ref()
        .ok_or(LlamaModelError::NoChatTemplate)?;
    let bos_token = &model.config.start_token_string;
    let eos_token = &model.config.stop_token_string;
    let current_text = if session.history.is_empty() {
        String::new()
    } else {
        let old_formatted_text =
            chat_template.format(bos_token, eos_token, &session.history, true)?;
        // Some chat templates (like llama v3) always include the generation prompt even when we
        // tell them not to. If they do, try to strip it off.
        let (before_last_eos, _) = old_formatted_text
            .rsplit_once(eos_token)
            .unwrap_or((&old_formatted_text, ""));
        before_last_eos.to_string() + eos_token
    };
    session.history.extend_from_slice(messages);
    let updated_text = chat_template.format(bos_token, eos_token, &session.history, true)?;
    let new_text = updated_text.strip_prefix(&current_text).ok_or_else(|| {
        LlamaModelError::ChatTemplateError(minijinja::Error::new(
            ErrorKind::InvalidOperation,
            format!("Chat template should only add text to the end of the current text. Old text: {current_text}, new text: {updated_text}"),
        ))
    })?;

    Ok(new_text.to_string())
}

impl CreateChatSession for Llama {
    type Error = LlamaModelError;
    type ChatSession = LlamaChatSession;

    fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error> {
        Ok(LlamaChatSession::new(self.new_session()?))
    }
}

impl<S: Sampler + 'static> ChatModel<S> for Llama {
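    // Compute the newly added prompt text up front (outside the async block), then stream the
    // model's reply. The reply is accumulated through the shared `model_response` buffer so it
    // can be appended to the chat history as a `ModelAnswer` once generation finishes.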
    fn add_messages_with_callback<'a>(
        &'a self,
        session: &'a mut Self::ChatSession,
        messages: &[ChatMessage],
        sampler: S,
        mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
        let new_text = get_new_tokens(messages, session, self);
        async move {
            let new_text = new_text?;
            let model_response = Arc::new(RwLock::new(String::new()));
            let on_token = {
                let model_response = model_response.clone();
                move |token: String| {
                    let mut model_response = model_response.write().unwrap();
                    *model_response += &token;
                    on_token(token)
                }
            };
            self.stream_text_with_callback(&mut session.session, &new_text, sampler, on_token)
                .await?;
            session.history.push(ChatMessage::new(
                MessageType::ModelAnswer,
                model_response.read().unwrap().clone(),
            ));
            Ok(())
        }
    }
}

impl<S, Constraints> StructuredChatModel<Constraints, S> for Llama
where
    <Constraints as Parser>::Output: Send,
    Constraints: CreateParserState + Send + 'static,
    S: Sampler + 'static,
{
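    // Same flow as `add_messages_with_callback` above, except that generation is constrained by
    // the supplied parser and the parsed value is returned in addition to updating the history.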
    fn add_message_with_callback_and_constraints<'a>(
        &'a self,
        session: &'a mut Self::ChatSession,
        messages: &[ChatMessage],
        sampler: S,
        constraints: Constraints,
        mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<
        Output = Result<
            <Constraints as kalosm_language_model::ModelConstraints>::Output,
            Self::Error,
        >,
    > + Send
           + 'a {
        let new_text = get_new_tokens(messages, session, self);
        async move {
            let new_text = new_text?;
            let model_response = Arc::new(RwLock::new(String::new()));
            let on_token = {
                let model_response = model_response.clone();
                move |token: String| {
                    let mut model_response = model_response.write().unwrap();
                    *model_response += &token;
                    on_token(token)
                }
            };
            let result = self
                .stream_text_with_callback_and_parser(
                    &mut session.session,
                    &new_text,
                    sampler,
                    constraints,
                    on_token,
                )
                .await?;
            session.history.push(ChatMessage::new(
                MessageType::ModelAnswer,
                model_response.read().unwrap().clone(),
            ));
            Ok(result)
        }
    }
}

/// A Llama chat session.
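///
/// A session owns the chat history together with the underlying [`LlamaSession`] (the model's
/// cached state), so a conversation can be serialized with [`ChatSession::write_to`] and restored
/// later with [`ChatSession::from_bytes`].
///
/// A minimal sketch of creating and persisting a session (assuming a `Llama` model has already
/// been loaded as `model`):
///
/// ```no_run
/// use kalosm_language_model::{ChatSession, CreateChatSession};
///
/// # fn example(model: &kalosm_llama::Llama) -> Result<(), Box<dyn std::error::Error>> {
/// // Start a fresh chat session for the model and persist it to bytes.
/// let session = model.new_chat_session()?;
/// let bytes = session.to_bytes()?;
/// # Ok(())
/// # }
/// ```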
#[derive(Clone)]
pub struct LlamaChatSession {
    history: Vec<ChatMessage>,
    session: LlamaSession,
}

impl ChatSession for LlamaChatSession {
    type Error = LlamaSessionLoadingError;

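    // Serialized layout (all integers little-endian):
    //   [u32 message count]
    //   per message: [u8 role tag][u32 content length][UTF-8 content bytes]
    //   followed by the safetensors-encoded session tensors.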
    fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error> {
        let device = accelerated_device_if_available()?;

        let history_items = self.history.len() as u32;
        let mut all_bytes = Vec::new();
        all_bytes.extend_from_slice(&history_items.to_le_bytes());
        for item in &self.history {
            let ty = match item.role() {
                MessageType::UserMessage => 0u8,
                MessageType::ModelAnswer => 1,
                MessageType::SystemPrompt => 2,
            };
            all_bytes.extend_from_slice(&ty.to_le_bytes());
            let content_bytes = item.content().as_bytes();
            let content_bytes_len = content_bytes.len() as u32;
            all_bytes.extend_from_slice(&content_bytes_len.to_le_bytes());
            all_bytes.extend_from_slice(content_bytes);
        }

        let tensors = self.session.get_tensor_map(&device);
        let bytes = safetensors::serialize(&tensors, &None)?;
        all_bytes.extend_from_slice(&bytes);

        into.extend_from_slice(&all_bytes);

        Ok(())
    }

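    // Parses the layout written by `write_to`: the chat messages come first, and the remaining
    // bytes are loaded as a safetensors buffer to restore the cached session state.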
    fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
    where
        Self: std::marker::Sized,
    {
        let mut history_items = Vec::new();
        let mut cursor_pos = 0;
        let history_item_count = u32::from_le_bytes(
            bytes
                .get(..4)
                .ok_or(LlamaSessionLoadingError::InvalidChatMessages)?
                .try_into()
                .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
        );
        cursor_pos += 4;
        history_items.reserve(history_item_count as usize);
        for _ in 0..history_item_count {
            let ty = bytes[cursor_pos];
            let ty = match ty {
                0 => MessageType::UserMessage,
                1 => MessageType::ModelAnswer,
                2 => MessageType::SystemPrompt,
                _ => return Err(LlamaSessionLoadingError::InvalidChatMessages),
            };
            cursor_pos += 1;
            let content_bytes_len = u32::from_le_bytes(
                bytes[cursor_pos..cursor_pos + 4]
                    .try_into()
                    .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
            );
            cursor_pos += 4;
            let content_bytes = &bytes[cursor_pos..cursor_pos + content_bytes_len as usize];
            cursor_pos += content_bytes_len as usize;
            let item = ChatMessage::new(
                ty,
                String::from_utf8(content_bytes.to_vec())
                    .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
            );
            history_items.push(item);
        }

        let device = accelerated_device_if_available()?;
        let tensors = candle_core::safetensors::load_buffer(&bytes[cursor_pos..], &device)?;

        let session = LlamaSession::from_tensor_map(tensors)?;

        Ok(Self {
            history: history_items,
            session,
        })
    }

    fn history(&self) -> Vec<ChatMessage> {
        self.history.clone()
    }

    fn try_clone(&self) -> Result<Self, Self::Error>
    where
        Self: std::marker::Sized,
    {
        Ok(self.clone())
    }
}

#[test]
fn test_serialize_deserialize_chat_session() {
    use crate::raw::LlamaConfig;

    let config = LlamaConfig::mock_test();
    let session = LlamaChatSession {
        history: vec![
            ChatMessage::new(MessageType::UserMessage, "Hello, world!".to_string()),
            ChatMessage::new(
                MessageType::ModelAnswer,
                "I'm doing great. How can I help you today?".to_string(),
            ),
            ChatMessage::new(
                MessageType::SystemPrompt,
                "The assistant will act like a pirate.".to_string(),
            ),
        ],
        session: LlamaSession::new(&config),
    };

    let bytes = session.to_bytes().unwrap();
    let deserialized = LlamaChatSession::from_bytes(&bytes).unwrap();

    // The round-tripped session should contain the same chat history as the original.
    assert_eq!(session.history, deserialized.history);
}

impl LlamaChatSession {
    /// Creates a new chat session with an empty history.
    fn new(session: LlamaSession) -> Self {
        Self {
            history: Vec::new(),
            session,
        }
    }
}