Skip to main content

mofa_kernel/workflow/
state.rs

1//! Graph State Trait and Types
2//!
3//! Defines the state management interface for workflow graphs.
4//! The GraphState trait allows custom state types to work with the workflow system.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::agent::error::AgentResult;
11
12use super::StateUpdate;
13
14/// Graph state trait
15///
16/// Implement this trait to define custom state types for workflows.
17/// The trait provides methods for applying updates and serialization.
18///
19/// # Example
20///
21/// ```rust,ignore
22/// use serde::{Serialize, Deserialize};
23/// use mofa_kernel::workflow::GraphState;
24///
25/// #[derive(Clone, Serialize, Deserialize)]
26/// struct MyState {
27///     messages: Vec<String>,
28///     result: Option<String>,
29/// }
30///
31/// impl GraphState for MyState {
32///     async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
33///         match key {
34///             "messages" => {
35///                 if let Some(msg) = value.as_str() {
36///                     self.messages.push(msg.to_string());
37///                 }
38///             }
39///             "result" => {
40///                 self.result = value.as_str().map(|s| s.to_string());
41///             }
42///             _ => {}
43///         }
44///         Ok(())
45///     }
46///
47///     fn get_value(&self, key: &str) -> Option<Value> {
48///         match key {
49///             "messages" => Some(serde_json::to_value(&self.messages).unwrap()),
50///             "result" => Some(serde_json::to_value(&self.result).unwrap()),
51///             _ => None,
52///         }
53///     }
54///
55///     fn keys(&self) -> Vec<&str> {
56///         vec!["messages", "result"]
57///     }
58/// }
59/// ```
60#[async_trait]
61pub trait GraphState: Clone + Send + Sync + 'static {
62    /// Apply a state update
63    ///
64    /// This method is called when a node returns state updates.
65    /// The implementation should merge the update into the state.
66    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()>;
67
68    /// Apply multiple updates
69    async fn apply_updates(&mut self, updates: &[StateUpdate]) -> AgentResult<()> {
70        for update in updates {
71            self.apply_update(&update.key, update.value.clone()).await?;
72        }
73        Ok(())
74    }
75
76    /// Get a value by key
77    ///
78    /// Returns the current value for a given key, or None if the key doesn't exist.
79    fn get_value(&self, key: &str) -> Option<Value>;
80
81    /// Get all keys in this state
82    fn keys(&self) -> Vec<&str>;
83
84    /// Check if a key exists
85    fn has_key(&self, key: &str) -> bool {
86        self.keys().contains(&key)
87    }
88
89    /// Convert entire state to a JSON Value
90    fn to_json(&self) -> AgentResult<Value>;
91
92    /// Create state from a JSON Value
93    fn from_json(value: Value) -> AgentResult<Self>;
94}
95
96/// State schema for validation and documentation
97///
98/// Describes the structure of a graph's state, including
99/// key names, types, and reducer configurations.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct StateSchema {
102    /// Schema name
103    pub name: String,
104    /// Field definitions
105    pub fields: Vec<StateField>,
106    /// Schema version
107    pub version: String,
108}
109
110impl StateSchema {
111    /// Create a new state schema
112    pub fn new(name: impl Into<String>) -> Self {
113        Self {
114            name: name.into(),
115            fields: Vec::new(),
116            version: "1.0".to_string(),
117        }
118    }
119
120    /// Add a field to the schema
121    pub fn add_field(mut self, field: StateField) -> Self {
122        self.fields.push(field);
123        self
124    }
125
126    /// Add a simple field
127    pub fn field(
128        mut self,
129        name: impl Into<String>,
130        type_name: impl Into<String>,
131    ) -> Self {
132        self.fields.push(StateField {
133            name: name.into(),
134            type_name: type_name.into(),
135            description: None,
136            default: None,
137            required: false,
138        });
139        self
140    }
141
142    /// Get a field by name
143    pub fn get_field(&self, name: &str) -> Option<&StateField> {
144        self.fields.iter().find(|f| f.name == name)
145    }
146
147    /// Get all field names
148    pub fn field_names(&self) -> Vec<&str> {
149        self.fields.iter().map(|f| f.name.as_str()).collect()
150    }
151}
152
153/// A single field in the state schema
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct StateField {
156    /// Field name
157    pub name: String,
158    /// Type name (e.g., "string", "number", "array", "object")
159    pub type_name: String,
160    /// Field description
161    pub description: Option<String>,
162    /// Default value
163    pub default: Option<Value>,
164    /// Whether this field is required
165    pub required: bool,
166}
167
168impl StateField {
169    /// Create a new state field
170    pub fn new(name: impl Into<String>, type_name: impl Into<String>) -> Self {
171        Self {
172            name: name.into(),
173            type_name: type_name.into(),
174            description: None,
175            default: None,
176            required: false,
177        }
178    }
179
180    /// Set description
181    pub fn with_description(mut self, description: impl Into<String>) -> Self {
182        self.description = Some(description.into());
183        self
184    }
185
186    /// Set default value
187    pub fn with_default(mut self, default: Value) -> Self {
188        self.default = Some(default);
189        self
190    }
191
192    /// Set required flag
193    pub fn with_required(mut self, required: bool) -> Self {
194        self.required = required;
195        self
196    }
197}
198
199/// A simple JSON-based state implementation
200///
201/// This is a basic implementation of GraphState that uses a JSON object
202/// as the backing store. Useful for simple workflows or testing.
203#[derive(Debug, Clone, Default, Serialize, Deserialize)]
204pub struct JsonState {
205    data: serde_json::Map<String, Value>,
206}
207
208impl JsonState {
209    /// Create a new empty JSON state
210    pub fn new() -> Self {
211        Self::default()
212    }
213
214    /// Create from a JSON object
215    pub fn from_map(data: serde_json::Map<String, Value>) -> Self {
216        Self { data }
217    }
218
219    /// Create from a JSON value (must be an object)
220    pub fn from_value(value: Value) -> AgentResult<Self> {
221        match value {
222            Value::Object(map) => Ok(Self { data: map }),
223            _ => Err(crate::agent::error::AgentError::InvalidInput(
224                "State must be a JSON object".to_string(),
225            )),
226        }
227    }
228
229    /// Get a reference to the underlying map
230    pub fn as_map(&self) -> &serde_json::Map<String, Value> {
231        &self.data
232    }
233
234    /// Get a mutable reference to the underlying map
235    pub fn as_map_mut(&mut self) -> &mut serde_json::Map<String, Value> {
236        &mut self.data
237    }
238}
239
240#[async_trait]
241impl GraphState for JsonState {
242    async fn apply_update(&mut self, key: &str, value: Value) -> AgentResult<()> {
243        self.data.insert(key.to_string(), value);
244        Ok(())
245    }
246
247    fn get_value(&self, key: &str) -> Option<Value> {
248        self.data.get(key).cloned()
249    }
250
251    fn keys(&self) -> Vec<&str> {
252        self.data.keys().map(|s| s.as_str()).collect()
253    }
254
255    fn to_json(&self) -> AgentResult<Value> {
256        Ok(Value::Object(self.data.clone()))
257    }
258
259    fn from_json(value: Value) -> AgentResult<Self> {
260        Self::from_value(value)
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_json_state() {
        let mut state = JsonState::new();

        state.apply_update("name", json!("test")).await.unwrap();
        state.apply_update("count", json!(42)).await.unwrap();

        assert_eq!(state.get_value("name"), Some(json!("test")));
        assert_eq!(state.get_value("count"), Some(json!(42)));
        assert!(state.has_key("name"));
        assert!(!state.has_key("unknown"));

        let keys: Vec<&str> = state.keys();
        assert_eq!(keys.len(), 2);
    }

    #[tokio::test]
    async fn test_json_state_update_overwrites_existing_key() {
        let mut state = JsonState::new();

        state.apply_update("count", json!(1)).await.unwrap();
        state.apply_update("count", json!(2)).await.unwrap();

        // JsonState uses last-write-wins semantics and never duplicates keys.
        assert_eq!(state.get_value("count"), Some(json!(2)));
        assert_eq!(state.keys(), vec!["count"]);
        assert!(state.has_key("count"));
    }

    #[test]
    fn test_state_schema() {
        let schema = StateSchema::new("MyState")
            .field("messages", "array")
            .field("result", "string")
            .add_field(
                StateField::new("count", "number")
                    .with_description("Execution count")
                    .with_default(json!(0))
                    .with_required(true),
            );

        assert_eq!(schema.name, "MyState");
        assert_eq!(schema.fields.len(), 3);
        assert!(schema.get_field("messages").is_some());
        assert!(schema.get_field("count").unwrap().required);
    }

    #[test]
    fn test_state_schema_simple_field_defaults() {
        let schema = StateSchema::new("S").field("b", "string").field("a", "number");

        assert_eq!(schema.version, "1.0");
        // Field order follows insertion order, not alphabetical order.
        assert_eq!(schema.field_names(), vec!["b", "a"]);

        let b = schema.get_field("b").unwrap();
        assert_eq!(b.type_name, "string");
        assert!(!b.required);
        assert!(b.description.is_none());
        assert!(b.default.is_none());
        assert!(schema.get_field("missing").is_none());
    }

    #[test]
    fn test_json_state_from_value() {
        let value = json!({
            "key1": "value1",
            "key2": 123
        });

        let state = JsonState::from_json(value).unwrap();
        assert_eq!(state.get_value("key1"), Some(json!("value1")));
        assert_eq!(state.get_value("key2"), Some(json!(123)));
    }

    #[test]
    fn test_json_state_json_round_trip() {
        let original = json!({
            "a": 1,
            "b": [true, null],
            "c": { "nested": "x" }
        });

        let state = JsonState::from_json(original.clone()).unwrap();
        assert_eq!(state.as_map().len(), 3);
        assert_eq!(state.to_json().unwrap(), original);
    }

    #[test]
    fn test_json_state_invalid_input() {
        // Only JSON objects can back a JsonState.
        assert!(JsonState::from_json(json!("not an object")).is_err());
        assert!(JsonState::from_json(json!([1, 2, 3])).is_err());
        assert!(JsonState::from_json(json!(null)).is_err());
        assert!(JsonState::from_json(json!(42)).is_err());
    }
}