Skip to main content

pmetal_distributed/
layer_assignment.rs

1//! Layer assignment solvers for pipeline-parallel inference.
2//!
3//! Determines how to partition a model's decoder layers across multiple nodes.
4//! Two strategies:
5//! - **Proportional**: layers proportional to available RAM (good default)
6//! - **Bandwidth-aware**: minimize bottleneck link cost
7//!
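//! Clusters are expected to be small (2-4 nodes, typical Apple Silicon home cluster),
//! so no MILP solver is needed: the bandwidth-aware solver searches every contiguous
//! split exhaustively for 2-3 nodes and uses a weighted proportional split for 4 or more.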
10
11use std::ops::Range;
12
/// Divide `num_layers` across nodes proportionally to `available_ram`.
///
/// Returns contiguous, non-overlapping layer ranges that cover `0..num_layers`,
/// one range per entry of `available_ram`. Every node receives at least one layer.
///
/// If every entry of `available_ram` is zero there is no proportion to honour, so
/// the layers are split as evenly as rounding allows.
///
/// # Panics
///
/// Panics if `available_ram` is empty, if `num_layers` is zero, or if there are
/// more nodes than layers.
// A single-node cluster legitimately produces a vector holding exactly one range.
#[allow(clippy::single_range_in_vec_init)]
pub fn assign_layers_proportional(num_layers: usize, available_ram: &[u64]) -> Vec<Range<usize>> {
    let world_size = available_ram.len();
    assert!(world_size > 0, "need at least one node");
    assert!(num_layers > 0, "need at least one layer");
    assert!(
        num_layers >= world_size,
        "more nodes ({world_size}) than layers ({num_layers})"
    );

    if world_size == 1 {
        return vec![0..num_layers];
    }

    // Accumulate in f64: summing as u64 overflows for large inputs (panic in debug
    // builds, silent wrap-around in release builds).
    let total_ram: f64 = available_ram.iter().map(|&ram| ram as f64).sum();
    let mut assignments = Vec::with_capacity(world_size);
    let mut start = 0;

    for (i, &ram) in available_ram.iter().enumerate() {
        if i == world_size - 1 {
            // Last node gets all remaining layers, absorbing accumulated rounding error.
            assignments.push(start..num_layers);
        } else {
            // A zero total would make the ratio NaN; treat all nodes as equal instead.
            let proportion = if total_ram > 0.0 {
                ram as f64 / total_ram
            } else {
                1.0 / world_size as f64
            };
            let remaining_nodes = world_size - i - 1;
            let remaining_layers = num_layers - start;
            // Leave >= 1 layer for each remaining node. This cannot underflow because
            // `num_layers >= world_size` and every earlier node respected the same cap.
            let max_for_this = remaining_layers - remaining_nodes;
            let count = (proportion * num_layers as f64).round() as usize;
            let count = count.clamp(1, max_for_this);
            assignments.push(start..start + count);
            start += count;
        }
    }

    assignments
}
52
53/// Divide layers to minimize bottleneck latency, accounting for per-node bandwidth.
54///
55/// `bandwidths[i]` is the link bandwidth (bytes/sec) from node i to node i+1.
56/// Nodes with higher bandwidth can handle the activation transfer cost of more layers.
57///
58/// For 2 nodes: exhaustive search over all split points.
59/// For 3+ nodes: heuristic weighted by bandwidth * ram.
60pub fn assign_layers_bandwidth_aware(
61    num_layers: usize,
62    available_ram: &[u64],
63    bandwidths: &[u64],
64) -> Vec<Range<usize>> {
65    let world_size = available_ram.len();
66    assert_eq!(world_size, bandwidths.len());
67
68    if world_size <= 1 {
69        return assign_layers_proportional(num_layers, available_ram);
70    }
71
72    if world_size == 2 {
73        return assign_two_nodes(num_layers, available_ram, bandwidths);
74    }
75
76    if world_size == 3 {
77        return assign_three_nodes(num_layers, available_ram, bandwidths);
78    }
79
80    // 4+ nodes: weighted proportional
81    let weights: Vec<u64> = available_ram
82        .iter()
83        .zip(bandwidths.iter())
84        .map(|(&r, &b)| {
85            let r_mb = (r / 1_000_000).max(1);
86            let b_mb = (b / 1_000_000).max(1);
87            r_mb * b_mb
88        })
89        .collect();
90    assign_layers_proportional(num_layers, &weights)
91}
92
/// Picks the split point for a two-node pipeline by trying every candidate.
///
/// A candidate `split` gives node 0 the layers `0..split` and node 1 the rest. Each
/// node's cost is its layer count divided by its bandwidth (a zero bandwidth is
/// treated as 1), and the candidate whose larger cost is smallest wins. Ties go to
/// the smallest split. `_available_ram` is not consulted.
fn assign_two_nodes(
    num_layers: usize,
    _available_ram: &[u64],
    bandwidths: &[u64],
) -> Vec<Range<usize>> {
    // Worst per-node cost for a given split point.
    let bottleneck = |split: usize| -> f64 {
        let head = split as f64 / bandwidths[0].max(1) as f64;
        let tail = (num_layers - split) as f64 / bandwidths[1].max(1) as f64;
        head.max(tail)
    };

    // Strict `<` keeps the earliest candidate on ties; the seed `(1, f64::MAX)` is
    // what is returned when there are no candidates at all.
    let (cut, _) = (1..num_layers)
        .map(|split| (split, bottleneck(split)))
        .fold((1usize, f64::MAX), |best, candidate| {
            if candidate.1 < best.1 {
                candidate
            } else {
                best
            }
        });

    vec![0..cut, cut..num_layers]
}
114
/// Picks the two split points for a three-node pipeline by trying every pair.
///
/// A candidate `(first_cut, second_cut)` assigns `0..first_cut`,
/// `first_cut..second_cut` and `second_cut..num_layers` to nodes 0, 1 and 2. Each
/// node's cost is its layer count divided by its bandwidth (a zero bandwidth is
/// treated as 1); the pair whose largest cost is smallest wins, and ties go to the
/// pair visited first (ascending `first_cut`, then ascending `second_cut`).
/// `_available_ram` is not consulted.
fn assign_three_nodes(
    num_layers: usize,
    _available_ram: &[u64],
    bandwidths: &[u64],
) -> Vec<Range<usize>> {
    // Worst per-node cost for a given pair of cut points.
    let bottleneck = |first_cut: usize, second_cut: usize| -> f64 {
        let stage_0 = first_cut as f64 / bandwidths[0].max(1) as f64;
        let stage_1 = (second_cut - first_cut) as f64 / bandwidths[1].max(1) as f64;
        let stage_2 = (num_layers - second_cut) as f64 / bandwidths[2].max(1) as f64;
        stage_0.max(stage_1).max(stage_2)
    };

    // Every pair with 1 <= first_cut < second_cut < num_layers, in row-major order.
    let candidates = (1..num_layers - 1).flat_map(|first_cut| {
        (first_cut + 1..num_layers).map(move |second_cut| (first_cut, second_cut))
    });

    // Strict `<` keeps the earliest candidate on ties; the seed pair `(1, 2)` is
    // what is returned when there are no candidates at all.
    let ((first_cut, second_cut), _) = candidates
        .map(|cuts| (cuts, bottleneck(cuts.0, cuts.1)))
        .fold(((1usize, 2usize), f64::MAX), |best, candidate| {
            if candidate.1 < best.1 {
                candidate
            } else {
                best
            }
        });

    vec![0..first_cut, first_cut..second_cut, second_cut..num_layers]
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Two nodes with identical RAM share the layers equally.
    #[test]
    fn proportional_equal() {
        let ranges = assign_layers_proportional(32, &[16_000, 16_000]);
        assert_eq!(ranges, [0..16, 16..32]);
    }

    /// A 3:1 RAM ratio yields a 3:1 layer ratio.
    #[test]
    fn proportional_3x1() {
        let ranges = assign_layers_proportional(32, &[48_000, 16_000]);
        assert_eq!(ranges, [0..24, 24..32]);
    }

    /// Three nodes with identical RAM each take a third of the layers.
    #[test]
    fn proportional_three_equal() {
        let ranges = assign_layers_proportional(30, &[10_000, 10_000, 10_000]);
        assert_eq!(ranges, [0..10, 10..20, 20..30]);
    }

    /// Node 0 has 2x the bandwidth of node 1, so it should receive more layers.
    #[test]
    fn bandwidth_two_nodes() {
        let ram = [16_000, 16_000];
        let bandwidths = [200_000, 100_000];
        let ranges = assign_layers_bandwidth_aware(30, &ram, &bandwidths);
        assert_eq!(ranges.len(), 2);
        let (fast, slow) = (ranges[0].len(), ranges[1].len());
        assert!(fast > slow, "faster node should get more layers");
    }

    /// The three-node result starts at 0, ends at the layer count and has no gaps.
    #[test]
    fn bandwidth_three_nodes() {
        let ram = [16_000, 16_000, 16_000];
        let bandwidths = [100_000, 100_000, 100_000];
        let ranges = assign_layers_bandwidth_aware(30, &ram, &bandwidths);
        assert_eq!(ranges.len(), 3);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[2].end, 30);
        // Each range must begin exactly where the previous one ended.
        assert_eq!(ranges[0].end, ranges[1].start);
        assert_eq!(ranges[1].end, ranges[2].start);
    }

    /// With as many nodes as layers, no node may be left empty.
    #[test]
    fn minimum_one_layer_per_node() {
        let ranges = assign_layers_proportional(4, &[100, 100, 100, 100]);
        assert_eq!(ranges.len(), 4);
        assert!(
            ranges.iter().all(|range| !range.is_empty()),
            "each node must get at least one layer"
        );
    }
}
196}