batch_mode_batch_scribe/
language_model_message_role.rs

1// ---------------- [ File: batch-mode-batch-scribe/src/language_model_message_role.rs ]
2crate::ix!();
3
/// Role of the author of a message sent to a language model.
///
/// On the wire each role is encoded as a lowercase string: `System` as
/// `"system"` and `User` as `"user"` (see the `Serialize`/`Deserialize` impls
/// in this file).
///
/// `PartialEq`/`Eq`/`Hash` are derived so roles can be compared directly and
/// used as map/set keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelMessageRole {
    /// Instructions/context supplied by the system; serialized as `"system"`.
    System,
    /// Content supplied by the end user; serialized as `"user"`.
    User,
}
10
11impl Serialize for LanguageModelMessageRole {
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where
14        S: Serializer,
15    {
16        trace!("Serializing LanguageModelMessageRole: {:?}", self);
17        match self {
18            LanguageModelMessageRole::System => {
19                trace!("Serializing as 'system'");
20                serializer.serialize_str("system")
21            }
22            LanguageModelMessageRole::User => {
23                trace!("Serializing as 'user'");
24                serializer.serialize_str("user")
25            }
26        }
27    }
28}
29
30impl<'de> Deserialize<'de> for LanguageModelMessageRole {
31    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
32    where
33        D: Deserializer<'de>,
34    {
35        let s: String = String::deserialize(deserializer)?;
36        trace!("Deserializing LanguageModelMessageRole from string: {:?}", s);
37        match s.as_str() {
38            "system" => Ok(LanguageModelMessageRole::System),
39            "user" => Ok(LanguageModelMessageRole::User),
40            other => {
41                error!("Unknown role: {}", other);
42                Err(DeError::custom("unknown message role"))
43            }
44        }
45    }
46}
47
48/// Field-level serializers/deserializers used for the `role` field in `LanguageModelMessage`.
49/// We keep these so that references like `#[serde(with = "message_role")]` remain valid.
50pub(crate) mod message_role {
51    use super::*;
52    use crate::imports::*;
53
54    pub fn serialize<S>(
55        value: &LanguageModelMessageRole,
56        serializer: S,
57    ) -> Result<S::Ok, S::Error>
58    where
59        S: Serializer,
60    {
61        trace!("(message_role) Serializing LanguageModelMessageRole: {:?}", value);
62        match value {
63            LanguageModelMessageRole::System => {
64                trace!("(message_role) Serializing as 'system'");
65                serializer.serialize_str("system")
66            }
67            LanguageModelMessageRole::User => {
68                trace!("(message_role) Serializing as 'user'");
69                serializer.serialize_str("user")
70            }
71        }
72    }
73
74    pub fn deserialize<'de, D>(
75        deserializer: D,
76    ) -> Result<LanguageModelMessageRole, D::Error>
77    where
78        D: Deserializer<'de>,
79    {
80        let s: String = String::deserialize(deserializer)?;
81        trace!("(message_role) Deserializing LanguageModelMessageRole from string: {:?}", s);
82        match s.as_str() {
83            "system" => Ok(LanguageModelMessageRole::System),
84            "user" => Ok(LanguageModelMessageRole::User),
85            other => {
86                error!("(message_role) Unknown role: {}", other);
87                Err(D::Error::custom("unknown message role"))
88            }
89        }
90    }
91}
92
#[cfg(test)]
mod language_model_message_role_exhaustive_tests {
    use super::*;
    use crate::imports::*;

