xai_rust/models/
message.rs1use serde::{Deserialize, Serialize};
4
5use super::content::ContentPart;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum Role {
11 System,
13 User,
15 Assistant,
17 Developer,
19 Tool,
21}
22
23impl Role {
24 pub fn is_system(&self) -> bool {
26 matches!(self, Role::System | Role::Developer)
27 }
28}
29
/// The body of a [`Message`]: either a plain string or a list of content parts.
///
/// Because of `#[serde(untagged)]` no variant tag appears in JSON: `Text` serializes as a
/// bare JSON string and `Parts` as a JSON array. When deserializing, serde tries the
/// variants in declaration order and keeps the first one that matches, so the variant
/// order here is part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text content, serialized as a JSON string.
    Text(String),
    /// Multi-part content, serialized as a JSON array of `ContentPart`s.
    Parts(Vec<ContentPart>),
}
39
40impl MessageContent {
41 pub fn text(text: impl Into<String>) -> Self {
43 Self::Text(text.into())
44 }
45
46 pub fn parts(parts: Vec<ContentPart>) -> Self {
48 Self::Parts(parts)
49 }
50
51 pub fn as_text(&self) -> Option<&str> {
53 match self {
54 MessageContent::Text(text) => Some(text),
55 MessageContent::Parts(_) => None,
56 }
57 }
58
59 pub fn to_text(&self) -> String {
61 match self {
62 MessageContent::Text(text) => text.clone(),
63 MessageContent::Parts(parts) => parts
64 .iter()
65 .filter_map(|p| p.as_text())
66 .collect::<Vec<_>>()
67 .join(""),
68 }
69 }
70}
71
72impl From<String> for MessageContent {
73 fn from(text: String) -> Self {
74 Self::Text(text)
75 }
76}
77
78impl From<&str> for MessageContent {
79 fn from(text: &str) -> Self {
80 Self::Text(text.to_string())
81 }
82}
83
84impl From<Vec<ContentPart>> for MessageContent {
85 fn from(parts: Vec<ContentPart>) -> Self {
86 Self::Parts(parts)
87 }
88}
89
/// A single chat message.
///
/// The field order and serde attributes below define the serialized JSON shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message; serialized as a lowercase string (see [`Role`]).
    pub role: Role,
    /// The message body: a bare JSON string or an array of content parts.
    pub content: MessageContent,
    /// Optional author name (set via `with_name`); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// ID of the tool call this message answers (set by `Message::tool`);
    /// omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
104
105impl Message {
106 pub fn new(role: Role, content: impl Into<MessageContent>) -> Self {
108 Self {
109 role,
110 content: content.into(),
111 name: None,
112 tool_call_id: None,
113 }
114 }
115
116 pub fn system(content: impl Into<String>) -> Self {
118 Self::new(Role::System, content.into())
119 }
120
121 pub fn user(content: impl Into<MessageContent>) -> Self {
123 Self::new(Role::User, content)
124 }
125
126 pub fn assistant(content: impl Into<String>) -> Self {
128 Self::new(Role::Assistant, content.into())
129 }
130
131 pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
133 Self {
134 role: Role::Tool,
135 content: MessageContent::Text(content.into()),
136 name: None,
137 tool_call_id: Some(tool_call_id.into()),
138 }
139 }
140
141 pub fn with_name(mut self, name: impl Into<String>) -> Self {
143 self.name = Some(name.into());
144 self
145 }
146
147 pub fn text(&self) -> String {
149 self.content.to_text()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Role ----

    #[test]
    fn role_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), r#""system""#);
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), r#""user""#);
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            r#""assistant""#
        );
        assert_eq!(
            serde_json::to_string(&Role::Developer).unwrap(),
            r#""developer""#
        );
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), r#""tool""#);
    }

    #[test]
    fn role_roundtrip_all_variants() {
        for role in [
            Role::System,
            Role::User,
            Role::Assistant,
            Role::Developer,
            Role::Tool,
        ] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn role_is_system() {
        assert!(Role::System.is_system());
        assert!(Role::Developer.is_system());
        assert!(!Role::User.is_system());
        assert!(!Role::Assistant.is_system());
        assert!(!Role::Tool.is_system());
    }

    #[test]
    fn role_rejects_unknown() {
        let result = serde_json::from_str::<Role>(r#""moderator""#);
        assert!(result.is_err());
    }

    #[test]
    fn role_rejects_wrong_case() {
        // `rename_all = "lowercase"` accepts only the exact lowercase spelling.
        assert!(serde_json::from_str::<Role>(r#""System""#).is_err());
    }

    // ---- MessageContent ----

    #[test]
    fn message_content_text_roundtrip() {
        let content = MessageContent::Text("hello world".to_string());
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json, serde_json::json!("hello world"));

        let back: MessageContent = serde_json::from_value(json).unwrap();
        assert_eq!(back.as_text().unwrap(), "hello world");
    }

    #[test]
    fn message_content_parts_roundtrip() {
        let content =
            MessageContent::Parts(vec![ContentPart::text("part1"), ContentPart::text("part2")]);
        let json = serde_json::to_value(&content).unwrap();
        assert!(json.is_array());

        let back: MessageContent = serde_json::from_value(json).unwrap();
        assert_eq!(back.to_text(), "part1part2");
    }

    #[test]
    fn message_content_as_text_returns_none_for_parts() {
        let content = MessageContent::Parts(vec![ContentPart::text("x")]);
        assert!(content.as_text().is_none());
    }

    #[test]
    fn message_content_from_string() {
        let content: MessageContent = "hello".into();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn message_content_from_owned_string() {
        let content: MessageContent = String::from("hello").into();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn message_content_from_vec_parts() {
        let parts = vec![ContentPart::text("a"), ContentPart::text("b")];
        let content: MessageContent = parts.into();
        assert_eq!(content.to_text(), "ab");
    }

    // ---- Message ----

    #[test]
    fn message_system_roundtrip() {
        let msg = Message::system("You are helpful");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::System);
        assert_eq!(back.text(), "You are helpful");
    }

    #[test]
    fn message_user_roundtrip() {
        let msg = Message::user("What is 1+1?");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::User);
        assert_eq!(back.text(), "What is 1+1?");
    }

    #[test]
    fn message_user_with_parts_serializes_array() {
        let msg = Message::user(vec![ContentPart::text("a"), ContentPart::text("b")]);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert!(json["content"].is_array());

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.text(), "ab");
    }

    #[test]
    fn message_assistant_roundtrip() {
        let msg = Message::assistant("The answer is 2");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::Assistant);
    }

    #[test]
    fn message_tool_roundtrip() {
        let msg = Message::tool("call_123", r#"{"result": 42}"#);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["tool_call_id"], "call_123");
        // Tool output is sent as a plain string, not parsed as JSON.
        assert_eq!(json["content"], r#"{"result": 42}"#);

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::Tool);
        assert_eq!(back.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn message_with_name_roundtrip() {
        let msg = Message::user("hi").with_name("alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["name"], "alice");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.name.as_deref(), Some("alice"));
    }

    #[test]
    fn message_skips_none_fields() {
        let msg = Message::user("hi");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
    }

    #[test]
    fn message_deserializes_without_optional_fields() {
        let msg: Message =
            serde_json::from_value(serde_json::json!({"role": "user", "content": "hi"}))
                .unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "hi");
        assert!(msg.name.is_none());
        assert!(msg.tool_call_id.is_none());
    }
}