use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
/// Trait for any type that can serve as graph state.
///
/// Implementors must be cloneable, thread-shareable (`Send + Sync`),
/// debuggable, and round-trippable through serde. Both methods have default
/// bodies, so an empty `impl StateSchema for MyState {}` is enough.
pub trait StateSchema:
    Clone + Send + Sync + Debug + Serialize + for<'de> Deserialize<'de>
{
    /// Builds the initial state from caller input.
    ///
    /// The default is the identity: the input is returned unchanged.
    fn from_input(input: Self) -> Self {
        input
    }

    /// Converts `self` into a `serde_json::Value`.
    ///
    /// Best-effort: if `serde_json::to_value` fails, the error is discarded
    /// and `serde_json::Value::Null` is returned instead of propagating it.
    fn to_json(&self) -> serde_json::Value {
        match serde_json::to_value(self) {
            Ok(value) => value,
            Err(_) => serde_json::Value::Null,
        }
    }
}
/// A state update: an optional new state value plus free-form JSON metadata.
///
/// `update` is `None` when there is no new state (as built by
/// `StateUpdate::unchanged`). Only `Serialize` is derived here, not
/// `Deserialize`; the derived output lists fields in declaration order.
#[derive(Debug, Clone, Serialize)]
pub struct StateUpdate<S: StateSchema> {
    /// The new state value, or `None` for "no change".
    pub update: Option<S>,
    /// Arbitrary key/value annotations attached to this update.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl<S: StateSchema> StateUpdate<S> {
    /// Wraps a complete new state with an empty metadata map.
    pub fn full(state: S) -> Self {
        Self::with_metadata(state, HashMap::new())
    }

    /// Wraps a complete new state together with caller-supplied metadata.
    pub fn with_metadata(state: S, metadata: HashMap<String, serde_json::Value>) -> Self {
        Self {
            metadata,
            update: Some(state),
        }
    }

    /// An update that carries no new state (`update` is `None`) and no metadata.
    pub fn unchanged() -> Self {
        Self {
            update: None,
            metadata: HashMap::default(),
        }
    }

    /// Inserts one metadata entry; an existing value for `key` is overwritten.
    pub fn add_metadata(&mut self, key: String, value: serde_json::Value) {
        self.metadata.insert(key, value);
    }
}
/// Combines the current state with an incoming update into the next state.
///
/// Implementors must be `Send + Sync`. Both states are taken by shared
/// reference and a new owned state is returned.
pub trait Reducer<S: StateSchema>: Send + Sync {
    /// Returns the state that results from applying `update` to `current`.
    fn reduce(&self, current: &S, update: &S) -> S;
}
/// Reducer that ignores the current state and adopts the update wholesale.
pub struct ReplaceReducer;

impl<S: StateSchema> Reducer<S> for ReplaceReducer {
    /// Returns an owned copy of `update`; the previous state is discarded.
    fn reduce(&self, _previous: &S, update: &S) -> S {
        S::clone(update)
    }
}
/// Generic reducer that concatenates one list-valued field of the state.
///
/// `field_accessor` borrows the list out of a state and `field_mutator`
/// writes a list back into a state. The merged list is `current`'s items
/// followed by `update`'s items, verbatim (no de-duplication).
///
/// The result starts as a clone of `current`. Every field other than the
/// merged list therefore comes from `current`, assuming `field_mutator`
/// touches only that list. Note that this is the opposite of
/// `AppendMessagesReducer` / `AppendStepsReducer`, which take the non-list
/// fields from `update`.
pub struct AppendReducer<S: StateSchema, T: Clone + Send + Sync> {
    /// Borrows the list field to merge from a state.
    pub field_accessor: fn(&S) -> &[T],
    /// Overwrites the list field of a state with the given items.
    pub field_mutator: fn(&mut S, Vec<T>),
}

impl<S: StateSchema, T: Clone + Send + Sync> Reducer<S> for AppendReducer<S, T> {
    fn reduce(&self, current: &S, update: &S) -> S {
        let current_items = (self.field_accessor)(current);
        let update_items = (self.field_accessor)(update);
        // `concat` allocates the exact combined length once. The previous
        // `to_vec()` + `extend` had to regrow whenever `update_items` was non-empty.
        let merged: Vec<T> = [current_items, update_items].concat();
        let mut result = current.clone();
        (self.field_mutator)(&mut result, merged);
        result
    }
}
pub struct AppendMessagesReducer;
impl Reducer<AgentState> for AppendMessagesReducer {
fn reduce(&self, current: &AgentState, update: &AgentState) -> AgentState {
let mut result = update.clone();
result.messages = current.messages.clone();
result.messages.extend(update.messages.iter().cloned());
result
}
}
pub struct AppendStepsReducer;
impl Reducer<AgentState> for AppendStepsReducer {
fn reduce(&self, current: &AgentState, update: &AgentState) -> AgentState {
let mut result = update.clone();
result.steps = current.steps.clone();
result.steps.extend(update.steps.iter().cloned());
result
}
}
/// Agent state: the original input, the message history, recorded
/// action/observation steps, and an optional final output.
///
/// Serialized with serde's derived representation, in field declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentState {
    /// The input text (as passed to `AgentState::new`).
    pub input: String,
    /// Message history; new entries are appended at the end.
    pub messages: Vec<MessageEntry>,
    /// Recorded steps; new entries are appended at the end.
    pub steps: Vec<StepEntry>,
    /// Final output; `None` until `set_output` is called.
    pub output: Option<String>,
}
// Relies entirely on the default `from_input` / `to_json` implementations.
impl StateSchema for AgentState {}
impl AgentState {
    /// Creates a fresh state for `input`.
    ///
    /// The history is seeded with `input` as a single human message; there
    /// are no steps and no output yet.
    pub fn new(input: String) -> Self {
        let seed = vec![MessageEntry::human(input.clone())];
        Self {
            messages: seed,
            steps: Vec::new(),
            output: None,
            input,
        }
    }

