use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// A chat session: its message history, an optional summary, and
/// creation/modification timestamps.
///
/// The mutating helpers [`Session::add_message`], [`Session::clear`] and
/// [`Session::set_summary`] refresh `updated_at`; writing to the public fields
/// directly does not.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// The key passed to [`Session::new`].
    pub key: String,
    /// Message history; [`Session::add_message`] appends to the end.
    pub messages: Vec<Message>,
    /// Optional summary text. `None` after [`Session::new`], replaced by
    /// [`Session::set_summary`], and reset to `None` by [`Session::clear`].
    pub summary: Option<String>,
    /// UTC time captured by [`Session::new`].
    pub created_at: DateTime<Utc>,
    /// UTC time of the most recent change made through the `Session` helpers.
    pub updated_at: DateTime<Utc>,
}
impl Session {
    /// Creates an empty session identified by `key`.
    ///
    /// `created_at` and `updated_at` are initialised to the same instant.
    pub fn new(key: &str) -> Self {
        let timestamp = Utc::now();
        Self {
            key: String::from(key),
            messages: Vec::new(),
            summary: None,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    /// Appends `message` to the history and refreshes `updated_at`.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
        self.touch();
    }

    /// Drops every message and the summary, then refreshes `updated_at`.
    ///
    /// `key` and `created_at` are left untouched.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.summary = None;
        self.touch();
    }

    /// Stores an owned copy of `summary`, replacing any previous one, and
    /// refreshes `updated_at`.
    pub fn set_summary(&mut self, summary: &str) {
        self.summary = Some(String::from(summary));
        self.touch();
    }

    /// Number of messages currently held.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// `true` when there are no messages (a stored summary does not count).
    pub fn is_empty(&self) -> bool {
        self.message_count() == 0
    }

    /// The most recently stored message, or `None` for an empty session.
    pub fn last_message(&self) -> Option<&Message> {
        self.messages.last()
    }

    /// Borrows every message whose role equals `role`, in history order.
    pub fn messages_by_role(&self, role: Role) -> Vec<&Message> {
        self.messages
            .iter()
            .filter(|candidate| candidate.role == role)
            .collect()
    }

    /// Stamps `updated_at` with the current UTC time; shared by every mutator.
    fn touch(&mut self) {
        self.updated_at = Utc::now();
    }
}
/// A single entry in a [`Session`]'s history.
///
/// Usually built with one of the constructors: [`Message::user`],
/// [`Message::assistant`], [`Message::system`], [`Message::tool_result`] or
/// [`Message::assistant_with_tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message.
    pub role: Role,
    /// Message text.
    pub content: String,
    /// Tool invocations attached to the message; set by
    /// [`Message::assistant_with_tools`]. Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Set by [`Message::tool_result`] from its `tool_call_id` argument
    /// (presumably the `id` of the [`ToolCall`] being answered — confirm with callers).
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl Message {
pub fn user(content: &str) -> Self {
Self {
role: Role::User,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
}
}
pub fn assistant(content: &str) -> Self {
Self {
role: Role::Assistant,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
}
}
pub fn system(content: &str) -> Self {
Self {
role: Role::System,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
}
}
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
Self {
role: Role::Tool,
content: content.to_string(),
tool_calls: None,
tool_call_id: Some(tool_call_id.to_string()),
}
}
pub fn assistant_with_tools(content: &str, tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: content.to_string(),
tool_calls: Some(tool_calls),
tool_call_id: None,
}
}
pub fn has_tool_calls(&self) -> bool {
self.tool_calls
.as_ref()
.map(|tc| !tc.is_empty())
.unwrap_or(false)
}
pub fn is_tool_result(&self) -> bool {
self.role == Role::Tool && self.tool_call_id.is_some()
}
}
/// Author of a [`Message`].
///
/// Serialized as a lowercase string (`"system"`, `"user"`, `"assistant"`, `"tool"`).
///
/// The enum carries no data, so it derives `Copy`: it can be passed by value
/// (e.g. to [`Session::messages_by_role`]) or read out of a borrowed
/// [`Message`] without an explicit `.clone()`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role used by [`Message::system`].
    System,
    /// Role used by [`Message::user`].
    User,
    /// Role used by [`Message::assistant`] and [`Message::assistant_with_tools`].
    Assistant,
    /// Role used by [`Message::tool_result`].
    Tool,
}
/// Formats a [`Role`] as its lowercase name (`system`, `user`, `assistant`, `tool`),
/// matching its serde representation.
///
/// Uses [`std::fmt::Formatter::pad`], so width, fill, alignment and precision
/// flags are honoured like they are for `str`; e.g. `format!("{:>6}", Role::User)`
/// yields `"  user"`.
impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };
        // `write!(f, "...")` would ignore the caller's formatting flags; `pad` applies them.
        f.pad(name)
    }
}
/// A tool invocation attached to an assistant [`Message`]
/// (see [`Message::assistant_with_tools`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call. [`Message::tool_result`] takes a `tool_call_id`
    /// that presumably refers back to it — confirm with callers.
    pub id: String,
    /// Tool name (e.g. `"search"`).
    pub name: String,
    /// Raw argument payload kept as a string; [`ToolCall::parse_arguments`]
    /// decodes it as JSON.
    pub arguments: String,
}
impl ToolCall {
    /// Builds a tool call, copying each borrowed part into an owned `String`.
    pub fn new(id: &str, name: &str, arguments: &str) -> Self {
        let (id, name, arguments) = (id.to_owned(), name.to_owned(), arguments.to_owned());
        Self {
            id,
            name,
            arguments,
        }
    }

