1use crate::{HandlerRegistry, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone)]
6pub struct PipelineHandler {
7 pub steps: Vec<PipelineStep>,
8}
9
10#[derive(Debug, Clone)]
11pub struct PipelineStep {
12 pub tool: String,
13 pub input: Option<serde_json::Value>,
14 pub output_var: Option<String>,
15 pub condition: Option<String>,
16 pub error_policy: ErrorPolicy,
17}
18
/// How the pipeline reacts when a step fails.
///
/// A fieldless enum, so it is `Copy` and `Eq` and can be passed and compared by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorPolicy {
    /// Stop the pipeline and return the failing step's error to the caller.
    FailFast,
    /// Record the failure in the step results and move on to the next step.
    Continue,
}
24
25#[derive(Debug, Deserialize)]
26pub struct PipelineInput {
27 #[serde(default)]
28 pub variables: HashMap<String, serde_json::Value>,
29}
30
31#[derive(Debug, Serialize)]
32pub struct PipelineOutput {
33 pub results: Vec<StepResult>,
34 pub variables: HashMap<String, serde_json::Value>,
35}
36
37#[derive(Debug, Serialize)]
38pub struct StepResult {
39 pub tool: String,
40 pub success: bool,
41 pub output: Option<serde_json::Value>,
42 pub error: Option<String>,
43}
44
45impl PipelineHandler {
46 pub fn new(steps: Vec<PipelineStep>) -> Self {
47 Self { steps }
48 }
49
50 pub async fn execute(
51 &self,
52 input: PipelineInput,
53 registry: &HandlerRegistry,
54 ) -> Result<PipelineOutput> {
55 let mut variables = input.variables;
56 let mut results = Vec::new();
57
58 for step in &self.steps {
59 if let Some(condition) = &step.condition {
61 if !self.evaluate_condition(condition, &variables) {
62 continue;
63 }
64 }
65
66 let step_input = if let Some(input_template) = &step.input {
68 self.interpolate_variables(input_template, &variables)
69 } else {
70 serde_json::json!({})
71 };
72
73 let step_result = match registry
75 .dispatch(&step.tool, &serde_json::to_vec(&step_input)?)
76 .await
77 {
78 Ok(output) => {
79 let output_value: serde_json::Value = serde_json::from_slice(&output)?;
80
81 if let Some(var_name) = &step.output_var {
83 variables.insert(var_name.clone(), output_value.clone());
84 }
85
86 StepResult {
87 tool: step.tool.clone(),
88 success: true,
89 output: Some(output_value),
90 error: None,
91 }
92 }
93 Err(e) => {
94 let result = StepResult {
95 tool: step.tool.clone(),
96 success: false,
97 output: None,
98 error: Some(e.to_string()),
99 };
100
101 if step.error_policy == ErrorPolicy::FailFast {
103 results.push(result);
104 return Err(e);
105 }
106
107 result
108 }
109 };
110
111 results.push(step_result);
112 }
113
114 Ok(PipelineOutput { results, variables })
115 }
116
117 fn evaluate_condition(
118 &self,
119 condition: &str,
120 variables: &HashMap<String, serde_json::Value>,
121 ) -> bool {
122 if let Some(var_name) = condition.strip_prefix('!') {
125 !variables.contains_key(var_name)
126 } else {
127 variables.contains_key(condition)
128 }
129 }
130
131 #[allow(clippy::only_used_in_recursion)]
132 fn interpolate_variables(
133 &self,
134 template: &serde_json::Value,
135 variables: &HashMap<String, serde_json::Value>,
136 ) -> serde_json::Value {
137 match template {
138 serde_json::Value::String(s) => {
139 let mut result = s.clone();
141 for (key, value) in variables {
142 let pattern = format!("{{{{{}}}}}", key);
143 if let Some(value_str) = value.as_str() {
144 result = result.replace(&pattern, value_str);
145 }
146 }
147 serde_json::Value::String(result)
148 }
149 serde_json::Value::Object(obj) => {
150 let mut new_obj = serde_json::Map::new();
151 for (k, v) in obj {
152 new_obj.insert(k.clone(), self.interpolate_variables(v, variables));
153 }
154 serde_json::Value::Object(new_obj)
155 }
156 serde_json::Value::Array(arr) => {
157 let new_arr: Vec<_> = arr
158 .iter()
159 .map(|v| self.interpolate_variables(v, variables))
160 .collect();
161 serde_json::Value::Array(new_arr)
162 }
163 other => other.clone(),
164 }
165 }
166}
167
// Unit tests for the pipeline types, condition evaluation, interpolation and execution.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_handler_new() {
        let steps = vec![PipelineStep {
            tool: "test_tool".to_string(),
            input: None,
            output_var: None,
            condition: None,
            error_policy: ErrorPolicy::FailFast,
        }];

        let handler = PipelineHandler::new(steps);
        assert_eq!(handler.steps.len(), 1);
        assert_eq!(handler.steps[0].tool, "test_tool");
    }

    #[test]
    fn test_error_policy_equality() {
        assert_eq!(ErrorPolicy::FailFast, ErrorPolicy::FailFast);
        assert_eq!(ErrorPolicy::Continue, ErrorPolicy::Continue);
        assert_ne!(ErrorPolicy::FailFast, ErrorPolicy::Continue);
    }

    #[test]
    fn test_evaluate_condition_exists() {
        let handler = PipelineHandler::new(vec![]);
        let mut vars = HashMap::new();
        vars.insert("key".to_string(), serde_json::json!("value"));

        assert!(handler.evaluate_condition("key", &vars));
        assert!(!handler.evaluate_condition("missing", &vars));
    }

    #[test]
    fn test_evaluate_condition_not_exists() {
        let handler = PipelineHandler::new(vec![]);
        let mut vars = HashMap::new();
        vars.insert("key".to_string(), serde_json::json!("value"));

        assert!(!handler.evaluate_condition("!key", &vars));
        assert!(handler.evaluate_condition("!missing", &vars));
    }

    #[test]
    fn test_interpolate_variables_string() {
        let handler = PipelineHandler::new(vec![]);
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), serde_json::json!("Alice"));

        let template = serde_json::json!("Hello {{name}}!");
        let result = handler.interpolate_variables(&template, &vars);

        assert_eq!(result, serde_json::json!("Hello Alice!"));
    }

    #[test]
    fn test_interpolate_variables_object() {
        let handler = PipelineHandler::new(vec![]);
        let mut vars = HashMap::new();
        vars.insert("user".to_string(), serde_json::json!("Bob"));

        let template = serde_json::json!({"greeting": "Hi {{user}}"});
        let result = handler.interpolate_variables(&template, &vars);

        assert_eq!(result["greeting"], "Hi Bob");
    }

    #[test]
    fn test_interpolate_variables_array() {
        let handler = PipelineHandler::new(vec![]);
        let mut vars = HashMap::new();
        vars.insert("item".to_string(), serde_json::json!("test"));

        let template = serde_json::json!(["{{item}}", "other"]);
        let result = handler.interpolate_variables(&template, &vars);

        assert_eq!(result[0], "test");
        assert_eq!(result[1], "other");
    }

    #[test]
    fn test_interpolate_variables_no_match() {
        let handler = PipelineHandler::new(vec![]);
        let vars = HashMap::new();

        let template = serde_json::json!("Hello {{missing}}!");
        let result = handler.interpolate_variables(&template, &vars);

        assert_eq!(result, serde_json::json!("Hello {{missing}}!"));
    }

    #[test]
    fn test_pipeline_input_deserialization() {
        let json = r#"{"variables": {"key": "value"}}"#;
        let input: PipelineInput = serde_json::from_str(json).unwrap();

        assert_eq!(input.variables.len(), 1);
        assert_eq!(input.variables["key"], "value");
    }

    #[test]
    fn test_pipeline_output_serialization() {
        let output = PipelineOutput {
            results: vec![StepResult {
                tool: "test".to_string(),
                success: true,
                output: Some(serde_json::json!({"result": "ok"})),
                error: None,
            }],
            variables: HashMap::new(),
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("\"tool\":\"test\""));
        assert!(json.contains("\"success\":true"));
    }

    #[tokio::test]
    async fn test_pipeline_execute_simple() {
        use crate::{Handler, HandlerRegistry};
        use schemars::JsonSchema;

        #[derive(Debug, serde::Deserialize, JsonSchema)]
        struct TestInput {
            value: String,
        }

        #[derive(Debug, serde::Serialize, JsonSchema)]
        struct TestOutput {
            result: String,
        }

        struct TestHandler;

        #[async_trait::async_trait]
        impl Handler for TestHandler {
            type Input = TestInput;
            type Output = TestOutput;
            type Error = crate::Error;

            async fn handle(&self, input: Self::Input) -> crate::Result<Self::Output> {
                Ok(TestOutput {
                    result: format!("processed: {}", input.value),
                })
            }
        }

        let mut registry = HandlerRegistry::new();
        registry.register("test_tool", TestHandler);

        let handler = PipelineHandler::new(vec![PipelineStep {
            tool: "test_tool".to_string(),
            input: Some(serde_json::json!({"value": "hello"})),
            output_var: Some("result".to_string()),
            condition: None,
            error_policy: ErrorPolicy::FailFast,
        }]);

        let input = PipelineInput {
            variables: HashMap::new(),
        };

        let output = handler.execute(input, &registry).await.unwrap();

        assert_eq!(output.results.len(), 1);
        assert!(output.results[0].success);
        assert!(output.variables.contains_key("result"));
    }

    #[tokio::test]
    async fn test_pipeline_execute_with_condition_skip() {
        use crate::HandlerRegistry;

        let registry = HandlerRegistry::new();

        let handler = PipelineHandler::new(vec![PipelineStep {
            tool: "nonexistent".to_string(),
            input: None,
            output_var: None,
            condition: Some("missing_var".to_string()),
            error_policy: ErrorPolicy::FailFast,
        }]);

        let input = PipelineInput {
            variables: HashMap::new(),
        };

        let output = handler.execute(input, &registry).await.unwrap();

        assert_eq!(output.results.len(), 0);
    }

    #[tokio::test]
    async fn test_pipeline_execute_continue_on_error() {
        use crate::HandlerRegistry;

        let registry = HandlerRegistry::new();

        let handler = PipelineHandler::new(vec![
            PipelineStep {
                tool: "nonexistent1".to_string(),
                input: None,
                output_var: None,
                condition: None,
                error_policy: ErrorPolicy::Continue,
            },
            PipelineStep {
                tool: "nonexistent2".to_string(),
                input: None,
                output_var: None,
                condition: None,
                error_policy: ErrorPolicy::Continue,
            },
        ]);

        let input = PipelineInput {
            variables: HashMap::new(),
        };

        let output = handler.execute(input, &registry).await.unwrap();

        assert_eq!(output.results.len(), 2);
        assert!(!output.results[0].success);
        assert!(!output.results[1].success);
    }

    #[tokio::test]
    async fn test_pipeline_execute_fail_fast() {
        use crate::HandlerRegistry;

        let registry = HandlerRegistry::new();

        let handler = PipelineHandler::new(vec![
            PipelineStep {
                tool: "nonexistent1".to_string(),
                input: None,
                output_var: None,
                condition: None,
                error_policy: ErrorPolicy::FailFast,
            },
            PipelineStep {
                tool: "nonexistent2".to_string(),
                input: None,
                output_var: None,
                condition: None,
                error_policy: ErrorPolicy::FailFast,
            },
        ]);

        let input = PipelineInput {
            variables: HashMap::new(),
        };

        let result = handler.execute(input, &registry).await;

        assert!(result.is_err());
    }
}