// burn_lm_inference/message.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Debug, Serialize, Deserialize, strum::Display, PartialEq)]
4#[serde(rename_all = "snake_case")]
5pub enum MessageRole {
6    System,
7    User,
8    Assistant,
9    Tool,
10    Unknown(String),
11}
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct Message {
15    pub role: MessageRole,
16    pub content: String,
17    pub refusal: Option<String>,
18}
19
20impl Message {
21    /// Update 'content' to be the text between the first occurrence of 'start'
22    /// and the last occurrence of 'end', excluding both markers.
23    /// If either marker is empty, not found, or in the wrong order,
24    /// the content remains unchanged.
25    pub fn cleanup(&mut self, start: &str, end: &str) {
26        if start.is_empty() || end.is_empty() {
27            return;
28        }
29
30        if let Some(start_index) = self.content.find(start) {
31            let content_start = start_index + start.len();
32            if let Some(last_end_index) = self.content.rfind(end) {
33                if last_end_index >= content_start {
34                    // Update content to be the text between the markers
35                    self.content = self.content[content_start..last_end_index].to_string();
36                }
37            }
38        }
39    }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Table-driven check of `Message::cleanup`: each case supplies the initial
    /// content, both markers and the content expected after the call.
    #[rstest(
        initial_content,
        start,
        end,
        expected_content,
        case::markers_found(
            "Hello, [start]This is a test[end] Goodbye",
            "[start]",
            "[end]",
            "This is a test"
        ),
        case::start_marker_not_found(
            "Hello, This is a test[end] Goodbye",
            "[start]",
            "[end]",
            "Hello, This is a test[end] Goodbye"
        ),
        case::end_marker_not_found(
            "Hello, [start]This is a test. Goodbye",
            "[start]",
            "[end]",
            "Hello, [start]This is a test. Goodbye"
        ),
        case::both_markers_not_found(
            "Hello, world! This is a test.",
            "[start]",
            "[end]",
            "Hello, world! This is a test."
        ),
        case::empty_start_marker(
            "Hello, [start]This is a test[end] Goodbye",
            "",
            "[end]",
            "Hello, [start]This is a test[end] Goodbye"
        ),
        case::empty_end_marker(
            "Hello, [start]This is a test[end] Goodbye",
            "[start]",
            "",
            "Hello, [start]This is a test[end] Goodbye"
        ),
        case::multiple_occurrences(
            "Ignore [start]Keep this[end] and [start]not this[end] end part",
            "[start]",
            "[end]",
            "Keep this[end] and [start]not this"
        ),
        case::end_marker_before_start(
            "Hello [end] there [start] world",
            "[start]",
            "[end]",
            "Hello [end] there [start] world"
        ),
        case::same_marker("abcXdefXghi", "X", "X", "def"),
        // Boundary: nothing between the markers yields an empty string.
        case::adjacent_markers("[start][end]", "[start]", "[end]", ""),
        // A shared marker that occurs only once cannot delimit a span.
        case::single_shared_marker("abcXdef", "X", "X", "abcXdef"),
        // The only `end` match overlaps the `start` match, so nothing is cut.
        case::overlapping_markers("abc", "ab", "bc", "abc"),
        // Multi-byte markers and content must be cut on char boundaries.
        case::multibyte_text("é«日本語»é", "«", "»", "日本語")
    )]
    fn test_cleanup(initial_content: &str, start: &str, end: &str, expected_content: &str) {
        let mut msg = Message {
            role: MessageRole::User,
            content: initial_content.to_string(),
            refusal: None,
        };
        msg.cleanup(start, end);
        assert_eq!(
            msg.content, expected_content,
            "Content should be '{expected_content}' after cleaning up with start '{start}' and end '{end}'"
        );
    }
}