use crate::errors::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
pub trait State: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + 'static {
fn merge(&mut self, other: Self) -> Result<()>;
fn to_value(&self) -> Result<serde_json::Value> {
serde_json::to_value(self).map_err(Error::from)
}
fn from_value(value: serde_json::Value) -> Result<Self> {
serde_json::from_value(value).map_err(Error::from)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
pub role: String,
pub content: String,
pub name: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub tool_call_id: Option<String>,
pub metadata: HashMap<String, serde_json::Value>,
}
/// Role-specific constructors and builder-style setters for [`Message`].
impl Message {
    /// Shared constructor: a message with the given role and content, and every optional
    /// field empty (`None`, empty `metadata`).
    fn base(role: &str, content: String) -> Self {
        Self {
            role: role.to_owned(),
            content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
            metadata: HashMap::new(),
        }
    }

    /// Creates a message with role `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::base("user", content.into())
    }

    /// Creates a message with role `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::base("assistant", content.into())
    }

    /// Creates a message with role `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::base("system", content.into())
    }

    /// Creates a message with role `"tool"`, recording `tool_call_id` on it.
    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
        let mut reply = Self::base("tool", content.into());
        reply.tool_call_id = Some(tool_call_id.into());
        reply
    }

    /// Returns this message with `tool_calls` set to `Some(tool_calls)`.
    pub fn with_tool_calls(self, tool_calls: Vec<ToolCall>) -> Self {
        Self {
            tool_calls: Some(tool_calls),
            ..self
        }
    }

    /// Returns this message with `name` set to `Some(name)`.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// A tool invocation carried in a [`Message`]'s `tool_calls` list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier for this call. `Message::tool` takes a `tool_call_id`, presumably
    /// matching this value — confirm against callers.
    pub id: String,
    /// Tool name (e.g. `"search"` in this module's tests).
    pub name: String,
    /// Arguments as an arbitrary JSON value; no schema is enforced here.
    pub arguments: serde_json::Value,
}
/// Construction helpers for [`ToolCall`].
impl ToolCall {
    /// Builds a tool call from its id, tool name, and JSON arguments.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        Self { id, name, arguments }
    }
}
/// State holding an ordered list of [`Message`]s.
///
/// Merging appends the update's messages after the existing ones (see its `State` impl).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagesState {
    /// Messages in insertion order.
    pub messages: Vec<Message>,
}
impl State for MessagesState {
fn merge(&mut self, other: Self) -> Result<()> {
add_messages(&mut self.messages, other.messages);
Ok(())
}
}
/// Appends every message in `new` to the end of `existing`, preserving order.
///
/// No deduplication or replacement is performed. Messages carry no id field here, so
/// duplicates are kept as-is.
pub fn add_messages(existing: &mut Vec<Message>, mut new: Vec<Message>) {
    existing.append(&mut new);
}
/// Loosely typed state: a map from string keys to arbitrary JSON values.
///
/// Merging copies every entry from the update, overwriting values for keys that already
/// exist.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DictState {
    /// Key/value storage; iteration order is unspecified (`HashMap`).
    pub data: HashMap<String, serde_json::Value>,
}
impl State for DictState {
fn merge(&mut self, other: Self) -> Result<()> {
self.data.extend(other.data);
Ok(())
}
}
impl DictState {
pub fn new() -> Self {
Self {
data: HashMap::new(),
}
}
pub fn with_data(data: HashMap<String, serde_json::Value>) -> Self {
Self { data }
}
pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
self.data.get(key)
}
pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
self.data.insert(key.into(), value);
}
}
impl Default for DictState {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `State` implementation whose merge adds counts, used to exercise the
    /// trait's default methods.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl State for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[test]
    fn test_state_merge() {
        let mut state = TestState { count: 5 };
        let other = TestState { count: 3 };
        state.merge(other).unwrap();
        assert_eq!(state.count, 8);
    }

    #[test]
    fn test_state_value_roundtrip() {
        let state = TestState { count: 7 };
        let value = state.to_value().unwrap();
        assert_eq!(value, serde_json::json!({ "count": 7 }));
        assert_eq!(TestState::from_value(value).unwrap(), state);
    }

    #[test]
    fn test_state_from_value_rejects_wrong_shape() {
        // A string where an i32 is expected must surface as an error, not a panic.
        assert!(TestState::from_value(serde_json::json!({ "count": "seven" })).is_err());
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");
        let msg = Message::assistant("Hi").with_name("bot");
        assert_eq!(msg.name.as_deref(), Some("bot"));
        let msg = Message::system("Be concise");
        assert_eq!(msg.role, "system");
        assert!(msg.name.is_none() && msg.tool_calls.is_none() && msg.metadata.is_empty());
    }

    #[test]
    fn test_tool_messages() {
        let call = ToolCall::new("call-1", "search", serde_json::json!({ "query": "rust" }));
        let request = Message::assistant("").with_tool_calls(vec![call.clone()]);
        assert_eq!(request.tool_calls, Some(vec![call]));
        let reply = Message::tool("result", "call-1");
        assert_eq!(reply.role, "tool");
        assert_eq!(reply.content, "result");
        assert_eq!(reply.tool_call_id.as_deref(), Some("call-1"));
    }

    #[test]
    fn test_messages_state() {
        let mut state = MessagesState {
            messages: vec![Message::user("Hello")],
        };
        let update = MessagesState {
            messages: vec![Message::assistant("Hi there!")],
        };
        state.merge(update).unwrap();
        assert_eq!(state.messages.len(), 2);
        assert_eq!(state.messages[0].role, "user");
        assert_eq!(state.messages[1].role, "assistant");
    }

    #[test]
    fn test_dict_state() {
        let mut state = DictState::new();
        state.set("key1", serde_json::json!("value1"));
        let mut other = DictState::new();
        other.set("key2", serde_json::json!(42));
        state.merge(other).unwrap();
        assert_eq!(state.data.len(), 2);
        assert_eq!(state.get("key1").unwrap(), &serde_json::json!("value1"));
        assert_eq!(state.get("key2").unwrap(), &serde_json::json!(42));
    }

    #[test]
    fn test_dict_state_merge_overwrites_existing_keys() {
        let mut state = DictState::default();
        state.set("k", serde_json::json!(1));
        let mut other = DictState::new();
        other.set("k", serde_json::json!(2));
        state.merge(other).unwrap();
        assert_eq!(state.data.len(), 1);
        assert_eq!(state.get("k"), Some(&serde_json::json!(2)));
        assert_eq!(state.get("missing"), None);
    }

    #[test]
    fn test_tool_call() {
        let tool_call = ToolCall::new("call-1", "search", serde_json::json!({"query": "rust"}));
        assert_eq!(tool_call.id, "call-1");
        assert_eq!(tool_call.name, "search");
        assert_eq!(tool_call.arguments, serde_json::json!({"query": "rust"}));
    }
}