Skip to main content

mofa_kernel/workflow/
state.rs

1//! Graph State Trait and Types
2//!
3//! Defines the state management interface for workflow graphs.
4//! The GraphState trait allows custom state types to work with the workflow system.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::agent::error::AgentResult;
11
12use super::StateUpdate;
13
14/// Graph state trait
15///
16/// Implement this trait to define custom state types for workflows.
17/// The trait provides methods for applying updates and serialization.
18///
19/// # Example
20///
21/// ```rust,ignore
22/// use serde::{Serialize, Deserialize};
23/// use mofa_kernel::workflow::GraphState;
24///
25/// #[derive(Clone, Serialize, Deserialize)]
26/// struct MyState {
27///     messages: Vec<String>,
28///     result: Option<String>,
29/// }
30///
31/// impl GraphState for MyState {
32///     async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
33///         match key {
34///             "messages" => {
35///                 if let Some(msg) = value.as_str() {
36///                     self.messages.push(msg.to_string());
37///                 }
38///             }
39///             "result" => {
40///                 self.result = value.as_str().map(|s| s.to_string());
41///             }
42///             _ => {}
43///         }
44///         Ok(())
45///     }
46///
47///     fn get_value(&self, key: &str) -> Option<Value> {
48///         match key {
49///             "messages" => Some(serde_json::to_value(&self.messages).unwrap()),
50///             "result" => Some(serde_json::to_value(&self.result).unwrap()),
51///             _ => None,
52///         }
53///     }
54///
55///     fn keys(&self) -> Vec<&str> {
56///         vec!["messages", "result"]
57///     }
58/// }
59/// ```
60#[async_trait]
61pub trait GraphState: Clone + Send + Sync + 'static {
62    /// Apply a state update
63    ///
64    /// This method is called when a node returns state updates.
65    /// The implementation should merge the update into the state.
66    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()>;
67
68    /// Apply multiple updates
69    async fn apply_updates(&mut self, updates: &[StateUpdate]) -> AgentResult<()> {
70        for update in updates {
71            self.apply_update(&update.key, update.value.clone()).await?;
72        }
73        Ok(())
74    }
75
76    /// Get a value by key
77    ///
78    /// Returns the current value for a given key, or None if the key doesn't exist.
79    fn get_value(&self, key: &str) -> Option<Value>;
80
81    /// Get all keys in this state
82    fn keys(&self) -> Vec<&str>;
83
84    /// Check if a key exists
85    fn has_key(&self, key: &str) -> bool {
86        self.keys().contains(&key)
87    }
88
89    /// Convert entire state to a JSON Value
90    fn to_json(&self) -> AgentResult<Value>;
91
92    /// Create state from a JSON Value
93    fn from_json(value: Value) -> AgentResult<Self>;
94}
95
96/// State schema for validation and documentation
97///
98/// Describes the structure of a graph's state, including
99/// key names, types, and reducer configurations.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct StateSchema {
102    /// Schema name
103    pub name: String,
104    /// Field definitions
105    pub fields: Vec<StateField>,
106    /// Schema version
107    pub version: String,
108}
109
110impl StateSchema {
111    /// Create a new state schema
112    pub fn new(name: impl Into<String>) -> Self {
113        Self {
114            name: name.into(),
115            fields: Vec::new(),
116            version: "1.0".to_string(),
117        }
118    }
119
120    /// Add a field to the schema
121    pub fn add_field(mut self, field: StateField) -> Self {
122        self.fields.push(field);
123        self
124    }
125
126    /// Add a simple field
127    pub fn field(mut self, name: impl Into<String>, type_name: impl Into<String>) -> Self {
128        self.fields.push(StateField {
129            name: name.into(),
130            type_name: type_name.into(),
131            description: None,
132            default: None,
133            required: false,
134        });
135        self
136    }
137
138    /// Get a field by name
139    pub fn get_field(&self, name: &str) -> Option<&StateField> {
140        self.fields.iter().find(|f| f.name == name)
141    }
142
143    /// Get all field names
144    pub fn field_names(&self) -> Vec<&str> {
145        self.fields.iter().map(|f| f.name.as_str()).collect()
146    }
147}
148
149/// A single field in the state schema
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct StateField {
152    /// Field name
153    pub name: String,
154    /// Type name (e.g., "string", "number", "array", "object")
155    pub type_name: String,
156    /// Field description
157    pub description: Option<String>,
158    /// Default value
159    pub default: Option<Value>,
160    /// Whether this field is required
161    pub required: bool,
162}
163
164impl StateField {
165    /// Create a new state field
166    pub fn new(name: impl Into<String>, type_name: impl Into<String>) -> Self {
167        Self {
168            name: name.into(),
169            type_name: type_name.into(),
170            description: None,
171            default: None,
172            required: false,
173        }
174    }
175
176    /// Set description
177    pub fn with_description(mut self, description: impl Into<String>) -> Self {
178        self.description = Some(description.into());
179        self
180    }
181
182    /// Set default value
183    pub fn with_default(mut self, default: Value) -> Self {
184        self.default = Some(default);
185        self
186    }
187
188    /// Set required flag
189    pub fn with_required(mut self, required: bool) -> Self {
190        self.required = required;
191        self
192    }
193}
194
195/// A simple JSON-based state implementation
196///
197/// This is a basic implementation of GraphState that uses a JSON object
198/// as the backing store. Useful for simple workflows or testing.
199#[derive(Debug, Clone, Default, Serialize, Deserialize)]
200pub struct JsonState {
201    data: serde_json::Map<String, Value>,
202}
203
204impl JsonState {
205    /// Create a new empty JSON state
206    pub fn new() -> Self {
207        Self::default()
208    }
209
210    /// Create from a JSON object
211    pub fn from_map(data: serde_json::Map<String, Value>) -> Self {
212        Self { data }
213    }
214
215    /// Create from a JSON value (must be an object)
216    pub fn from_value(value: Value) -> AgentResult<Self> {
217        match value {
218            Value::Object(map) => Ok(Self { data: map }),
219            _ => Err(crate::agent::error::AgentError::InvalidInput(
220                "State must be a JSON object".to_string(),
221            )),
222        }
223    }
224
225    /// Get a reference to the underlying map
226    pub fn as_map(&self) -> &serde_json::Map<String, Value> {
227        &self.data
228    }
229
230    /// Get a mutable reference to the underlying map
231    pub fn as_map_mut(&mut self) -> &mut serde_json::Map<String, Value> {
232        &mut self.data
233    }
234}
235
236#[async_trait]
237impl GraphState for JsonState {
238    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
239        self.data.insert(key.to_string(), value);
240        Ok(())
241    }
242
243    fn get_value(&self, key: &str) -> Option<Value> {
244        self.data.get(key).cloned()
245    }
246
247    fn keys(&self) -> Vec<&str> {
248        self.data.keys().map(|s| s.as_str()).collect()
249    }
250
251    fn to_json(&self) -> AgentResult<Value> {
252        Ok(Value::Object(self.data.clone()))
253    }
254
255    fn from_json(value: Value) -> AgentResult<Self> {
256        Self::from_value(value)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Values written with `apply_update` can be read back, are reported by
    /// `has_key`, and show up in `keys`.
    #[tokio::test]
    async fn test_json_state() {
        let mut state = JsonState::new();

        state.apply_update("name", json!("test")).await.unwrap();
        state.apply_update("count", json!(42)).await.unwrap();

        // Reads return exactly what was written.
        let name = state.get_value("name");
        let count = state.get_value("count");
        assert_eq!(name, Some(json!("test")));
        assert_eq!(count, Some(json!(42)));

        // Membership reflects only the keys that were inserted.
        assert!(state.has_key("name"));
        assert!(!state.has_key("unknown"));

        let keys: Vec<&str> = state.keys();
        assert_eq!(keys.len(), 2);
    }

    /// The fluent schema builder records every field and preserves the
    /// attributes set through `StateField`'s setters.
    #[test]
    fn test_state_schema() {
        let count_field = StateField::new("count", "number")
            .with_description("Execution count")
            .with_default(json!(0))
            .with_required(true);

        let schema = StateSchema::new("MyState")
            .field("messages", "array")
            .field("result", "string")
            .add_field(count_field);

        assert_eq!(schema.name, "MyState");
        assert_eq!(schema.fields.len(), 3);
        assert!(schema.get_field("messages").is_some());
        assert!(schema.get_field("count").unwrap().required);
    }

    /// A JSON object converts into a state exposing each member by key.
    #[test]
    fn test_json_state_from_value() {
        let value = json!({
            "key1": "value1",
            "key2": 123
        });

        let state = JsonState::from_json(value).unwrap();

        assert_eq!(state.get_value("key1"), Some(json!("value1")));
        assert_eq!(state.get_value("key2"), Some(json!(123)));
    }

    /// Anything other than a JSON object is rejected.
    #[test]
    fn test_json_state_invalid_input() {
        let not_an_object = json!("not an object");
        assert!(JsonState::from_json(not_an_object).is_err());
    }
}