use serde::{Deserialize, Serialize};
use std::fmt;
/// The author of a chat [`Message`].
///
/// Four roles are built in and correspond to the lowercase strings `"user"`,
/// `"assistant"`, `"system"` and `"tool"`. Any other string is kept verbatim
/// in [`Role::Custom`].
///
/// Serde reads and writes a `Role` as a plain string through the
/// `From<String> for Role` and `From<Role> for String` impls. Every string
/// maps to some role, so that conversion cannot fail. The infallible `from`
/// container attribute is therefore used rather than `try_from`.
///
/// Note: equality and hashing are derived structurally. As a result,
/// `Role::Custom("user".into())` is *not* equal to [`Role::User`], even
/// though both have the string name `"user"`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(into = "String", from = "String")]
pub enum Role {
    /// String name `"user"`; this is the [`Default`] role.
    #[default]
    User,
    /// String name `"assistant"`.
    Assistant,
    /// String name `"system"`.
    System,
    /// String name `"tool"`.
    Tool,
    /// Any role name not listed above, stored exactly as given.
    Custom(String),
}
impl Role {
#[must_use]
pub fn as_str(&self) -> &str {
match self {
Role::User => "user",
Role::Assistant => "assistant",
Role::System => "system",
Role::Tool => "tool",
Role::Custom(s) => s.as_str(),
}
}
#[must_use]
pub fn matches(&self, role_str: &str) -> bool {
self.as_str() == role_str
}
}
impl fmt::Display for Role {
    /// Writes the role's string name, i.e. the text returned by `Role::as_str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honors the caller's width, fill, alignment and precision,
        // just like formatting a `&str`. For example, `format!("{:<10}", role)`
        // pads the name. Re-formatting with `write!(f, "{}", ...)` would
        // silently discard those flags.
        f.pad(self.as_str())
    }
}
impl From<&str> for Role {
fn from(s: &str) -> Self {
match s {
"user" => Role::User,
"assistant" => Role::Assistant,
"system" => Role::System,
"tool" => Role::Tool,
other => Role::Custom(other.to_string()),
}
}
}
impl From<String> for Role {
fn from(s: String) -> Self {
Role::from(s.as_str())
}
}
impl From<Role> for String {
fn from(role: Role) -> Self {
role.as_str().to_string()
}
}
/// A single chat message: its author [`Role`] plus its text.
///
/// Serde field names are `role` and `content`, in that order. `role` is
/// written and read as a plain string through the `role_serde` adapter module.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message. An unrecognised role string in the input is
    /// kept as `Role::Custom`.
    #[serde(with = "role_serde")]
    pub role: Role,
    /// The message text.
    pub content: String,
}
/// Field-level serde adapter for [`Role`].
///
/// It writes a role as its string name and reads any string back through
/// `Role::from`.
mod role_serde {
    use super::Role;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializes `role` as the string returned by `Role::as_str`.
    pub fn serialize<S>(role: &Role, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = role.as_str();
        serializer.serialize_str(name)
    }

    /// Reads a string and converts it to a [`Role`].
    ///
    /// Errors come only from `String::deserialize`. Once a string has been
    /// read, the conversion itself always succeeds.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Role, D::Error>
    where
        D: Deserializer<'de>,
    {
        String::deserialize(deserializer).map(Role::from)
    }
}
impl Message {
    /// Builds a message from a role *name* and its content.
    ///
    /// The name goes through `Role::from`, so unknown names become
    /// `Role::Custom`.
    #[must_use]
    #[deprecated(
        since = "0.3.0",
        note = "Use Message::with_role(Role::..., ...) or Message::user()/assistant()/system()/tool()"
    )]
    pub fn new(role: &str, content: &str) -> Self {
        Self::with_role(Role::from(role), content)
    }

    /// Builds a message with an explicit [`Role`], copying `content` into an
    /// owned `String`.
    #[must_use]
    pub fn with_role(role: Role, content: &str) -> Self {
        let content = String::from(content);
        Self { role, content }
    }

    /// Shorthand for `Message::with_role(Role::User, content)`.
    #[must_use]
    pub fn user(content: &str) -> Self {
        Self::with_role(Role::User, content)
    }

    /// Shorthand for `Message::with_role(Role::Assistant, content)`.
    #[must_use]
    pub fn assistant(content: &str) -> Self {
        Self::with_role(Role::Assistant, content)
    }

    /// Shorthand for `Message::with_role(Role::System, content)`.
    #[must_use]
    pub fn system(content: &str) -> Self {
        Self::with_role(Role::System, content)
    }

    /// Shorthand for `Message::with_role(Role::Tool, content)`.
    #[must_use]
    pub fn tool(content: &str) -> Self {
        Self::with_role(Role::Tool, content)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("tool"), Role::Tool);
        assert_eq!(Role::from("custom"), Role::Custom("custom".to_string()));
    }

    #[test]
    fn test_role_as_str() {
        assert_eq!(Role::User.as_str(), "user");
        assert_eq!(Role::Assistant.as_str(), "assistant");
        assert_eq!(Role::System.as_str(), "system");
        assert_eq!(Role::Tool.as_str(), "tool");
        assert_eq!(Role::Custom("foo".into()).as_str(), "foo");
    }

    #[test]
    fn test_role_matches_is_exact() {
        assert!(Role::User.matches("user"));
        // Matching is case-sensitive.
        assert!(!Role::User.matches("User"));
        assert!(Role::Custom("function".into()).matches("function"));
        assert!(!Role::Tool.matches("tools"));
    }

    #[test]
    fn test_role_display_and_string_conversion() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Custom("function".into()).to_string(), "function");
        assert_eq!(String::from(Role::Tool), "tool");
        assert_eq!(String::from(Role::Custom("function".into())), "function");
    }

    #[test]
    fn test_defaults() {
        assert_eq!(Role::default(), Role::User);
        let msg = Message::default();
        assert_eq!(msg.role, Role::User);
        assert!(msg.content.is_empty());
    }

    #[test]
    fn test_message_role_typed_field() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        let msg = Message::assistant("hi");
        assert_eq!(msg.role, Role::Assistant);
        let msg = Message::system("rules");
        assert_eq!(msg.role, Role::System);
        let msg = Message::with_role(Role::Custom("custom".into()), "data");
        assert_eq!(msg.role, Role::Custom("custom".into()));
    }

    #[test]
    #[allow(deprecated)] // exercising the deprecated constructor on purpose
    fn test_message_new_deprecated_compat() {
        let msg = Message::new("custom", "data");
        assert_eq!(msg.role, Role::Custom("custom".into()));
        let msg = Message::new("assistant", "data");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "data");
    }

    #[test]
    fn test_message_with_role() {
        let msg = Message::with_role(Role::Tool, "result");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "result");
    }

    #[test]
    fn test_role_serialization() {
        let role = Role::User;
        let json = serde_json::to_string(&role).unwrap();
        assert_eq!(json, "\"user\"");
        let parsed: Role = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(parsed, Role::Assistant);
        let custom: Role = serde_json::from_str("\"function\"").unwrap();
        assert_eq!(custom, Role::Custom("function".into()));
        // A role must be a JSON string; other JSON types are rejected.
        assert!(serde_json::from_str::<Role>("42").is_err());
    }

    #[test]
    fn test_message_backward_compatibility() {
        let json = r#"{"role": "user", "content": "hello"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[test]
    fn test_message_serialization_round_trip() {
        let msg = Message::with_role(Role::Custom("function".into()), "x");
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"role":"function","content":"x"}"#);
        let back: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(back, msg);

        let msg = Message::tool("out");
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"role":"tool","content":"out"}"#);
    }
}