use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::agent::error::AgentResult;
use super::StateUpdate;
/// Contract for a state container that is addressed by string keys and
/// exchanges its contents as JSON [`Value`]s.
///
/// Implementors must be `Clone + Send + Sync + 'static`.
#[async_trait]
pub trait GraphState: Clone + Send + Sync + 'static {
    /// Applies a single update identified by `key` carrying `value`.
    ///
    /// # Errors
    ///
    /// Whatever error the implementor reports for the update.
    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()>;

    /// Applies every entry of `updates`, in slice order, through
    /// [`apply_update`](Self::apply_update).
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `apply_update`. Processing stops
    /// there; no rollback of the entries already applied is attempted here.
    async fn apply_updates(&mut self, updates: &[StateUpdate]) -> AgentResult<()> {
        for entry in updates {
            // The slice is only borrowed, so each value is cloned to hand over ownership.
            let value = entry.value.clone();
            self.apply_update(&entry.key, value).await?;
        }
        Ok(())
    }

    /// Looks up `key` and returns an owned value, or `None` when the
    /// implementor has nothing to return for it.
    fn get_value(&self, key: &str) -> Option<Value>;

    /// Lists the keys currently exposed by the state.
    fn keys(&self) -> Vec<&str>;

    /// Reports whether `key` appears in [`keys`](Self::keys).
    fn has_key(&self, key: &str) -> bool {
        self.keys().into_iter().any(|candidate| candidate == key)
    }

    /// Converts the state into a JSON value.
    fn to_json(&self) -> AgentResult<Value>;

    /// Builds a state from a JSON value; which shapes are accepted is up to
    /// the implementor.
    fn from_json(value: Value) -> AgentResult<Self>;
}
/// Named description of a state layout: a list of [`StateField`] declarations
/// plus a version string.
///
/// Derives serde `Serialize`/`Deserialize` alongside `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateSchema {
/// Name of the schema, as passed to `StateSchema::new`.
pub name: String,
/// Field declarations; the builder methods (`add_field`, `field`) append to the end.
pub fields: Vec<StateField>,
/// Version label; `StateSchema::new` initializes it to `"1.0"`.
pub version: String,
}
impl StateSchema {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
fields: Vec::new(),
version: "1.0".to_string(),
}
}
pub fn add_field(mut self, field: StateField) -> Self {
self.fields.push(field);
self
}
pub fn field(mut self, name: impl Into<String>, type_name: impl Into<String>) -> Self {
self.fields.push(StateField {
name: name.into(),
type_name: type_name.into(),
description: None,
default: None,
required: false,
});
self
}
pub fn get_field(&self, name: &str) -> Option<&StateField> {
self.fields.iter().find(|f| f.name == name)
}
pub fn field_names(&self) -> Vec<&str> {
self.fields.iter().map(|f| f.name.as_str()).collect()
}
}
/// Declaration of a single field inside a [`StateSchema`].
///
/// Derives serde `Serialize`/`Deserialize` alongside `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateField {
/// Field name; `StateSchema::get_field` matches on it.
pub name: String,
/// Free-form type label (the tests in this file use "array", "string" and
/// "number"). Nothing visible in this file validates it.
pub type_name: String,
/// Optional description text; `None` unless set via `with_description`.
pub description: Option<String>,
/// Optional default as a JSON value; `None` unless set via `with_default`.
/// NOTE(review): nothing in this file applies the default — confirm where it is consumed.
pub default: Option<Value>,
/// Required flag; the constructors set it to `false`.
/// NOTE(review): not enforced anywhere in this file — confirm against callers.
pub required: bool,
}
impl StateField {
    /// Creates a field declaration with the given name and type name.
    ///
    /// The description and default start out as `None`, and the field is not
    /// marked required.
    pub fn new(name: impl Into<String>, type_name: impl Into<String>) -> Self {
        let (name, type_name) = (name.into(), type_name.into());
        Self {
            name,
            type_name,
            description: None,
            default: None,
            required: false,
        }
    }

