use crate::{Chain, Chord, Group};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Optimization steps that [`WorkflowCompiler`] can run over a workflow.
///
/// The short descriptions below summarize what the compiler in this file does for each pass.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Removal of tasks that repeat an earlier task's name and arguments (label `CSE`).
    CommonSubexpressionElimination,
    /// Removal of tasks whose task name is empty (label `DCE`).
    DeadCodeElimination,
    /// Merging of adjacent chain tasks that call the same task.
    TaskFusion,
    /// Ordering of group tasks by descending priority.
    ParallelScheduling,
    /// Ordering of tasks by queue name.
    ResourceOptimization,
}

impl std::fmt::Display for OptimizationPass {
    /// Writes the pass label (`CSE`, `DCE`, or the variant name).
    ///
    /// The label goes through [`std::fmt::Formatter::pad`], so width, fill and alignment
    /// flags such as `{:>20}` are honored; plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::CommonSubexpressionElimination => "CSE",
            Self::DeadCodeElimination => "DCE",
            Self::TaskFusion => "TaskFusion",
            Self::ParallelScheduling => "ParallelScheduling",
            Self::ResourceOptimization => "ResourceOptimization",
        };
        f.pad(label)
    }
}
#[derive(Debug, Clone)]
pub struct WorkflowCompiler {
pub passes: Vec<OptimizationPass>,
pub aggressive: bool,
}
impl WorkflowCompiler {
pub fn new() -> Self {
Self {
passes: vec![
OptimizationPass::DeadCodeElimination,
OptimizationPass::CommonSubexpressionElimination,
],
aggressive: false,
}
}
pub fn aggressive(mut self) -> Self {
self.aggressive = true;
self.passes.push(OptimizationPass::TaskFusion);
self.passes.push(OptimizationPass::ParallelScheduling);
self.passes.push(OptimizationPass::ResourceOptimization);
self
}
pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
if !self.passes.contains(&pass) {
self.passes.push(pass);
}
self
}
pub fn optimize_chain(&self, chain: &Chain) -> Chain {
let mut optimized = chain.clone();
for pass in &self.passes {
optimized = match pass {
OptimizationPass::CommonSubexpressionElimination => {
self.apply_cse_chain(&optimized)
}
OptimizationPass::DeadCodeElimination => self.apply_dce_chain(&optimized),
OptimizationPass::TaskFusion => self.apply_task_fusion(&optimized),
OptimizationPass::ParallelScheduling => {
optimized
}
OptimizationPass::ResourceOptimization => {
self.apply_resource_optimization_chain(&optimized)
}
};
}
optimized
}
pub fn optimize_group(&self, group: &Group) -> Group {
let mut optimized = group.clone();
for pass in &self.passes {
optimized = match pass {
OptimizationPass::CommonSubexpressionElimination => {
self.apply_cse_group(&optimized)
}
OptimizationPass::DeadCodeElimination => self.apply_dce_group(&optimized),
OptimizationPass::TaskFusion => {
optimized
}
OptimizationPass::ParallelScheduling => self.apply_parallel_scheduling(&optimized),
OptimizationPass::ResourceOptimization => {
self.apply_resource_optimization_group(&optimized)
}
};
}
optimized
}
pub fn optimize_chord(&self, chord: &Chord) -> Chord {
let optimized_group = self.optimize_group(&chord.header);
Chord {
header: optimized_group,
body: chord.body.clone(),
}
}
fn apply_cse_chain(&self, chain: &Chain) -> Chain {
let mut seen = HashMap::new();
let mut optimized_tasks = Vec::new();
for (idx, task) in chain.tasks.iter().enumerate() {
let key = format!(
"{}:{}:{}",
task.task,
serde_json::to_string(&task.args).unwrap_or_default(),
serde_json::to_string(&task.kwargs).unwrap_or_default()
);
if let Some(&prev_idx) = seen.get(&key) {
if self.aggressive && prev_idx < idx {
continue;
}
} else {
seen.insert(key, idx);
}
optimized_tasks.push(task.clone());
}
Chain {
tasks: optimized_tasks,
}
}
fn apply_cse_group(&self, group: &Group) -> Group {
let mut seen = HashMap::new();
let mut optimized_tasks = Vec::new();
for task in &group.tasks {
let key = format!(
"{}:{}:{}",
task.task,
serde_json::to_string(&task.args).unwrap_or_default(),
serde_json::to_string(&task.kwargs).unwrap_or_default()
);
if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) {
e.insert(true);
optimized_tasks.push(task.clone());
} else {
if !self.aggressive {
optimized_tasks.push(task.clone());
}
}
}
Group {
tasks: optimized_tasks,
group_id: group.group_id,
}
}
fn apply_dce_chain(&self, chain: &Chain) -> Chain {
let optimized_tasks: Vec<_> = chain
.tasks
.iter()
.filter(|task| !task.task.is_empty())
.cloned()
.collect();
Chain {
tasks: optimized_tasks,
}
}
fn apply_dce_group(&self, group: &Group) -> Group {
let optimized_tasks: Vec<_> = group
.tasks
.iter()
.filter(|task| !task.task.is_empty())
.cloned()
.collect();
Group {
tasks: optimized_tasks,
group_id: group.group_id,
}
}
fn apply_task_fusion(&self, chain: &Chain) -> Chain {
if !self.aggressive || chain.tasks.len() < 2 {
return chain.clone();
}
let mut optimized_tasks = Vec::new();
let mut i = 0;
while i < chain.tasks.len() {
let current = &chain.tasks[i];
if i + 1 < chain.tasks.len() {
let next = &chain.tasks[i + 1];
if current.task == next.task
&& current.immutable
&& next.immutable
&& current.options.priority == next.options.priority
{
let mut fused = current.clone();
fused.args.extend(next.args.clone());
optimized_tasks.push(fused);
i += 2; continue;
}
}
optimized_tasks.push(current.clone());
i += 1;
}
Chain {
tasks: optimized_tasks,
}
}
fn apply_parallel_scheduling(&self, group: &Group) -> Group {
let mut optimized_tasks = group.tasks.clone();
optimized_tasks.sort_by(|a, b| {
let a_priority = a.options.priority.unwrap_or(0);
let b_priority = b.options.priority.unwrap_or(0);
b_priority.cmp(&a_priority)
});
Group {
tasks: optimized_tasks,
group_id: group.group_id,
}
}
fn apply_resource_optimization_chain(&self, chain: &Chain) -> Chain {
let mut optimized_tasks = chain.tasks.clone();
if self.aggressive {
optimized_tasks.sort_by(|a, b| {
let a_queue = a.options.queue.as_deref().unwrap_or("");
let b_queue = b.options.queue.as_deref().unwrap_or("");
a_queue.cmp(b_queue)
});
}
Chain {
tasks: optimized_tasks,
}
}
fn apply_resource_optimization_group(&self, group: &Group) -> Group {
let mut optimized_tasks = group.tasks.clone();
optimized_tasks.sort_by(|a, b| {
let a_queue = a.options.queue.as_deref().unwrap_or("");
let b_queue = b.options.queue.as_deref().unwrap_or("");
a_queue.cmp(b_queue)
});
Group {
tasks: optimized_tasks,
group_id: group.group_id,
}
}
}
impl Default for WorkflowCompiler {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for WorkflowCompiler {
    /// Renders as `WorkflowCompiler[<pass>, <pass>, ...]`, with ` aggressive` inserted
    /// before the closing bracket when aggressive mode is on.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WorkflowCompiler[")?;
        // Empty before the first pass, ", " before every later one.
        let mut separator = "";
        for pass in &self.passes {
            write!(f, "{}{}", separator, pass)?;
            separator = ", ";
        }
        if self.aggressive {
            f.write_str(" aggressive")?;
        }
        f.write_str("]")
    }
}
/// A value bundled with the name of its Rust type and free-form JSON metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedResult<T> {
    /// The wrapped value.
    pub value: T,
    /// Type name recorded for `value`; [`TypedResult::new`] fills it from
    /// `std::any::type_name::<T>()`.
    pub type_name: String,
    /// Extra key/value data; defaults to an empty map when missing from the input being
    /// deserialized.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}

