1use crate::{
2 agent::{Agent, AgentConfig},
3 runtime::Runtime,
4 step::{Step, StepError, StepInput, StepOutput, StepOutputMetadata, StepResult, StepType},
5 workflow::Workflow,
6};
7use async_trait::async_trait;
8
// Unit tests for the step implementations in this file. The module is only
// compiled for test builds and its source is located through the `#[path]`
// attribute below.
#[cfg(test)]
#[path = "step_impls_test.rs"]
mod step_impls_test;
12
/// A workflow step that runs an [`Agent`] under a given step name.
pub struct AgentStep {
    /// The agent that is executed when the step runs.
    agent: Agent,
    /// Name returned by `Step::name` and stamped into the output metadata.
    name: String,
}
18
19impl AgentStep {
20 pub fn new(config: AgentConfig) -> Self {
22 let name = config.name.clone();
23 Self {
24 agent: Agent::new(config),
25 name,
26 }
27 }
28
29 pub fn from_agent(agent: Agent, name: String) -> Self {
31 Self { agent, name }
32 }
33}
34
35#[async_trait]
36impl Step for AgentStep {
37 async fn execute_with_context(
38 &self,
39 input: StepInput,
40 ctx: crate::step::ExecutionContext<'_>,
41 ) -> StepResult {
42 let start = std::time::Instant::now();
43
44 let agent_input = crate::types::AgentInput {
46 data: input.data,
47 metadata: crate::types::AgentInputMetadata {
48 step_index: input.metadata.step_index,
49 previous_agent: input.metadata.previous_step.clone(),
50 },
51 };
52
53 let result = self
55 .agent
56 .execute_with_events(agent_input, ctx.event_stream)
57 .await
58 .map_err(|e| StepError::AgentError(e.to_string()))?;
59
60 Ok(StepOutput {
61 data: result.data,
62 metadata: StepOutputMetadata {
63 step_name: self.name.clone(),
64 step_type: StepType::Agent,
65 execution_time_ms: start.elapsed().as_millis() as u64,
66 },
67 })
68 }
69
70 fn name(&self) -> &str {
71 &self.name
72 }
73
74 fn step_type(&self) -> StepType {
75 StepType::Agent
76 }
77
78 fn description(&self) -> Option<&str> {
79 Some(self.agent.config().system_prompt.as_str())
80 }
81}
82
/// A workflow step that applies a synchronous function to the JSON payload.
pub struct TransformStep {
    /// Name returned by `Step::name` and stamped into the output metadata.
    name: String,
    /// Function applied to the input data; its return value becomes the
    /// step's output data.
    transform_fn: Box<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>,
}
88
89impl TransformStep {
90 pub fn new<F>(name: String, transform_fn: F) -> Self
91 where
92 F: Fn(serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
93 {
94 Self {
95 name,
96 transform_fn: Box::new(transform_fn),
97 }
98 }
99}
100
101#[async_trait]
102impl Step for TransformStep {
103 async fn execute_with_context(
104 &self,
105 input: StepInput,
106 _ctx: crate::step::ExecutionContext<'_>,
107 ) -> StepResult {
108 self.execute(input).await
110 }
111
112 async fn execute(&self, input: StepInput) -> StepResult {
113 let start = std::time::Instant::now();
114
115 let output_data = (self.transform_fn)(input.data);
116
117 Ok(StepOutput {
118 data: output_data,
119 metadata: StepOutputMetadata {
120 step_name: self.name.clone(),
121 step_type: StepType::Transform,
122 execution_time_ms: start.elapsed().as_millis() as u64,
123 },
124 })
125 }
126
127 fn name(&self) -> &str {
128 &self.name
129 }
130
131 fn step_type(&self) -> StepType {
132 StepType::Transform
133 }
134}
135
/// A workflow step that picks one of two child steps based on a predicate
/// over the input data.
pub struct ConditionalStep {
    /// Name returned by `Step::name` and stamped into the output metadata.
    name: String,
    /// Predicate evaluated against the input data to choose a branch.
    condition_fn: Box<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
    /// Step executed when the predicate returns `true`.
    true_step: Box<dyn Step>,
    /// Step executed when the predicate returns `false`.
    false_step: Box<dyn Step>,
}
143
144impl ConditionalStep {
145 pub fn new<F>(
146 name: String,
147 condition_fn: F,
148 true_step: Box<dyn Step>,
149 false_step: Box<dyn Step>,
150 ) -> Self
151 where
152 F: Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
153 {
154 Self {
155 name,
156 condition_fn: Box::new(condition_fn),
157 true_step,
158 false_step,
159 }
160 }
161}
162
163#[async_trait]
164impl Step for ConditionalStep {
165 async fn execute_with_context(
166 &self,
167 input: StepInput,
168 ctx: crate::step::ExecutionContext<'_>,
169 ) -> StepResult {
170 let start = std::time::Instant::now();
171
172 let condition_result = (self.condition_fn)(&input.data);
173
174 let chosen_step = if condition_result {
175 &self.true_step
176 } else {
177 &self.false_step
178 };
179
180 let mut result = chosen_step.execute_with_context(input, ctx).await?;
182
183 result.metadata.step_name = self.name.clone();
185 result.metadata.step_type = StepType::Conditional;
186 result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
187
188 Ok(result)
189 }
190
191 async fn execute(&self, input: StepInput) -> StepResult {
192 let start = std::time::Instant::now();
193
194 let condition_result = (self.condition_fn)(&input.data);
195
196 let chosen_step = if condition_result {
197 &self.true_step
198 } else {
199 &self.false_step
200 };
201
202 let mut result = chosen_step.execute(input).await?;
204
205 result.metadata.step_name = self.name.clone();
207 result.metadata.step_type = StepType::Conditional;
208 result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
209
210 Ok(result)
211 }
212
213 fn name(&self) -> &str {
214 &self.name
215 }
216
217 fn step_type(&self) -> StepType {
218 StepType::Conditional
219 }
220
221 fn get_branches(&self) -> Option<(&dyn Step, &dyn Step)> {
222 Some((self.true_step.as_ref(), self.false_step.as_ref()))
223 }
224}
225
/// A workflow step that builds and runs a nested [`Workflow`].
pub struct SubWorkflowStep {
    /// Name returned by `Step::name` and stamped into the output metadata.
    name: String,
    /// Factory invoked to produce a fresh workflow each time the step is
    /// executed or its sub-workflow is requested.
    workflow_builder: Box<dyn Fn() -> Workflow + Send + Sync>,
}
231
232impl SubWorkflowStep {
233 pub fn new<F>(name: String, workflow_builder: F) -> Self
234 where
235 F: Fn() -> Workflow + Send + Sync + 'static,
236 {
237 Self {
238 name,
239 workflow_builder: Box::new(workflow_builder),
240 }
241 }
242
243 pub(crate) fn execute_with_runtime<'a>(
246 &'a self,
247 input: StepInput,
248 runtime: &'a crate::runtime::Runtime,
249 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = StepResult> + Send + 'a>> {
250 Box::pin(async move {
251 let start = std::time::Instant::now();
252
253 let mut sub_workflow = (self.workflow_builder)();
255
256 sub_workflow.initial_input = input.data.clone();
258
259 let parent_workflow_id = Some(input.metadata.workflow_id.clone());
261 let run = runtime
262 .execute_with_parent(sub_workflow, parent_workflow_id)
263 .await;
264
265 if run.state != crate::workflow::WorkflowState::Completed {
266 return Err(StepError::ExecutionFailed(format!(
267 "Sub-workflow failed: {:?}",
268 run.state
269 )));
270 }
271
272 let output_data = run.final_output.unwrap_or(serde_json::json!({}));
273
274 Ok(StepOutput {
275 data: output_data,
276 metadata: StepOutputMetadata {
277 step_name: self.name.clone(),
278 step_type: StepType::SubWorkflow,
279 execution_time_ms: start.elapsed().as_millis() as u64,
280 },
281 })
282 })
283 }
284}
285
286#[async_trait]
287impl Step for SubWorkflowStep {
288 async fn execute_with_context(
289 &self,
290 input: StepInput,
291 _ctx: crate::step::ExecutionContext<'_>,
292 ) -> StepResult {
293 let runtime = Runtime::new();
296 self.execute_with_runtime(input, &runtime).await
297 }
298
299 async fn execute(&self, input: StepInput) -> StepResult {
300 let runtime = Runtime::new();
303 self.execute_with_runtime(input, &runtime).await
304 }
305
306 fn name(&self) -> &str {
307 &self.name
308 }
309
310 fn step_type(&self) -> StepType {
311 StepType::SubWorkflow
312 }
313
314 fn description(&self) -> Option<&str> {
315 Some("Executes a nested workflow")
316 }
317
318 fn get_sub_workflow(&self) -> Option<crate::workflow::Workflow> {
319 Some((self.workflow_builder)())
320 }
321}