Skip to main content

llama_rs/backend/
tensor_parallel.rs

1//! Tensor parallelism for multi-GPU inference
2//!
3//! Splits model layers across multiple GPUs on a single node.
4//! Attention heads are split across GPUs (each GPU handles a subset of heads).
5//! FFN matrices are split column-wise (gate/up) or row-wise (down).
6//! All-reduce is used to synchronize after each layer.
7
8use crate::backend::{BackendError, BackendResult};
9use crate::tensor::{compute_strides, Tensor};
10pub use crate::tensor::DType;
11
/// Communication primitives required for tensor-parallel execution.
///
/// Implementors must be `Send + Sync`. Every collective reports failure
/// through [`BackendResult`].
pub trait TensorParallel: Send + Sync {
    /// Returns the number of participating devices (the world size).
    fn world_size(&self) -> usize;

    /// Returns this device's rank, counted from 0.
    fn rank(&self) -> usize;

    /// All-reduce sum: sums `tensor` across all devices, writing the result
    /// back into `tensor` in place.
    fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;

    /// All-gather: collects the `local` tensor of every device into `output`.
    ///
    /// `output` is expected to hold `world_size * local_size`.
    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>;

    /// Scatter: splits `input` across devices; each device receives
    /// `1/world_size` of it in `output`.
    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;

    /// Barrier: synchronizes all devices.
    fn barrier(&self) -> BackendResult<()>;
}
33
/// Settings that select the devices used for tensor parallelism.
#[derive(Debug, Clone)]
pub struct TPConfig {
    /// How many GPUs take part.
    pub num_devices: usize,
    /// Identifiers of the participating devices (e.g., `[0, 1]` for 2 GPUs).
    pub device_ids: Vec<usize>,
}

impl Default for TPConfig {
    /// Builds the single-device configuration: one device, with id 0.
    fn default() -> Self {
        let device_ids = vec![0];
        let num_devices = 1;
        TPConfig {
            num_devices,
            device_ids,
        }
    }
}
51
/// Describes how a model's heads and FFN width are divided across devices.
#[derive(Debug, Clone)]
pub struct ShardingPlan {
    /// Number of attention heads per device
    pub heads_per_device: usize,
    /// Number of KV heads per device
    pub kv_heads_per_device: usize,
    /// FFN intermediate size per device
    pub ffn_dim_per_device: usize,
    /// Total number of heads
    pub total_heads: usize,
    /// Total number of KV heads
    pub total_kv_heads: usize,
    /// Total FFN intermediate size
    pub total_ffn_dim: usize,
}

