use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use blazen_events::AnyEvent;
use crate::context::Context;
use crate::error::WorkflowError;
/// Successful result of a step handler (the `Ok` side of [`StepFn`]'s future).
///
/// Note that the `None` variant is distinct from `Option::None`; spell it
/// `StepOutput::None` at use sites to avoid confusion.
#[derive(Debug)]
pub enum StepOutput {
/// The step produced exactly one event.
Single(Box<dyn AnyEvent>),
/// The step produced a list of events.
Multiple(Vec<Box<dyn AnyEvent>>),
/// The step produced no event.
None,
}
/// Type-erased, shareable async step handler.
///
/// It is called with the triggering event (boxed as `dyn AnyEvent`) and a
/// [`Context`] passed by value. It returns a pinned, boxed, `Send` future
/// that resolves to the step's [`StepOutput`] or a [`WorkflowError`].
/// Wrapping the closure in `Arc` with `Send + Sync` bounds lets one handler be
/// cloned cheaply and shared across threads.
pub type StepFn = Arc<
dyn Fn(
Box<dyn AnyEvent>,
Context,
) -> Pin<Box<dyn Future<Output = Result<StepOutput, WorkflowError>> + Send>>
+ Send
+ Sync,
>;
/// Registration record for a regular (handler-backed) step.
///
/// Build one with [`StepRegistration::new`] plus the `with_*` / `no_*` /
/// `clear_*` builder methods. `Clone` is derived. Because `handler`,
/// `semaphore` and `retry_config` are `Arc`s, clones share those values
/// rather than copying them.
#[derive(Clone)]
pub struct StepRegistration {
/// Step name (also returned by `StepKind::name`).
pub name: String,
/// Event type names this step accepts (e.g. `"ev::a"` in the tests).
/// NOTE(review): presumably used to route incoming events to this step — confirm.
pub accepts: Vec<&'static str>,
/// Event type names listed as this step's outputs.
pub emits: Vec<&'static str>,
/// Handler for this step; see [`StepFn`] for its signature.
pub handler: StepFn,
/// Concurrency limit passed to `new`; `0` means no semaphore was created.
pub max_concurrency: usize,
/// Created by `new`, sized from `max_concurrency`, when that value is
/// positive; otherwise `None`.
pub semaphore: Option<Arc<Semaphore>>,
/// Optional time limit: `None` from `new`, set by `with_timeout`, cleared by
/// `no_timeout`. NOTE(review): enforcement happens outside this file.
pub timeout: Option<Duration>,
/// Step-specific retry policy. `None` (the value from `new`, or after
/// `clear_retry_config`) means nothing is set here. `no_retry` instead
/// stores an explicit zero-retry policy.
pub retry_config: Option<Arc<blazen_llm::retry::RetryConfig>>,
}
impl std::fmt::Debug for StepRegistration {
    /// Writes a diagnostic view of the registration.
    ///
    /// `handler` and `semaphore` are not printed (the output ends with `..`).
    /// `retry_config` is reported only as a presence flag, not its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_retry_config = self.retry_config.is_some();
        let mut out = f.debug_struct("StepRegistration");
        out.field("name", &self.name);
        out.field("accepts", &self.accepts);
        out.field("emits", &self.emits);
        out.field("max_concurrency", &self.max_concurrency);
        out.field("timeout", &self.timeout);
        out.field("retry_config", &has_retry_config);
        out.finish_non_exhaustive()
    }
}
impl StepRegistration {
    /// Creates a registration for `handler` with no timeout and no retry config.
    ///
    /// `max_concurrency == 0` means "no limit", and no semaphore is created.
    /// Any positive value creates a [`Semaphore`] with that many permits.
    /// NOTE(review): it is presumably acquired by the executor around each
    /// handler call — confirm.
    ///
    /// Tokio's `Semaphore::new` panics above [`Semaphore::MAX_PERMITS`], so
    /// larger requests are clamped to that limit. The `max_concurrency` field
    /// still records the value passed in.
    #[must_use]
    pub fn new(
        name: String,
        accepts: Vec<&'static str>,
        emits: Vec<&'static str>,
        handler: StepFn,
        max_concurrency: usize,
    ) -> Self {
        // Clamp so a very large "effectively unlimited" request cannot panic
        // inside `Semaphore::new` and abort registration.
        let semaphore = (max_concurrency > 0).then(|| {
            Arc::new(Semaphore::new(max_concurrency.min(Semaphore::MAX_PERMITS)))
        });
        Self {
            name,
            accepts,
            emits,
            handler,
            max_concurrency,
            semaphore,
            timeout: None,
            retry_config: None,
        }
    }

