use serde::{Deserialize, Serialize};
/// Participant role attached to a message.
///
/// On the wire each variant is its lowercase name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`); [`Role::as_str`] returns the same text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"tool"`.
    Tool,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
impl std::fmt::Display for Role {
    /// Writes the role's lowercase name (the text returned by [`Role::as_str`]).
    ///
    /// Uses [`std::fmt::Formatter::pad`] so that width, fill, alignment and
    /// precision flags such as `{:>10}` are honoured; `write!(f, "{}", ..)`
    /// would start a fresh format and silently discard them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
impl From<Role> for String {
    /// Produces an owned copy of the role's lowercase name.
    fn from(role: Role) -> Self {
        role.as_str().to_owned()
    }
}
/// Reason a response stopped being produced.
///
/// (De)serialized as snake_case strings: `"stop"`, `"length"`, `"tool_calls"`,
/// `"content_filter"`. Any other string deserializes to [`StopReason::Other`]
/// via `#[serde(other)]`, so unrecognised values do not cause a
/// deserialization error; `Other` itself serializes as `"other"`.
///
/// Like [`Role`], this is a fieldless enum and derives `Copy`, so callers can
/// pass it by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Serialized as `"stop"`.
    Stop,
    /// Serialized as `"length"`.
    Length,
    /// Serialized as `"tool_calls"`.
    ToolCalls,
    /// Serialized as `"content_filter"`.
    ContentFilter,
    /// Catch-all for any unrecognised value on deserialization;
    /// serialized as `"other"`.
    #[serde(other)]
    Other,
}
impl StopReason {
pub fn is_complete(&self) -> bool {
matches!(self, StopReason::Stop)
}
pub fn needs_more_tokens(&self) -> bool {
matches!(self, StopReason::Length)
}
pub fn has_tool_calls(&self) -> bool {
matches!(self, StopReason::ToolCalls)
}
}
impl std::fmt::Display for StopReason {
    /// Writes the snake_case name of the variant (the same text serde
    /// produces for it).
    ///
    /// Goes through [`std::fmt::Formatter::pad`] so width, fill, alignment and
    /// precision flags are honoured; `write!(f, "...")` would ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            StopReason::Stop => "stop",
            StopReason::Length => "length",
            StopReason::ToolCalls => "tool_calls",
            StopReason::ContentFilter => "content_filter",
            StopReason::Other => "other",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_ROLES: [Role; 4] = [Role::System, Role::User, Role::Assistant, Role::Tool];

    #[test]
    fn test_role_conversion() {
        assert_eq!(Role::System.as_str(), "system");
        assert_eq!(Role::User.as_str(), "user");
        assert_eq!(Role::Assistant.as_str(), "assistant");
        assert_eq!(Role::Tool.as_str(), "tool");
        let role_string: String = Role::User.into();
        assert_eq!(role_string, "user");
    }

    #[test]
    fn test_role_display_matches_as_str() {
        for role in ALL_ROLES {
            assert_eq!(role.to_string(), role.as_str());
        }
    }

    #[test]
    fn test_stop_reason_checks() {
        assert!(StopReason::Stop.is_complete());
        assert!(!StopReason::Length.is_complete());
        assert!(StopReason::Length.needs_more_tokens());
        assert!(!StopReason::Stop.needs_more_tokens());
        assert!(StopReason::ToolCalls.has_tool_calls());
        assert!(!StopReason::Stop.has_tool_calls());
        // Variants without a dedicated predicate answer `false` everywhere.
        for reason in [StopReason::ContentFilter, StopReason::Other] {
            assert!(!reason.is_complete());
            assert!(!reason.needs_more_tokens());
            assert!(!reason.has_tool_calls());
        }
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, r#""system""#);
        let role: Role = serde_json::from_str(r#""user""#).unwrap();
        assert_eq!(role, Role::User);
        // Every role round-trips and its wire name equals `as_str`.
        for role in ALL_ROLES {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, format!("\"{}\"", role.as_str()));
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
        // `Role` has no catch-all variant, so unknown names are rejected.
        assert!(serde_json::from_str::<Role>(r#""admin""#).is_err());
    }

    #[test]
    fn test_stop_reason_serialization() {
        let json = serde_json::to_string(&StopReason::ToolCalls).unwrap();
        assert_eq!(json, r#""tool_calls""#);
        let reason: StopReason = serde_json::from_str(r#""content_filter""#).unwrap();
        assert_eq!(reason, StopReason::ContentFilter);
        // Unrecognised values fall back to `Other` via `#[serde(other)]`.
        let reason: StopReason = serde_json::from_str(r#""function_call""#).unwrap();
        assert_eq!(reason, StopReason::Other);
    }

    #[test]
    fn test_stop_reason_display_matches_serde() {
        for reason in [
            StopReason::Stop,
            StopReason::Length,
            StopReason::ToolCalls,
            StopReason::ContentFilter,
            StopReason::Other,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{}\"", reason));
        }
    }
}