// mofa_kernel/workflow/state.rs
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::agent::error::AgentResult;

use super::StateUpdate;
13
14#[async_trait]
61pub trait GraphState: Clone + Send + Sync + 'static {
62 async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()>;
67
68 async fn apply_updates(&mut self, updates: &[StateUpdate]) -> AgentResult<()> {
70 for update in updates {
71 self.apply_update(&update.key, update.value.clone()).await?;
72 }
73 Ok(())
74 }
75
76 fn get_value(&self, key: &str) -> Option<Value>;
80
81 fn keys(&self) -> Vec<&str>;
83
84 fn has_key(&self, key: &str) -> bool {
86 self.keys().contains(&key)
87 }
88
89 fn to_json(&self) -> AgentResult<Value>;
91
92 fn from_json(value: Value) -> AgentResult<Self>;
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct StateSchema {
102 pub name: String,
104 pub fields: Vec<StateField>,
106 pub version: String,
108}
109
110impl StateSchema {
111 pub fn new(name: impl Into<String>) -> Self {
113 Self {
114 name: name.into(),
115 fields: Vec::new(),
116 version: "1.0".to_string(),
117 }
118 }
119
120 pub fn add_field(mut self, field: StateField) -> Self {
122 self.fields.push(field);
123 self
124 }
125
126 pub fn field(mut self, name: impl Into<String>, type_name: impl Into<String>) -> Self {
128 self.fields.push(StateField {
129 name: name.into(),
130 type_name: type_name.into(),
131 description: None,
132 default: None,
133 required: false,
134 });
135 self
136 }
137
138 pub fn get_field(&self, name: &str) -> Option<&StateField> {
140 self.fields.iter().find(|f| f.name == name)
141 }
142
143 pub fn field_names(&self) -> Vec<&str> {
145 self.fields.iter().map(|f| f.name.as_str()).collect()
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct StateField {
152 pub name: String,
154 pub type_name: String,
156 pub description: Option<String>,
158 pub default: Option<Value>,
160 pub required: bool,
162}
163
164impl StateField {
165 pub fn new(name: impl Into<String>, type_name: impl Into<String>) -> Self {
167 Self {
168 name: name.into(),
169 type_name: type_name.into(),
170 description: None,
171 default: None,
172 required: false,
173 }
174 }
175
176 pub fn with_description(mut self, description: impl Into<String>) -> Self {
178 self.description = Some(description.into());
179 self
180 }
181
182 pub fn with_default(mut self, default: Value) -> Self {
184 self.default = Some(default);
185 self
186 }
187
188 pub fn with_required(mut self, required: bool) -> Self {
190 self.required = required;
191 self
192 }
193}
194
195#[derive(Debug, Clone, Default, Serialize, Deserialize)]
200pub struct JsonState {
201 data: serde_json::Map<String, Value>,
202}
203
204impl JsonState {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn from_map(data: serde_json::Map<String, Value>) -> Self {
212 Self { data }
213 }
214
215 pub fn from_value(value: Value) -> AgentResult<Self> {
217 match value {
218 Value::Object(map) => Ok(Self { data: map }),
219 _ => Err(crate::agent::error::AgentError::InvalidInput(
220 "State must be a JSON object".to_string(),
221 )),
222 }
223 }
224
225 pub fn as_map(&self) -> &serde_json::Map<String, Value> {
227 &self.data
228 }
229
230 pub fn as_map_mut(&mut self) -> &mut serde_json::Map<String, Value> {
232 &mut self.data
233 }
234}
235
236#[async_trait]
237impl GraphState for JsonState {
238 async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
239 self.data.insert(key.to_string(), value);
240 Ok(())
241 }
242
243 fn get_value(&self, key: &str) -> Option<Value> {
244 self.data.get(key).cloned()
245 }
246
247 fn keys(&self) -> Vec<&str> {
248 self.data.keys().map(|s| s.as_str()).collect()
249 }
250
251 fn to_json(&self) -> AgentResult<Value> {
252 Ok(Value::Object(self.data.clone()))
253 }
254
255 fn from_json(value: Value) -> AgentResult<Self> {
256 Self::from_value(value)
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Writes two keys through `apply_update` and reads them back.
    #[tokio::test]
    async fn test_json_state() {
        let mut state = JsonState::new();

        let writes = vec![("name", json!("test")), ("count", json!(42))];
        for (key, value) in writes {
            state.apply_update(key, value).await.unwrap();
        }

        assert_eq!(state.get_value("name"), Some(json!("test")));
        assert_eq!(state.get_value("count"), Some(json!(42)));
        assert!(state.has_key("name"));
        assert!(!state.has_key("unknown"));
        assert_eq!(state.keys().len(), 2);
    }

    /// Builds a schema mixing the `field` shorthand with a fully configured field.
    #[test]
    fn test_state_schema() {
        let count_field = StateField::new("count", "number")
            .with_description("Execution count")
            .with_default(json!(0))
            .with_required(true);

        let schema = StateSchema::new("MyState")
            .field("messages", "array")
            .field("result", "string")
            .add_field(count_field);

        assert_eq!(schema.name, "MyState");
        assert_eq!(schema.fields.len(), 3);
        assert!(schema.get_field("messages").is_some());
        let count = schema.get_field("count").expect("count field should exist");
        assert!(count.required);
    }

    /// A JSON object converts into a state with the same members.
    #[test]
    fn test_json_state_from_value() {
        let state = JsonState::from_json(json!({
            "key1": "value1",
            "key2": 123
        }))
        .expect("object input should be accepted");

        assert_eq!(state.get_value("key1"), Some(json!("value1")));
        assert_eq!(state.get_value("key2"), Some(json!(123)));
    }

    /// Non-object JSON is rejected.
    #[test]
    fn test_json_state_invalid_input() {
        let outcome = JsonState::from_json(json!("not an object"));
        assert!(outcome.is_err());
    }
}