    /// Sets `timeout` to `Some(timeout)`, replacing any previous value.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Clears `timeout` back to `None`.
    #[must_use]
    pub fn no_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Stores `config` (wrapped in an `Arc`) as this step's retry policy.
    #[must_use]
    pub fn with_retry_config(mut self, config: blazen_llm::retry::RetryConfig) -> Self {
        self.retry_config = Some(Arc::new(config));
        self
    }

    /// Stores an explicit retry policy with `max_retries: 0`. All other fields
    /// come from `RetryConfig::default()`.
    ///
    /// Unlike [`clear_retry_config`](Self::clear_retry_config), this leaves
    /// `retry_config` set (`Some`) rather than unset.
    #[must_use]
    pub fn no_retry(mut self) -> Self {
        self.retry_config = Some(Arc::new(blazen_llm::retry::RetryConfig {
            max_retries: 0,
            ..blazen_llm::retry::RetryConfig::default()
        }));
        self
    }

    /// Resets `retry_config` to `None`, removing any step-specific retry policy.
    #[must_use]
    pub fn clear_retry_config(mut self) -> Self {
        self.retry_config = None;
        self
    }
}
/// How a `ParallelSubWorkflowsStep` combines its branches (its `join_strategy`).
///
/// NOTE(review): the variant descriptions below are inferred from their names.
/// The executor that interprets them is not in this file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinStrategy {
/// Presumably waits for every branch to finish.
WaitAll,
/// Presumably proceeds as soon as the first branch finishes.
FirstCompletes,
}
/// Shareable function that turns the triggering event into a JSON value.
/// NOTE(review): presumably this is the nested workflow's input — confirm
/// against the executor.
pub type SubWorkflowInputMapper = Arc<dyn Fn(&dyn AnyEvent) -> serde_json::Value + Send + Sync>;
/// Shareable function that turns a JSON value into a boxed event.
/// NOTE(review): presumably the value is the nested workflow's result and the
/// event is what the step emits — confirm against the executor.
pub type SubWorkflowOutputMapper =
Arc<dyn Fn(serde_json::Value) -> Box<dyn AnyEvent> + Send + Sync>;
/// A step backed by a nested workflow, with mappers converting events to and
/// from JSON.
///
/// `Clone` is derived. The workflow, mappers and retry config are `Arc`s, so
/// clones share them.
#[derive(Clone)]
pub struct SubWorkflowStep {
/// Step name (also returned by `StepKind::name`).
pub name: String,
/// Event type names this step accepts.
pub accepts: Vec<&'static str>,
/// Event type names listed as this step's outputs.
pub emits: Vec<&'static str>,
/// The nested workflow, shared via `Arc`.
pub workflow: Arc<crate::workflow::Workflow>,
/// Builds a JSON value from the triggering event.
pub input_mapper: SubWorkflowInputMapper,
/// Builds an event from a JSON value.
pub output_mapper: SubWorkflowOutputMapper,
/// Optional time limit; `None` unless set with `with_timeout`.
pub timeout: Option<Duration>,
/// Optional step-specific retry policy; `None` unless set with `with_retry_config`.
pub retry_config: Option<Arc<blazen_llm::retry::RetryConfig>>,
}
impl std::fmt::Debug for SubWorkflowStep {
    /// Writes a diagnostic view of the sub-workflow step.
    ///
    /// The mapper closures are not printed (the output ends with `..`).
    /// `retry_config` is reported only as a presence flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_retry_config = self.retry_config.is_some();
        let mut out = f.debug_struct("SubWorkflowStep");
        out.field("name", &self.name);
        out.field("accepts", &self.accepts);
        out.field("emits", &self.emits);
        out.field("workflow", &self.workflow);
        out.field("timeout", &self.timeout);
        out.field("retry_config", &has_retry_config);
        out.finish_non_exhaustive()
    }
}
impl SubWorkflowStep {
#[must_use]
pub fn with_json_mappers(
name: impl Into<String>,
accepts: Vec<&'static str>,
emits: Vec<&'static str>,
workflow: std::sync::Arc<crate::workflow::Workflow>,
) -> Self {
let name_str = name.into();
let output_event_type: &'static str =
blazen_events::intern_event_type(&format!("{name_str}::output"));
let output_event_type_owned = output_event_type;
Self {
name: name_str,
accepts,
emits,
workflow,
input_mapper: std::sync::Arc::new(|event| event.to_json()),
output_mapper: std::sync::Arc::new(move |value| {
Box::new(blazen_events::DynamicEvent {
event_type: output_event_type_owned.to_string(),
data: value,
})
}),
timeout: None,
retry_config: None,
}
}
#[must_use]
pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
self.timeout = Some(timeout);
self
}
#[must_use]
pub fn with_retry_config(mut self, cfg: blazen_llm::retry::RetryConfig) -> Self {
self.retry_config = Some(std::sync::Arc::new(cfg));
self
}
}
/// A step made of several [`SubWorkflowStep`] branches, combined according to
/// `join_strategy`.
/// NOTE(review): branches are presumably run concurrently by the executor — confirm.
#[derive(Clone)]
pub struct ParallelSubWorkflowsStep {
/// Step name (also returned by `StepKind::name`).
pub name: String,
/// Event type names this step accepts.
pub accepts: Vec<&'static str>,
/// Event type names listed as this step's outputs.
pub emits: Vec<&'static str>,
/// The branch sub-workflow steps.
pub branches: Vec<SubWorkflowStep>,
/// How branch outcomes are combined; see [`JoinStrategy`].
pub join_strategy: JoinStrategy,
}
impl std::fmt::Debug for ParallelSubWorkflowsStep {
    /// Writes a diagnostic view of the parallel step.
    ///
    /// Branches are summarized by count (`branch_count`) rather than printed
    /// individually. The output ends with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let branch_count = self.branches.len();
        let mut out = f.debug_struct("ParallelSubWorkflowsStep");
        out.field("name", &self.name);
        out.field("accepts", &self.accepts);
        out.field("emits", &self.emits);
        out.field("branch_count", &branch_count);
        out.field("join_strategy", &self.join_strategy);
        out.finish_non_exhaustive()
    }
}
/// Union of the step types defined in this module.
///
/// The `From` impls below convert each concrete step type into its variant.
/// `name`, `accepts` and `emits` read the shared fields regardless of variant.
#[derive(Clone, Debug)]
pub enum StepKind {
/// A handler-backed step.
Regular(StepRegistration),
/// A step backed by one nested workflow.
SubWorkflow(SubWorkflowStep),
/// A step backed by several nested-workflow branches.
ParallelSubWorkflows(ParallelSubWorkflowsStep),
}
impl StepKind {
#[must_use]
pub fn name(&self) -> &str {
match self {
StepKind::Regular(r) => &r.name,
StepKind::SubWorkflow(s) => &s.name,
StepKind::ParallelSubWorkflows(p) => &p.name,
}
}
#[must_use]
pub fn accepts(&self) -> &[&'static str] {
match self {
StepKind::Regular(r) => &r.accepts,
StepKind::SubWorkflow(s) => &s.accepts,
StepKind::ParallelSubWorkflows(p) => &p.accepts,
}
}
#[must_use]
pub fn emits(&self) -> &[&'static str] {
match self {
StepKind::Regular(r) => &r.emits,
StepKind::SubWorkflow(s) => &s.emits,
StepKind::ParallelSubWorkflows(p) => &p.emits,
}
}
}
impl From<StepRegistration> for StepKind {
fn from(reg: StepRegistration) -> Self {
StepKind::Regular(reg)
}
}
impl From<SubWorkflowStep> for StepKind {
fn from(step: SubWorkflowStep) -> Self {
StepKind::SubWorkflow(step)
}
}
impl From<ParallelSubWorkflowsStep> for StepKind {
fn from(step: ParallelSubWorkflowsStep) -> Self {
StepKind::ParallelSubWorkflows(step)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that ignores its input and resolves to `StepOutput::None`.
    fn make_handler() -> StepFn {
        Arc::new(|_event, _ctx| Box::pin(async { Ok(StepOutput::None) }))
    }

    /// Baseline registration with `max_concurrency == 0` (no semaphore).
    fn make_registration() -> StepRegistration {
        StepRegistration::new(
            "test_step".to_owned(),
            vec!["ev::a"],
            vec!["ev::b"],
            make_handler(),
            0,
        )
    }

    #[test]
    fn step_registration_default_has_no_timeout() {
        let reg = make_registration();
        assert!(reg.timeout.is_none());
    }

    #[test]
    fn step_registration_with_timeout_sets_field() {
        let reg = make_registration().with_timeout(Duration::from_millis(250));
        assert_eq!(reg.timeout, Some(Duration::from_millis(250)));
    }

    #[test]
    fn step_registration_no_timeout_clears_field() {
        let reg = make_registration()
            .with_timeout(Duration::from_secs(1))
            .no_timeout();
        assert!(reg.timeout.is_none());
    }

    #[test]
    fn step_registration_default_has_no_retry_config() {
        let reg = StepRegistration::new("n".into(), vec![], vec![], make_handler(), 0);
        assert!(reg.retry_config.is_none());
    }

    #[test]
    fn step_registration_with_retry_config_sets_field() {
        let reg = StepRegistration::new("n".into(), vec![], vec![], make_handler(), 0)
            .with_retry_config(blazen_llm::retry::RetryConfig {
                max_retries: 11,
                ..blazen_llm::retry::RetryConfig::default()
            });
        assert_eq!(reg.retry_config.unwrap().max_retries, 11);
    }

    #[test]
    fn step_registration_no_retry_sets_zero_retries() {
        let reg = make_registration().no_retry();
        let cfg = reg.retry_config.expect("no_retry must store an explicit config");
        assert_eq!(cfg.max_retries, 0);
    }

    #[test]
    fn step_registration_clear_retry_config_resets_to_none() {
        let reg = make_registration()
            .with_retry_config(blazen_llm::retry::RetryConfig::default())
            .clear_retry_config();
        assert!(reg.retry_config.is_none());
    }

    #[test]
    fn step_registration_zero_concurrency_has_no_semaphore() {
        let reg = make_registration();
        assert!(reg.semaphore.is_none());
    }

    #[test]
    fn step_registration_positive_concurrency_creates_semaphore() {
        let reg = StepRegistration::new("n".into(), vec![], vec![], make_handler(), 3);
        let sem = reg
            .semaphore
            .expect("max_concurrency > 0 must create a semaphore");
        assert_eq!(sem.available_permits(), 3);
    }

    #[test]
    fn step_registration_clones_share_semaphore() {
        let reg = StepRegistration::new("n".into(), vec![], vec![], make_handler(), 2);
        let copy = reg.clone();
        assert!(Arc::ptr_eq(
            reg.semaphore.as_ref().unwrap(),
            copy.semaphore.as_ref().unwrap(),
        ));
    }

    #[test]
    fn step_kind_accessors_delegate_to_registration() {
        let kind = StepKind::from(make_registration());
        assert_eq!(kind.name(), "test_step");
        assert_eq!(kind.accepts(), ["ev::a"].as_slice());
        assert_eq!(kind.emits(), ["ev::b"].as_slice());
    }
}