    /// Appends `message` to the end of the history.
    pub fn add_message(&mut self, message: MessageEntry) {
        self.messages.push(message)
    }

    /// Appends `step` to the end of the recorded steps.
    pub fn add_step(&mut self, step: StepEntry) {
        self.steps.push(step)
    }

    /// Stores the final output, replacing any previous value.
    pub fn set_output(&mut self, output: String) {
        self.output = Some(output);
    }
}
/// One entry in the message history: who produced it and its text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageEntry {
    /// The role the message is attributed to.
    pub role: MessageRole,
    /// The message text.
    pub content: String,
}
impl MessageEntry {
pub fn human(content: String) -> Self {
Self {
role: MessageRole::Human,
content,
}
}
pub fn ai(content: String) -> Self {
Self {
role: MessageRole::AI,
content,
}
}
pub fn system(content: String) -> Self {
Self {
role: MessageRole::System,
content,
}
}
pub fn tool(content: String) -> Self {
Self {
role: MessageRole::Tool,
content,
}
}
}
/// The role a `MessageEntry` is attributed to.
///
/// With serde's default enum representation, these unit variants serialize
/// as their names (`"System"`, `"Human"`, `"AI"`, `"Tool"`). Renaming a
/// variant (e.g. `AI` -> `Ai` to follow Rust naming) would change the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MessageRole {
    System,
    Human,
    AI,
    Tool,
}
/// One recorded step: a textual action and the observation paired with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepEntry {
    /// Description of the action taken.
    pub action: String,
    /// The observation recorded for that action.
    pub observation: String,
}
impl StepEntry {
pub fn new(action: String, observation: String) -> Self {
Self {
action,
observation,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_append_messages_reducer() {
        let mut current = AgentState::new("Hello".to_string());
        current.add_message(MessageEntry::ai("Response 1".to_string()));

        // `AgentState::new` seeds a human message, and the reducer appends the
        // update's messages verbatim. The update therefore carries only the
        // new message; otherwise the seeded "Hello" would be appended again.
        let mut update = AgentState::new("Hello".to_string());
        update.messages = vec![MessageEntry::ai("Response 2".to_string())];
        update.set_output("Done".to_string());

        let reducer = AppendMessagesReducer;
        let result = reducer.reduce(&current, &update);

        assert_eq!(result.messages.len(), 3);
        let contents: Vec<&str> = result.messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, ["Hello", "Response 1", "Response 2"]);
        assert_eq!(result.messages[2].role, MessageRole::AI);
        assert_eq!(result.output, Some("Done".to_string()));
    }

    #[test]
    fn test_append_steps_reducer() {
        let mut current = AgentState::new("Test".to_string());
        current.add_step(StepEntry::new(
            "Action 1".to_string(),
            "Result 1".to_string(),
        ));
        let mut update = AgentState::new("Test".to_string());
        update.add_step(StepEntry::new(
            "Action 2".to_string(),
            "Result 2".to_string(),
        ));

        let reducer = AppendStepsReducer;
        let result = reducer.reduce(&current, &update);

        assert_eq!(result.steps.len(), 2);
        assert_eq!(result.steps[0].action, "Action 1");
        assert_eq!(result.steps[1].action, "Action 2");
        // Non-step fields come from `update`, which has only its seeded message.
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn test_replace_reducer() {
        let current = AgentState::new("Old".to_string());
        let update = AgentState::new("New".to_string());

        let result: AgentState = ReplaceReducer.reduce(&current, &update);

        assert_eq!(result.input, "New");
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn test_generic_append_reducer() {
        fn steps_of(state: &AgentState) -> &[StepEntry] {
            &state.steps
        }
        fn set_steps(state: &mut AgentState, items: Vec<StepEntry>) {
            state.steps = items;
        }

        let reducer: AppendReducer<AgentState, StepEntry> = AppendReducer {
            field_accessor: steps_of,
            field_mutator: set_steps,
        };

        let mut current = AgentState::new("Test".to_string());
        current.add_step(StepEntry::new("a1".to_string(), "o1".to_string()));
        current.set_output("kept".to_string());
        let mut update = AgentState::new("Other".to_string());
        update.add_step(StepEntry::new("a2".to_string(), "o2".to_string()));

        let result = reducer.reduce(&current, &update);

        let actions: Vec<&str> = result.steps.iter().map(|s| s.action.as_str()).collect();
        assert_eq!(actions, ["a1", "a2"]);
        // The generic reducer starts from `current`, so its other fields survive.
        assert_eq!(result.input, "Test");
        assert_eq!(result.output, Some("kept".to_string()));
    }
}