mofa_kernel/workflow/
state.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::agent::error::AgentResult;
11
12use super::StateUpdate;
13
14#[async_trait]
61pub trait GraphState: Clone + Send + Sync + 'static {
62 async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()>;
67
68 async fn apply_updates(&mut self, updates: &[StateUpdate]) -> AgentResult<()> {
70 for update in updates {
71 self.apply_update(&update.key, update.value.clone()).await?;
72 }
73 Ok(())
74 }
75
76 fn get_value(&self, key: &str) -> Option<Value>;
80
81 fn keys(&self) -> Vec<&str>;
83
84 fn has_key(&self, key: &str) -> bool {
86 self.keys().contains(&key)
87 }
88
89 fn to_json(&self) -> AgentResult<Value>;
91
92 fn from_json(value: Value) -> AgentResult<Self>;
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct StateSchema {
102 pub name: String,
104 pub fields: Vec<StateField>,
106 pub version: String,
108}
109
110impl StateSchema {
111 pub fn new(name: impl Into<String>) -> Self {
113 Self {
114 name: name.into(),
115 fields: Vec::new(),
116 version: "1.0".to_string(),
117 }
118 }
119
120 pub fn add_field(mut self, field: StateField) -> Self {
122 self.fields.push(field);
123 self
124 }
125
126 pub fn field(
128 mut self,
129 name: impl Into<String>,
130 type_name: impl Into<String>,
131 ) -> Self {
132 self.fields.push(StateField {
133 name: name.into(),
134 type_name: type_name.into(),
135 description: None,
136 default: None,
137 required: false,
138 });
139 self
140 }
141
142 pub fn get_field(&self, name: &str) -> Option<&StateField> {
144 self.fields.iter().find(|f| f.name == name)
145 }
146
147 pub fn field_names(&self) -> Vec<&str> {
149 self.fields.iter().map(|f| f.name.as_str()).collect()
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct StateField {
156 pub name: String,
158 pub type_name: String,
160 pub description: Option<String>,
162 pub default: Option<Value>,
164 pub required: bool,
166}
167
168impl StateField {
169 pub fn new(name: impl Into<String>, type_name: impl Into<String>) -> Self {
171 Self {
172 name: name.into(),
173 type_name: type_name.into(),
174 description: None,
175 default: None,
176 required: false,
177 }
178 }
179
180 pub fn with_description(mut self, description: impl Into<String>) -> Self {
182 self.description = Some(description.into());
183 self
184 }
185
186 pub fn with_default(mut self, default: Value) -> Self {
188 self.default = Some(default);
189 self
190 }
191
192 pub fn with_required(mut self, required: bool) -> Self {
194 self.required = required;
195 self
196 }
197}
198
199#[derive(Debug, Clone, Default, Serialize, Deserialize)]
204pub struct JsonState {
205 data: serde_json::Map<String, Value>,
206}
207
208impl JsonState {
209 pub fn new() -> Self {
211 Self::default()
212 }
213
214 pub fn from_map(data: serde_json::Map<String, Value>) -> Self {
216 Self { data }
217 }
218
219 pub fn from_value(value: Value) -> AgentResult<Self> {
221 match value {
222 Value::Object(map) => Ok(Self { data: map }),
223 _ => Err(crate::agent::error::AgentError::InvalidInput(
224 "State must be a JSON object".to_string(),
225 )),
226 }
227 }
228
229 pub fn as_map(&self) -> &serde_json::Map<String, Value> {
231 &self.data
232 }
233
234 pub fn as_map_mut(&mut self) -> &mut serde_json::Map<String, Value> {
236 &mut self.data
237 }
238}
239
240#[async_trait]
241impl GraphState for JsonState {
242 async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
243 self.data.insert(key.to_string(), value);
244 Ok(())
245 }
246
247 fn get_value(&self, key: &str) -> Option<Value> {
248 self.data.get(key).cloned()
249 }
250
251 fn keys(&self) -> Vec<&str> {
252 self.data.keys().map(|s| s.as_str()).collect()
253 }
254
255 fn to_json(&self) -> AgentResult<Value> {
256 Ok(Value::Object(self.data.clone()))
257 }
258
259 fn from_json(value: Value) -> AgentResult<Self> {
260 Self::from_value(value)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Values written with `apply_update` are readable back and reflected in
    /// `has_key` / `keys`.
    #[tokio::test]
    async fn test_json_state() {
        let mut state = JsonState::new();

        state.apply_update("name", json!("test")).await.unwrap();
        state.apply_update("count", json!(42)).await.unwrap();

        assert_eq!(state.get_value("name"), Some(json!("test")));
        assert_eq!(state.get_value("count"), Some(json!(42)));
        assert!(state.has_key("name"));
        assert!(!state.has_key("unknown"));

        let keys: Vec<&str> = state.keys();
        assert_eq!(keys.len(), 2);
    }

    /// Writing an existing key replaces its value instead of adding a new key.
    #[tokio::test]
    async fn test_apply_update_overwrites_existing_key() {
        let mut state = JsonState::new();

        state.apply_update("count", json!(1)).await.unwrap();
        state.apply_update("count", json!(2)).await.unwrap();

        assert_eq!(state.get_value("count"), Some(json!(2)));
        assert_eq!(state.keys(), vec!["count"]);
        assert!(state.has_key("count"));
    }

    /// Builder methods accumulate fields and `get_field` finds them by name.
    #[test]
    fn test_state_schema() {
        let schema = StateSchema::new("MyState")
            .field("messages", "array")
            .field("result", "string")
            .add_field(
                StateField::new("count", "number")
                    .with_description("Execution count")
                    .with_default(json!(0))
                    .with_required(true),
            );

        assert_eq!(schema.name, "MyState");
        assert_eq!(schema.fields.len(), 3);
        assert!(schema.get_field("messages").is_some());
        assert!(schema.get_field("count").unwrap().required);
    }

    /// `new` and the `field` shorthand produce the documented defaults.
    #[test]
    fn test_state_schema_defaults() {
        let schema = StateSchema::new("S").field("x", "string");

        assert_eq!(schema.version, "1.0");
        assert_eq!(schema.field_names(), vec!["x"]);

        let field = schema.get_field("x").unwrap();
        assert_eq!(field.type_name, "string");
        assert!(field.description.is_none());
        assert!(field.default.is_none());
        assert!(!field.required);

        assert!(schema.get_field("missing").is_none());
    }

    /// A JSON object converts into a state whose keys mirror the object.
    #[test]
    fn test_json_state_from_value() {
        let value = json!({
            "key1": "value1",
            "key2": 123
        });

        let state = JsonState::from_json(value).unwrap();
        assert_eq!(state.get_value("key1"), Some(json!("value1")));
        assert_eq!(state.get_value("key2"), Some(json!(123)));
    }

    /// `from_json` followed by `to_json` reproduces the original object.
    #[test]
    fn test_json_state_round_trip() {
        let original = json!({"a": 1, "b": [true, null]});

        let state = JsonState::from_json(original.clone()).unwrap();
        assert_eq!(state.to_json().unwrap(), original);
    }

    /// Non-object JSON values are rejected.
    #[test]
    fn test_json_state_invalid_input() {
        let result = JsonState::from_json(json!("not an object"));
        assert!(result.is_err());
    }
}