batch_mode_batch_schema/
message_role.rs1crate::ix!();
3
/// Role of the author of a message.
///
/// The serde impls for this type use lowercase role strings. Any role string
/// that is not one of the known values is kept verbatim in
/// [`MessageRole::Unknown`] instead of being rejected.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MessageRole {
    /// The `"assistant"` role.
    Assistant,
    /// The `"user"` role.
    User,
    /// The `"system"` role.
    System,
    /// Any other role string, stored exactly as received.
    Unknown(String),
}
11
12impl<'de> Deserialize<'de> for MessageRole {
13 fn deserialize<D>(deserializer: D) -> Result<MessageRole, D::Error>
14 where
15 D: serde::Deserializer<'de>,
16 {
17 let s = String::deserialize(deserializer)?;
18 match s.as_str() {
19 "assistant" => Ok(MessageRole::Assistant),
20 "user" => Ok(MessageRole::User),
21 "system" => Ok(MessageRole::System),
22 other => Ok(MessageRole::Unknown(other.to_string())),
23 }
24 }
25}
26
27impl Serialize for MessageRole {
28 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29 where
30 S: serde::Serializer,
31 {
32 let s = match self {
33 MessageRole::Assistant => "assistant",
34 MessageRole::User => "user",
35 MessageRole::System => "system",
36 MessageRole::Unknown(other) => other.as_str(),
37 };
38 serializer.serialize_str(s)
39 }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_role_deserialization() {
        // Known role strings map to their dedicated variants.
        let cases = [
            ("assistant", MessageRole::Assistant),
            ("user", MessageRole::User),
            ("system", MessageRole::System),
        ];

        for (role_str, expected_role) in cases.iter() {
            let json = format!("\"{}\"", role_str);
            let role: MessageRole = serde_json::from_str(&json).unwrap();
            assert_eq!(&role, expected_role);
        }

        // Unrecognized strings are preserved rather than rejected.
        let json = "\"unknown_role\"";
        let role: MessageRole = serde_json::from_str(json).unwrap();
        assert_eq!(role, MessageRole::Unknown("unknown_role".to_string()));

        // The empty string is also just an unknown role.
        let json = "\"\"";
        let role: MessageRole = serde_json::from_str(json).unwrap();
        assert_eq!(role, MessageRole::Unknown("".to_string()));

        // Non-string JSON values are errors.
        let json = "123";
        let result: Result<MessageRole, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = "null";
        let result: Result<MessageRole, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_message_role_serialization() {
        assert_eq!(
            serde_json::to_value(&MessageRole::Assistant).unwrap(),
            json!("assistant")
        );
        assert_eq!(serde_json::to_value(&MessageRole::User).unwrap(), json!("user"));
        assert_eq!(
            serde_json::to_value(&MessageRole::System).unwrap(),
            json!("system")
        );
        // Unknown roles are written back exactly as stored.
        assert_eq!(
            serde_json::to_value(&MessageRole::Unknown("tool".to_string())).unwrap(),
            json!("tool")
        );
    }

    #[test]
    fn test_message_role_round_trip() {
        let roles = [
            MessageRole::Assistant,
            MessageRole::User,
            MessageRole::System,
            MessageRole::Unknown("tool".to_string()),
            MessageRole::Unknown("".to_string()),
        ];

        for role in roles.iter() {
            let encoded = serde_json::to_string(role).unwrap();
            let decoded: MessageRole = serde_json::from_str(&encoded).unwrap();
            assert_eq!(&decoded, role);
        }
    }
}