/// The author of a [`ChatMessage`].
///
/// On the wire every variant is represented by its lowercase name:
/// `"system"`, `"user"`, `"assistant"` or `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Wire name `"system"`.
    System,
    /// Wire name `"user"`.
    User,
    /// Wire name `"assistant"`.
    Assistant,
    /// Wire name `"tool"`.
    Tool,
}
impl Role {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::System => "system",
Self::User => "user",
Self::Assistant => "assistant",
Self::Tool => "tool",
}
}
}
impl std::fmt::Display for Role {
    /// Writes the role's lowercase wire name (see [`Role::as_str`]).
    ///
    /// Goes through `Formatter::pad` rather than `write_str` so that width, fill,
    /// alignment and precision flags behave exactly as they do for a plain `&str`
    /// (e.g. `format!("{:>6}", Role::User)` yields `"  user"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
impl std::str::FromStr for Role {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"system" => Self::System,
"user" => Self::User,
"assistant" => Self::Assistant,
"tool" => Self::Tool,
other => return Err(format!("unknown chat role: {other}")),
})
}
}
/// One message in a chat conversation.
///
/// Optional fields are omitted from serialized output when empty and default to
/// empty when missing from the input.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Text body of the message.
    pub content: String,
    /// Identifier of the tool call this message answers; populated by
    /// `ChatMessage::tool_result`. Skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub tool_call_id: Option<String>,
    /// Tool calls attached to the message (see `ChatMessage::with_tool_call`).
    /// Skipped when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<ToolCall>,
    /// Optional name attached to the message. Skipped when `None`.
    /// NOTE(review): its intended meaning (participant vs. tool name) is not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub name: Option<String>,
}
impl ChatMessage {
    /// Creates a message with the given `role` and `content` and no tool metadata
    /// (`tool_call_id` and `name` are `None`, `tool_calls` is empty).
    #[must_use]
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self {
            content,
            role,
            name: None,
            tool_calls: Vec::new(),
            tool_call_id: None,
        }
    }

    /// Creates a [`Role::Tool`] message holding `content` as the output of the
    /// tool call identified by `tool_call_id`.
    #[must_use]
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        // `content` is converted before `tool_call_id`, matching the field order of `new`.
        let mut message = Self::new(Role::Tool, content);
        message.tool_call_id = Some(tool_call_id.into());
        message
    }

    /// Appends `call` to the end of this message's tool calls and returns the
    /// updated message, allowing calls to be chained.
    #[must_use]
    pub fn with_tool_call(self, call: ToolCall) -> Self {
        let mut message = self;
        message.tool_calls.push(call);
        message
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::tool_call::ToolCall;
    use serde_json::json;
    use std::str::FromStr;

    /// Every `Role` variant, used to check that all textual forms stay in sync.
    const ALL_ROLES: [Role; 4] = [Role::System, Role::User, Role::Assistant, Role::Tool];

    #[test]
    fn new_constructs() {
        let m = ChatMessage::new(Role::User, "hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content, "hi");
        assert!(m.tool_call_id.is_none());
        assert!(m.tool_calls.is_empty());
        assert!(m.name.is_none());
    }

    #[test]
    fn tool_result_constructs() {
        let m = ChatMessage::tool_result("call_1", "ok");
        assert_eq!(m.role, Role::Tool);
        assert_eq!(m.tool_call_id.as_deref(), Some("call_1"));
        assert_eq!(m.content, "ok");
        assert!(m.tool_calls.is_empty());
        assert!(m.name.is_none());
    }

    #[test]
    fn with_tool_call_appends() {
        let c = ToolCall::new("id_1", "f", json!({}));
        // `c` is not used afterwards, so it is moved in rather than cloned.
        let m = ChatMessage::new(Role::Assistant, "x").with_tool_call(c);
        assert_eq!(m.tool_calls.len(), 1);
        assert_eq!(m.tool_calls[0].id, "id_1");
    }

    #[test]
    fn with_tool_call_preserves_order() {
        let m = ChatMessage::new(Role::Assistant, "x")
            .with_tool_call(ToolCall::new("id_1", "f", json!({})))
            .with_tool_call(ToolCall::new("id_2", "g", json!({})));
        assert_eq!(m.tool_calls.len(), 2);
        assert_eq!(m.tool_calls[0].id, "id_1");
        assert_eq!(m.tool_calls[1].id, "id_2");
    }

    #[test]
    fn role_from_str() {
        assert_eq!(Role::from_str("system").unwrap(), Role::System);
        assert_eq!(Role::from_str("user").unwrap(), Role::User);
        assert_eq!(Role::from_str("assistant").unwrap(), Role::Assistant);
        assert_eq!(Role::from_str("tool").unwrap(), Role::Tool);
        assert!(Role::from_str("nope").is_err());
        // Parsing is case-sensitive and reports the offending input.
        assert!(Role::from_str("User").is_err());
        assert_eq!(Role::from_str("nope").unwrap_err(), "unknown chat role: nope");
    }

    #[test]
    fn role_serialize_round_trip() {
        let json = serde_json::to_string(&Role::User).unwrap();
        assert_eq!(json, "\"user\"");
        let r: Role = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(r, Role::Assistant);
    }

    #[test]
    fn role_text_forms_agree() {
        // as_str, Display, FromStr and serde must all use the same wire names.
        for role in ALL_ROLES {
            assert_eq!(role.to_string(), role.as_str());
            assert_eq!(role.as_str().parse::<Role>().unwrap(), role);
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, format!("\"{}\"", role.as_str()));
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn message_serialization_omits_empty_optionals() {
        let plain = ChatMessage::new(Role::User, "hi");
        assert_eq!(
            serde_json::to_value(&plain).unwrap(),
            json!({"role": "user", "content": "hi"})
        );
        let tool = ChatMessage::tool_result("call_1", "ok");
        assert_eq!(
            serde_json::to_value(&tool).unwrap(),
            json!({"role": "tool", "content": "ok", "tool_call_id": "call_1"})
        );
    }

    #[test]
    fn message_deserializes_with_missing_optionals() {
        let m: ChatMessage =
            serde_json::from_value(json!({"role": "assistant", "content": "x"})).unwrap();
        assert_eq!(m, ChatMessage::new(Role::Assistant, "x"));
    }
}
use crate::chat::tool_call::ToolCall;