use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Contract for values that can be used as state.
///
/// Implementors must be cloneable, `Send + Sync`, `'static`, and both
/// serializable and deserializable through serde.
pub trait State:
    Clone + Send + Sync + Serialize + DeserializeOwned + 'static
{
    /// Returns a JSON value describing this state type.
    ///
    /// Unless overridden, this is an empty JSON object (`{}`).
    fn schema() -> serde_json::Value {
        serde_json::Value::Object(serde_json::Map::new())
    }
}
/// The unit type is a valid (empty) state; its schema is `{"type": "null"}`.
impl State for () {
    fn schema() -> serde_json::Value {
        let mut schema = serde_json::Map::new();
        schema.insert("type".to_owned(), serde_json::Value::from("null"));
        serde_json::Value::Object(schema)
    }
}
/// A plain string can act as state; its schema is `{"type": "string"}`.
impl State for String {
    fn schema() -> serde_json::Value {
        let mut schema = serde_json::Map::new();
        schema.insert("type".to_owned(), serde_json::Value::from("string"));
        serde_json::Value::Object(schema)
    }
}
/// An arbitrary JSON value can act as state; its schema is the empty object
/// (`{}`).
impl State for serde_json::Value {
    fn schema() -> serde_json::Value {
        Self::Object(serde_json::Map::new())
    }
}
/// A single message: a role, its text content, and two optional fields.
///
/// Optional fields that are `None` are left out of the serialized form
/// entirely rather than being written as `null`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who the message is attributed to.
    pub role: MessageRole,
    /// Text body of the message.
    pub content: String,
    /// Optional name attached to the message; none of the constructors on
    /// this type set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Identifier of the tool call this message answers; populated by
    /// [`Message::tool_result`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: content.into(),
name: None,
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: content.into(),
name: None,
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: content.into(),
name: None,
tool_call_id: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: MessageRole::Tool,
content: content.into(),
name: None,
tool_call_id: Some(tool_call_id.into()),
}
}
}
/// Role attributed to a [`Message`].
///
/// Variants are serialized by their lowercase names: `"system"`, `"user"`,
/// `"assistant"` and `"tool"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Role used by [`Message::system`].
    System,
    /// Role used by [`Message::user`].
    User,
    /// Role used by [`Message::assistant`].
    Assistant,
    /// Role used by [`Message::tool_result`].
    Tool,
}
/// A tool invocation record: an id, a name, and JSON arguments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    // NOTE(review): presumably matched against `Message::tool_call_id` — confirm.
    pub id: String,
    /// Name associated with the call (presumably the tool to invoke).
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
impl ToolCall {
    /// Creates a tool call from its id, name and JSON arguments.
    ///
    /// `id` and `name` accept anything convertible into a `String`.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        Self { id, name, arguments }
    }
}
/// Mutable state: message history, tool calls, a free-form JSON context map,
/// an iteration counter and a completion flag.
///
/// Every field is marked `#[serde(default)]`, so a partial or empty JSON
/// object deserializes successfully, with missing fields taking the same
/// values they have in `AgentState::default()`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AgentState {
    /// Message history, oldest first (new messages are pushed at the end).
    #[serde(default)]
    pub messages: Vec<Message>,
    /// Tool calls currently recorded on the state; a non-empty list is what
    /// `has_pending_tool_calls` reports as pending.
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
    /// Arbitrary keyed JSON values.
    #[serde(default)]
    pub context: HashMap<String, serde_json::Value>,
    /// Counter advanced by `increment_iteration`.
    #[serde(default)]
    pub iteration: usize,
    /// Flag set by `mark_complete`.
    #[serde(default)]
    pub is_complete: bool,
}
impl State for AgentState {
    /// Describes `AgentState` as a JSON object with one entry per field.
    ///
    /// The `messages` property additionally carries `"channel": "append"`.
    fn schema() -> serde_json::Value {
        let properties = serde_json::json!({
            "messages": { "type": "array", "channel": "append" },
            "tool_calls": { "type": "array" },
            "context": { "type": "object" },
            "iteration": { "type": "integer" },
            "is_complete": { "type": "boolean" }
        });

        serde_json::json!({
            "type": "object",
            "properties": properties
        })
    }
}
impl AgentState {
pub fn new() -> Self {
Self::default()
}
pub fn with_user_message(message: impl Into<String>) -> Self {
Self {
messages: vec![Message::user(message)],
..Default::default()
}
}
pub fn with_system_and_user(system: impl Into<String>, user: impl Into<String>) -> Self {
Self {
messages: vec![Message::system(system), Message::user(user)],
..Default::default()
}
}
pub fn add_assistant_message(&mut self, content: impl Into<String>) {
self.messages.push(Message::assistant(content));
}
pub fn add_user_message(&mut self, content: impl Into<String>) {
self.messages.push(Message::user(content));
}
pub fn add_tool_result(&mut self, tool_call_id: impl Into<String>, content: impl Into<String>) {
self.messages
.push(Message::tool_result(tool_call_id, content));
}
pub fn get_context<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
self.context
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
pub fn set_context<T: Serialize>(&mut self, key: impl Into<String>, value: T) {
if let Ok(v) = serde_json::to_value(value) {
self.context.insert(key.into(), v);
}
}
pub fn remove_context(&mut self, key: &str) -> Option<serde_json::Value> {
self.context.remove(key)
}
pub fn has_context(&self, key: &str) -> bool {
self.context.contains_key(key)
}
pub fn last_message(&self) -> Option<&Message> {
self.messages.last()
}
pub fn last_assistant_message(&self) -> Option<&Message> {
self.messages
.iter()
.rev()
.find(|m| m.role == MessageRole::Assistant)
}
pub fn last_user_message(&self) -> Option<&Message> {
self.messages
.iter()
.rev()
.find(|m| m.role == MessageRole::User)
}
pub fn clear_tool_calls(&mut self) {
self.tool_calls.clear();
}
pub fn has_pending_tool_calls(&self) -> bool {
!self.tool_calls.is_empty()
}
pub fn mark_complete(&mut self) {
self.is_complete = true;
}
pub fn increment_iteration(&mut self) {
self.iteration += 1;
}
}
/// An [`AgentState`] wrapped in an `RwLock` and shared through an `Arc`.
pub type SharedState = Arc<RwLock<AgentState>>;
/// Constructor helpers that produce a [`SharedState`].
pub trait SharedStateExt {
    /// Wraps an existing `state` into a [`SharedState`].
    fn new_shared(state: AgentState) -> SharedState;