    /// Decodes the stored `arguments` string as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `arguments` is not valid JSON or does
    /// not match the shape of `T`.
    pub fn parse_arguments<T>(&self) -> serde_json::Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let raw: &str = self.arguments.as_str();
        serde_json::from_str::<T>(raw)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A timestamp far in the past. Tests overwrite `updated_at` with it and then
    /// check that a mutator moved it forward. This is deterministic and needs no sleep.
    fn epoch() -> DateTime<Utc> {
        DateTime::<Utc>::from(std::time::UNIX_EPOCH)
    }

    #[test]
    fn test_session_new() {
        let session = Session::new("test-session");
        assert_eq!(session.key, "test-session");
        assert!(session.messages.is_empty());
        assert!(session.summary.is_none());
        // Both timestamps come from the same `Utc::now()` call in `Session::new`.
        assert_eq!(session.created_at, session.updated_at);
    }

    #[test]
    fn test_session_add_message() {
        let mut session = Session::new("test");
        session.updated_at = epoch();
        session.add_message(Message::user("Hello"));
        assert_eq!(session.messages.len(), 1);
        // Strict comparison: the old `>=` check passed even if nothing was updated.
        assert!(session.updated_at > epoch());
    }

    #[test]
    fn test_session_set_summary() {
        let mut session = Session::new("test");
        session.updated_at = epoch();
        session.set_summary("A greeting");
        assert_eq!(session.summary.as_deref(), Some("A greeting"));
        assert!(session.updated_at > epoch());
    }

    #[test]
    fn test_session_clear() {
        let mut session = Session::new("test");
        session.add_message(Message::user("Hello"));
        session.set_summary("A greeting");
        let created_at = session.created_at;
        session.updated_at = epoch();
        session.clear();
        assert!(session.messages.is_empty());
        assert!(session.summary.is_none());
        // `clear` keeps the session's identity and creation time.
        assert_eq!(session.key, "test");
        assert_eq!(session.created_at, created_at);
        assert!(session.updated_at > epoch());
    }

    #[test]
    fn test_session_helpers() {
        let mut session = Session::new("test");
        assert!(session.is_empty());
        assert_eq!(session.message_count(), 0);
        assert!(session.last_message().is_none());
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi!"));
        assert!(!session.is_empty());
        assert_eq!(session.message_count(), 2);
        assert_eq!(session.last_message().unwrap().role, Role::Assistant);
        assert_eq!(session.messages_by_role(Role::User).len(), 1);
        assert!(session.messages_by_role(Role::Tool).is_empty());
    }

    #[test]
    fn test_message_user() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello");
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("Hi there");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "Hi there");
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("You are helpful");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content, "You are helpful");
    }

    #[test]
    fn test_message_tool_result() {
        let msg = Message::tool_result("call_123", "Success");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "Success");
        assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
        assert!(msg.is_tool_result());
    }

    #[test]
    fn test_message_is_tool_result_requires_tool_role_and_id() {
        assert!(!Message::user("Hello").is_tool_result());
        let mut msg = Message::tool_result("call_1", "ok");
        msg.tool_call_id = None;
        assert!(!msg.is_tool_result());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let tool_call = ToolCall::new("call_1", "search", r#"{"q": "test"}"#);
        let msg = Message::assistant_with_tools("Searching...", vec![tool_call]);
        assert!(msg.has_tool_calls());
        let calls = msg.tool_calls.unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }

    #[test]
    fn test_message_with_empty_tool_calls() {
        // An empty vector is stored but does not count as having tool calls.
        let msg = Message::assistant_with_tools("Thinking...", Vec::new());
        assert!(msg.tool_calls.is_some());
        assert!(!msg.has_tool_calls());
        assert!(!Message::user("Hello").has_tool_calls());
    }

    #[test]
    fn test_role_display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_role_serialize() {
        let user = Role::User;
        let json = serde_json::to_string(&user).unwrap();
        assert_eq!(json, r#""user""#);
        let parsed: Role = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, Role::User);
    }

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("call_123", "web_search", r#"{"query": "rust"}"#);
        assert_eq!(call.id, "call_123");
        assert_eq!(call.name, "web_search");
        assert_eq!(call.arguments, r#"{"query": "rust"}"#);
    }

    #[test]
    fn test_tool_call_parse_arguments() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct SearchArgs {
            query: String,
        }
        let call = ToolCall::new("call_1", "search", r#"{"query": "rust"}"#);
        let args: SearchArgs = call.parse_arguments().unwrap();
        assert_eq!(args.query, "rust");
    }

    #[test]
    fn test_tool_call_parse_arguments_invalid_json() {
        let call = ToolCall::new("call_1", "search", "not json");
        assert!(call.parse_arguments::<serde_json::Value>().is_err());
    }

    #[test]
    fn test_session_serialization() {
        let mut session = Session::new("test-session");
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi!"));
        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.key, "test-session");
        assert_eq!(parsed.messages.len(), 2);
        assert_eq!(parsed.messages[0].role, Role::User);
        assert_eq!(parsed.messages[1].role, Role::Assistant);
    }

    #[test]
    fn test_message_serialization_skips_none() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_tool_result_serialization_keeps_id() {
        let msg = Message::tool_result("call_1", "ok");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""tool_call_id":"call_1""#));
        assert!(!json.contains("tool_calls"));
    }
}