1use crate::{
2 agent::{Agent, AgentConfig},
3 runtime::Runtime,
4 step::{Step, StepError, StepInput, StepOutput, StepOutputMetadata, StepResult, StepType},
5 workflow::Workflow,
6};
7use async_trait::async_trait;
8
9#[cfg(test)]
10#[path = "step_impls_test.rs"]
11mod step_impls_test;
12
13pub struct AgentStep {
15 agent: Agent,
16 name: String,
17}
18
19impl AgentStep {
20 pub fn new(config: AgentConfig) -> Self {
22 let name = config.name.clone();
23 Self {
24 agent: Agent::new(config),
25 name,
26 }
27 }
28
29 pub fn from_agent(agent: Agent, name: String) -> Self {
31 Self { agent, name }
32 }
33}
34
35#[async_trait]
36impl Step for AgentStep {
37 async fn execute_with_context(
38 &self,
39 input: StepInput,
40 ctx: crate::step::ExecutionContext<'_>,
41 ) -> StepResult {
42 let start = std::time::Instant::now();
43
44 let agent_input = crate::types::AgentInput {
46 data: input.data,
47 metadata: crate::types::AgentInputMetadata {
48 step_index: input.metadata.step_index,
49 previous_agent: input.metadata.previous_step.clone(),
50 },
51 chat_history: None,
52 };
53
54 let result = self
56 .agent
57 .execute_with_events(agent_input, ctx.event_stream)
58 .await
59 .map_err(|e| StepError::AgentError(e.to_string()))?;
60
61 Ok(StepOutput {
62 data: result.data,
63 metadata: StepOutputMetadata {
64 step_name: self.name.clone(),
65 step_type: StepType::Agent,
66 execution_time_ms: start.elapsed().as_millis() as u64,
67 },
68 })
69 }
70
71 fn name(&self) -> &str {
72 &self.name
73 }
74
75 fn step_type(&self) -> StepType {
76 StepType::Agent
77 }
78
79 fn description(&self) -> Option<&str> {
80 Some(self.agent.config().system_prompt.as_str())
81 }
82}
83
84pub struct TransformStep {
86 name: String,
87 transform_fn: Box<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>,
88}
89
90impl TransformStep {
91 pub fn new<F>(name: String, transform_fn: F) -> Self
92 where
93 F: Fn(serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
94 {
95 Self {
96 name,
97 transform_fn: Box::new(transform_fn),
98 }
99 }
100}
101
102#[async_trait]
103impl Step for TransformStep {
104 async fn execute_with_context(
105 &self,
106 input: StepInput,
107 _ctx: crate::step::ExecutionContext<'_>,
108 ) -> StepResult {
109 self.execute(input).await
111 }
112
113 async fn execute(&self, input: StepInput) -> StepResult {
114 let start = std::time::Instant::now();
115
116 let output_data = (self.transform_fn)(input.data);
117
118 Ok(StepOutput {
119 data: output_data,
120 metadata: StepOutputMetadata {
121 step_name: self.name.clone(),
122 step_type: StepType::Transform,
123 execution_time_ms: start.elapsed().as_millis() as u64,
124 },
125 })
126 }
127
128 fn name(&self) -> &str {
129 &self.name
130 }
131
132 fn step_type(&self) -> StepType {
133 StepType::Transform
134 }
135}
136
137pub struct ConditionalStep {
139 name: String,
140 condition_fn: Box<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
141 true_step: Box<dyn Step>,
142 false_step: Box<dyn Step>,
143}
144
145impl ConditionalStep {
146 pub fn new<F>(
147 name: String,
148 condition_fn: F,
149 true_step: Box<dyn Step>,
150 false_step: Box<dyn Step>,
151 ) -> Self
152 where
153 F: Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
154 {
155 Self {
156 name,
157 condition_fn: Box::new(condition_fn),
158 true_step,
159 false_step,
160 }
161 }
162}
163
164#[async_trait]
165impl Step for ConditionalStep {
166 async fn execute_with_context(
167 &self,
168 input: StepInput,
169 ctx: crate::step::ExecutionContext<'_>,
170 ) -> StepResult {
171 let start = std::time::Instant::now();
172
173 let condition_result = (self.condition_fn)(&input.data);
174
175 let chosen_step = if condition_result {
176 &self.true_step
177 } else {
178 &self.false_step
179 };
180
181 let mut result = chosen_step.execute_with_context(input, ctx).await?;
183
184 result.metadata.step_name = self.name.clone();
186 result.metadata.step_type = StepType::Conditional;
187 result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
188
189 Ok(result)
190 }
191
192 async fn execute(&self, input: StepInput) -> StepResult {
193 let start = std::time::Instant::now();
194
195 let condition_result = (self.condition_fn)(&input.data);
196
197 let chosen_step = if condition_result {
198 &self.true_step
199 } else {
200 &self.false_step
201 };
202
203 let mut result = chosen_step.execute(input).await?;
205
206 result.metadata.step_name = self.name.clone();
208 result.metadata.step_type = StepType::Conditional;
209 result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
210
211 Ok(result)
212 }
213
214 fn name(&self) -> &str {
215 &self.name
216 }
217
218 fn step_type(&self) -> StepType {
219 StepType::Conditional
220 }
221
222 fn get_branches(&self) -> Option<(&dyn Step, &dyn Step)> {
223 Some((self.true_step.as_ref(), self.false_step.as_ref()))
224 }
225}
226
227pub struct SubWorkflowStep {
229 name: String,
230 workflow_builder: Box<dyn Fn() -> Workflow + Send + Sync>,
231}
232
233impl SubWorkflowStep {
234 pub fn new<F>(name: String, workflow_builder: F) -> Self
235 where
236 F: Fn() -> Workflow + Send + Sync + 'static,
237 {
238 Self {
239 name,
240 workflow_builder: Box::new(workflow_builder),
241 }
242 }
243
244 pub(crate) fn execute_with_runtime<'a>(
247 &'a self,
248 input: StepInput,
249 runtime: &'a crate::runtime::Runtime,
250 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = StepResult> + Send + 'a>> {
251 Box::pin(async move {
252 let start = std::time::Instant::now();
253
254 let mut sub_workflow = (self.workflow_builder)();
256
257 sub_workflow.initial_input = input.data.clone();
259
260 let parent_workflow_id = Some(input.metadata.workflow_id.clone());
262 let run = runtime
263 .execute_with_parent(sub_workflow, parent_workflow_id)
264 .await;
265
266 if run.state != crate::workflow::WorkflowState::Completed {
267 return Err(StepError::ExecutionFailed(format!(
268 "Sub-workflow failed: {:?}",
269 run.state
270 )));
271 }
272
273 let output_data = run.final_output.unwrap_or(serde_json::json!({}));
274
275 Ok(StepOutput {
276 data: output_data,
277 metadata: StepOutputMetadata {
278 step_name: self.name.clone(),
279 step_type: StepType::SubWorkflow,
280 execution_time_ms: start.elapsed().as_millis() as u64,
281 },
282 })
283 })
284 }
285}
286
287#[async_trait]
288impl Step for SubWorkflowStep {
289 async fn execute_with_context(
290 &self,
291 input: StepInput,
292 _ctx: crate::step::ExecutionContext<'_>,
293 ) -> StepResult {
294 let runtime = Runtime::new();
297 self.execute_with_runtime(input, &runtime).await
298 }
299
300 async fn execute(&self, input: StepInput) -> StepResult {
301 let runtime = Runtime::new();
304 self.execute_with_runtime(input, &runtime).await
305 }
306
307 fn name(&self) -> &str {
308 &self.name
309 }
310
311 fn step_type(&self) -> StepType {
312 StepType::SubWorkflow
313 }
314
315 fn description(&self) -> Option<&str> {
316 Some("Executes a nested workflow")
317 }
318
319 fn get_sub_workflow(&self) -> Option<crate::workflow::Workflow> {
320 Some((self.workflow_builder)())
321 }
322}