    /// Produces a [`SharedState`] without requiring an initial state from the
    /// caller.
    fn new_shared_empty() -> SharedState;
}
impl SharedStateExt for SharedState {
    /// Puts `state` behind an `RwLock` and wraps the lock in an `Arc`.
    fn new_shared(state: AgentState) -> SharedState {
        let guarded = RwLock::new(state);
        Arc::new(guarded)
    }

    /// Shares a freshly created, empty `AgentState`.
    fn new_shared_empty() -> SharedState {
        // `AgentState::new()` is defined as `Self::default()`.
        let empty = AgentState::default();
        Self::new_shared(empty)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal custom state used to exercise a hand-written `State::schema`.
    #[derive(Clone, Debug, Default, Serialize, Deserialize)]
    struct TestState {
        value: i32,
    }

    impl State for TestState {
        fn schema() -> serde_json::Value {
            let properties = serde_json::json!({
                "value": { "type": "integer" }
            });
            serde_json::json!({
                "type": "object",
                "properties": properties
            })
        }
    }

    /// A custom implementation reports the schema it defines.
    #[test]
    fn test_state_schema() {
        assert_eq!(TestState::schema()["type"], "object");
    }

    /// The unit type advertises the JSON `null` type.
    #[test]
    fn test_unit_state() {
        assert_eq!(<()>::schema()["type"], "null");
    }

    /// Each constructor assigns the matching role; `tool_result` also records
    /// the tool-call id.
    #[test]
    fn test_message_constructors() {
        let system = Message::system("You are helpful");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.content, "You are helpful");

        assert_eq!(Message::user("Hello").role, MessageRole::User);
        assert_eq!(Message::assistant("Hi there").role, MessageRole::Assistant);

        let tool = Message::tool_result("call_1", "result");
        assert_eq!(tool.role, MessageRole::Tool);
        assert_eq!(tool.tool_call_id.as_deref(), Some("call_1"));
    }

    /// Messages are appended in order and the `last_*` lookups find them.
    #[test]
    fn test_agent_state_messages() {
        let mut state = AgentState::new();
        state.add_user_message("Hello");
        state.add_assistant_message("Hi!");

        assert_eq!(state.messages.len(), 2);

        let newest = state.last_message().unwrap();
        assert_eq!(newest.role, MessageRole::Assistant);

        let newest_user = state.last_user_message().unwrap();
        assert_eq!(newest_user.content, "Hello");
    }

    /// Context values round-trip through JSON and can be removed again.
    #[test]
    fn test_agent_state_context() {
        let mut state = AgentState::new();
        state.set_context("count", 42i32);
        state.set_context("name", "test".to_string());

        let count: Option<i32> = state.get_context("count");
        assert_eq!(count, Some(42));

        let name: Option<String> = state.get_context("name");
        assert_eq!(name.as_deref(), Some("test"));

        assert!(state.has_context("count"));
        assert!(!state.has_context("missing"));

        state.remove_context("count");
        assert!(!state.has_context("count"));
    }

    /// `ToolCall::new` stores the id and name it was given.
    #[test]
    fn test_tool_call() {
        let ToolCall { id, name, .. } =
            ToolCall::new("id1", "get_weather", serde_json::json!({"city": "NYC"}));
        assert_eq!(id, "id1");
        assert_eq!(name, "get_weather");
    }
}