llm_chain/prompt/chat.rs
1use serde::{Deserialize, Serialize};
2use std::collections::VecDeque;
3use std::fmt;
4
5use crate::tokens::{Tokenizer, TokenizerError};
6
7use super::{StringTemplate, StringTemplateError};
8use crate::Parameters;
9
10/// The `ChatRole` enum represents the role of a chat message sender in a conversation.
11///
12/// It has four variants:
13/// - `User`: Represents a message sent by a user.
14/// - `Assistant`: Represents a message sent by an AI assistant.
15/// - `System`: Represents a message sent by a system or service.
16/// - `Other`: Represents a message sent by any other role, specified by a string.
17#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
18pub enum ChatRole {
19 User,
20 Assistant,
21 System,
22 Other(String),
23}
24
25impl fmt::Display for ChatRole {
26 /// Formats the `ChatRole` enum as a string.
27 ///
28 /// # Examples
29 ///
30 /// ```
31 /// use llm_chain::prompt::ChatRole;
32 ///
33 /// let user_role = ChatRole::User;
34 /// let assistant_role = ChatRole::Assistant;
35 ///
36 /// assert_eq!(format!("{}", user_role), "User");
37 /// assert_eq!(format!("{}", assistant_role), "Assistant");
38 /// ```
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ChatRole::User => write!(f, "User"),
42 ChatRole::Assistant => write!(f, "Assistant"),
43 ChatRole::System => write!(f, "System"),
44 ChatRole::Other(s) => write!(f, "{}", s),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50/// The `ChatMessage` struct represents a chat message.
51/// It has two fields:
52/// - `role`: The role of the message sender.
53/// - `body`: The body of the message.
54pub struct ChatMessage<Body> {
55 role: ChatRole,
56 body: Body,
57}
58
59impl<Body> ChatMessage<Body> {
60 /// Creates a new chat message.
61 ///
62 /// # Arguments
63 /// * `role` - The role of the message sender.
64 /// * `body` - The body of the message.
65 pub fn new(role: ChatRole, body: Body) -> Self {
66 Self { role, body }
67 }
68
69 /// Creates a new chat message with the role of `Assistant`.
70 ///
71 /// # Arguments
72 /// * `body` - The body of the message.
73 ///
74 /// # Example
75 ///
76 /// ```
77 /// use llm_chain::prompt::{ChatMessage, ChatRole};
78 /// let msg = ChatMessage::assistant("Hello, how can I help you?");
79 ///
80 /// assert_eq!(msg.role(), &ChatRole::Assistant);
81 /// ```
82 pub fn assistant(body: Body) -> Self {
83 Self::new(ChatRole::Assistant, body)
84 }
85
86 /// Creates a new chat message with the role of `User`.
87 ///
88 /// # Arguments
89 /// * `body` - The body of the message.
90 ///
91 /// # Example
92 ///
93 /// ```
94 /// use llm_chain::prompt::{ChatMessage, ChatRole};
95 /// let msg = ChatMessage::user("What's the weather like today?");
96 ///
97 /// assert_eq!(msg.role(), &ChatRole::User);
98 /// ```
99 pub fn user(body: Body) -> Self {
100 Self::new(ChatRole::User, body)
101 }
102
103 /// Creates a new chat message with the role of `System`.
104 ///
105 /// # Arguments
106 /// * `body` - The body of the message.
107 ///
108 /// # Example
109 ///
110 /// ```
111 /// use llm_chain::prompt::{ChatMessage, ChatRole};
112 /// let msg = ChatMessage::system("Session started.");
113 ///
114 /// assert_eq!(msg.role(), &ChatRole::System);
115 /// ```
116 pub fn system(body: Body) -> Self {
117 Self::new(ChatRole::System, body)
118 }
119
120 /// Maps the body of the chat message using the provided function `f`.
121 ///
122 /// # Arguments
123 /// * `f` - The function to apply to the message body.
124 ///
125 /// # Example
126 ///
127 /// ```
128 /// use llm_chain::prompt::{ChatMessage, ChatRole};
129 /// let msg = ChatMessage::new(ChatRole::Assistant, "Hello!");
130 /// let mapped_msg = msg.map(|body| body.to_uppercase());
131 ///
132 /// assert_eq!(mapped_msg.body(), "HELLO!");
133 /// ```
134 pub fn map<U, F: FnOnce(&Body) -> U>(&self, f: F) -> ChatMessage<U> {
135 let role = self.role.clone();
136 ChatMessage {
137 role,
138 body: f(&self.body),
139 }
140 }
141
142 /// Applies a fallible function `f` to the body of the chat message and returns a new chat message
143 /// with the mapped body or an error if the function fails.
144 ///
145 /// # Arguments
146 /// * `f` - The fallible function to apply to the message body.
147 pub fn try_map<U, E, F: Fn(&Body) -> Result<U, E>>(&self, f: F) -> Result<ChatMessage<U>, E> {
148 let body = f(&self.body)?;
149 let role = self.role.clone();
150 Ok(ChatMessage { role, body })
151 }
152
153 /// Returns a reference to the role of the message sender.
154 pub fn role(&self) -> &ChatRole {
155 &self.role
156 }
157
158 /// Returns a reference to the body of the message.
159 pub fn body(&self) -> &Body {
160 &self.body
161 }
162}
163
164impl<T: fmt::Display> fmt::Display for ChatMessage<T> {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 write!(f, "{}: {}", self.role, self.body)
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171/// A collection of chat messages with various roles (e.g., user, assistant, system).
172pub struct ChatMessageCollection<Body> {
173 messages: VecDeque<ChatMessage<Body>>,
174}
175
176impl<Body> ChatMessageCollection<Body> {
177 /// Creates a new empty `ChatMessageCollection`.
178 pub fn new() -> Self {
179 ChatMessageCollection {
180 messages: VecDeque::new(),
181 }
182 }
183
184 /// Creates a `ChatMessageCollection` from a given vector of `ChatMessage`.
185 ///
186 /// # Arguments
187 ///
188 /// * `messages` - A vector of `ChatMessage` instances to be included in the collection.
189 pub fn for_vector(messages: Vec<ChatMessage<Body>>) -> Self {
190 ChatMessageCollection {
191 messages: messages.into(),
192 }
193 }
194
195 /// Adds a system message to the collection with the given body.
196 ///
197 /// # Arguments
198 ///
199 /// * `body` - The message body to be added as a system message.
200 pub fn with_system(mut self, body: Body) -> Self {
201 self.add_message(ChatMessage::system(body));
202 self
203 }
204
205 /// Adds a user message to the collection with the given body.
206 ///
207 /// # Arguments
208 ///
209 /// * `body` - The message body to be added as a user message.
210 pub fn with_user(mut self, body: Body) -> Self {
211 self.add_message(ChatMessage::user(body));
212 self
213 }
214
215 /// Adds an assistant message to the collection with the given body.
216 ///
217 /// # Arguments
218 ///
219 /// * `body` - The message body to be added as an assistant message.
220 pub fn with_assistant(mut self, body: Body) -> Self {
221 self.add_message(ChatMessage::assistant(body));
222 self
223 }
224
225 /// Appends another ChatMessageCollection to this one
226 ///
227 /// # Arguments
228 /// - `other` - The other ChatMessageCollection to append to this one
229 pub fn append(&mut self, other: ChatMessageCollection<Body>) {
230 self.messages.extend(other.messages);
231 }
232
233 /// Appends a `ChatMessage` to the collection.
234 ///
235 /// # Arguments
236 ///
237 /// * `message` - The `ChatMessage` instance to be added to the collection.
238 pub fn add_message(&mut self, message: ChatMessage<Body>) {
239 self.messages.push_back(message);
240 }
241
242 /// Removes the first message from the collection and returns it, or `None` if the collection is empty.
243 pub fn remove_first_message(&mut self) -> Option<ChatMessage<Body>> {
244 self.messages.pop_front()
245 }
246
247 /// Returns the number of messages in the collection.
248 pub fn len(&self) -> usize {
249 self.messages.len()
250 }
251
252 /// Gets the body of the last message in the collection
253 pub(crate) fn extract_last_body(&self) -> Option<&Body> {
254 self.messages.back().map(|x| &x.body)
255 }
256
257 /// Returns `true` if the collection contains no messages.
258 pub fn is_empty(&self) -> bool {
259 self.messages.is_empty()
260 }
261
262 /// Returns a reference to the message at the specified index, or `None` if the index is out of bounds.
263 ///
264 /// # Arguments
265 ///
266 /// * `index` - The index of the desired message in the collection.
267 pub fn get_message(&self, index: usize) -> Option<&ChatMessage<Body>> {
268 self.messages.get(index)
269 }
270
271 /// Returns an iterator over the messages in the collection.
272 pub fn iter(&self) -> std::collections::vec_deque::Iter<'_, ChatMessage<Body>> {
273 self.messages.iter()
274 }
275
276 /// Creates a new `ChatMessageCollection` with the results of applying a function to each `ChatMessage`.
277 ///
278 /// # Arguments
279 ///
280 /// * `f` - The function to apply to each `ChatMessage`.
281 pub fn map<U, F>(&self, f: F) -> ChatMessageCollection<U>
282 where
283 F: FnMut(&ChatMessage<Body>) -> ChatMessage<U>,
284 {
285 let mapped_messages: VecDeque<ChatMessage<U>> = self.messages.iter().map(f).collect();
286 ChatMessageCollection {
287 messages: mapped_messages,
288 }
289 }
290
291 /// Creates a new `ChatMessageCollection` by applying a fallible function to each message body
292 /// in the current collection. Returns an error if the function fails for any message.
293 ///
294 /// # Arguments
295 ///
296 /// * `f` - The fallible function to apply to each message body.
297 pub fn try_map<U, E, F: Fn(&Body) -> Result<U, E>>(
298 &self,
299 f: F,
300 ) -> Result<ChatMessageCollection<U>, E> {
301 let mut mapped_messages = VecDeque::new();
302
303 for msg in self.messages.iter() {
304 let mapped_msg = msg.try_map(|body| f(body))?;
305
306 mapped_messages.push_back(mapped_msg);
307 }
308
309 Ok(ChatMessageCollection {
310 messages: mapped_messages,
311 })
312 }
313
314 /// Trims the conversation to the specified number of messages by removing the oldest messages.
315 ///
316 /// # Arguments
317 ///
318 /// * `max_number_of_messages` - The desired number of messages to keep in the conversation.
319 pub fn trim_to_max_messages(&mut self, max_number_of_messages: usize) {
320 while self.len() > max_number_of_messages {
321 self.messages.pop_front();
322 }
323 }
324}
325
326impl<Body> Default for ChatMessageCollection<Body> {
327 fn default() -> Self {
328 ChatMessageCollection::new()
329 }
330}
331
332impl<T: fmt::Display> fmt::Display for ChatMessageCollection<T> {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 for message in self.messages.iter() {
335 writeln!(f, "{}", message)?;
336 }
337 Ok(())
338 }
339}
340
341/// Implementation of `ChatMessageCollection` for `String`.
342impl ChatMessageCollection<String> {
343 /// Trims the conversation context by removing the oldest messages in the collection
344 /// until the total number of tokens in the remaining messages is less than or equal
345 /// to the specified `max_tokens` limit.
346 ///
347 /// # Arguments
348 ///
349 /// * `tokenizer` - An instance of a `Tokenizer` that is used to tokenize the chat message bodies.
350 /// * `max_tokens` - The maximum number of tokens allowed in the trimmed conversation context.
351 ///
352 /// # Returns
353 ///
354 /// A `Result<(), TokenizerError>` indicating success or failure.
355 pub fn trim_context<Tok>(
356 &mut self,
357 tokenizer: &Tok,
358 max_tokens: i32,
359 ) -> Result<(), TokenizerError>
360 where
361 Tok: Tokenizer,
362 {
363 let mut total_tokens: i32 = 0;
364
365 // Remove the oldest messages from the collection
366 // until the total tokens are within the limit.
367 while let Some(msg) = self.messages.back() {
368 let tokens = tokenizer.tokenize_str(&msg.body)?;
369 total_tokens += tokens.len() as i32;
370 if total_tokens > max_tokens {
371 self.messages.pop_back();
372 } else {
373 break;
374 }
375 }
376 Ok(())
377 }
378
379 /// Adds a user message to the conversation by templating the specified template string and parameters.
380 ///
381 /// # Arguments
382 ///
383 /// * `body` - A template string representing the message body.
384 /// * `parameters` - Parameters used to template the message body
385 ///
386 /// # Returns
387 ///
388 /// Result<Self, StringTemplateError> If Ok()
389 /// A Result containing a modified `ChatMessageCollection` with the new user message added on success, or an error if the body couldn't be templated
390
391 pub fn with_user_template(
392 self,
393 body: &str,
394 parameters: &Parameters,
395 ) -> Result<Self, StringTemplateError> {
396 match StringTemplate::tera(body).format(parameters) {
397 Err(e) => Err(e),
398 Ok(templated_body) => Ok(self.with_user(templated_body)),
399 }
400 }
401
402 /// Adds a system message to the conversation by templating the specified template string and parameters.
403 ///
404 /// # Arguments
405 ///
406 /// * `body` - A template string representing the message body.
407 /// * `parameters` - Parameters used to template the message body
408 ///
409 /// # Returns
410 ///
411 /// Result<Self, StringTemplateError> If Ok()
412 /// A Result containing a modified `ChatMessageCollection` with the new system message added on success, or an error if the body couldn't be templated
413
414 pub fn with_system_template(
415 self,
416 body: &str,
417 parameters: &Parameters,
418 ) -> Result<Self, StringTemplateError> {
419 match StringTemplate::tera(body).format(parameters) {
420 Err(e) => Err(e),
421 Ok(templated_body) => Ok(self.with_system(templated_body)),
422 }
423 }
424
425 /// Adds a assistant message to the conversation by templating the specified template string and parameters.
426 ///
427 /// # Arguments
428 ///
429 /// * `body` - A template string representing the message body.
430 /// * `parameters` - Parameters used to template the message body
431 ///
432 /// # Returns
433 ///
434 /// Result<Self, StringTemplateError> If Ok()
435 /// A Result containing a modified `ChatMessageCollection` with the new assistant message added on success, or an error if the body couldn't be templated
436
437 pub fn with_assistant_template(
438 self,
439 body: &str,
440 parameters: &Parameters,
441 ) -> Result<Self, StringTemplateError> {
442 match StringTemplate::tera(body).format(parameters) {
443 Err(e) => Err(e),
444 Ok(templated_body) => Ok(self.with_assistant(templated_body)),
445 }
446 }
447}
448
449/// Implementation of `ChatMessageCollection` for `StringTemplate`.
450impl ChatMessageCollection<StringTemplate> {
451 /// Adds a user message to the conversation using the specified template string.
452 ///
453 /// # Arguments
454 ///
455 /// * `body` - A template string representing the message body.
456 ///
457 /// # Returns
458 ///
459 /// A modified `ChatMessageCollection` with the new user message added.
460 pub fn with_user_template(self, body: &str) -> Self {
461 self.with_user(StringTemplate::tera(body))
462 }
463
464 /// Adds a system message to the conversation using the specified template string.
465 ///
466 /// # Arguments
467 ///
468 /// * `body` - A template string representing the message body.
469 ///
470 /// # Returns
471 ///
472 /// A modified `ChatMessageCollection` with the new system message added.
473 pub fn with_system_template(self, body: &str) -> Self {
474 self.with_system(StringTemplate::tera(body))
475 }
476
477 /// Adds an assistant message to the conversation using the specified template string.
478 ///
479 /// # Arguments
480 ///
481 /// * `body` - A template string representing the message body.
482 ///
483 /// # Returns
484 ///
485 /// A modified `ChatMessageCollection` with the new assistant message added.
486 pub fn with_assistant_template(self, body: &str) -> Self {
487 self.with_assistant(StringTemplate::tera(body))
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChatMessage::map` transforms the body and keeps the role.
    #[test]
    fn test_map() {
        let original = ChatMessage::new(ChatRole::Assistant, "Hello!");
        let shouted = original.map(|text| text.to_uppercase());

        assert_eq!(shouted.body, "HELLO!");
        assert_eq!(shouted.role, ChatRole::Assistant);
    }

    /// Adding, indexing and removing messages keeps the length in sync.
    #[test]
    fn test_chat_message_list() {
        let mut conversation = ChatMessageCollection::new();
        assert_eq!(conversation.len(), 0);

        conversation.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        conversation.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));
        assert_eq!(conversation.len(), 2);

        let first = conversation.get_message(0).unwrap();
        assert_eq!(first.body, "Hello!");
        let second = conversation.get_message(1).unwrap();
        assert_eq!(second.body, "Hi there!");

        conversation.remove_first_message();
        assert_eq!(conversation.len(), 1);
    }

    /// `ChatMessageCollection::map` applies the function to every message in order.
    #[test]
    fn test_chat_message_list_map() {
        let mut conversation = ChatMessageCollection::new();
        conversation.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        conversation.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));

        let mapped = conversation.map(|message| {
            let body = format!("{} (mapped)", message.body);
            ChatMessage::new(message.role.clone(), body)
        });

        let expected = ["Hello! (mapped)", "Hi there! (mapped)"];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(mapped.get_message(index).unwrap().body, *want);
        }
    }
}
538}