    /// Sets the description, replacing any previous one, and returns the field.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        let mut updated = self;
        updated.description = Some(description.into());
        updated
    }

    /// Sets the default value, replacing any previous one, and returns the field.
    pub fn with_default(self, default: Value) -> Self {
        let mut updated = self;
        updated.default = Some(default);
        updated
    }

    /// Sets the required flag to `required` and returns the field.
    pub fn with_required(self, required: bool) -> Self {
        let mut updated = self;
        updated.required = required;
        updated
    }
}
/// State container backed by a JSON object map; implements [`GraphState`]
/// further down in this file.
///
/// `Default` yields an empty map.
/// NOTE(review): the derived serde form nests the map under a `data` key,
/// whereas `to_json` returns the map itself as a JSON object — confirm which
/// representation callers expect.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct JsonState {
// Key/value storage. Private; reachable through `as_map` / `as_map_mut`.
data: serde_json::Map<String, Value>,
}
impl JsonState {
pub fn new() -> Self {
Self::default()
}
pub fn from_map(data: serde_json::Map<String, Value>) -> Self {
Self { data }
}
pub fn from_value(value: Value) -> AgentResult<Self> {
match value {
Value::Object(map) => Ok(Self { data: map }),
_ => Err(crate::agent::error::AgentError::InvalidInput(
"State must be a JSON object".to_string(),
)),
}
}
pub fn as_map(&self) -> &serde_json::Map<String, Value> {
&self.data
}
pub fn as_map_mut(&mut self) -> &mut serde_json::Map<String, Value> {
&mut self.data
}
}
#[async_trait]
impl GraphState for JsonState {
    /// Stores `value` under `key`, replacing any value already present.
    ///
    /// This implementation always returns `Ok(())`.
    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
        self.data.insert(key.to_string(), value);
        Ok(())
    }

    /// Returns a clone of the value stored under `key`, or `None` if absent.
    fn get_value(&self, key: &str) -> Option<Value> {
        self.data.get(key).cloned()
    }

    /// Returns borrowed views of every key in the map.
    fn keys(&self) -> Vec<&str> {
        self.data.keys().map(String::as_str).collect()
    }

    /// Reports whether `key` is present.
    ///
    /// Overrides the trait default, which collects every key into a `Vec` and
    /// scans it; a direct map lookup gives the same answer without the
    /// allocation or the linear search.
    fn has_key(&self, key: &str) -> bool {
        self.data.contains_key(key)
    }

    /// Returns a copy of the map wrapped as a JSON object.
    ///
    /// This implementation always returns `Ok`.
    fn to_json(&self) -> AgentResult<Value> {
        Ok(Value::Object(self.data.clone()))
    }

    /// Builds a state from `value`; see [`JsonState::from_value`].
    ///
    /// # Errors
    ///
    /// Fails when `value` is not a JSON object.
    fn from_json(value: Value) -> AgentResult<Self> {
        Self::from_value(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Updates are stored, readable, and reflected in key queries.
    #[tokio::test]
    async fn test_json_state() {
        let mut state = JsonState::new();
        state.apply_update("name", json!("test")).await.unwrap();
        state.apply_update("count", json!(42)).await.unwrap();

        let name = state.get_value("name");
        let count = state.get_value("count");
        assert_eq!(name, Some(json!("test")));
        assert_eq!(count, Some(json!(42)));

        assert!(state.has_key("name"));
        assert!(!state.has_key("unknown"));

        assert_eq!(state.keys().len(), 2);
    }

    /// The builder API accumulates fields and keeps their attributes.
    #[test]
    fn test_state_schema() {
        let count_field = StateField::new("count", "number")
            .with_description("Execution count")
            .with_default(json!(0))
            .with_required(true);

        let schema = StateSchema::new("MyState")
            .field("messages", "array")
            .field("result", "string")
            .add_field(count_field);

        assert_eq!(schema.name, "MyState");
        assert_eq!(schema.fields.len(), 3);
        assert!(schema.get_field("messages").is_some());

        let count = schema.get_field("count").unwrap();
        assert!(count.required);
    }

    /// A JSON object converts into a state exposing the same entries.
    #[test]
    fn test_json_state_from_value() {
        let source = json!({
            "key1": "value1",
            "key2": 123
        });
        let state = JsonState::from_json(source).unwrap();

        assert_eq!(state.get_value("key1"), Some(json!("value1")));
        assert_eq!(state.get_value("key2"), Some(json!(123)));
    }

    /// Anything other than a JSON object is rejected.
    #[test]
    fn test_json_state_invalid_input() {
        let outcome = JsonState::from_json(json!("not an object"));
        assert!(outcome.is_err());
    }
}