/// One contiguous slice of model layers assigned to a single node in a
/// pipeline-parallel schedule.
///
/// The layer range is half-open: the stage owns layers
/// `start_layer..end_layer` (`end_layer` itself is not included).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineStage {
    /// Identifier of the node that runs this stage.
    pub node_id: String,
    /// First layer owned by this stage (inclusive).
    pub start_layer: usize,
    /// One past the last layer owned by this stage (exclusive).
    pub end_layer: usize,
    /// `true` when `start_layer == 0`, i.e. this is the first stage
    /// (presumably the one that also runs the input embedding).
    pub has_embedding: bool,
    /// `true` when `end_layer` equals the model's total layer count, i.e. this
    /// is the final stage (presumably the one that also runs the LM head).
    pub has_lm_head: bool,
}
/// A pipeline-parallel execution plan: one [`PipelineStage`] per node,
/// ordered by the first layer each node owns.
#[derive(Debug, Clone)]
pub struct PipelineSchedule {
    stages: Vec<PipelineStage>,
}

impl PipelineSchedule {
    /// Builds a schedule from a `node_id -> (start_layer, end_layer)` map.
    ///
    /// Each range is half-open (`start..end`). A stage gets `has_embedding`
    /// when `start == 0` and `has_lm_head` when `end == total_layers`.
    ///
    /// Stages are ordered by `start_layer`, with ties broken by `end_layer`
    /// and then `node_id`, so the resulting order never depends on the map's
    /// iteration order.
    ///
    /// No validation is performed: overlapping, gapped or empty ranges are
    /// accepted as given.
    pub fn new(
        assignments: &std::collections::HashMap<String, (usize, usize)>,
        total_layers: usize,
    ) -> Self {
        let mut stages: Vec<PipelineStage> = assignments
            .iter()
            .map(|(node_id, &(start, end))| PipelineStage {
                node_id: node_id.clone(),
                start_layer: start,
                end_layer: end,
                has_embedding: start == 0,
                has_lm_head: end == total_layers,
            })
            .collect();
        // `HashMap` iteration order is randomized per process, so sorting on
        // `start_layer` alone leaves ties in arbitrary order. `node_id` is the
        // map key and therefore unique, which makes this key a total order.
        stages.sort_unstable_by(|a, b| {
            (a.start_layer, a.end_layer, &a.node_id)
                .cmp(&(b.start_layer, b.end_layer, &b.node_id))
        });
        Self { stages }
    }

    /// All stages, in schedule order (ascending `start_layer`).
    #[must_use]
    pub fn stages(&self) -> &[PipelineStage] {
        &self.stages
    }

    /// Number of stages (one per assigned node).
    #[must_use]
    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }

    /// Returns the first stage, in schedule order, whose half-open range
    /// `start_layer..end_layer` contains `layer_idx`, or `None` if no stage
    /// covers that layer.
    #[must_use]
    pub fn stage_for_layer(&self, layer_idx: usize) -> Option<&PipelineStage> {
        self.stages
            .iter()
            .find(|s| (s.start_layer..s.end_layer).contains(&layer_idx))
    }

    /// Node id of the last stage in schedule order (the one with the highest
    /// `start_layer`), or `None` if the schedule is empty.
    #[must_use]
    pub fn output_node(&self) -> Option<&str> {
        self.stages.last().map(|s| s.node_id.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Two nodes splitting a 32-layer model evenly.
    fn two_stage_schedule() -> PipelineSchedule {
        let mut assignments = HashMap::new();
        assignments.insert("node-0".to_string(), (0, 16));
        assignments.insert("node-1".to_string(), (16, 32));
        PipelineSchedule::new(&assignments, 32)
    }

    /// Node id owning `layer`, if any.
    fn owner(schedule: &PipelineSchedule, layer: usize) -> Option<&str> {
        schedule.stage_for_layer(layer).map(|s| s.node_id.as_str())
    }

    #[test]
    fn test_pipeline_schedule() {
        let schedule = two_stage_schedule();
        assert_eq!(schedule.num_stages(), 2);
        assert!(schedule.stages()[0].has_embedding);
        assert!(!schedule.stages()[0].has_lm_head);
        assert!(!schedule.stages()[1].has_embedding);
        assert!(schedule.stages()[1].has_lm_head);
        assert_eq!(schedule.output_node(), Some("node-1"));
    }

    #[test]
    fn test_stage_for_layer_half_open_bounds() {
        let schedule = two_stage_schedule();
        assert_eq!(owner(&schedule, 0), Some("node-0"));
        assert_eq!(owner(&schedule, 15), Some("node-0"));
        // `end_layer` is exclusive, so layer 16 belongs to the next stage.
        assert_eq!(owner(&schedule, 16), Some("node-1"));
        assert_eq!(owner(&schedule, 31), Some("node-1"));
        assert_eq!(owner(&schedule, 32), None);
    }

    #[test]
    fn test_empty_schedule() {
        let schedule = PipelineSchedule::new(&HashMap::new(), 32);
        assert_eq!(schedule.num_stages(), 0);
        assert!(schedule.stages().is_empty());
        assert_eq!(schedule.output_node(), None);
        assert_eq!(owner(&schedule, 0), None);
    }
}