Skip to main content

agent_runtime/
step_impls.rs

1use crate::{
2    agent::{Agent, AgentConfig},
3    runtime::Runtime,
4    step::{Step, StepError, StepInput, StepOutput, StepOutputMetadata, StepResult, StepType},
5    workflow::Workflow,
6};
7use async_trait::async_trait;
8
9#[cfg(test)]
10#[path = "step_impls_test.rs"]
11mod step_impls_test;
12
13/// A step that executes an agent
14pub struct AgentStep {
15    agent: Agent,
16    name: String,
17}
18
19impl AgentStep {
20    /// Create a new agent step from an agent configuration
21    pub fn new(config: AgentConfig) -> Self {
22        let name = config.name.clone();
23        Self {
24            agent: Agent::new(config),
25            name,
26        }
27    }
28
29    /// Create from an existing Agent
30    pub fn from_agent(agent: Agent, name: String) -> Self {
31        Self { agent, name }
32    }
33}
34
35#[async_trait]
36impl Step for AgentStep {
37    async fn execute_with_context(
38        &self,
39        input: StepInput,
40        ctx: crate::step::ExecutionContext<'_>,
41    ) -> StepResult {
42        let start = std::time::Instant::now();
43
44        // Convert StepInput to AgentInput
45        let agent_input = crate::types::AgentInput {
46            data: input.data,
47            metadata: crate::types::AgentInputMetadata {
48                step_index: input.metadata.step_index,
49                previous_agent: input.metadata.previous_step.clone(),
50            },
51        };
52
53        // Execute agent with event stream
54        let result = self
55            .agent
56            .execute_with_events(agent_input, ctx.event_stream)
57            .await
58            .map_err(|e| StepError::AgentError(e.to_string()))?;
59
60        Ok(StepOutput {
61            data: result.data,
62            metadata: StepOutputMetadata {
63                step_name: self.name.clone(),
64                step_type: StepType::Agent,
65                execution_time_ms: start.elapsed().as_millis() as u64,
66            },
67        })
68    }
69
70    fn name(&self) -> &str {
71        &self.name
72    }
73
74    fn step_type(&self) -> StepType {
75        StepType::Agent
76    }
77
78    fn description(&self) -> Option<&str> {
79        Some(self.agent.config().system_prompt.as_str())
80    }
81}
82
83/// A step that transforms data using a pure function
84pub struct TransformStep {
85    name: String,
86    transform_fn: Box<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>,
87}
88
89impl TransformStep {
90    pub fn new<F>(name: String, transform_fn: F) -> Self
91    where
92        F: Fn(serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
93    {
94        Self {
95            name,
96            transform_fn: Box::new(transform_fn),
97        }
98    }
99}
100
101#[async_trait]
102impl Step for TransformStep {
103    async fn execute_with_context(
104        &self,
105        input: StepInput,
106        _ctx: crate::step::ExecutionContext<'_>,
107    ) -> StepResult {
108        // Use the same logic as execute() - transforms don't need events yet
109        self.execute(input).await
110    }
111
112    async fn execute(&self, input: StepInput) -> StepResult {
113        let start = std::time::Instant::now();
114
115        let output_data = (self.transform_fn)(input.data);
116
117        Ok(StepOutput {
118            data: output_data,
119            metadata: StepOutputMetadata {
120                step_name: self.name.clone(),
121                step_type: StepType::Transform,
122                execution_time_ms: start.elapsed().as_millis() as u64,
123            },
124        })
125    }
126
127    fn name(&self) -> &str {
128        &self.name
129    }
130
131    fn step_type(&self) -> StepType {
132        StepType::Transform
133    }
134}
135
136/// A step that conditionally executes one of two branches
137pub struct ConditionalStep {
138    name: String,
139    condition_fn: Box<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
140    true_step: Box<dyn Step>,
141    false_step: Box<dyn Step>,
142}
143
144impl ConditionalStep {
145    pub fn new<F>(
146        name: String,
147        condition_fn: F,
148        true_step: Box<dyn Step>,
149        false_step: Box<dyn Step>,
150    ) -> Self
151    where
152        F: Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
153    {
154        Self {
155            name,
156            condition_fn: Box::new(condition_fn),
157            true_step,
158            false_step,
159        }
160    }
161}
162
163#[async_trait]
164impl Step for ConditionalStep {
165    async fn execute_with_context(
166        &self,
167        input: StepInput,
168        ctx: crate::step::ExecutionContext<'_>,
169    ) -> StepResult {
170        let start = std::time::Instant::now();
171
172        let condition_result = (self.condition_fn)(&input.data);
173
174        let chosen_step = if condition_result {
175            &self.true_step
176        } else {
177            &self.false_step
178        };
179
180        // Execute the chosen branch with context
181        let mut result = chosen_step.execute_with_context(input, ctx).await?;
182
183        // Update metadata to reflect this conditional step
184        result.metadata.step_name = self.name.clone();
185        result.metadata.step_type = StepType::Conditional;
186        result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
187
188        Ok(result)
189    }
190
191    async fn execute(&self, input: StepInput) -> StepResult {
192        let start = std::time::Instant::now();
193
194        let condition_result = (self.condition_fn)(&input.data);
195
196        let chosen_step = if condition_result {
197            &self.true_step
198        } else {
199            &self.false_step
200        };
201
202        // Execute the chosen branch
203        let mut result = chosen_step.execute(input).await?;
204
205        // Update metadata to reflect this conditional step
206        result.metadata.step_name = self.name.clone();
207        result.metadata.step_type = StepType::Conditional;
208        result.metadata.execution_time_ms = start.elapsed().as_millis() as u64;
209
210        Ok(result)
211    }
212
213    fn name(&self) -> &str {
214        &self.name
215    }
216
217    fn step_type(&self) -> StepType {
218        StepType::Conditional
219    }
220
221    fn get_branches(&self) -> Option<(&dyn Step, &dyn Step)> {
222        Some((self.true_step.as_ref(), self.false_step.as_ref()))
223    }
224}
225
226/// A step that executes an entire workflow as a sub-workflow
227pub struct SubWorkflowStep {
228    name: String,
229    workflow_builder: Box<dyn Fn() -> Workflow + Send + Sync>,
230}
231
232impl SubWorkflowStep {
233    pub fn new<F>(name: String, workflow_builder: F) -> Self
234    where
235        F: Fn() -> Workflow + Send + Sync + 'static,
236    {
237        Self {
238            name,
239            workflow_builder: Box::new(workflow_builder),
240        }
241    }
242
243    /// Execute the sub-workflow using the provided runtime
244    /// This ensures events are emitted to the parent's event stream
245    pub(crate) fn execute_with_runtime<'a>(
246        &'a self,
247        input: StepInput,
248        runtime: &'a crate::runtime::Runtime,
249    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = StepResult> + Send + 'a>> {
250        Box::pin(async move {
251            let start = std::time::Instant::now();
252
253            // Build the sub-workflow
254            let mut sub_workflow = (self.workflow_builder)();
255
256            // Override initial input with step input
257            sub_workflow.initial_input = input.data.clone();
258
259            // Execute the sub-workflow with parent context
260            let parent_workflow_id = Some(input.metadata.workflow_id.clone());
261            let run = runtime
262                .execute_with_parent(sub_workflow, parent_workflow_id)
263                .await;
264
265            if run.state != crate::workflow::WorkflowState::Completed {
266                return Err(StepError::ExecutionFailed(format!(
267                    "Sub-workflow failed: {:?}",
268                    run.state
269                )));
270            }
271
272            let output_data = run.final_output.unwrap_or(serde_json::json!({}));
273
274            Ok(StepOutput {
275                data: output_data,
276                metadata: StepOutputMetadata {
277                    step_name: self.name.clone(),
278                    step_type: StepType::SubWorkflow,
279                    execution_time_ms: start.elapsed().as_millis() as u64,
280                },
281            })
282        })
283    }
284}
285
286#[async_trait]
287impl Step for SubWorkflowStep {
288    async fn execute_with_context(
289        &self,
290        input: StepInput,
291        _ctx: crate::step::ExecutionContext<'_>,
292    ) -> StepResult {
293        // This creates a new runtime - won't share events with parent
294        // Use execute_with_runtime() from the parent runtime instead
295        let runtime = Runtime::new();
296        self.execute_with_runtime(input, &runtime).await
297    }
298
299    async fn execute(&self, input: StepInput) -> StepResult {
300        // This creates a new runtime - won't share events with parent
301        // Use execute_with_runtime() from the parent runtime instead
302        let runtime = Runtime::new();
303        self.execute_with_runtime(input, &runtime).await
304    }
305
306    fn name(&self) -> &str {
307        &self.name
308    }
309
310    fn step_type(&self) -> StepType {
311        StepType::SubWorkflow
312    }
313
314    fn description(&self) -> Option<&str> {
315        Some("Executes a nested workflow")
316    }
317
318    fn get_sub_workflow(&self) -> Option<crate::workflow::Workflow> {
319        Some((self.workflow_builder)())
320    }
321}