// sayr_engine/workflow.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::{Map, Value};
7
8use crate::agent::Agent;
9use crate::{LanguageModel, Result};
10
/// Shared state threaded through a workflow execution.
///
/// Tasks receive this context by mutable reference, so values written by one
/// step are visible to the steps that run after it.
#[derive(Debug, Clone, Default)]
pub struct WorkflowContext {
    /// Key/value store accessed through [`WorkflowContext::insert`] and
    /// [`WorkflowContext::get`].
    pub state: Map<String, Value>,
    /// Free-form log lines. Nothing visible in this file appends to it;
    /// presumably filled by tasks or callers — TODO confirm.
    pub logs: Vec<String>,
}
17
18impl WorkflowContext {
19    pub fn insert(&mut self, key: impl Into<String>, value: Value) {
20        self.state.insert(key.into(), value);
21    }
22
23    pub fn get(&self, key: &str) -> Option<&Value> {
24        self.state.get(key)
25    }
26}
27
/// A unit of work that can be placed in a workflow.
///
/// Implementors must be `Send + Sync`; they receive the shared
/// [`WorkflowContext`] mutably and produce a JSON value on success.
#[async_trait]
pub trait WorkflowTask: Send + Sync {
    /// Runs the task against `ctx` and returns its result.
    async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value>;
}
32
/// Boxed, pinned, `Send` future that resolves to a task result and may borrow
/// data for the lifetime `'a`.
type TaskFuture<'a> = Pin<Box<dyn Future<Output = Result<Value>> + Send + 'a>>;
34
/// Wrap a plain async function as a workflow task.
///
/// The struct itself places no bound on `F`. The callable contract
/// (`for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync`) is
/// stated on the `impl` blocks that construct and run the task, which is where
/// it is actually needed; repeating it here only forces every mention of the
/// type to restate the bound.
pub struct FunctionTask<F> {
    /// The wrapped callable.
    func: F,
}
42
43impl<F> FunctionTask<F>
44where
45    F: for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync,
46{
47    pub fn new(func: F) -> Self {
48        Self { func }
49    }
50}
51
52#[async_trait]
53impl<F> WorkflowTask for FunctionTask<F>
54where
55    F: for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync,
56{
57    async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
58        (self.func)(ctx).await
59    }
60}
61
/// Task that dispatches to an individual agent and stores the reply under a key.
pub struct AgentTask<M: LanguageModel> {
    /// Agent asked for the reply; it sits behind an async mutex that `run` locks.
    agent: Arc<tokio::sync::Mutex<Agent<M>>>,
    /// Context key whose string value is used as the prompt, if set.
    prompt_key: Option<String>,
    /// Context key the reply is written to; `None` means the reply is only returned.
    store_under: Option<String>,
    /// Prompt used when `prompt_key` is unset, absent from the context, or
    /// holds a non-string value.
    fallback_prompt: String,
}
69
70impl<M: LanguageModel> AgentTask<M> {
71    pub fn new(
72        agent: Arc<tokio::sync::Mutex<Agent<M>>>,
73        prompt_key: Option<String>,
74        store_under: Option<String>,
75        fallback_prompt: impl Into<String>,
76    ) -> Self {
77        Self {
78            agent,
79            prompt_key,
80            store_under,
81            fallback_prompt: fallback_prompt.into(),
82        }
83    }
84}
85
86#[async_trait]
87impl<M: LanguageModel> WorkflowTask for AgentTask<M> {
88    async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
89        let prompt = self
90            .prompt_key
91            .as_ref()
92            .and_then(|k| ctx.get(k))
93            .and_then(|v| v.as_str())
94            .unwrap_or(&self.fallback_prompt)
95            .to_string();
96        let mut agent = self.agent.lock().await;
97        let reply = agent.respond(prompt).await?;
98        let value = Value::String(reply.clone());
99        if let Some(key) = &self.store_under {
100            ctx.insert(key.clone(), value.clone());
101        }
102        Ok(value)
103    }
104}
105
/// Shared, thread-safe predicate over the workflow context, used by
/// [`WorkflowNode::Conditional`] and [`WorkflowNode::Loop`].
pub type Condition = Arc<dyn Fn(&WorkflowContext) -> bool + Send + Sync>;
107
/// A node in the workflow tree.
#[derive(Clone)]
pub enum WorkflowNode {
    /// Runs a single task.
    Task(Arc<dyn WorkflowTask>),
    /// Runs the steps in order, stopping at the first error; yields the last
    /// step's value (`Value::Null` when there are no steps).
    Sequence(Vec<WorkflowNode>),
    /// Collects every step's value, in order, into a JSON array.
    ///
    /// NOTE: the visible `execute` awaits the steps one after another against
    /// the same context; it does not run them concurrently.
    Parallel(Vec<WorkflowNode>),
    /// Runs `then_branch` when `condition` holds, otherwise `else_branch` if
    /// present; yields `Value::Null` when neither branch runs.
    Conditional {
        condition: Condition,
        then_branch: Box<WorkflowNode>,
        else_branch: Option<Box<WorkflowNode>>,
    },
    /// Repeats `body` while `condition` holds, at most `max_iterations` times;
    /// yields the last iteration's value (`Value::Null` if the body never ran).
    Loop {
        condition: Condition,
        body: Box<WorkflowNode>,
        max_iterations: usize,
    },
}
124
125impl WorkflowNode {
126    fn execute<'a>(
127        &'a self,
128        ctx: &'a mut WorkflowContext,
129    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Value>> + Send + 'a>> {
130        Box::pin(async move {
131            match self {
132                WorkflowNode::Task(task) => task.run(ctx).await,
133                WorkflowNode::Sequence(steps) => {
134                    let mut last = Value::Null;
135                    for step in steps {
136                        last = step.execute(ctx).await?;
137                    }
138                    Ok(last)
139                }
140                WorkflowNode::Parallel(steps) => {
141                    let mut combined = Vec::new();
142                    for step in steps {
143                        combined.push(step.execute(ctx).await?);
144                    }
145                    Ok(Value::Array(combined))
146                }
147                WorkflowNode::Conditional {
148                    condition,
149                    then_branch,
150                    else_branch,
151                } => {
152                    if condition(ctx) {
153                        then_branch.execute(ctx).await
154                    } else if let Some(other) = else_branch {
155                        other.execute(ctx).await
156                    } else {
157                        Ok(Value::Null)
158                    }
159                }
160                WorkflowNode::Loop {
161                    condition,
162                    body,
163                    max_iterations,
164                } => {
165                    let mut last = Value::Null;
166                    for _ in 0..*max_iterations {
167                        if !(condition)(ctx) {
168                            break;
169                        }
170                        last = body.execute(ctx).await?;
171                    }
172                    Ok(last)
173                }
174            }
175        })
176    }
177}
178
/// A named workflow: a tree of [`WorkflowNode`]s entered at `root`.
#[derive(Clone)]
pub struct Workflow {
    /// Name given to the workflow at construction; not read by `run`.
    pub name: String,
    /// Node at which execution starts.
    pub root: WorkflowNode,
}
184
185impl Workflow {
186    pub fn new(name: impl Into<String>, root: WorkflowNode) -> Self {
187        Self {
188            name: name.into(),
189            root,
190        }
191    }
192
193    pub async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
194        self.root.execute(ctx).await
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Reads `key` from the context as an integer, treating a missing or
    /// non-integer entry as 0.
    fn int_at(ctx: &WorkflowContext, key: &str) -> i64 {
        ctx.get(key).and_then(|v| v.as_i64()).unwrap_or(0)
    }

    /// A sequence whose second step is a `Parallel` node: every task must run,
    /// later tasks must see earlier writes, and the overall result is the
    /// array produced by the `Parallel` node.
    #[tokio::test]
    async fn executes_sequential_and_parallel_nodes() {
        let set_a = WorkflowNode::Task(Arc::new(FunctionTask::new(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                ctx.insert("a", json!(1));
                Ok(json!("done"))
            })
        })));
        let set_b = WorkflowNode::Task(Arc::new(FunctionTask::new(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                let current = int_at(ctx, "a");
                ctx.insert("b", json!(current + 1));
                Ok(json!("b"))
            })
        })));
        let set_c = WorkflowNode::Task(Arc::new(FunctionTask::new(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                ctx.insert("c", json!(true));
                Ok(json!("c"))
            })
        })));

        let fan_out = WorkflowNode::Parallel(vec![set_b, set_c]);
        let flow = Workflow::new("demo", WorkflowNode::Sequence(vec![set_a, fan_out]));

        let mut ctx = WorkflowContext::default();
        let result = flow.run(&mut ctx).await.unwrap();

        assert!(result.is_array());
        assert_eq!(ctx.get("a"), Some(&json!(1)));
        assert_eq!(ctx.get("b"), Some(&json!(2)));
        assert_eq!(ctx.get("c"), Some(&json!(true)));
    }

    /// The loop must stop as soon as its condition turns false, well before
    /// the iteration cap is reached.
    #[tokio::test]
    async fn executes_conditional_loop() {
        let increment =
            WorkflowNode::Task(Arc::new(FunctionTask::new(|ctx: &mut WorkflowContext| {
                Box::pin(async move {
                    let next = int_at(ctx, "count") + 1;
                    ctx.insert("count", json!(next));
                    Ok(json!(next))
                })
            })));

        let below_three: Condition = Arc::new(|ctx: &WorkflowContext| int_at(ctx, "count") < 3);

        let flow = Workflow::new(
            "looping",
            WorkflowNode::Loop {
                condition: below_three,
                body: Box::new(increment),
                max_iterations: 10,
            },
        );

        let mut ctx = WorkflowContext::default();
        ctx.insert("count", json!(0));
        flow.run(&mut ctx).await.unwrap();

        assert_eq!(ctx.get("count"), Some(&json!(3)));
    }
}