use crate::callable::Callable;
use crate::kernel::ExecutionId;
use std::sync::Arc;
/// Outcome of running one branch of a `ParallelFlow`.
#[derive(Debug)]
pub struct ParallelResult {
    /// Name of the branch that produced this result.
    pub name: String,
    /// Identifier generated (via `ExecutionId::new()`) for this branch run.
    pub execution_id: ExecutionId,
    /// `Ok` with the branch's output, or `Err` with the stringified branch
    /// error (task failures such as panics are reported here as well).
    pub output: Result<String, String>,
}
/// Strategy for deriving each branch's input from the flow's input.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FanOut {
    /// Every branch receives the full input unchanged.
    #[default]
    Broadcast,
    /// The input is split on `delimiter` and branch `i` receives the `i`-th
    /// piece; surplus pieces are ignored and branches beyond the last piece
    /// receive an empty string.
    Split { delimiter: String },
    /// Reserved for caller-defined distribution; currently handled exactly
    /// like [`FanOut::Broadcast`].
    Custom,
}
/// Strategy for combining branch outputs into a single string.
///
/// Every strategy considers only successful branch outputs; failed branches
/// are skipped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FanIn {
    /// Join successful outputs, in branch order, with `separator`.
    Concat { separator: String },
    /// Return the first successful output in branch order; it is an error if
    /// no branch succeeded.
    FirstSuccess,
    /// Serialize the successful outputs as a JSON array of strings.
    JsonArray,
    /// Reserved for caller-defined aggregation; currently joins successful
    /// outputs with `"\n"`.
    Custom,
}
impl Default for FanIn {
fn default() -> Self {
FanIn::Concat {
separator: "\n".to_string(),
}
}
}
/// A named group of callables that are run concurrently on (a derivative of)
/// one input, with configurable input distribution and output aggregation.
pub struct ParallelFlow<C: Callable> {
    /// Branches to run; results are reported in this order.
    branches: Vec<Arc<C>>,
    /// Human-readable name of the flow.
    name: String,
    /// How the flow's input is distributed to branches.
    fan_out: FanOut,
    /// How branch outputs are combined by `execute_aggregated`.
    fan_in: FanIn,
}
impl<C: Callable + 'static> ParallelFlow<C> {
    /// Creates a flow named `name` with no branches, using the default
    /// strategies: [`FanOut::Broadcast`] and newline-separated [`FanIn::Concat`].
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            branches: Vec::new(),
            name: name.into(),
            // Delegate to the enums' `Default` impls so defaults live in one place.
            fan_out: FanOut::default(),
            fan_in: FanIn::default(),
        }
    }

    /// Appends a branch. Branches run concurrently and report their results
    /// in the order they were added.
    #[must_use]
    pub fn add_branch(mut self, callable: Arc<C>) -> Self {
        self.branches.push(callable);
        self
    }

    /// Sets how the flow's input is distributed to branches.
    #[must_use]
    pub fn with_fan_out(mut self, strategy: FanOut) -> Self {
        self.fan_out = strategy;
        self
    }

    /// Sets how branch outputs are combined by [`Self::execute_aggregated`].
    #[must_use]
    pub fn with_fan_in(mut self, strategy: FanIn) -> Self {
        self.fan_in = strategy;
        self
    }

    /// Runs every branch concurrently, each on its own Tokio task, and waits
    /// for all of them to finish.
    ///
    /// Results are returned in branch-registration order, regardless of which
    /// branch finishes first. A branch error, or a panic or cancellation of its
    /// task, is captured as `Err(String)` in that branch's [`ParallelResult`].
    /// The result keeps the branch's real name and execution id, so this
    /// method itself never fails.
    ///
    /// Branch tasks are started with `tokio::spawn`. If the returned future is
    /// dropped before completion, the branches that were already started keep
    /// running detached.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (from `tokio::spawn`).
    pub async fn execute(&self, input: &str) -> Vec<ParallelResult> {
        let inputs = self.prepare_inputs(input);

        // Spawn every branch before awaiting any, so they actually overlap.
        // Name and id are kept outside the task so a panic cannot erase them.
        let pending: Vec<_> = self
            .branches
            .iter()
            .zip(inputs)
            .map(|(branch, branch_input)| {
                let branch = Arc::clone(branch);
                let name = branch.name().to_string();
                let execution_id = ExecutionId::new();
                let handle = tokio::spawn(async move {
                    branch
                        .run(&branch_input)
                        .await
                        .map_err(|e| e.to_string())
                });
                (name, execution_id, handle)
            })
            .collect();

        let mut results = Vec::with_capacity(pending.len());
        for (name, execution_id, handle) in pending {
            let output = match handle.await {
                Ok(output) => output,
                Err(join_err) if join_err.is_panic() => {
                    Err(format!("Task panicked: {}", join_err))
                }
                Err(join_err) => Err(format!("Task cancelled: {}", join_err)),
            };
            results.push(ParallelResult {
                name,
                execution_id,
                output,
            });
        }
        results
    }

    /// Runs all branches via [`Self::execute`] and combines their successful
    /// outputs with the configured [`FanIn`] strategy.
    ///
    /// # Errors
    ///
    /// Returns an error when the strategy is [`FanIn::FirstSuccess`] and no
    /// branch succeeded, or when [`FanIn::JsonArray`] serialization fails.
    pub async fn execute_aggregated(&self, input: &str) -> anyhow::Result<String> {
        let results = self.execute(input).await;
        self.aggregate_results(results)
    }

    /// Builds one input per branch, in branch order, according to `fan_out`.
    ///
    /// With [`FanOut::Split`], the input is split once. Branch `i` receives the
    /// `i`-th piece, surplus pieces are ignored, and branches beyond the last
    /// piece receive an empty string. [`FanOut::Custom`] currently behaves like
    /// [`FanOut::Broadcast`].
    fn prepare_inputs(&self, input: &str) -> Vec<String> {
        let count = self.branches.len();
        match &self.fan_out {
            FanOut::Split { delimiter } => {
                // `Split` is a fused iterator: once exhausted it keeps yielding `None`.
                let mut parts = input.split(delimiter.as_str());
                (0..count)
                    .map(|_| parts.next().unwrap_or("").to_string())
                    .collect()
            }
            FanOut::Broadcast | FanOut::Custom => vec![input.to_string(); count],
        }
    }

    /// Combines branch results according to `fan_in`. Failed branches are
    /// skipped by every strategy.
    fn aggregate_results(&self, results: Vec<ParallelResult>) -> anyhow::Result<String> {
        let mut successes = results.into_iter().filter_map(|r| r.output.ok());
        match &self.fan_in {
            FanIn::Concat { separator } => {
                Ok(successes.collect::<Vec<_>>().join(separator.as_str()))
            }
            FanIn::FirstSuccess => successes
                .next()
                .ok_or_else(|| anyhow::anyhow!("All branches failed")),
            FanIn::JsonArray => Ok(serde_json::to_string(&successes.collect::<Vec<_>>())?),
            // Placeholder until a caller-supplied hook exists: newline-joined successes.
            FanIn::Custom => Ok(successes.collect::<Vec<_>>().join("\n")),
        }
    }

    /// Returns the flow's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the number of registered branches.
    pub fn branch_count(&self) -> usize {
        self.branches.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Test double that returns `"{response}:{input}"`, optionally after a delay.
    struct MockCallable {
        name: String,
        response: String,
        delay_ms: Option<u64>,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: None,
            }
        }

        fn with_delay(name: &str, response: &str, delay_ms: u64) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: Some(delay_ms),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            if let Some(delay) = self.delay_ms {
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            Ok(format!("{}:{}", self.response, input))
        }
    }

    #[tokio::test]
    async fn test_parallel_single_branch() {
        let flow =
            ParallelFlow::new("single").add_branch(Arc::new(MockCallable::new("b1", "result1")));
        let results = flow.execute("input").await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "b1");
        assert_eq!(results[0].output.as_deref(), Ok("result1:input"));
    }

    #[tokio::test]
    async fn test_parallel_multiple_branches() {
        let flow = ParallelFlow::new("multi")
            .add_branch(Arc::new(MockCallable::new("b1", "r1")))
            .add_branch(Arc::new(MockCallable::new("b2", "r2")))
            .add_branch(Arc::new(MockCallable::new("b3", "r3")));
        assert_eq!(flow.branch_count(), 3);
        assert_eq!(flow.name(), "multi");
        let results = flow.execute("test").await;
        assert_eq!(results.len(), 3);
        for result in &results {
            assert!(result.output.is_ok());
        }
        // Results come back in branch-registration order.
        let names: Vec<&str> = results.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["b1", "b2", "b3"]);
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        use std::time::Instant;
        let flow = ParallelFlow::new("concurrent")
            .add_branch(Arc::new(MockCallable::with_delay("b1", "r1", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b2", "r2", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b3", "r3", 50)));
        let start = Instant::now();
        let results = flow.execute("test").await;
        let elapsed = start.elapsed();
        // Sequential execution would take ~150ms; concurrent should be ~50ms.
        assert!(
            elapsed.as_millis() < 120,
            "Expected <120ms but took {}ms",
            elapsed.as_millis()
        );
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_parallel_aggregated_concat() {
        let flow = ParallelFlow::new("concat")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_in(FanIn::Concat {
                separator: "|".to_string(),
            });
        let result = flow.execute_aggregated("x").await.unwrap();
        // Branch order is preserved, so the joined output is fully determined.
        assert_eq!(result, "A:x|B:x");
    }

    #[tokio::test]
    async fn test_parallel_aggregated_json_array() {
        let flow = ParallelFlow::new("json")
            .add_branch(Arc::new(MockCallable::new("a", "result_a")))
            .add_branch(Arc::new(MockCallable::new("b", "result_b")))
            .with_fan_in(FanIn::JsonArray);
        let result = flow.execute_aggregated("input").await.unwrap();
        let parsed: Vec<String> = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed, ["result_a:input", "result_b:input"]);
    }

    #[tokio::test]
    async fn test_fan_out_split_distributes_by_index() {
        let flow = ParallelFlow::new("split")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")))
            .add_branch(Arc::new(MockCallable::new("c", "third")));
        let results = flow.execute("one,two,three").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();
        assert_eq!(outputs[0], "first:one");
        assert_eq!(outputs[1], "second:two");
        assert_eq!(outputs[2], "third:three");
    }

    #[tokio::test]
    async fn test_fan_out_split_with_fewer_parts_than_branches() {
        let flow = ParallelFlow::new("split_short")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")));
        let results = flow.execute("only").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();
        // Branches past the last piece receive an empty input.
        assert_eq!(outputs, ["first:only", "second:"]);
    }

    #[tokio::test]
    async fn test_parallel_first_success() {
        struct MaybeFailCallable {
            name: &'static str,
            should_fail: bool,
        }

        #[async_trait]
        impl Callable for MaybeFailCallable {
            fn name(&self) -> &str {
                self.name
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                if self.should_fail {
                    anyhow::bail!("Intentional failure")
                }
                Ok("success_result".to_string())
            }
        }

        let flow = ParallelFlow::new("first_success")
            .add_branch(Arc::new(MaybeFailCallable {
                name: "fail",
                should_fail: true,
            }))
            .add_branch(Arc::new(MaybeFailCallable {
                name: "success",
                should_fail: false,
            }))
            .with_fan_in(FanIn::FirstSuccess);
        let result = flow.execute_aggregated("test").await.unwrap();
        assert_eq!(result, "success_result");
    }

    #[tokio::test]
    async fn test_parallel_all_fail_first_success() {
        struct FailCallable(&'static str);

        #[async_trait]
        impl Callable for FailCallable {
            fn name(&self) -> &str {
                self.0
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Failed: {}", self.0)
            }
        }

        let flow = ParallelFlow::new("all_fail")
            .add_branch(Arc::new(FailCallable("f1")))
            .add_branch(Arc::new(FailCallable("f2")))
            .with_fan_in(FanIn::FirstSuccess);
        let result = flow.execute_aggregated("test").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("All branches failed"));
    }

    #[tokio::test]
    async fn test_fan_out_broadcast() {
        let flow = ParallelFlow::new("broadcast")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_out(FanOut::Broadcast);
        let results = flow.execute("same_input").await;
        assert_eq!(results.len(), 2);
        for result in results {
            assert!(result.output.as_ref().unwrap().contains("same_input"));
        }
    }

    #[tokio::test]
    async fn test_parallel_result_contains_execution_id() {
        let flow =
            ParallelFlow::new("with_ids").add_branch(Arc::new(MockCallable::new("b1", "r1")));
        let results = flow.execute("test").await;
        assert_eq!(results.len(), 1);
        assert!(!results[0].execution_id.as_str().is_empty());
    }

    #[test]
    fn test_fan_out_default() {
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(FanOut::default(), FanOut::Broadcast));
    }

    #[test]
    fn test_fan_in_default() {
        assert!(matches!(
            FanIn::default(),
            FanIn::Concat { separator } if separator == "\n"
        ));
    }
}