// agent_runtime/step_impls.rs

1use crate::{
2    agent::{Agent, AgentConfig},
3    runtime::Runtime,
4    step::{Step, StepError, StepInput, StepOutput, StepOutputMetadata, StepResult, StepType},
5    workflow::Workflow,
6};
7use async_trait::async_trait;
8
9#[cfg(test)]
10#[path = "step_impls_test.rs"]
11mod step_impls_test;
12
13/// A step that executes an agent
14pub struct AgentStep {
15    agent: Agent,
16    name: String,
17}
18
19impl AgentStep {
20    /// Create a new agent step from an agent configuration
21    pub fn new(config: AgentConfig) -> Self {
22        let name = config.name.clone();
23        Self {
24            agent: Agent::new(config),
25            name,
26        }
27    }
28
29    /// Create from an existing Agent
30    pub fn from_agent(agent: Agent, name: String) -> Self {
31        Self { agent, name }
32    }
33}
34
35#[async_trait]
36impl Step for AgentStep {
37    async fn execute_with_context(
38        &self,
39        input: StepInput,
40        ctx: crate::step::ExecutionContext<'_>,
41    ) -> StepResult {
42        let start = std::time::Instant::now();
43
44        // Convert StepInput to AgentInput
45        let agent_input = crate::types::AgentInput {
46            data: input.data,
47            metadata: crate::types::AgentInputMetadata {
48                step_index: input.metadata.step_index,
49                previous_agent: input.metadata.previous_step.clone(),
50            },
51            chat_history: None,
52        };
53
54        // Execute agent with event stream
55        let result = self
56            .agent
57            .execute_with_events(agent_input, ctx.event_stream)
58            .await
59            .map_err(|e| StepError::AgentError(e.to_string()))?;
60
61        Ok(StepOutput {
62            data: result.data,
63            metadata: StepOutputMetadata {
64                step_name: self.name.clone(),
65                step_type: StepType::Agent,
66                execution_time_ms: start.elapsed().as_millis() as u64,
67            },
68        })
69    }
70
71    fn name(&self) -> &str {
72        &self.name
73    }
74
75    fn step_type(&self) -> StepType {
76        StepType::Agent
77    }
78
79    fn description(&self) -> Option<&str> {
80        Some(self.agent.config().system_prompt.as_str())
81    }
82}
83
84/// A step that transforms data using a pure function
85pub struct TransformStep {
86    name: String,
87    transform_fn: Box<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>,
88}
89
90impl TransformStep {
91    pub fn new<F>(name: String, transform_fn: F) -> Self
92    where
93        F: Fn(serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
94    {
95        Self {
96            name,
97            transform_fn: Box::new(transform_fn),
98        }
99    }
100}
101
102#[async_trait]
103impl Step for TransformStep {
104    async fn execute_with_context(
105        &self,
106        input: StepInput,
107        _ctx: crate::step::ExecutionContext<'_>,
108    ) -> StepResult {
109        // Use the same logic as execute() - transforms don't need events yet
110        self.execute(input).await
111    }
112
113    async fn execute(&self, input: StepInput) -> StepResult {
114        let start = std::time::Instant::now();
115
116        let output_data = (self.transform_fn)(input.data);
117
118        Ok(StepOutput {
119            data: output_data,
120            metadata: StepOutputMetadata {
121                step_name: self.name.clone(),
122                step_type: StepType::Transform,
123                execution_time_ms: start.elapsed().as_millis() as u64,
124            },
125        })
126    }
127
128    fn name(&self) -> &str {
129        &self.name
130    }
131
132    fn step_type(&self) -> StepType {
133        StepType::Transform
134    }
135}
136
137/// A step that conditionally executes one of two branches
138pub struct ConditionalStep {
139    name: String,
140    condition_fn: Box<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
141    true_step: Box<dyn Step>,
142    false_step: Box<dyn Step>,
143}
144
145impl ConditionalStep {
146    pub fn new<F>(
147        name: String,
148        condition_fn: F,
149        true_step: Box<dyn Step>,
150        false_step: Box<dyn Step>,
151    ) -> Self
152    where
153        F: Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
154    {
155        Self {
156            name,
157            condition_fn: Box::new(condition_fn),
158            true_step,
159            false_step,
160        }
161    }
162}
163
164#[async_trait]
165impl Step for ConditionalStep {
166    async fn execute_with_context(
167        &self,
168        input: StepInput,
169        ctx: crate::step::ExecutionContext<'_>,
170    ) -> StepResult {
171        let start = std::time::Instant::now();
172
173        let condition_result = (self.condition_fn)(&input.data);
174
175        let chosen_step = if condition_result {
176            &self.true_step
177        } else {
178            &self.false_step
179        };
180
181        // Execute the chosen branch with context
182        let mut result = chosen_step.execute_with_context(input, ctx).await?;
183
184        // Update metadata to reflect this conditional step
185        result.metadata.step_name = self.name.clone();
186        result.metadata.step_type = StepType::Conditional;
187        result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
188
189        Ok(result)
190    }
191
192    async fn execute(&self, input: StepInput) -> StepResult {
193        let start = std::time::Instant::now();
194
195        let condition_result = (self.condition_fn)(&input.data);
196
197        let chosen_step = if condition_result {
198            &self.true_step
199        } else {
200            &self.false_step
201        };
202
203        // Execute the chosen branch
204        let mut result = chosen_step.execute(input).await?;
205
206        // Update metadata to reflect this conditional step
207        result.metadata.step_name = self.name.clone();
208        result.metadata.step_type = StepType::Conditional;
209        result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
210
211        Ok(result)
212    }
213
214    fn name(&self) -> &str {
215        &self.name
216    }
217
218    fn step_type(&self) -> StepType {
219        StepType::Conditional
220    }
221
222    fn get_branches(&self) -> Option<(&dyn Step, &dyn Step)> {
223        Some((self.true_step.as_ref(), self.false_step.as_ref()))
224    }
225}
226
227/// A step that executes an entire workflow as a sub-workflow
228pub struct SubWorkflowStep {
229    name: String,
230    workflow_builder: Box<dyn Fn() -> Workflow + Send + Sync>,
231}
232
233impl SubWorkflowStep {
234    pub fn new<F>(name: String, workflow_builder: F) -> Self
235    where
236        F: Fn() -> Workflow + Send + Sync + 'static,
237    {
238        Self {
239            name,
240            workflow_builder: Box::new(workflow_builder),
241        }
242    }
243
244    /// Execute the sub-workflow using the provided runtime
245    /// This ensures events are emitted to the parent's event stream
246    pub(crate) fn execute_with_runtime<'a>(
247        &'a self,
248        input: StepInput,
249        runtime: &'a crate::runtime::Runtime,
250    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = StepResult> + Send + 'a>> {
251        Box::pin(async move {
252            let start = std::time::Instant::now();
253
254            // Build the sub-workflow
255            let mut sub_workflow = (self.workflow_builder)();
256
257            // Override initial input with step input
258            sub_workflow.initial_input = input.data.clone();
259
260            // Execute the sub-workflow with parent context
261            let parent_workflow_id = Some(input.metadata.workflow_id.clone());
262            let run = runtime
263                .execute_with_parent(sub_workflow, parent_workflow_id)
264                .await;
265
266            if run.state != crate::workflow::WorkflowState::Completed {
267                return Err(StepError::ExecutionFailed(format!(
268                    "Sub-workflow failed: {:?}",
269                    run.state
270                )));
271            }
272
273            let output_data = run.final_output.unwrap_or(serde_json::json!({}));
274
275            Ok(StepOutput {
276                data: output_data,
277                metadata: StepOutputMetadata {
278                    step_name: self.name.clone(),
279                    step_type: StepType::SubWorkflow,
280                    execution_time_ms: start.elapsed().as_millis() as u64,
281                },
282            })
283        })
284    }
285}
286
287#[async_trait]
288impl Step for SubWorkflowStep {
289    async fn execute_with_context(
290        &self,
291        input: StepInput,
292        _ctx: crate::step::ExecutionContext<'_>,
293    ) -> StepResult {
294        // This creates a new runtime - won't share events with parent
295        // Use execute_with_runtime() from the parent runtime instead
296        let runtime = Runtime::new();
297        self.execute_with_runtime(input, &runtime).await
298    }
299
300    async fn execute(&self, input: StepInput) -> StepResult {
301        // This creates a new runtime - won't share events with parent
302        // Use execute_with_runtime() from the parent runtime instead
303        let runtime = Runtime::new();
304        self.execute_with_runtime(input, &runtime).await
305    }
306
307    fn name(&self) -> &str {
308        &self.name
309    }
310
311    fn step_type(&self) -> StepType {
312        StepType::SubWorkflow
313    }
314
315    fn description(&self) -> Option<&str> {
316        Some("Executes a nested workflow")
317    }
318
319    fn get_sub_workflow(&self) -> Option<crate::workflow::Workflow> {
320        Some((self.workflow_builder)())
321    }
322}