burn_lm_inference/
message.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Debug, Serialize, Deserialize, strum::Display, PartialEq)]
4#[serde(rename_all = "snake_case")]
5pub enum MessageRole {
6 System,
7 User,
8 Assistant,
9 Tool,
10 Unknown(String),
11}
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct Message {
15 pub role: MessageRole,
16 pub content: String,
17 pub refusal: Option<String>,
18}
19
20impl Message {
21 pub fn cleanup(&mut self, start: &str, end: &str) {
26 if start.is_empty() || end.is_empty() {
27 return;
28 }
29
30 if let Some(start_index) = self.content.find(start) {
31 let content_start = start_index + start.len();
32 if let Some(last_end_index) = self.content.rfind(end) {
33 if last_end_index >= content_start {
34 self.content = self.content[content_start..last_end_index].to_string();
36 }
37 }
38 }
39 }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Table-driven checks for [`Message::cleanup`].
    ///
    /// Covers marker presence/absence, empty markers, multiple and reversed
    /// occurrences, shared start/end markers, and edge cases: adjacent markers,
    /// overlapping markers, multibyte UTF-8 text and empty content.
    #[rstest]
    #[case::markers_found(
        "Hello, [start]This is a test[end] Goodbye",
        "[start]",
        "[end]",
        "This is a test"
    )]
    #[case::start_marker_not_found(
        "Hello, This is a test[end] Goodbye",
        "[start]",
        "[end]",
        "Hello, This is a test[end] Goodbye"
    )]
    #[case::end_marker_not_found(
        "Hello, [start]This is a test. Goodbye",
        "[start]",
        "[end]",
        "Hello, [start]This is a test. Goodbye"
    )]
    #[case::both_markers_not_found(
        "Hello, world! This is a test.",
        "[start]",
        "[end]",
        "Hello, world! This is a test."
    )]
    #[case::empty_start_marker(
        "Hello, [start]This is a test[end] Goodbye",
        "",
        "[end]",
        "Hello, [start]This is a test[end] Goodbye"
    )]
    #[case::empty_end_marker(
        "Hello, [start]This is a test[end] Goodbye",
        "[start]",
        "",
        "Hello, [start]This is a test[end] Goodbye"
    )]
    #[case::multiple_occurrences(
        "Ignore [start]Keep this[end] and [start]not this[end] end part",
        "[start]",
        "[end]",
        "Keep this[end] and [start]not this"
    )]
    #[case::end_marker_before_start(
        "Hello [end] there [start] world",
        "[start]",
        "[end]",
        "Hello [end] there [start] world"
    )]
    #[case::same_marker("abcXdefXghi", "X", "X", "def")]
    // A single occurrence of a shared marker cannot delimit anything.
    #[case::same_marker_single_occurrence("abcXdef", "X", "X", "abcXdef")]
    // Markers directly next to each other leave an empty string.
    #[case::adjacent_markers("Hello [start][end] Goodbye", "[start]", "[end]", "")]
    // `end` overlapping the `start` match does not count as a closing marker.
    #[case::overlapping_markers("abc", "ab", "bc", "abc")]
    // Byte offsets around multibyte characters must stay on char boundaries.
    #[case::multibyte_content("é[start]ü ñ[end]ß", "[start]", "[end]", "ü ñ")]
    #[case::empty_content("", "[start]", "[end]", "")]
    fn test_cleanup(
        #[case] initial_content: &str,
        #[case] start: &str,
        #[case] end: &str,
        #[case] expected_content: &str,
    ) {
        let mut msg = Message {
            role: MessageRole::User,
            content: initial_content.to_string(),
            refusal: None,
        };
        msg.cleanup(start, end);
        assert_eq!(
            msg.content, expected_content,
            "Content should be '{expected_content}' after cleaning up with start '{start}' and end '{end}'"
        );
    }
}