use std::sync::Arc;
use atomr_agents_callable::CallableHandle;
use atomr_agents_core::{Result, Value};
use serde::{Deserialize, Serialize};
use crate::dag::StepId;
/// How a [`Step::Parallel`] step combines the outcomes of its child steps.
///
/// The behaviour is inferred from the variant names only (the executor is not
/// in this file): presumably `All` waits for every child and `Any` completes as
/// soon as one child does — confirm against the runtime.
///
/// `PartialEq`/`Eq`/`Hash` are derived so a strategy can be compared with `==`
/// or used as a map key instead of requiring `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JoinStrategy {
    /// Presumably: wait for every parallel child step.
    All,
    /// Presumably: finish once any single child step finishes.
    Any,
}
/// Concurrency setting carried by [`Step::Map`] (its only use in this file).
///
/// The wrapped `u32` is presumably the maximum number of `body` executions in
/// flight at once; the code that interprets it is not visible here, so confirm
/// how a value of `0` is treated before relying on it.
///
/// `PartialEq`/`Eq`/`Hash` are derived so limits can be compared or used as
/// keys without unwrapping the inner value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Concurrency(pub u32);
/// Input-selection settings attached to a [`Step::Invoke`].
///
/// `fields` is a list of names — presumably keys picked out of the incoming
/// `Value`; the consumer is not in this file. The list is empty both for
/// `InputMapping::default()` and when the key is missing during
/// deserialization.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct InputMapping {
    /// Field names to map; empty when omitted from serialized input.
    #[serde(default)]
    pub fields: Vec<String>,
}
/// Payload of a [`Step::Human`] step.
///
/// When deserializing, `prompt` must be present, while `context` falls back
/// to `Value::default()` if its key is missing.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct HumanApproval {
    /// Prompt text — presumably what is shown to the approver.
    pub prompt: String,
    /// Extra data accompanying the prompt; defaulted when absent.
    #[serde(default)]
    pub context: Value,
}
/// A boolean test over a step's output `Value`.
///
/// `Step::Branch` and `Step::Loop` store predicates as
/// `Arc<dyn BranchPredicate>`. The `Send + Sync + 'static` requirement
/// (written here as a `where Self:` clause, which is equivalent to listing
/// supertraits) lets one predicate instance be shared across threads.
pub trait BranchPredicate
where
    Self: Send + Sync + 'static,
{
    /// Returns this predicate's verdict for `output`.
    fn evaluate(&self, output: &Value) -> bool;
}
/// One node of the workflow graph.
///
/// Variants that refer to other steps (`Branch`, `Parallel`, `Loop`, `Map`)
/// do so through [`StepId`] rather than by nesting `Step` values. The
/// per-variant descriptions below are read off the field names; the code that
/// executes steps is not in this file.
pub enum Step {
    /// Call `callable`, with inputs selected according to `mapping`.
    Invoke {
        callable: CallableHandle,
        mapping: InputMapping,
    },
    /// Evaluate `predicate`; presumably continue at `if_true` or `if_false`.
    Branch {
        predicate: Arc<dyn BranchPredicate>,
        if_true: StepId,
        if_false: StepId,
    },
    /// Run the listed `steps` and combine them according to `join`.
    Parallel {
        steps: Vec<StepId>,
        join: JoinStrategy,
    },
    /// Repeat `body`; `predicate` presumably decides whether to iterate again.
    Loop {
        body: StepId,
        predicate: Arc<dyn BranchPredicate>,
    },
    /// Apply `body` to multiple inputs, limited by `concurrency`.
    Map {
        body: StepId,
        concurrency: Concurrency,
    },
    /// Wait for a human decision described by `approval`.
    Human {
        approval: HumanApproval,
    },
}
impl Step {
pub fn invoke(callable: CallableHandle) -> Self {
Self::Invoke {
callable,
mapping: InputMapping::default(),
}
}
}
/// Adapter that turns a thread-safe closure `Fn(&Value) -> bool` into a
/// [`BranchPredicate`].
///
/// The bounds are kept on the struct itself (not only on the impl) so the tuple
/// constructor can guide type inference for a bare closure argument, e.g.
/// `FnPredicate(|v| ...)`.
// NOTE(review): `dead_code` is allowed presumably because nothing in this crate
// constructs `FnPredicate` yet; confirm, and drop the allow once it is used.
#[allow(dead_code)]
pub struct FnPredicate<F>(pub F)
where
    F: Fn(&Value) -> bool + Send + Sync + 'static;

impl<F> BranchPredicate for FnPredicate<F>
where
    F: Fn(&Value) -> bool + Send + Sync + 'static,
{
    /// Delegates directly to the wrapped closure.
    fn evaluate(&self, output: &Value) -> bool {
        let predicate = &self.0;
        predicate(output)
    }
}
/// Private no-op that references the imported `Result` alias.
///
/// Nothing else in this file appears to use `Result`. This function seems to
/// exist only so that `use atomr_agents_core::{Result, ...}` does not trip the
/// `unused_imports` lint — confirm before deleting either one.
// Never called, hence the `dead_code` allowance.
#[allow(dead_code)]
fn _result_unused() -> Result<()> {
    Ok(())
}