1use std::{
2 future::Future,
3 sync::{Arc, RwLock},
4};
5
6use crate::{model::LlamaModelError, session::LlamaSessionLoadingError, Llama, LlamaSession};
7use kalosm_common::accelerated_device_if_available;
8use kalosm_language_model::{
9 ChatMessage, ChatModel, ChatSession, CreateChatSession, CreateTextCompletionSession,
10 MessageType, StructuredChatModel, StructuredTextCompletionModel, TextCompletionModel,
11};
12use kalosm_sample::{CreateParserState, Parser};
13use llm_samplers::types::Sampler;
14use minijinja::ErrorKind;
15
16#[cfg(test)]
17use pretty_assertions::assert_eq;
18
19fn get_new_tokens(
20 messages: &[ChatMessage],
21 session: &mut LlamaChatSession,
22 model: &Llama,
23) -> Result<String, LlamaModelError> {
24 let chat_template = model
25 .config
26 .chat_template
27 .as_ref()
28 .ok_or(LlamaModelError::NoChatTemplate)?;
29 let bos_token = &model.config.start_token_string;
30 let eos_token = &model.config.stop_token_string;
31 let current_text = if session.history.is_empty() {
32 String::new()
33 } else {
34 let old_formatted_text =
35 chat_template.format(bos_token, eos_token, &session.history, true)?;
36 let (before_last_eos, _) = old_formatted_text
38 .rsplit_once(eos_token)
39 .unwrap_or((&old_formatted_text, ""));
40 before_last_eos.to_string() + eos_token
41 };
42 session.history.extend_from_slice(messages);
43 let updated_text = chat_template.format(bos_token, eos_token, &session.history, true)?;
44 let new_text = updated_text.strip_prefix(¤t_text).ok_or_else(|| {
45 LlamaModelError::ChatTemplateError(minijinja::Error::new(
46 ErrorKind::InvalidOperation,
47 format!("Chat template should only add text to the end of the current text. Old text: {current_text}, new text: {updated_text}"),
48 ))
49 })?;
50
51 Ok(new_text.to_string())
52}
53
54impl CreateChatSession for Llama {
55 type Error = LlamaModelError;
56 type ChatSession = LlamaChatSession;
57
58 fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error> {
59 Ok(LlamaChatSession::new(self.new_session()?))
60 }
61}
62
63impl<S: Sampler + 'static> ChatModel<S> for Llama {
64 fn add_messages_with_callback<'a>(
65 &'a self,
66 session: &'a mut Self::ChatSession,
67 messages: &[ChatMessage],
68 sampler: S,
69 mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
70 ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
71 let new_text = get_new_tokens(messages, session, self);
72 async move {
73 let new_text = new_text?;
74 let model_response = Arc::new(RwLock::new(String::new()));
75 let on_token = {
76 let model_response = model_response.clone();
77 move |token: String| {
78 let mut model_response = model_response.write().unwrap();
79 *model_response += &token;
80 on_token(token)
81 }
82 };
83 self.stream_text_with_callback(&mut session.session, &new_text, sampler, on_token)
84 .await?;
85 session.history.push(ChatMessage::new(
86 MessageType::ModelAnswer,
87 model_response.read().unwrap().clone(),
88 ));
89 Ok(())
90 }
91 }
92}
93
94impl<S, Constraints> StructuredChatModel<Constraints, S> for Llama
95where
96 <Constraints as Parser>::Output: Send,
97 Constraints: CreateParserState + Send + 'static,
98 S: Sampler + 'static,
99{
100 fn add_message_with_callback_and_constraints<'a>(
101 &'a self,
102 session: &'a mut Self::ChatSession,
103 messages: &[ChatMessage],
104 sampler: S,
105 constraints: Constraints,
106 mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
107 ) -> impl Future<
108 Output = Result<
109 <Constraints as kalosm_language_model::ModelConstraints>::Output,
110 Self::Error,
111 >,
112 > + Send
113 + 'a {
114 let new_text = get_new_tokens(messages, session, self);
115 async move {
116 let new_text = new_text?;
117 let model_response = Arc::new(RwLock::new(String::new()));
118 let on_token = {
119 let model_response = model_response.clone();
120 move |token: String| {
121 let mut model_response = model_response.write().unwrap();
122 *model_response += &token;
123 on_token(token)
124 }
125 };
126 let result = self
127 .stream_text_with_callback_and_parser(
128 &mut session.session,
129 &new_text,
130 sampler,
131 constraints,
132 on_token,
133 )
134 .await?;
135 session.history.push(ChatMessage::new(
136 MessageType::ModelAnswer,
137 model_response.read().unwrap().clone(),
138 ));
139 Ok(result)
140 }
141 }
142}
143
/// A chat session for a Llama model: the message history plus the underlying
/// completion session that holds the model state between calls.
#[derive(Clone)]
pub struct LlamaChatSession {
    // All messages exchanged so far (user, model-answer, and system-prompt
    // entries), in order.
    history: Vec<ChatMessage>,
    // The underlying text-completion session the tokens are streamed into.
    session: LlamaSession,
}
150
151impl ChatSession for LlamaChatSession {
152 type Error = LlamaSessionLoadingError;
153
154 fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error> {
155 let device = accelerated_device_if_available()?;
156
157 let history_items = self.history.len() as u32;
158 let mut all_bytes = Vec::new();
159 all_bytes.extend_from_slice(&history_items.to_le_bytes());
160 for item in &self.history {
161 let ty = match item.role() {
162 MessageType::UserMessage => 0u8,
163 MessageType::ModelAnswer => 1,
164 MessageType::SystemPrompt => 2,
165 };
166 all_bytes.extend_from_slice(&ty.to_le_bytes());
167 let content_bytes = item.content().as_bytes();
168 let content_bytes_len = content_bytes.len() as u32;
169 all_bytes.extend_from_slice(&content_bytes_len.to_le_bytes());
170 all_bytes.extend_from_slice(content_bytes);
171 }
172
173 let tensors = self.session.get_tensor_map(&device);
174 let bytes = safetensors::serialize(&tensors, &None)?;
175 all_bytes.extend_from_slice(&bytes);
176
177 into.extend_from_slice(&all_bytes);
178
179 Ok(())
180 }
181
182 fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
183 where
184 Self: std::marker::Sized,
185 {
186 let mut history_items = Vec::new();
187 let mut cursor_pos = 0;
188 let history_item_count = u32::from_le_bytes(
189 bytes
190 .get(..4)
191 .ok_or(LlamaSessionLoadingError::InvalidChatMessages)?
192 .try_into()
193 .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
194 );
195 cursor_pos += 4;
196 history_items.reserve(history_item_count as usize);
197 for _ in 0..history_item_count {
198 let ty = bytes[cursor_pos];
199 let ty = match ty {
200 0 => MessageType::UserMessage,
201 1 => MessageType::ModelAnswer,
202 2 => MessageType::SystemPrompt,
203 _ => return Err(LlamaSessionLoadingError::InvalidChatMessages),
204 };
205 cursor_pos += 1;
206 let content_bytes_len = u32::from_le_bytes(
207 bytes[cursor_pos..cursor_pos + 4]
208 .try_into()
209 .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
210 );
211 cursor_pos += 4;
212 let content_bytes = &bytes[cursor_pos..cursor_pos + content_bytes_len as usize];
213 cursor_pos += content_bytes_len as usize;
214 let item = ChatMessage::new(
215 ty,
216 String::from_utf8(content_bytes.to_vec())
217 .map_err(|_| LlamaSessionLoadingError::InvalidChatMessages)?,
218 );
219 history_items.push(item);
220 }
221
222 let device = accelerated_device_if_available()?;
223 let tensors = candle_core::safetensors::load_buffer(&bytes[cursor_pos..], &device)?;
224
225 let session = LlamaSession::from_tensor_map(tensors)?;
226
227 Ok(Self {
228 history: history_items,
229 session,
230 })
231 }
232
233 fn history(&self) -> Vec<ChatMessage> {
234 self.history.clone()
235 }
236
237 fn try_clone(&self) -> Result<Self, Self::Error>
238 where
239 Self: std::marker::Sized,
240 {
241 Ok(self.clone())
242 }
243}
244
#[test]
fn test_serialize_deserialize_chat_session() {
    use crate::raw::LlamaConfig;

    let config = LlamaConfig::mock_test();
    let session = LlamaChatSession {
        history: vec![
            ChatMessage::new(MessageType::UserMessage, "Hello, world!".to_string()),
            ChatMessage::new(
                MessageType::ModelAnswer,
                "I'm doing great. How can I help you today?".to_string(),
            ),
            ChatMessage::new(
                MessageType::SystemPrompt,
                "The assistant will act like a pirate.".to_string(),
            ),
        ],
        session: LlamaSession::new(&config),
    };

    let bytes = session.to_bytes().unwrap();
    // Fix: the deserialized value used to shadow `session`, so the assert
    // compared the deserialized history with itself and always passed.
    let deserialized = LlamaChatSession::from_bytes(&bytes).unwrap();

    assert_eq!(session.history, deserialized.history);
}
270
271impl LlamaChatSession {
272 #[allow(clippy::too_many_arguments)]
273 fn new(session: LlamaSession) -> Self {
275 Self {
276 history: Vec::new(),
277 session,
278 }
279 }
280}