Skip to main content

somatize_compiler/
scheduler.rs

1//! Scheduler: distributes ExecutionPlan nodes across available workers.
2//!
3//! Rules:
4//! 1. Sequential phases → single worker (avoid data transfer)
5//! 2. Parallel branches → distribute across workers by capability
6//! 3. Differentiable connected nodes → same worker (gradient flow)
7//! 4. Study trials → round-robin across all workers
8//! 5. Auto-assign: users don't pick workers, the scheduler does
9
10use crate::ExecutionPlan;
11use serde::{Deserialize, Serialize};
12
/// A worker's capabilities and current load.
///
/// The placement logic in this module reads only `id`, `name`,
/// `active_jobs` and `max_concurrent`; `tags`, `gpu` and `cpu_cores` are
/// carried for callers and are not consulted when choosing a worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Identifier referenced by assignments, phases and transfers.
    pub id: String,
    /// Human-readable name, copied into each [`Assignment`].
    pub name: String,
    /// Free-form capability labels, tested by [`WorkerInfo::matches_tag`].
    pub tags: Vec<String>,
    /// Whether the worker has a GPU.
    pub gpu: bool,
    /// Number of CPU cores on the worker.
    pub cpu_cores: usize,
    /// Jobs currently running on the worker.
    pub active_jobs: usize,
    /// Maximum number of jobs the worker accepts at once.
    pub max_concurrent: usize,
}
24
25impl WorkerInfo {
26    pub fn available_slots(&self) -> usize {
27        self.max_concurrent.saturating_sub(self.active_jobs)
28    }
29
30    pub fn has_capacity(&self) -> bool {
31        self.available_slots() > 0
32    }
33
34    pub fn matches_tag(&self, tag: &str) -> bool {
35        self.tags.iter().any(|t| t == tag)
36    }
37}
38
/// Placement of a single plan node on a specific worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Assignment {
    /// Plan node being placed.
    pub node_id: String,
    /// [`WorkerInfo::id`] of the chosen worker.
    pub worker_id: String,
    /// [`WorkerInfo::name`] of the chosen worker, copied for display.
    pub worker_name: String,
    /// Phase tag for this node. The scheduler in this module always records
    /// `Phase::Sequential` here; parallelism is recorded on [`PlanPhase`].
    pub phase: Phase,
    /// Short human-readable explanation of why this worker was chosen.
    pub reason: String,
}
48
/// Execution phase type.
///
/// Variant names are serialized in snake_case (`sequential`, `parallel`,
/// `trial`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Phase {
    /// Nodes run in order; phases of this type list exactly one worker.
    Sequential,
    /// Independent branches, each of which may run on a different worker.
    Parallel,
    /// One trial of a study. Not emitted by the scheduling code in this
    /// module. NOTE(review): presumably `trial_index` is zero-based and
    /// `< total` — confirm with whoever produces this variant.
    Trial { trial_index: usize, total: usize },
}
57
/// The complete distribution plan produced by the scheduler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionPlan {
    /// One entry per placed plan node, in the order the scheduler visited them.
    pub assignments: Vec<Assignment>,
    /// Grouping of nodes into sequential / parallel phases.
    pub phases: Vec<PlanPhase>,
    /// Cross-worker data movements detected at parallel forks.
    pub data_transfers: Vec<DataTransfer>,
    /// Non-fatal diagnostics, e.g. no workers or no free capacity.
    pub warnings: Vec<String>,
}
66
/// A phase in the execution plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanPhase {
    /// Sequence number taken from the scheduler's running phase counter.
    pub phase_index: usize,
    /// Whether the phase is a sequential block or a parallel fork.
    pub phase_type: Phase,
    /// Node ids covered by the phase. For sequences and parallel forks this is
    /// whatever `ExecutionPlan::node_ids` reports for the sub-plan.
    pub node_ids: Vec<String>,
    /// Worker(s) used: one for sequential phases, one per branch for parallel.
    pub worker_ids: Vec<String>,
}
75
/// A data transfer between workers.
///
/// The scheduler emits one when a parallel branch runs on a different
/// worker than the node placed just before the fork.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataTransfer {
    /// Node whose output has to move.
    pub from_node: String,
    /// Receiving node (first node of the consuming branch as produced by the
    /// scheduler; empty if that branch reports no node ids).
    pub to_node: String,
    /// Worker holding the producer's output.
    pub from_worker: String,
    /// Worker that runs the consumer.
    pub to_worker: String,
    /// Transport mechanism; the scheduler in this module only emits `"s3"`.
    pub transfer_type: String, // "s3", "direct", "cached"
}
85
/// Mutable state accumulated during scheduling.
///
/// Built by [`schedule`] and threaded through the recursive plan walk; its
/// output vectors become the fields of the returned [`DistributionPlan`].
struct ScheduleState<'a> {
    /// Workers that passed the capacity filter; placement only picks from these.
    workers: Vec<&'a WorkerInfo>,
    /// Ids of differentiable nodes; a sequence containing any of them stays on one worker.
    diff_nodes: &'a [String],
    /// Node placements made so far, in visit order.
    assignments: Vec<Assignment>,
    /// Phases recorded so far.
    phases: Vec<PlanPhase>,
    /// Cross-worker transfers detected so far.
    transfers: Vec<DataTransfer>,
    /// Non-fatal diagnostics surfaced to the caller.
    warnings: Vec<String>,
    /// Index to give the next recorded phase; incremented after each one.
    phase_index: usize,
}
96
97/// Schedule an execution plan across available workers.
98pub fn schedule(
99    plan: &ExecutionPlan,
100    workers: &[WorkerInfo],
101    differentiable_nodes: &[String],
102) -> DistributionPlan {
103    let mut state = ScheduleState {
104        workers: Vec::new(),
105        diff_nodes: differentiable_nodes,
106        assignments: Vec::new(),
107        phases: Vec::new(),
108        transfers: Vec::new(),
109        warnings: Vec::new(),
110        phase_index: 0,
111    };
112
113    if workers.is_empty() {
114        state
115            .warnings
116            .push("No workers available — will execute locally".into());
117        return DistributionPlan {
118            assignments: state.assignments,
119            phases: state.phases,
120            data_transfers: state.transfers,
121            warnings: state.warnings,
122        };
123    }
124
125    state.workers = workers.iter().filter(|w| w.has_capacity()).collect();
126    if state.workers.is_empty() {
127        state.warnings.push("All workers are at capacity".into());
128        return DistributionPlan {
129            assignments: state.assignments,
130            phases: state.phases,
131            data_transfers: state.transfers,
132            warnings: state.warnings,
133        };
134    }
135
136    schedule_plan(plan, &mut state, None);
137
138    DistributionPlan {
139        assignments: state.assignments,
140        phases: state.phases,
141        data_transfers: state.transfers,
142        warnings: state.warnings,
143    }
144}
145
146fn schedule_plan(plan: &ExecutionPlan, state: &mut ScheduleState<'_>, forced_worker: Option<&str>) {
147    match plan {
148        ExecutionPlan::Execute { node_id } => {
149            let worker = if let Some(fw) = forced_worker {
150                state
151                    .workers
152                    .iter()
153                    .find(|w| w.id == fw)
154                    .unwrap_or(&state.workers[0])
155            } else {
156                least_loaded(&state.workers)
157            };
158
159            state.assignments.push(Assignment {
160                node_id: node_id.clone(),
161                worker_id: worker.id.clone(),
162                worker_name: worker.name.clone(),
163                phase: Phase::Sequential,
164                reason: if forced_worker.is_some() {
165                    "grouped with differentiable neighbors".into()
166                } else {
167                    "least loaded worker".into()
168                },
169            });
170        }
171
172        ExecutionPlan::Sequence(steps) => {
173            let worker = forced_worker
174                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
175                .unwrap_or_else(|| least_loaded(&state.workers));
176
177            let node_ids = collect_node_ids(plan);
178            let has_diff = node_ids.iter().any(|n| state.diff_nodes.contains(n));
179            let force = if has_diff {
180                Some(worker.id.as_str())
181            } else {
182                forced_worker
183            };
184
185            state.phases.push(PlanPhase {
186                phase_index: state.phase_index,
187                phase_type: Phase::Sequential,
188                node_ids: node_ids.clone(),
189                worker_ids: vec![worker.id.clone()],
190            });
191            state.phase_index += 1;
192
193            for step in steps {
194                schedule_plan(step, state, force);
195            }
196        }
197
198        ExecutionPlan::Parallel(branches) => {
199            let branch_ids: Vec<Vec<String>> = branches.iter().map(collect_node_ids).collect();
200            let mut assigned_workers = Vec::new();
201
202            for (i, branch) in branches.iter().enumerate() {
203                let worker_idx = i % state.workers.len();
204                let worker = state.workers[worker_idx];
205                assigned_workers.push(worker.id.clone());
206
207                let worker_id = worker.id.clone();
208                schedule_plan(branch, state, Some(&worker_id));
209
210                // Check if data transfer is needed from previous phase
211                if let Some(prev) = state
212                    .assignments
213                    .iter()
214                    .rev()
215                    .find(|a| !branch_ids[i].contains(&a.node_id))
216                    .filter(|prev| prev.worker_id != state.workers[worker_idx].id)
217                {
218                    state.transfers.push(DataTransfer {
219                        from_node: prev.node_id.clone(),
220                        to_node: branch_ids[i].first().cloned().unwrap_or_default(),
221                        from_worker: prev.worker_id.clone(),
222                        to_worker: state.workers[worker_idx].id.clone(),
223                        transfer_type: "s3".into(),
224                    });
225                }
226            }
227
228            state.phases.push(PlanPhase {
229                phase_index: state.phase_index,
230                phase_type: Phase::Parallel,
231                node_ids: branch_ids.into_iter().flatten().collect(),
232                worker_ids: assigned_workers,
233            });
234            state.phase_index += 1;
235        }
236
237        ExecutionPlan::Cached { node_id, .. } => {
238            let worker = forced_worker
239                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
240                .unwrap_or_else(|| least_loaded(&state.workers));
241            state.assignments.push(Assignment {
242                node_id: node_id.clone(),
243                worker_id: worker.id.clone(),
244                worker_name: worker.name.clone(),
245                phase: Phase::Sequential,
246                reason: "cached — will skip execution".into(),
247            });
248        }
249
250        ExecutionPlan::Remote { plan, .. } => {
251            schedule_plan(plan, state, None);
252        }
253
254        ExecutionPlan::Loop { body, node_id, .. } => {
255            let worker = forced_worker
256                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
257                .unwrap_or_else(|| least_loaded(&state.workers));
258            state.assignments.push(Assignment {
259                node_id: node_id.clone(),
260                worker_id: worker.id.clone(),
261                worker_name: worker.name.clone(),
262                phase: Phase::Sequential,
263                reason: "loop controller".into(),
264            });
265            let worker_id = worker.id.clone();
266            schedule_plan(body, state, Some(&worker_id));
267        }
268
269        ExecutionPlan::Branch { node_id, arms, .. } => {
270            let worker = forced_worker
271                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
272                .unwrap_or_else(|| least_loaded(&state.workers));
273            state.assignments.push(Assignment {
274                node_id: node_id.clone(),
275                worker_id: worker.id.clone(),
276                worker_name: worker.name.clone(),
277                phase: Phase::Sequential,
278                reason: "branch condition".into(),
279            });
280            let worker_id = worker.id.clone();
281            for (_, arm_plan) in arms {
282                schedule_plan(arm_plan, state, Some(&worker_id));
283            }
284        }
285
286        ExecutionPlan::Composite { node_ids } => {
287            let worker = forced_worker
288                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
289                .unwrap_or_else(|| least_loaded(&state.workers));
290
291            state.phases.push(PlanPhase {
292                phase_index: state.phase_index,
293                phase_type: Phase::Sequential,
294                node_ids: node_ids.clone(),
295                worker_ids: vec![worker.id.clone()],
296            });
297            state.phase_index += 1;
298
299            let worker_id = worker.id.clone();
300            for nid in node_ids {
301                state.assignments.push(Assignment {
302                    node_id: nid.clone(),
303                    worker_id: worker.id.clone(),
304                    worker_name: worker.name.clone(),
305                    phase: Phase::Sequential,
306                    reason: "composite block — same worker for gradient flow".into(),
307                });
308            }
309            drop(worker_id);
310        }
311
312        ExecutionPlan::Stream { node_ids, .. } => {
313            // Stream: all filters on the same worker for stateful chunk processing.
314            let worker = forced_worker
315                .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
316                .unwrap_or_else(|| least_loaded(&state.workers));
317
318            state.phases.push(PlanPhase {
319                phase_index: state.phase_index,
320                phase_type: Phase::Sequential,
321                node_ids: node_ids.clone(),
322                worker_ids: vec![worker.id.clone()],
323            });
324            state.phase_index += 1;
325
326            for nid in node_ids {
327                state.assignments.push(Assignment {
328                    node_id: nid.clone(),
329                    worker_id: worker.id.clone(),
330                    worker_name: worker.name.clone(),
331                    phase: Phase::Sequential,
332                    reason: "stream block — same worker for stateful chunk processing".into(),
333                });
334            }
335        }
336
337        ExecutionPlan::Empty => {}
338    }
339}
340
341fn least_loaded<'a>(workers: &[&'a WorkerInfo]) -> &'a WorkerInfo {
342    workers.iter().max_by_key(|w| w.available_slots()).unwrap()
343}
344
345fn collect_node_ids(plan: &ExecutionPlan) -> Vec<String> {
346    plan.node_ids().into_iter().map(|s| s.to_string()).collect()
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf plan that executes a single node.
    fn exec(id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute { node_id: id.into() }
    }

    /// A small GPU box (4 free slots) and a larger CPU server (7 free slots).
    fn test_workers() -> Vec<WorkerInfo> {
        let gpu_box = WorkerInfo {
            id: "w1".into(),
            name: "GPU-A100".into(),
            tags: vec!["gpu".into()],
            gpu: true,
            cpu_cores: 16,
            active_jobs: 0,
            max_concurrent: 4,
        };
        let cpu_server = WorkerInfo {
            id: "w2".into(),
            name: "CPU-Server".into(),
            tags: vec!["cpu".into()],
            gpu: false,
            cpu_cores: 64,
            active_jobs: 1,
            max_concurrent: 8,
        };
        vec![gpu_box, cpu_server]
    }

    /// Worker id of each assignment, in the order the assignments were made.
    fn worker_ids(result: &DistributionPlan) -> Vec<&str> {
        result
            .assignments
            .iter()
            .map(|a| a.worker_id.as_str())
            .collect()
    }

    #[test]
    fn sequential_same_worker() {
        let plan = ExecutionPlan::Sequence(vec![
            exec("normalize"),
            exec("select"),
            exec("classify"),
        ]);

        let result = schedule(&plan, &test_workers(), &[]);
        // Every step of a plain sequence lands on one worker.
        let ids = worker_ids(&result);
        assert!(ids.windows(2).all(|pair| pair[0] == pair[1]));
    }

    #[test]
    fn parallel_distributes() {
        let plan = ExecutionPlan::Parallel(vec![exec("train_svm"), exec("train_knn")]);

        let result = schedule(&plan, &test_workers(), &[]);
        assert_eq!(result.assignments.len(), 2);
        // Two branches over two workers: round-robin splits them.
        let ids = worker_ids(&result);
        assert_ne!(ids[0], ids[1]);
    }

    #[test]
    fn no_workers_warns() {
        let result = schedule(&exec("test"), &[], &[]);
        assert!(!result.warnings.is_empty());
    }

    #[test]
    fn sequence_then_parallel() {
        let plan = ExecutionPlan::Sequence(vec![
            exec("load"),
            exec("normalize"),
            ExecutionPlan::Parallel(vec![exec("train_a"), exec("train_b")]),
        ]);

        let result = schedule(&plan, &test_workers(), &[]);
        // load + normalize share a worker; train_a / train_b are spread out.
        assert!(result.assignments.len() >= 4);
        let ids = worker_ids(&result);
        assert_eq!(ids[0], ids[1]);
    }

    #[test]
    fn data_transfer_on_split() {
        let plan = ExecutionPlan::Sequence(vec![
            exec("preprocess"),
            ExecutionPlan::Parallel(vec![exec("branch_a"), exec("branch_b")]),
        ]);

        let result = schedule(&plan, &test_workers(), &[]);
        // Either preprocess's output is shipped to a branch on another worker,
        // or everything ran on one worker and no transfer was needed.
        let has_transfer = !result.data_transfers.is_empty();
        assert!(has_transfer || {
            let first = &result.assignments[0].worker_id;
            result.assignments.iter().all(|a| &a.worker_id == first)
        });
    }
}