// batch_mode_batch_schema/message_role.rs
crate::ix!();

4#[derive(Default,Debug,Clone,PartialEq,Eq,Hash)]
5pub enum MessageRole {
6 #[default]
7 Assistant,
8 User,
9 System,
10 Tool,
11 Function,
12 Unknown(String),
13}
14
15impl<'de> Deserialize<'de> for MessageRole {
16 fn deserialize<D>(deserializer: D) -> Result<MessageRole, D::Error>
17 where
18 D: serde::Deserializer<'de>,
19 {
20 let s = String::deserialize(deserializer)?;
21 match s.as_str() {
22 "assistant" => Ok(MessageRole::Assistant),
23 "user" => Ok(MessageRole::User),
24 "system" => Ok(MessageRole::System),
25 "tool" => Ok(MessageRole::Tool),
26 "function" => Ok(MessageRole::Function),
27 other => Ok(MessageRole::Unknown(other.to_string())),
28 }
29 }
30}
31
32impl Serialize for MessageRole {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: serde::Serializer,
36 {
37 let s = match self {
38 MessageRole::Assistant => "assistant",
39 MessageRole::User => "user",
40 MessageRole::System => "system",
41 MessageRole::Tool => "tool",
42 MessageRole::Function => "function",
43 MessageRole::Unknown(other) => other.as_str(),
44 };
45 serializer.serialize_str(s)
46 }
47}
48
#[cfg(test)]
mod message_role_tests {
    use super::*;

    // Exercises deserialization of every known role name, the fallback to
    // `Unknown` for unrecognized / empty strings, and rejection of
    // non-string JSON values.
    #[test]
    fn test_message_role_deserialization() {
        // Each known wire name paired with the variant it must produce.
        let known_cases = [
            ("assistant", MessageRole::Assistant),
            ("user", MessageRole::User),
            ("system", MessageRole::System),
            ("tool", MessageRole::Tool),
            ("function", MessageRole::Function),
        ];

        for (wire_name, expected) in known_cases.iter() {
            let quoted = format!("\"{}\"", wire_name);
            let parsed: MessageRole = serde_json::from_str(&quoted).unwrap();
            pretty_assert_eq!(&parsed, expected);
        }

        // Unrecognized names are preserved inside `Unknown`.
        let parsed: MessageRole = serde_json::from_str("\"unknown_role\"").unwrap();
        pretty_assert_eq!(parsed, MessageRole::Unknown("unknown_role".to_string()));

        // The empty string is also treated as an unknown role.
        let parsed: MessageRole = serde_json::from_str("\"\"").unwrap();
        pretty_assert_eq!(parsed, MessageRole::Unknown("".to_string()));

        // A JSON number is not a string and must be rejected.
        let outcome: Result<MessageRole, _> = serde_json::from_str("123");
        assert!(outcome.is_err());

        // JSON null must be rejected as well.
        let outcome: Result<MessageRole, _> = serde_json::from_str("null");
        assert!(outcome.is_err());
    }
}