impl ShardingPlan {
    /// Creates a sharding plan by dividing each model dimension evenly
    /// across `world_size` devices.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `world_size` is 0, or when
    /// `num_heads`, `num_kv_heads` or `ffn_dim` (checked in that order) is
    /// not divisible by `world_size`.
    pub fn from_config(
        num_heads: usize,
        num_kv_heads: usize,
        ffn_dim: usize,
        world_size: usize,
    ) -> Result<Self, String> {
        // `%` and `/` by zero panic, so reject an empty world up front.
        if world_size == 0 {
            return Err("world_size must be greater than 0".to_string());
        }

        // Returns the per-device share of `total`, or the divisibility error.
        let split = |name: &str, total: usize| -> Result<usize, String> {
            if total % world_size == 0 {
                Ok(total / world_size)
            } else {
                Err(format!(
                    "{} ({}) must be divisible by world_size ({})",
                    name, total, world_size
                ))
            }
        };

        let heads_per_device = split("num_heads", num_heads)?;
        let kv_heads_per_device = split("num_kv_heads", num_kv_heads)?;
        let ffn_dim_per_device = split("ffn_dim", ffn_dim)?;

        Ok(Self {
            heads_per_device,
            kv_heads_per_device,
            ffn_dim_per_device,
            total_heads: num_heads,
            total_kv_heads: num_kv_heads,
            total_ffn_dim: ffn_dim,
        })
    }
}
107
108/// Split a weight tensor along a dimension for tensor parallelism
109///
110/// For a [out, in] matrix split along dim=0 with world_size=2:
111/// - rank 0 gets [0..out/2, :]
112/// - rank 1 gets [out/2..out, :]
113///
114/// Supports F32, F16, BF16, F64. Quantized types are not supported.
115pub fn shard_weight(
116    weight: &Tensor,
117    dim: usize,
118    rank: usize,
119    world_size: usize,
120) -> Result<Tensor, BackendError> {
121    let shape = weight.shape();
122    if dim >= shape.len() {
123        return Err(BackendError::InvalidArgument(format!(
124            "dim {} out of range for shape {:?}",
125            dim, shape
126        )));
127    }
128    if rank >= world_size {
129        return Err(BackendError::InvalidArgument(format!(
130            "rank {} must be < world_size {}",
131            rank, world_size
132        )));
133    }
134    let dim_size = shape[dim];
135    if dim_size % world_size != 0 {
136        return Err(BackendError::InvalidArgument(format!(
137            "shape[{}] ({}) must be divisible by world_size ({})",
138            dim, dim_size, world_size
139        )));
140    }
141    if !weight.is_contiguous() {
142        return Err(BackendError::InvalidArgument(
143            "weight must be contiguous for sharding".into(),
144        ));
145    }
146    if weight.dtype().is_quantized() {
147        return Err(BackendError::Unsupported(
148            "shard_weight does not support quantized tensors".into(),
149        ));
150    }
151
152    let chunk_size = dim_size / world_size;
153    let start_idx = rank * chunk_size;
154
155    // Build output shape: replace shape[dim] with chunk_size
156    let mut out_shape = shape.to_vec();
157    out_shape[dim] = chunk_size;
158
159    let out_numel: usize = out_shape.iter().product();
160    let elem_size = weight.dtype().size_for_elements(1);
161    let out_bytes = weight.dtype().size_for_elements(out_numel);
162    let mut out_data = vec![0u8; out_bytes];
163
164    let in_strides = weight.strides();
165    let in_data = weight.data();
166
167    // Iterate over all output indices and copy from input
168    for out_linear in 0..out_numel {
169        // Decode output linear index to multi-index
170        let mut out_idx = vec![0; out_shape.len()];
171        let mut rem = out_linear;
172        for d in (0..out_shape.len()).rev() {
173            out_idx[d] = rem % out_shape[d];
174            rem /= out_shape[d];
175        }
176        // Map to input index (offset in split dimension)
177        let mut in_idx = out_idx.clone();
178        in_idx[dim] += start_idx;
179        // Compute input linear index
180        let in_linear: usize = in_idx.iter().zip(in_strides.iter()).map(|(i, s)| i * s).sum();
181        let src_off = in_linear * elem_size;
182        let dst_off = out_linear * elem_size;
183        out_data[dst_off..dst_off + elem_size]
184            .copy_from_slice(&in_data[src_off..src_off + elem_size]);
185    }
186
187    Tensor::new(out_data, out_shape, weight.dtype())
188        .map_err(|e| BackendError::OperationFailed(format!("{}", e)))
189}
190
191/// Merge sharded weight tensors back (inverse of shard_weight)
192///
193/// Concatenates shards along the given dimension.
194/// Supports F32, F16, BF16, F64. Quantized types are not supported.
195pub fn merge_shards(shards: &[Tensor], dim: usize) -> Result<Tensor, BackendError> {
196    if shards.is_empty() {
197        return Err(BackendError::InvalidArgument(
198            "merge_shards requires at least one shard".into(),
199        ));
200    }
201    let dtype = shards[0].dtype();
202    if dtype.is_quantized() {
203        return Err(BackendError::Unsupported(
204            "merge_shards does not support quantized tensors".into(),
205        ));
206    }
207    for s in shards {
208        if s.dtype() != dtype {
209            return Err(BackendError::InvalidArgument(
210                "all shards must have the same dtype".into(),
211            ));
212        }
213        if !s.is_contiguous() {
214            return Err(BackendError::InvalidArgument(
215                "all shards must be contiguous".into(),
216            ));
217        }
218    }
219
220    let first_shape = shards[0].shape();
221    if dim >= first_shape.len() {
222        return Err(BackendError::InvalidArgument(format!(
223            "dim {} out of range for shape {:?}",
224            dim, first_shape
225        )));
226    }
227
228    // Build merged shape: sum shards along dim
229    let mut merged_shape = first_shape.to_vec();
230    merged_shape[dim] = 0;
231    for s in shards {
232        let s_shape = s.shape();
233        if s_shape.len() != merged_shape.len() {
234            return Err(BackendError::InvalidArgument(
235                "all shards must have same number of dimensions".into(),
236            ));
237        }
238        for (i, (m, &ss)) in merged_shape.iter_mut().zip(s_shape.iter()).enumerate() {
239            if i == dim {
240                *m += ss;
241            } else if *m != ss {
242                return Err(BackendError::InvalidArgument(format!(
243                    "shard shape mismatch at dim {}: expected {}, got {}",
244                    i, m, ss
245                )));
246            }
247        }
248    }
249
250    let merged_numel: usize = merged_shape.iter().product();
251    let elem_size = dtype.size_for_elements(1);
252    let merged_bytes = dtype.size_for_elements(merged_numel);
253    let mut merged_data = vec![0u8; merged_bytes];
254    let merged_strides = compute_strides(&merged_shape);
255
256    let mut offset_along_dim = 0;
257    for shard in shards {
258        let shard_shape = shard.shape();
259        let shard_size = shard_shape[dim];
260        let shard_numel: usize = shard_shape.iter().product();
261        let shard_data = shard.data();
262
263        for shard_linear in 0..shard_numel {
264            let mut shard_idx = vec![0; shard_shape.len()];
265            let mut rem = shard_linear;
266            for d in (0..shard_shape.len()).rev() {
267                shard_idx[d] = rem % shard_shape[d];
268                rem /= shard_shape[d];
269            }
270            let mut merged_idx = shard_idx.clone();
271            merged_idx[dim] += offset_along_dim;
272            let merged_linear: usize = merged_idx
273                .iter()
274                .zip(merged_strides.iter())
275                .map(|(i, s)| i * s)
276                .sum();
277            let src_off = shard_linear * elem_size;
278            let dst_off = merged_linear * elem_size;
279            merged_data[dst_off..dst_off + elem_size]
280                .copy_from_slice(&shard_data[src_off..src_off + elem_size]);
281        }
282        offset_along_dim += shard_size;
283    }
284
285    Tensor::new(merged_data, merged_shape, dtype)
286        .map_err(|e| BackendError::OperationFailed(format!("{}", e)))
287}
288
289/// No-op tensor parallel for single device (world_size=1)
290pub struct SingleDeviceTP;
291
292impl TensorParallel for SingleDeviceTP {
293    fn world_size(&self) -> usize {
294        1
295    }
296    fn rank(&self) -> usize {
297        0
298    }
299    fn all_reduce_sum(&self, _tensor: &mut Tensor) -> BackendResult<()> {
300        Ok(())
301    }
302    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()> {
303        let local_data = local.data();
304        let out_data = output
305            .data_mut()
306            .ok_or_else(|| BackendError::InvalidArgument("output must be mutable".into()))?;
307        let copy_len = local_data.len().min(out_data.len());
308        out_data[..copy_len].copy_from_slice(&local_data[..copy_len]);
309        Ok(())
310    }
311    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()> {
312        let input_data = input.data();
313        let out_data = output
314            .data_mut()
315            .ok_or_else(|| BackendError::InvalidArgument("output must be mutable".into()))?;
316        let copy_len = input_data.len().min(out_data.len());
317        out_data[..copy_len].copy_from_slice(&input_data[..copy_len]);
318        Ok(())
319    }
320    fn barrier(&self) -> BackendResult<()> {
321        Ok(())
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `f32` sequence `start, start + 1, ..., end - 1`.
    fn ramp(start: usize, end: usize) -> Vec<f32> {
        (start..end).map(|v| v as f32).collect()
    }

    #[test]
    fn test_sharding_plan_valid() {
        let plan = ShardingPlan::from_config(32, 8, 11008, 2).unwrap();

        // Per-device shares are the totals halved.
        assert_eq!(plan.heads_per_device, 16);
        assert_eq!(plan.kv_heads_per_device, 4);
        assert_eq!(plan.ffn_dim_per_device, 5504);

        // Totals are carried through unchanged.
        assert_eq!(plan.total_heads, 32);
        assert_eq!(plan.total_kv_heads, 8);
        assert_eq!(plan.total_ffn_dim, 11008);
    }

    #[test]
    fn test_sharding_plan_invalid() {
        // In order: num_heads, num_kv_heads, ffn_dim not divisible by 2.
        let cases = [(31, 8, 11008), (32, 7, 11008), (32, 8, 11007)];
        for &(heads, kv_heads, ffn) in cases.iter() {
            assert!(ShardingPlan::from_config(heads, kv_heads, ffn, 2).is_err());
        }
    }

    #[test]
    fn test_shard_weight() {
        // [8, 4] tensor, split along dim=0 into 2 shards of [4, 4]:
        // rank 0 holds values 0..16, rank 1 holds values 16..32.
        let data = ramp(0, 32);
        let weight = Tensor::from_f32(&data, vec![8, 4]).unwrap();

        for rank in 0..2 {
            let shard = shard_weight(&weight, 0, rank, 2).unwrap();
            assert_eq!(shard.shape(), &[4, 4]);

            let expected = ramp(rank * 16, (rank + 1) * 16);
            assert_eq!(shard.as_f32().unwrap(), expected.as_slice());
        }
    }

    #[test]
    fn test_merge_shards() {
        let lower = ramp(0, 16);
        let upper = ramp(16, 32);
        let shard0 = Tensor::from_f32(&lower, vec![4, 4]).unwrap();
        let shard1 = Tensor::from_f32(&upper, vec![4, 4]).unwrap();

        let merged = merge_shards(&[shard0, shard1], 0).unwrap();
        assert_eq!(merged.shape(), &[8, 4]);

        let expected = ramp(0, 32);
        assert_eq!(merged.as_f32().unwrap(), expected.as_slice());
    }

    #[test]
    fn test_single_device_tp() {
        let tp = SingleDeviceTP;
        assert_eq!(tp.world_size(), 1);
        assert_eq!(tp.rank(), 0);

        // all_reduce_sum leaves the tensor as it was.
        let mut tensor = Tensor::from_f32(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tp.all_reduce_sum(&mut tensor).unwrap();
        assert_eq!(tensor.as_f32().unwrap(), &[1.0, 2.0, 3.0]);

        // all_gather copies the local tensor into the output.
        let local = Tensor::from_f32(&[1.0, 2.0], vec![2]).unwrap();
        let mut output = Tensor::zeros(vec![2], DType::F32);
        tp.all_gather(&local, &mut output).unwrap();
        assert_eq!(output.as_f32().unwrap(), &[1.0, 2.0]);

        // scatter copies the input tensor into the output.
        let input = Tensor::from_f32(&[1.0, 2.0], vec![2]).unwrap();
        let mut out = Tensor::zeros(vec![2], DType::F32);
        tp.scatter(&input, &mut out).unwrap();
        assert_eq!(out.as_f32().unwrap(), &[1.0, 2.0]);

        tp.barrier().unwrap();
    }
}
409}