celers-canvas 0.2.0

Workflow primitives for CeleRS (Chain, Chord, Group, Map)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
use crate::{Chain, Chord, Group};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

/// A single optimization pass that a [`WorkflowCompiler`] can run.
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizationPass {
    /// Common subexpression elimination (displayed as `CSE`)
    CommonSubexpressionElimination,
    /// Dead code elimination (displayed as `DCE`)
    DeadCodeElimination,
    /// Task fusion (combine sequential tasks)
    TaskFusion,
    /// Parallel task scheduling optimization
    ParallelScheduling,
    /// Resource allocation optimization
    ResourceOptimization,
}

impl std::fmt::Display for OptimizationPass {
    /// Writes the short label of the pass: `CSE`, `DCE`, or the variant name
    /// for the remaining passes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::CommonSubexpressionElimination => "CSE",
            Self::DeadCodeElimination => "DCE",
            Self::TaskFusion => "TaskFusion",
            Self::ParallelScheduling => "ParallelScheduling",
            Self::ResourceOptimization => "ResourceOptimization",
        };
        f.write_str(label)
    }
}

/// Workflow compiler for optimization.
///
/// Holds the ordered list of optimization passes together with the flag that
/// switches on the more invasive rewrites.
#[derive(Debug, Clone)]
pub struct WorkflowCompiler {
    /// Optimization passes to apply, run in the order they appear here.
    pub passes: Vec<OptimizationPass>,
    /// Whether to enable aggressive optimizations. Some passes only rewrite
    /// the workflow when this is set (for example duplicate removal in CSE
    /// and task fusion).
    pub aggressive: bool,
}

impl WorkflowCompiler {
    /// Build a compiler with the default, conservative configuration.
    ///
    /// The default pipeline runs dead code elimination first and common
    /// subexpression elimination second; aggressive mode is off.
    pub fn new() -> Self {
        let passes = Vec::from([
            OptimizationPass::DeadCodeElimination,
            OptimizationPass::CommonSubexpressionElimination,
        ]);

        Self {
            aggressive: false,
            passes,
        }
    }

    /// Enable aggressive optimizations.
    ///
    /// Sets the `aggressive` flag and registers the `TaskFusion`,
    /// `ParallelScheduling` and `ResourceOptimization` passes, in that order.
    ///
    /// Registration goes through [`Self::add_pass`], so a pass that is
    /// already in the pipeline is not added a second time. This keeps
    /// `add_pass(..).aggressive()` and repeated `aggressive()` calls from
    /// running the same pass twice.
    pub fn aggressive(mut self) -> Self {
        self.aggressive = true;
        self.add_pass(OptimizationPass::TaskFusion)
            .add_pass(OptimizationPass::ParallelScheduling)
            .add_pass(OptimizationPass::ResourceOptimization)
    }

    /// Append `pass` to the pipeline unless it is already registered.
    ///
    /// Consumes the compiler and returns it again so calls can be chained.
    pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
        let already_registered = self.passes.iter().any(|existing| *existing == pass);
        if already_registered {
            return self;
        }