    #[traced_test]
    fn serialize_system_role_to_json() {
        trace!("===== BEGIN TEST: serialize_system_role_to_json =====");
        let role = LanguageModelMessageRole::System;
        let serialized = serde_json::to_string(&role)
            .expect("Failed to serialize LanguageModelMessageRole");
        debug!("Serialized system role: {}", serialized);
        pretty_assert_eq!(serialized, r#""system""#, "System role should serialize to \"system\"");
        trace!("===== END TEST: serialize_system_role_to_json =====");
    }

    #[traced_test]
    fn serialize_user_role_to_json() {
        trace!("===== BEGIN TEST: serialize_user_role_to_json =====");
        let role = LanguageModelMessageRole::User;
        let serialized = serde_json::to_string(&role)
            .expect("Failed to serialize LanguageModelMessageRole");
        debug!("Serialized user role: {}", serialized);
        pretty_assert_eq!(serialized, r#""user""#, "User role should serialize to \"user\"");
        trace!("===== END TEST: serialize_user_role_to_json =====");
    }

    #[traced_test]
    fn deserialize_system_role_from_json() {
        trace!("===== BEGIN TEST: deserialize_system_role_from_json =====");
        let json_str = r#""system""#;
        let role: LanguageModelMessageRole = serde_json::from_str(json_str)
            .expect("Failed to deserialize system role");
        debug!("Deserialized role: {:?}", role);
        match role {
            LanguageModelMessageRole::System => trace!("Correctly deserialized as System"),
            _ => panic!("Deserialization mismatch for system role"),
        }
        trace!("===== END TEST: deserialize_system_role_from_json =====");
    }

    #[traced_test]
    fn deserialize_user_role_from_json() {
        trace!("===== BEGIN TEST: deserialize_user_role_from_json =====");
        let json_str = r#""user""#;
        let role: LanguageModelMessageRole = serde_json::from_str(json_str)
            .expect("Failed to deserialize user role");
        debug!("Deserialized role: {:?}", role);
        match role {
            LanguageModelMessageRole::User => trace!("Correctly deserialized as User"),
            _ => panic!("Deserialization mismatch for user role"),
        }
        trace!("===== END TEST: deserialize_user_role_from_json =====");
    }

    #[traced_test]
    fn deserialize_unknown_role_returns_error() {
        trace!("===== BEGIN TEST: deserialize_unknown_role_returns_error =====");
        let json_str = r#""admin""#;
        let result = serde_json::from_str::<LanguageModelMessageRole>(json_str);
        debug!("Deserialization result: {:?}", result);
        assert!(result.is_err(), "Unknown role should result in an error");
        trace!("===== END TEST: deserialize_unknown_role_returns_error =====");
    }

    /// Role names are matched exactly; capitalized variants must be rejected.
    #[traced_test]
    fn deserialize_is_case_sensitive() {
        trace!("===== BEGIN TEST: deserialize_is_case_sensitive =====");
        for json_str in [r#""System""#, r#""USER""#, r#""""#] {
            let result = serde_json::from_str::<LanguageModelMessageRole>(json_str);
            debug!("Deserialization of {} -> {:?}", json_str, result);
            assert!(result.is_err(), "{} should not deserialize to a role", json_str);
        }
        trace!("===== END TEST: deserialize_is_case_sensitive =====");
    }

    /// Non-string JSON values must be rejected rather than coerced.
    #[traced_test]
    fn deserialize_non_string_returns_error() {
        trace!("===== BEGIN TEST: deserialize_non_string_returns_error =====");
        for json_str in ["42", "null", "true", "[]", "{}"] {
            let result = serde_json::from_str::<LanguageModelMessageRole>(json_str);
            debug!("Deserialization of {} -> {:?}", json_str, result);
            assert!(result.is_err(), "{} should not deserialize to a role", json_str);
        }
        trace!("===== END TEST: deserialize_non_string_returns_error =====");
    }

    /// Every role survives a serialize -> deserialize round trip unchanged.
    #[traced_test]
    fn roles_round_trip_through_json() {
        trace!("===== BEGIN TEST: roles_round_trip_through_json =====");
        for role in [LanguageModelMessageRole::System, LanguageModelMessageRole::User] {
            let json = serde_json::to_string(&role)
                .expect("Failed to serialize LanguageModelMessageRole");
            let back: LanguageModelMessageRole = serde_json::from_str(&json)
                .expect("Failed to deserialize LanguageModelMessageRole");
            // Compare via Debug so this test does not depend on PartialEq.
            pretty_assert_eq!(format!("{:?}", back), format!("{:?}", role));
        }
        trace!("===== END TEST: roles_round_trip_through_json =====");
    }

    /// The `#[serde(with = "message_role")]` serializer emits the same wire names.
    #[traced_test]
    fn message_role_helper_serializes_both_roles() {
        trace!("===== BEGIN TEST: message_role_helper_serializes_both_roles =====");
        for (role, expected) in [
            (LanguageModelMessageRole::System, r#""system""#),
            (LanguageModelMessageRole::User, r#""user""#),
        ] {
            let mut buf: Vec<u8> = Vec::new();
            message_role::serialize(&role, &mut serde_json::Serializer::new(&mut buf))
                .expect("message_role::serialize failed");
            let json = String::from_utf8(buf).expect("serde_json emitted invalid UTF-8");
            debug!("message_role serialized {:?} as {}", role, json);
            pretty_assert_eq!(json, expected);
        }
        trace!("===== END TEST: message_role_helper_serializes_both_roles =====");
    }

    /// The `#[serde(with = "message_role")]` deserializer accepts known roles
    /// and rejects unknown ones.
    #[traced_test]
    fn message_role_helper_deserializes_known_and_rejects_unknown() {
        trace!("===== BEGIN TEST: message_role_helper_deserializes_known_and_rejects_unknown =====");
        let mut de = serde_json::Deserializer::from_str(r#""system""#);
        let role = message_role::deserialize(&mut de).expect("\"system\" should deserialize");
        assert!(matches!(role, LanguageModelMessageRole::System));

        let mut de = serde_json::Deserializer::from_str(r#""user""#);
        let role = message_role::deserialize(&mut de).expect("\"user\" should deserialize");
        assert!(matches!(role, LanguageModelMessageRole::User));

        let mut de = serde_json::Deserializer::from_str(r#""admin""#);
        assert!(
            message_role::deserialize(&mut de).is_err(),
            "Unknown role should result in an error via message_role"
        );
        trace!("===== END TEST: message_role_helper_deserializes_known_and_rejects_unknown =====");
    }
}