use serde::{Deserialize, Serialize};
use std::fmt;
/// The author of a message in a conversation.
///
/// Two textual forms exist:
/// - [`Role::as_str`] and the [`Display`](fmt::Display) impl produce lowercase
///   names (`"system"`, `"user"`, `"assistant"`, `"tool"`);
/// - the derived serde representation uses the variant name verbatim
///   (e.g. JSON `"Assistant"`).
///
/// Keep that difference in mind when a role crosses both a text and a
/// serialized boundary.
///
/// NOTE: serde hands the variant index to serializers, so index-based formats
/// (e.g. bincode) depend on declaration order — append new variants instead of
/// reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Role {
    /// Rendered as `"system"` by [`Role::as_str`].
    System,
    /// Rendered as `"user"` by [`Role::as_str`].
    User,
    /// Rendered as `"assistant"` by [`Role::as_str`].
    Assistant,
    /// Rendered as `"tool"` by [`Role::as_str`].
    Tool,
}
impl Role {
    /// Returns the role's lowercase name: `"system"`, `"user"`,
    /// `"assistant"`, or `"tool"`.
    ///
    /// This is the text the `Display` impl writes. It is a `const fn`, so the
    /// name can also be used in constant contexts.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
        }
    }
}
/// Formats a role as its lowercase name (see [`Role::as_str`]).
///
/// Width, fill, alignment and precision flags are honoured just like for
/// `str`, so `format!("{:<9}|", Role::User)` yields `"user     |"`.
impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad`, unlike `write_str`, applies the caller's formatting flags,
        // so roles line up correctly in column-aligned output.
        f.pad(self.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant in declaration order, so each test covers the whole enum.
    const ALL: [Role; 4] = [Role::System, Role::User, Role::Assistant, Role::Tool];

    /// Compile-time guard: adding a variant makes this match non-exhaustive,
    /// which is the cue to extend `ALL` and the expectations below.
    #[allow(dead_code)]
    fn all_is_exhaustive(role: Role) {
        match role {
            Role::System | Role::User | Role::Assistant | Role::Tool => {}
        }
    }

    #[test]
    fn display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn as_str_matches_display() {
        for &role in &ALL {
            assert_eq!(role.as_str(), role.to_string(), "mismatch for {:?}", role);
        }
    }

    #[test]
    fn serde_roundtrip() {
        // The derived serde form uses the variant name, not the lowercase
        // `as_str` text.
        let cases = [
            (Role::System, "\"System\""),
            (Role::User, "\"User\""),
            (Role::Assistant, "\"Assistant\""),
            (Role::Tool, "\"Tool\""),
        ];
        for &(role, expected) in &cases {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected, "serialized form of {:?}", role);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }
}