        self.passes.push(pass);
        self
    }

    /// Run every configured pass over `chain`, in pipeline order, and return
    /// the rewritten chain. The input chain is left untouched.
    pub fn optimize_chain(&self, chain: &Chain) -> Chain {
        self.passes
            .iter()
            .fold(chain.clone(), |current, pass| match pass {
                OptimizationPass::DeadCodeElimination => self.apply_dce_chain(&current),
                OptimizationPass::CommonSubexpressionElimination => {
                    self.apply_cse_chain(&current)
                }
                OptimizationPass::TaskFusion => self.apply_task_fusion(&current),
                OptimizationPass::ResourceOptimization => {
                    self.apply_resource_optimization_chain(&current)
                }
                // Chains are sequential, so there is nothing to schedule in
                // parallel; pass the chain through unchanged.
                OptimizationPass::ParallelScheduling => current,
            })
    }

    /// Run every configured pass over `group`, in pipeline order, and return
    /// the rewritten group. The input group is left untouched.
    pub fn optimize_group(&self, group: &Group) -> Group {
        self.passes
            .iter()
            .fold(group.clone(), |current, pass| match pass {
                OptimizationPass::DeadCodeElimination => self.apply_dce_group(&current),
                OptimizationPass::CommonSubexpressionElimination => {
                    self.apply_cse_group(&current)
                }
                OptimizationPass::ParallelScheduling => self.apply_parallel_scheduling(&current),
                OptimizationPass::ResourceOptimization => {
                    self.apply_resource_optimization_group(&current)
                }
                // Groups are parallel, so task fusion does not apply; pass
                // the group through unchanged.
                OptimizationPass::TaskFusion => current,
            })
    }

    /// Optimize a chord.
    ///
    /// Only the header group is run through [`Self::optimize_group`]; the
    /// body is cloned as-is.
    pub fn optimize_chord(&self, chord: &Chord) -> Chord {
        let body = chord.body.clone();
        let header = self.optimize_group(&chord.header);

        Chord { header, body }
    }

    // Helper methods for optimization passes

    /// Common Subexpression Elimination for chains.
    ///
    /// Two tasks count as duplicates when their task name, serialized `args`
    /// and serialized `kwargs` all match. In aggressive mode only the first
    /// occurrence of each duplicate is kept; otherwise every task is kept.
    /// If serializing `args` or `kwargs` fails, that part of the key is the
    /// empty string.
    fn apply_cse_chain(&self, chain: &Chain) -> Chain {
        // Maps each fingerprint to the position where it first appeared.
        let mut first_seen: HashMap<String, usize> = HashMap::new();
        let mut kept = Vec::new();

        for (position, task) in chain.tasks.iter().enumerate() {
            let args_json = serde_json::to_string(&task.args).unwrap_or_default();
            let kwargs_json = serde_json::to_string(&task.kwargs).unwrap_or_default();
            let fingerprint = format!("{}:{}:{}", task.task, args_json, kwargs_json);

            // `or_insert` records this position only for a new fingerprint,
            // so a smaller stored position means the task is a repeat.
            let earliest = *first_seen.entry(fingerprint).or_insert(position);
            let is_repeat = earliest < position;

            if is_repeat && self.aggressive {
                continue;
            }
            kept.push(task.clone());
        }

        Chain { tasks: kept }
    }

    /// Common Subexpression Elimination for groups.
    ///
    /// Two tasks count as duplicates when their task name, serialized `args`
    /// and serialized `kwargs` all match. In aggressive mode only the first
    /// occurrence of each duplicate is kept; otherwise every task is kept.
    /// The group id is carried over unchanged.
    fn apply_cse_group(&self, group: &Group) -> Group {
        let mut known = HashMap::new();

        let kept: Vec<_> = group
            .tasks
            .iter()
            .filter(|task| {
                let args_json = serde_json::to_string(&task.args).unwrap_or_default();
                let kwargs_json = serde_json::to_string(&task.kwargs).unwrap_or_default();
                let fingerprint = format!("{}:{}:{}", task.task, args_json, kwargs_json);

                // `insert` returns `None` only the first time a key is seen.
                let first_time = known.insert(fingerprint, true).is_none();
                first_time || !self.aggressive
            })
            .cloned()
            .collect();

        Group {
            group_id: group.group_id,
            tasks: kept,
        }
    }

    /// Dead Code Elimination for chains.
    ///
    /// Drops every task whose task name is empty; the remaining tasks keep
    /// their original order.
    fn apply_dce_chain(&self, chain: &Chain) -> Chain {
        let mut live_tasks = chain.tasks.clone();
        live_tasks.retain(|task| !task.task.is_empty());

        Chain { tasks: live_tasks }
    }

    /// Dead Code Elimination for groups.
    ///
    /// Drops every task whose task name is empty; the remaining tasks keep
    /// their original order and the group id is carried over unchanged.
    fn apply_dce_group(&self, group: &Group) -> Group {
        let mut live_tasks = group.tasks.clone();
        live_tasks.retain(|task| !task.task.is_empty());

        Group {
            group_id: group.group_id,
            tasks: live_tasks,
        }
    }

    /// Task Fusion for chains (combine similar sequential tasks).
    ///
    /// Only active in aggressive mode and for chains with at least two tasks;
    /// otherwise the chain is returned as a plain clone.
    ///
    /// Two neighbouring tasks are fused when they share the same task name,
    /// are both immutable and have equal priority. The fused task is a clone
    /// of the first one with the second task's `args` appended; nothing else
    /// is taken from the second task. Fusion works on pairs: once two tasks
    /// are fused, scanning resumes after the second one.
    fn apply_task_fusion(&self, chain: &Chain) -> Chain {
        let fusion_possible = self.aggressive && chain.tasks.len() >= 2;
        if !fusion_possible {
            return chain.clone();
        }

        let mut fused_tasks = Vec::new();
        let mut remaining = &chain.tasks[..];

        while let Some((head, rest)) = remaining.split_first() {
            // The immediate successor, if it qualifies for fusion with `head`.
            let partner = rest.first().filter(|candidate| {
                head.task == candidate.task
                    && head.immutable
                    && candidate.immutable
                    && head.options.priority == candidate.options.priority
            });

            match partner {
                Some(candidate) => {
                    let mut merged = head.clone();
                    merged.args.extend(candidate.args.clone());
                    fused_tasks.push(merged);
                    // Both tasks are consumed by the fusion.
                    remaining = &rest[1..];
                }
                None => {
                    fused_tasks.push(head.clone());
                    remaining = rest;
                }
            }
        }

        Chain { tasks: fused_tasks }
    }

    /// Parallel Scheduling optimization for groups.
    ///
    /// Orders the group's tasks from highest to lowest priority. A task
    /// without a priority is treated as priority 0. The sort is stable, so
    /// tasks with equal priority keep their relative order. The group id is
    /// carried over unchanged.
    fn apply_parallel_scheduling(&self, group: &Group) -> Group {
        let mut by_priority = group.tasks.clone();

        // `Reverse` flips the natural ascending order into descending.
        by_priority.sort_by_key(|task| std::cmp::Reverse(task.options.priority.unwrap_or(0)));

        Group {
            group_id: group.group_id,
            tasks: by_priority,
        }
    }

    /// Resource Optimization for chains
    fn apply_resource_optimization_chain(&self, chain: &Chain) -> Chain {
        // Group tasks by queue to improve resource utilization
        let mut optimized_tasks = chain.tasks.clone();

        if self.aggressive {
            // Reorder tasks to group by queue while maintaining dependencies
            optimized_tasks.sort_by(|a, b| {
                let a_queue = a.options.queue.as_deref().unwrap_or("");
                let b_queue = b.options.queue.as_deref().unwrap_or("");
                a_queue.cmp(b_queue)
            });
        }

        Chain {
            tasks: optimized_tasks,
        }
    }

    /// Resource Optimization for groups.
    ///
    /// Sorts the group's tasks by queue name so that tasks for the same
    /// queue sit next to each other. A task without a queue compares as the
    /// empty string and therefore sorts first. The sort is stable and the
    /// group id is carried over unchanged.
    fn apply_resource_optimization_group(&self, group: &Group) -> Group {
        let mut by_queue = group.tasks.clone();

        by_queue.sort_by(|left, right| {
            let left_queue = left.options.queue.as_deref().unwrap_or_default();
            let right_queue = right.options.queue.as_deref().unwrap_or_default();
            left_queue.cmp(right_queue)
        });

        Group {
            group_id: group.group_id,
            tasks: by_queue,
        }
    }
}