impl<T> TypedResult<T>
where
    T: Serialize,
{
    /// Wraps `value`, recording `std::any::type_name::<T>()` as its type name and starting
    /// with empty metadata.
    pub fn new(value: T) -> Self {
        let type_name = String::from(std::any::type_name::<T>());
        Self {
            value,
            type_name,
            metadata: HashMap::new(),
        }
    }

    /// Stores `value` under `key` in the metadata, replacing any previous entry for that
    /// key, and returns the updated result.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns the recorded type name.
    pub fn type_name(&self) -> &str {
        self.type_name.as_str()
    }
}

impl<T> std::fmt::Display for TypedResult<T>
where
    T: std::fmt::Display,
{
    /// Renders as `TypedResult[type=<type_name>, value=<value>]`; metadata is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            type_name, value, ..
        } = self;
        write!(f, "TypedResult[type={}, value={}]", type_name, value)
    }
}
/// Checks a type name against an expected type name, optionally with relaxed rules.
#[derive(Debug, Clone)]
pub struct TypeValidator {
    /// The type name that [`TypeValidator::validate`] compares against.
    pub expected_type: String,
    /// When `true`, a non-identical type name may still pass via the compatibility rules.
    pub allow_compatible: bool,
}

impl TypeValidator {
    /// Creates a strict validator for `expected_type` (exact matches only).
    pub fn new(expected_type: impl Into<String>) -> Self {
        let expected_type = expected_type.into();
        Self {
            expected_type,
            allow_compatible: false,
        }
    }

