llm_chain/prompt/
chat.rs

1use serde::{Deserialize, Serialize};
2use std::collections::VecDeque;
3use std::fmt;
4
5use crate::tokens::{Tokenizer, TokenizerError};
6
7use super::{StringTemplate, StringTemplateError};
8use crate::Parameters;
9
10/// The `ChatRole` enum represents the role of a chat message sender in a conversation.
11///
12/// It has four variants:
13/// - `User`: Represents a message sent by a user.
14/// - `Assistant`: Represents a message sent by an AI assistant.
15/// - `System`: Represents a message sent by a system or service.
16/// - `Other`: Represents a message sent by any other role, specified by a string.
17#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
18pub enum ChatRole {
19    User,
20    Assistant,
21    System,
22    Other(String),
23}
24
25impl fmt::Display for ChatRole {
26    /// Formats the `ChatRole` enum as a string.
27    ///
28    /// # Examples
29    ///
30    /// ```
31    /// use llm_chain::prompt::ChatRole;
32    ///
33    /// let user_role = ChatRole::User;
34    /// let assistant_role = ChatRole::Assistant;
35    ///
36    /// assert_eq!(format!("{}", user_role), "User");
37    /// assert_eq!(format!("{}", assistant_role), "Assistant");
38    /// ```
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            ChatRole::User => write!(f, "User"),
42            ChatRole::Assistant => write!(f, "Assistant"),
43            ChatRole::System => write!(f, "System"),
44            ChatRole::Other(s) => write!(f, "{}", s),
45        }
46    }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50/// The `ChatMessage` struct represents a chat message.
51/// It has two fields:
52/// - `role`: The role of the message sender.
53/// - `body`: The body of the message.
54pub struct ChatMessage<Body> {
55    role: ChatRole,
56    body: Body,
57}
58
59impl<Body> ChatMessage<Body> {
60    /// Creates a new chat message.
61    ///
62    /// # Arguments
63    /// * `role` - The role of the message sender.
64    /// * `body` - The body of the message.
65    pub fn new(role: ChatRole, body: Body) -> Self {
66        Self { role, body }
67    }
68
69    /// Creates a new chat message with the role of `Assistant`.
70    ///
71    /// # Arguments
72    /// * `body` - The body of the message.
73    ///
74    /// # Example
75    ///
76    /// ```
77    /// use llm_chain::prompt::{ChatMessage, ChatRole};
78    /// let msg = ChatMessage::assistant("Hello, how can I help you?");
79    ///
80    /// assert_eq!(msg.role(), &ChatRole::Assistant);
81    /// ```
82    pub fn assistant(body: Body) -> Self {
83        Self::new(ChatRole::Assistant, body)
84    }
85
86    /// Creates a new chat message with the role of `User`.
87    ///
88    /// # Arguments
89    /// * `body` - The body of the message.
90    ///
91    /// # Example
92    ///
93    /// ```
94    /// use llm_chain::prompt::{ChatMessage, ChatRole};
95    /// let msg = ChatMessage::user("What's the weather like today?");
96    ///
97    /// assert_eq!(msg.role(), &ChatRole::User);
98    /// ```
99    pub fn user(body: Body) -> Self {
100        Self::new(ChatRole::User, body)
101    }
102
103    /// Creates a new chat message with the role of `System`.
104    ///
105    /// # Arguments
106    /// * `body` - The body of the message.
107    ///
108    /// # Example
109    ///
110    /// ```
111    /// use llm_chain::prompt::{ChatMessage, ChatRole};
112    /// let msg = ChatMessage::system("Session started.");
113    ///
114    /// assert_eq!(msg.role(), &ChatRole::System);
115    /// ```
116    pub fn system(body: Body) -> Self {
117        Self::new(ChatRole::System, body)
118    }
119
120    /// Maps the body of the chat message using the provided function `f`.
121    ///
122    /// # Arguments
123    /// * `f` - The function to apply to the message body.
124    ///
125    /// # Example
126    ///
127    /// ```
128    /// use llm_chain::prompt::{ChatMessage, ChatRole};
129    /// let msg = ChatMessage::new(ChatRole::Assistant, "Hello!");
130    /// let mapped_msg = msg.map(|body| body.to_uppercase());
131    ///
132    /// assert_eq!(mapped_msg.body(), "HELLO!");
133    /// ```
134    pub fn map<U, F: FnOnce(&Body) -> U>(&self, f: F) -> ChatMessage<U> {
135        let role = self.role.clone();
136        ChatMessage {
137            role,
138            body: f(&self.body),
139        }
140    }
141
142    /// Applies a fallible function `f` to the body of the chat message and returns a new chat message
143    /// with the mapped body or an error if the function fails.
144    ///
145    /// # Arguments
146    /// * `f` - The fallible function to apply to the message body.
147    pub fn try_map<U, E, F: Fn(&Body) -> Result<U, E>>(&self, f: F) -> Result<ChatMessage<U>, E> {
148        let body = f(&self.body)?;
149        let role = self.role.clone();
150        Ok(ChatMessage { role, body })
151    }
152
153    /// Returns a reference to the role of the message sender.
154    pub fn role(&self) -> &ChatRole {
155        &self.role
156    }
157
158    /// Returns a reference to the body of the message.
159    pub fn body(&self) -> &Body {
160        &self.body
161    }
162}
163
164impl<T: fmt::Display> fmt::Display for ChatMessage<T> {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        write!(f, "{}: {}", self.role, self.body)
167    }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171/// A collection of chat messages with various roles (e.g., user, assistant, system).
172pub struct ChatMessageCollection<Body> {
173    messages: VecDeque<ChatMessage<Body>>,
174}
175
176impl<Body> ChatMessageCollection<Body> {
177    /// Creates a new empty `ChatMessageCollection`.
178    pub fn new() -> Self {
179        ChatMessageCollection {
180            messages: VecDeque::new(),
181        }
182    }
183
184    /// Creates a `ChatMessageCollection` from a given vector of `ChatMessage`.
185    ///
186    /// # Arguments
187    ///
188    /// * `messages` - A vector of `ChatMessage` instances to be included in the collection.
189    pub fn for_vector(messages: Vec<ChatMessage<Body>>) -> Self {
190        ChatMessageCollection {
191            messages: messages.into(),
192        }
193    }
194
195    /// Adds a system message to the collection with the given body.
196    ///
197    /// # Arguments
198    ///
199    /// * `body` - The message body to be added as a system message.
200    pub fn with_system(mut self, body: Body) -> Self {
201        self.add_message(ChatMessage::system(body));
202        self
203    }
204
205    /// Adds a user message to the collection with the given body.
206    ///
207    /// # Arguments
208    ///
209    /// * `body` - The message body to be added as a user message.
210    pub fn with_user(mut self, body: Body) -> Self {
211        self.add_message(ChatMessage::user(body));
212        self
213    }
214
215    /// Adds an assistant message to the collection with the given body.
216    ///
217    /// # Arguments
218    ///
219    /// * `body` - The message body to be added as an assistant message.
220    pub fn with_assistant(mut self, body: Body) -> Self {
221        self.add_message(ChatMessage::assistant(body));
222        self
223    }
224
225    /// Appends another ChatMessageCollection to this one
226    ///
227    /// # Arguments
228    /// - `other` - The other ChatMessageCollection to append to this one
229    pub fn append(&mut self, other: ChatMessageCollection<Body>) {
230        self.messages.extend(other.messages);
231    }
232
233    /// Appends a `ChatMessage` to the collection.
234    ///
235    /// # Arguments
236    ///
237    /// * `message` - The `ChatMessage` instance to be added to the collection.
238    pub fn add_message(&mut self, message: ChatMessage<Body>) {
239        self.messages.push_back(message);
240    }
241
242    /// Removes the first message from the collection and returns it, or `None` if the collection is empty.
243    pub fn remove_first_message(&mut self) -> Option<ChatMessage<Body>> {
244        self.messages.pop_front()
245    }
246
247    /// Returns the number of messages in the collection.
248    pub fn len(&self) -> usize {
249        self.messages.len()
250    }
251
252    /// Gets the body of the last message in the collection
253    pub(crate) fn extract_last_body(&self) -> Option<&Body> {
254        self.messages.back().map(|x| &x.body)
255    }
256
257    /// Returns `true` if the collection contains no messages.
258    pub fn is_empty(&self) -> bool {
259        self.messages.is_empty()
260    }
261
262    /// Returns a reference to the message at the specified index, or `None` if the index is out of bounds.
263    ///
264    /// # Arguments
265    ///
266    /// * `index` - The index of the desired message in the collection.
267    pub fn get_message(&self, index: usize) -> Option<&ChatMessage<Body>> {
268        self.messages.get(index)
269    }
270
271    /// Returns an iterator over the messages in the collection.
272    pub fn iter(&self) -> std::collections::vec_deque::Iter<'_, ChatMessage<Body>> {
273        self.messages.iter()
274    }
275
276    /// Creates a new `ChatMessageCollection` with the results of applying a function to each `ChatMessage`.
277    ///
278    /// # Arguments
279    ///
280    /// * `f` - The function to apply to each `ChatMessage`.
281    pub fn map<U, F>(&self, f: F) -> ChatMessageCollection<U>
282    where
283        F: FnMut(&ChatMessage<Body>) -> ChatMessage<U>,
284    {
285        let mapped_messages: VecDeque<ChatMessage<U>> = self.messages.iter().map(f).collect();
286        ChatMessageCollection {
287            messages: mapped_messages,
288        }
289    }
290
291    /// Creates a new `ChatMessageCollection` by applying a fallible function to each message body
292    /// in the current collection. Returns an error if the function fails for any message.
293    ///
294    /// # Arguments
295    ///
296    /// * `f` - The fallible function to apply to each message body.
297    pub fn try_map<U, E, F: Fn(&Body) -> Result<U, E>>(
298        &self,
299        f: F,
300    ) -> Result<ChatMessageCollection<U>, E> {
301        let mut mapped_messages = VecDeque::new();
302
303        for msg in self.messages.iter() {
304            let mapped_msg = msg.try_map(|body| f(body))?;
305
306            mapped_messages.push_back(mapped_msg);
307        }
308
309        Ok(ChatMessageCollection {
310            messages: mapped_messages,
311        })
312    }
313
314    /// Trims the conversation to the specified number of messages by removing the oldest messages.
315    ///
316    /// # Arguments
317    ///
318    /// * `max_number_of_messages` - The desired number of messages to keep in the conversation.
319    pub fn trim_to_max_messages(&mut self, max_number_of_messages: usize) {
320        while self.len() > max_number_of_messages {
321            self.messages.pop_front();
322        }
323    }
324}
325
326impl<Body> Default for ChatMessageCollection<Body> {
327    fn default() -> Self {
328        ChatMessageCollection::new()
329    }
330}
331
332impl<T: fmt::Display> fmt::Display for ChatMessageCollection<T> {
333    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334        for message in self.messages.iter() {
335            writeln!(f, "{}", message)?;
336        }
337        Ok(())
338    }
339}
340
341/// Implementation of `ChatMessageCollection` for `String`.
342impl ChatMessageCollection<String> {
343    /// Trims the conversation context by removing the oldest messages in the collection
344    /// until the total number of tokens in the remaining messages is less than or equal
345    /// to the specified `max_tokens` limit.
346    ///
347    /// # Arguments
348    ///
349    /// * `tokenizer` - An instance of a `Tokenizer` that is used to tokenize the chat message bodies.
350    /// * `max_tokens` - The maximum number of tokens allowed in the trimmed conversation context.
351    ///
352    /// # Returns
353    ///
354    /// A `Result<(), TokenizerError>` indicating success or failure.
355    pub fn trim_context<Tok>(
356        &mut self,
357        tokenizer: &Tok,
358        max_tokens: i32,
359    ) -> Result<(), TokenizerError>
360    where
361        Tok: Tokenizer,
362    {
363        let mut total_tokens: i32 = 0;
364
365        // Remove the oldest messages from the collection
366        // until the total tokens are within the limit.
367        while let Some(msg) = self.messages.back() {
368            let tokens = tokenizer.tokenize_str(&msg.body)?;
369            total_tokens += tokens.len() as i32;
370            if total_tokens > max_tokens {
371                self.messages.pop_back();
372            } else {
373                break;
374            }
375        }
376        Ok(())
377    }
378
379    /// Adds a user message to the conversation by templating the specified template string and parameters.
380    ///
381    /// # Arguments
382    ///
383    /// * `body` - A template string representing the message body.
384    /// * `parameters` - Parameters used to template the message body
385    ///
386    /// # Returns
387    ///
388    /// Result<Self, StringTemplateError> If Ok()
389    /// A Result containing a modified `ChatMessageCollection` with the new user message added on success, or an error if the body couldn't be templated
390
391    pub fn with_user_template(
392        self,
393        body: &str,
394        parameters: &Parameters,
395    ) -> Result<Self, StringTemplateError> {
396        match StringTemplate::tera(body).format(parameters) {
397            Err(e) => Err(e),
398            Ok(templated_body) => Ok(self.with_user(templated_body)),
399        }
400    }
401
402    /// Adds a system message to the conversation by templating the specified template string and parameters.
403    ///
404    /// # Arguments
405    ///
406    /// * `body` - A template string representing the message body.
407    /// * `parameters` - Parameters used to template the message body
408    ///
409    /// # Returns
410    ///
411    /// Result<Self, StringTemplateError> If Ok()
412    /// A Result containing a modified `ChatMessageCollection` with the new system message added on success, or an error if the body couldn't be templated
413
414    pub fn with_system_template(
415        self,
416        body: &str,
417        parameters: &Parameters,
418    ) -> Result<Self, StringTemplateError> {
419        match StringTemplate::tera(body).format(parameters) {
420            Err(e) => Err(e),
421            Ok(templated_body) => Ok(self.with_system(templated_body)),
422        }
423    }
424
425    /// Adds a assistant message to the conversation by templating the specified template string and parameters.
426    ///
427    /// # Arguments
428    ///
429    /// * `body` - A template string representing the message body.
430    /// * `parameters` - Parameters used to template the message body
431    ///
432    /// # Returns
433    ///
434    /// Result<Self, StringTemplateError> If Ok()
435    /// A Result containing a modified `ChatMessageCollection` with the new assistant message added on success, or an error if the body couldn't be templated
436
437    pub fn with_assistant_template(
438        self,
439        body: &str,
440        parameters: &Parameters,
441    ) -> Result<Self, StringTemplateError> {
442        match StringTemplate::tera(body).format(parameters) {
443            Err(e) => Err(e),
444            Ok(templated_body) => Ok(self.with_assistant(templated_body)),
445        }
446    }
447}
448
449/// Implementation of `ChatMessageCollection` for `StringTemplate`.
450impl ChatMessageCollection<StringTemplate> {
451    /// Adds a user message to the conversation using the specified template string.
452    ///
453    /// # Arguments
454    ///
455    /// * `body` - A template string representing the message body.
456    ///
457    /// # Returns
458    ///
459    /// A modified `ChatMessageCollection` with the new user message added.
460    pub fn with_user_template(self, body: &str) -> Self {
461        self.with_user(StringTemplate::tera(body))
462    }
463
464    /// Adds a system message to the conversation using the specified template string.
465    ///
466    /// # Arguments
467    ///
468    /// * `body` - A template string representing the message body.
469    ///
470    /// # Returns
471    ///
472    /// A modified `ChatMessageCollection` with the new system message added.
473    pub fn with_system_template(self, body: &str) -> Self {
474        self.with_system(StringTemplate::tera(body))
475    }
476
477    /// Adds an assistant message to the conversation using the specified template string.
478    ///
479    /// # Arguments
480    ///
481    /// * `body` - A template string representing the message body.
482    ///
483    /// # Returns
484    ///
485    /// A modified `ChatMessageCollection` with the new assistant message added.
486    pub fn with_assistant_template(self, body: &str) -> Self {
487        self.with_assistant(StringTemplate::tera(body))
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_map() {
        let msg = ChatMessage::new(ChatRole::Assistant, "Hello!");
        let mapped_msg = msg.map(|body| body.to_uppercase());

        assert_eq!(mapped_msg.body, "HELLO!");
        assert_eq!(mapped_msg.role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_message_list() {
        let mut chat_message_list = ChatMessageCollection::new();

        assert_eq!(chat_message_list.len(), 0);

        chat_message_list.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        chat_message_list.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));

        assert_eq!(chat_message_list.len(), 2);

        assert_eq!(chat_message_list.get_message(0).unwrap().body, "Hello!");
        assert_eq!(chat_message_list.get_message(1).unwrap().body, "Hi there!");

        chat_message_list.remove_first_message();
        assert_eq!(chat_message_list.len(), 1);
    }

    #[test]
    fn test_chat_message_list_map() {
        let mut chat_message_list = ChatMessageCollection::new();

        chat_message_list.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        chat_message_list.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));

        let mapped_list = chat_message_list
            .map(|msg| ChatMessage::new(msg.role.clone(), format!("{} (mapped)", msg.body)));

        assert_eq!(mapped_list.get_message(0).unwrap().body, "Hello! (mapped)");
        assert_eq!(
            mapped_list.get_message(1).unwrap().body,
            "Hi there! (mapped)"
        );
    }

    #[test]
    fn test_chat_message_list_try_map() {
        let ok_list = ChatMessageCollection::new()
            .with_user("1")
            .with_assistant("2");
        let parsed = ok_list.try_map(|body| body.parse::<i32>()).unwrap();
        assert_eq!(parsed.get_message(0).unwrap().body, 1);
        assert_eq!(parsed.get_message(1).unwrap().body, 2);
        assert_eq!(parsed.get_message(1).unwrap().role, ChatRole::Assistant);

        // A failing body makes the whole mapping fail.
        let bad_list = ChatMessageCollection::new()
            .with_user("1")
            .with_assistant("two")
            .with_user("3");
        assert!(bad_list.try_map(|body| body.parse::<i32>()).is_err());
    }

    #[test]
    fn test_trim_to_max_messages_removes_oldest() {
        let mut list = ChatMessageCollection::new()
            .with_system("a")
            .with_user("b")
            .with_assistant("c");

        list.trim_to_max_messages(5);
        assert_eq!(list.len(), 3);

        list.trim_to_max_messages(2);
        assert_eq!(list.len(), 2);
        assert_eq!(list.get_message(0).unwrap().body, "b");
        assert_eq!(list.get_message(1).unwrap().body, "c");

        list.trim_to_max_messages(0);
        assert!(list.is_empty());
    }

    #[test]
    fn test_display() {
        let list = ChatMessageCollection::new()
            .with_user("Hello!")
            .with_assistant("Hi there!");
        assert_eq!(list.to_string(), "User: Hello!\nAssistant: Hi there!\n");
        assert_eq!(ChatRole::Other("tool".to_string()).to_string(), "tool");
    }
}