impl Default for WorkflowCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for WorkflowCompiler {
    /// Writes `WorkflowCompiler[<pass>, <pass>, ...]`, with ` aggressive`
    /// inserted before the closing bracket when aggressive mode is on.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WorkflowCompiler[")?;

        // Empty before the first pass, ", " before every later one.
        let mut separator = "";
        for pass in &self.passes {
            write!(f, "{}{}", separator, pass)?;
            separator = ", ";
        }

        let closing = if self.aggressive { " aggressive]" } else { "]" };
        f.write_str(closing)
    }
}

// ============================================================================
// Type-Safe Result Passing
// ============================================================================

/// Type-safe result wrapper for workflow results.
///
/// Pairs a value with the name of its type so that a consumer can check the
/// type name before using the value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedResult<T> {
    /// Result value
    pub value: T,
    /// Result type name for validation. `TypedResult::new` fills this with
    /// `std::any::type_name::<T>()`.
    pub type_name: String,
    /// Result metadata. May be omitted from serialized input, in which case
    /// it deserializes to an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}

impl<T: Serialize> TypedResult<T> {
    /// Wrap `value`, recording `std::any::type_name::<T>()` as its type name
    /// and starting with empty metadata.
    pub fn new(value: T) -> Self {
        let type_name = String::from(std::any::type_name::<T>());

        Self {
            value,
            type_name,
            metadata: HashMap::new(),
        }
    }

