1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum Role {
7 System,
9 User,
11 Assistant,
13 Tool,
15}
16
17impl Role {
18 pub fn as_str(&self) -> &'static str {
20 match self {
21 Self::System => "system",
22 Self::User => "user",
23 Self::Assistant => "assistant",
24 Self::Tool => "tool",
25 }
26 }
27}
28
29impl fmt::Display for Role {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 f.write_str(self.as_str())
32 }
33}
34
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant currently defined on `Role`; extend this list when a
    /// variant is added so the loops below keep covering the whole enum.
    const ALL: [Role; 4] = [Role::System, Role::User, Role::Assistant, Role::Tool];

    /// `Display` renders each variant as its lowercase name.
    #[test]
    fn display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    /// `as_str` and `Display` must agree for every variant, not just one.
    #[test]
    fn as_str_matches_display() {
        for role in &ALL {
            assert_eq!(role.as_str(), role.to_string());
        }
    }

    /// Pins the serialized form of one variant and checks that every
    /// variant survives a serialize/deserialize round trip.
    #[test]
    fn serde_roundtrip() {
        // Wire format uses the Rust identifier, not the lowercase name.
        let json = serde_json::to_string(&Role::Assistant).unwrap();
        assert_eq!(json, "\"Assistant\"");

        for role in &ALL {
            let encoded = serde_json::to_string(role).unwrap();
            let decoded: Role = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, *role);
        }
    }
}