use serde::{Deserialize, Serialize};
/// Chat role identifier.
///
/// Serialized with serde as the lowercase variant name: `"system"`,
/// `"user"`, `"assistant"`, or `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// Wire name `"system"`.
    System,
    /// Wire name `"user"`.
    User,
    /// Wire name `"assistant"`.
    Assistant,
    /// Wire name `"tool"`.
    Tool,
}
impl Default for ChatRole {
fn default() -> Self {
Self::User
}
}
/// Formats the role as its lowercase name: `"system"`, `"user"`,
/// `"assistant"`, or `"tool"`. This is the same spelling the serde
/// `rename_all = "lowercase"` attribute on the enum produces.
///
/// The name is emitted through [`std::fmt::Formatter::pad`], so width, fill,
/// alignment and precision flags are honored the same way they are for `str`.
/// For example, `format!("{:>6}", ChatRole::User)` yields `"  user"`.
impl std::fmt::Display for ChatRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ChatRole::System => "system",
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
            ChatRole::Tool => "tool",
        };
        // `write!(f, "...")` would silently drop padding/alignment flags;
        // `pad` applies them.
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every variant paired with its expected serde / `Display` spelling.
    const ALL_ROLES: [(ChatRole, &str); 4] = [
        (ChatRole::System, "system"),
        (ChatRole::User, "user"),
        (ChatRole::Assistant, "assistant"),
        (ChatRole::Tool, "tool"),
    ];

    /// Exhaustive match with no wildcard arm: adding a `ChatRole` variant
    /// breaks compilation here, as a reminder to extend `ALL_ROLES`.
    fn expect_listed(role: ChatRole) {
        match role {
            ChatRole::System | ChatRole::User | ChatRole::Assistant | ChatRole::Tool => {}
        }
    }

    #[test]
    fn test_chat_role_serialize() {
        for (role, name) in ALL_ROLES {
            expect_listed(role);
            assert_eq!(
                serde_json::to_string(&role).unwrap(),
                format!("\"{}\"", name),
                "serializing {:?}",
                role
            );
        }
    }

    #[test]
    fn test_chat_role_deserialize() {
        for (role, name) in ALL_ROLES {
            let json = format!("\"{}\"", name);
            assert_eq!(
                serde_json::from_str::<ChatRole>(&json).unwrap(),
                role,
                "deserializing {}",
                json
            );
        }
    }

    #[test]
    fn test_chat_role_roundtrip() {
        for (role, _) in ALL_ROLES {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(serde_json::from_str::<ChatRole>(&json).unwrap(), role);
        }
    }

    #[test]
    fn test_chat_role_deserialize_rejects_unknown() {
        // `rename_all = "lowercase"` accepts only the exact lowercase names.
        for bad in ["\"System\"", "\"USER\"", "\"unknown\"", "\"\""] {
            assert!(
                serde_json::from_str::<ChatRole>(bad).is_err(),
                "{} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn test_chat_role_default() {
        assert_eq!(ChatRole::default(), ChatRole::User);
    }

    #[test]
    fn test_chat_role_display() {
        for (role, name) in ALL_ROLES {
            assert_eq!(role.to_string(), name);
            // Display and serde must agree on the spelling.
            assert_eq!(
                serde_json::to_string(&role).unwrap(),
                format!("\"{}\"", role)
            );
        }
    }

    #[test]
    fn test_chat_role_equality() {
        for (i, (a, _)) in ALL_ROLES.iter().enumerate() {
            for (j, (b, _)) in ALL_ROLES.iter().enumerate() {
                assert_eq!(a == b, i == j, "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_chat_role_hash_distinct() {
        let set: HashSet<ChatRole> = ALL_ROLES.iter().map(|&(role, _)| role).collect();
        assert_eq!(set.len(), ALL_ROLES.len());
        assert!(set.contains(&ChatRole::Tool));
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercises the derived `Clone` impl
    fn test_chat_role_clone() {
        let role = ChatRole::Assistant;
        let cloned = role.clone();
        assert_eq!(role, cloned);
    }

    #[test]
    fn test_chat_role_copy() {
        let role = ChatRole::System;
        let copied: ChatRole = role;
        // `role` remains usable after the by-value binding only because it is `Copy`.
        assert_eq!(role, copied);
    }
}