    /// Attach one metadata entry, replacing any earlier entry stored under
    /// the same key. Returns the result for builder-style chaining.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Borrow the recorded type name.
    pub fn type_name(&self) -> &str {
        self.type_name.as_str()
    }
}

impl<T: std::fmt::Display> std::fmt::Display for TypedResult<T> {
    /// Writes `TypedResult[type=<type name>, value=<value>]`; metadata is not
    /// part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            type_name, value, ..
        } = self;

        write!(f, "TypedResult[type={}, value={}]", type_name, value)
    }
}

/// Type validator for result passing.
///
/// Compares an actual type name against an expected one, optionally
/// accepting a small set of "compatible" mismatches.
#[derive(Debug, Clone)]
pub struct TypeValidator {
    /// Expected type name
    pub expected_type: String,
    /// Whether to allow compatible types
    pub allow_compatible: bool,
}

impl TypeValidator {
    /// Create a strict validator for `expected_type`; only an exact name
    /// match is accepted until [`Self::allow_compatible`] is called.
    pub fn new(expected_type: impl Into<String>) -> Self {
        let expected_type = expected_type.into();

        Self {
            expected_type,
            allow_compatible: false,
        }
    }

    /// Switch the validator to also accept compatible types.
    pub fn allow_compatible(self) -> Self {
        Self {
            allow_compatible: true,
            ..self
        }
    }

    /// Return `true` when `actual_type` is acceptable.
    ///
    /// An exact match with the expected name always passes. Anything else
    /// passes only when compatible types are allowed and the compatibility
    /// rules accept it.
    pub fn validate(&self, actual_type: &str) -> bool {
        let exact_match = self.expected_type == actual_type;
        exact_match || (self.allow_compatible && self.is_compatible(actual_type))
    }

    /// Simple compatibility rules (can be extended):
    ///
    /// * if the expected name contains `Option`, every actual name except
    ///   `"None"` is accepted;
    /// * if the expected name is exactly `serde_json::Value`, every actual
    ///   name is accepted.
    fn is_compatible(&self, actual_type: &str) -> bool {
        let expected = self.expected_type.as_str();

        let accepted_by_option = expected.contains("Option") && actual_type != "None";
        accepted_by_option || expected == "serde_json::Value"
    }
}

impl std::fmt::Display for TypeValidator {
    /// Writes `TypeValidator[expected=<name>]`, followed by
    /// ` (allow_compatible)` when compatible types are allowed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let suffix = if self.allow_compatible {
            " (allow_compatible)"
        } else {
            ""
        };

        write!(f, "TypeValidator[expected={}]{}", self.expected_type, suffix)
    }
}

// ============================================================================
// Data Dependencies
// ============================================================================

/// Task dependency specification.
///
/// Names the task another task depends on, optionally narrowing the
/// dependency to a single output key and marking it as optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TaskDependency {
    /// Task ID that this task depends on
    pub task_id: Uuid,
    /// Output key to use from the dependency (optional)
    pub output_key: Option<String>,
    /// Whether this dependency is optional. May be omitted from serialized
    /// input, in which case it deserializes to `false`.
    #[serde(default)]
    pub optional: bool,
}

impl TaskDependency {
    /// Create a required dependency on `task_id` with no output key.
    pub fn new(task_id: Uuid) -> Self {
        Self {
            optional: false,
            output_key: None,
            task_id,
        }
    }

    /// Set the output key, replacing any key set earlier.
    pub fn with_output_key(self, key: impl Into<String>) -> Self {
        Self {
            output_key: Some(key.into()),
            ..self
        }
    }

    /// Mark the dependency as optional.
    pub fn optional(self) -> Self {
        Self {
            optional: true,
            ..self
        }
    }
}

impl std::fmt::Display for TaskDependency {
    /// Writes `TaskDependency[<task id>]`, then ` output=<key>` when an
    /// output key is set, then ` (optional)` when the dependency is optional.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let optional_suffix = if self.optional { " (optional)" } else { "" };

        match &self.output_key {
            Some(key) => write!(
                f,
                "TaskDependency[{}] output={}{}",
                self.task_id, key, optional_suffix
            ),
            None => write!(f, "TaskDependency[{}]{}", self.task_id, optional_suffix),
        }
    }
}