batch_mode_batch_schema/
message_role.rs

// ---------------- [ File: batch-mode-batch-schema/src/message_role.rs ]
2crate::ix!();
3
4#[derive(Default,Debug,Clone,PartialEq,Eq,Hash)]
5pub enum MessageRole {
6    #[default]
7    Assistant,
8    User,
9    System,
10    Tool,
11    Function,
12    Unknown(String),
13}
14
15impl<'de> Deserialize<'de> for MessageRole {
16    fn deserialize<D>(deserializer: D) -> Result<MessageRole, D::Error>
17    where
18        D: serde::Deserializer<'de>,
19    {
20        let s = String::deserialize(deserializer)?;
21        match s.as_str() {
22            "assistant" => Ok(MessageRole::Assistant),
23            "user" => Ok(MessageRole::User),
24            "system" => Ok(MessageRole::System),
25            "tool" => Ok(MessageRole::Tool),
26            "function" => Ok(MessageRole::Function),
27            other => Ok(MessageRole::Unknown(other.to_string())),
28        }
29    }
30}
31
32impl Serialize for MessageRole {
33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: serde::Serializer,
36    {
37        let s = match self {
38            MessageRole::Assistant => "assistant",
39            MessageRole::User => "user",
40            MessageRole::System => "system",
41            MessageRole::Tool => "tool",
42            MessageRole::Function => "function",
43            MessageRole::Unknown(other) => other.as_str(),
44        };
45        serializer.serialize_str(s)
46    }
47}
48
#[cfg(test)]
mod message_role_tests {
    use super::*;

    /// Every known role paired with the wire name it (de)serializes as.
    fn known_roles() -> [(&'static str, MessageRole); 5] {
        [
            ("assistant", MessageRole::Assistant),
            ("user", MessageRole::User),
            ("system", MessageRole::System),
            ("tool", MessageRole::Tool),
            ("function", MessageRole::Function),
        ]
    }

    #[test]
    fn test_message_role_deserialization() {
        // Known roles
        for (role_str, expected_role) in known_roles() {
            let json = format!("\"{role_str}\"");
            let role: MessageRole = serde_json::from_str(&json).unwrap();
            pretty_assert_eq!(role, expected_role);
        }

        // Unknown role
        let role: MessageRole = serde_json::from_str("\"unknown_role\"").unwrap();
        pretty_assert_eq!(role, MessageRole::Unknown("unknown_role".to_string()));

        // Matching is case-sensitive: a capitalised name is not a known role.
        let role: MessageRole = serde_json::from_str("\"Assistant\"").unwrap();
        pretty_assert_eq!(role, MessageRole::Unknown("Assistant".to_string()));

        // Escape sequences are decoded before matching (`\u0075` is `u`).
        let role: MessageRole = serde_json::from_str("\"\\u0075ser\"").unwrap();
        pretty_assert_eq!(role, MessageRole::User);

        // Empty string as role
        let role: MessageRole = serde_json::from_str("\"\"").unwrap();
        pretty_assert_eq!(role, MessageRole::Unknown("".to_string()));

        // Invalid role (non-string)
        assert!(serde_json::from_str::<MessageRole>("123").is_err());

        // Null role
        assert!(serde_json::from_str::<MessageRole>("null").is_err());
    }

    #[test]
    fn test_message_role_serialization() {
        // Known roles serialize to their wire names.
        for (role_str, role) in known_roles() {
            let json = serde_json::to_string(&role).unwrap();
            pretty_assert_eq!(json, format!("\"{role_str}\""));
        }

        // Unknown roles are written back verbatim, including the empty string.
        let json = serde_json::to_string(&MessageRole::Unknown("custom_role".to_string())).unwrap();
        pretty_assert_eq!(json.as_str(), "\"custom_role\"");

        let json = serde_json::to_string(&MessageRole::Unknown(String::new())).unwrap();
        pretty_assert_eq!(json.as_str(), "\"\"");
    }

    #[test]
    fn test_message_role_round_trip() {
        let mut roles: Vec<MessageRole> = known_roles()
            .iter()
            .map(|(_, role)| role.clone())
            .collect();
        roles.push(MessageRole::Unknown("custom_role".to_string()));
        roles.push(MessageRole::Unknown(String::new()));

        for role in roles {
            let json = serde_json::to_string(&role).unwrap();
            let back: MessageRole = serde_json::from_str(&json).unwrap();
            pretty_assert_eq!(back, role);
        }

        // An `Unknown` carrying a known wire name normalises to the known variant.
        let json = serde_json::to_string(&MessageRole::Unknown("user".to_string())).unwrap();
        let back: MessageRole = serde_json::from_str(&json).unwrap();
        pretty_assert_eq!(back, MessageRole::User);
    }

    #[test]
    fn test_message_role_default_is_assistant() {
        pretty_assert_eq!(MessageRole::default(), MessageRole::Assistant);
    }
}