    /// Enables the relaxed compatibility rules and returns the validator.
    pub fn allow_compatible(self) -> Self {
        Self {
            allow_compatible: true,
            ..self
        }
    }

    /// Returns `true` when `actual_type` equals the expected type name, or — if
    /// compatibility is allowed — when the compatibility rules accept it.
    pub fn validate(&self, actual_type: &str) -> bool {
        let exact = actual_type == self.expected_type;
        exact || (self.allow_compatible && self.is_compatible(actual_type))
    }

    /// Relaxed matching rules, tried in this order:
    ///
    /// 1. the expected type name contains the substring `Option` and `actual_type` is not
    ///    the literal `None`;
    /// 2. the expected type name is exactly `serde_json::Value`.
    ///
    /// NOTE(review): rule 1 is a plain substring test and does not compare the actual type
    /// with the option's inner type — confirm this leniency is intended.
    fn is_compatible(&self, actual_type: &str) -> bool {
        let accepts_as_option = self.expected_type.contains("Option") && actual_type != "None";
        let accepts_any_json = self.expected_type == "serde_json::Value";
        accepts_as_option || accepts_any_json
    }
}

impl std::fmt::Display for TypeValidator {
    /// Renders as `TypeValidator[expected=<type>]`, followed by ` (allow_compatible)` when
    /// the relaxed rules are enabled.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let suffix = if self.allow_compatible {
            " (allow_compatible)"
        } else {
            ""
        };
        write!(f, "TypeValidator[expected={}]{}", self.expected_type, suffix)
    }
}
/// A reference to another task, identified by id, that something depends on.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TaskDependency {
    /// Id of the referenced task.
    pub task_id: Uuid,
    /// Optional output key; presumably selects part of the referenced task's output —
    /// its exact meaning is defined by consumers outside this file.
    pub output_key: Option<String>,
    /// Marks the dependency as optional; defaults to `false` when missing from the input
    /// being deserialized.
    #[serde(default)]
    pub optional: bool,
}

impl TaskDependency {
    /// Creates a non-optional dependency on `task_id` with no output key.
    pub fn new(task_id: Uuid) -> Self {
        TaskDependency {
            task_id,
            output_key: None,
            optional: false,
        }
    }

    /// Sets the output key and returns the updated dependency.
    pub fn with_output_key(self, key: impl Into<String>) -> Self {
        Self {
            output_key: Some(key.into()),
            ..self
        }
    }

    /// Flags the dependency as optional and returns it.
    pub fn optional(self) -> Self {
        Self {
            optional: true,
            ..self
        }
    }
}

impl std::fmt::Display for TaskDependency {
    /// Renders as `TaskDependency[<task_id>]`, followed by ` output=<key>` when an output
    /// key is set and ` (optional)` when the dependency is optional.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "TaskDependency[{}]", self.task_id)?;
        if let Some(key) = self.output_key.as_ref() {
            write!(f, " output={}", key)?;
        }
        if self.optional {
            f.write_str(" (optional)")?;
        }
        Ok(())
    }
}