1use crate::backend::{BackendError, BackendResult};
9use crate::tensor::{compute_strides, Tensor};
10pub use crate::tensor::DType;
11
12pub trait TensorParallel: Send + Sync {
14 fn world_size(&self) -> usize;
16
17 fn rank(&self) -> usize;
19
20 fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;
22
23 fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>;
26
27 fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;
29
30 fn barrier(&self) -> BackendResult<()>;
32}
33
/// Configuration for a tensor-parallel group.
#[derive(Debug, Clone)]
pub struct TPConfig {
    /// Number of devices in the group.
    pub num_devices: usize,
    /// Identifiers of the devices in the group.
    ///
    /// NOTE(review): presumably one entry per device, i.e.
    /// `device_ids.len() == num_devices` — nothing in this type enforces it.
    pub device_ids: Vec<usize>,
}

impl Default for TPConfig {
    /// Single-device configuration: one device whose id is `0`.
    fn default() -> Self {
        let device_ids = vec![0];
        TPConfig {
            num_devices: device_ids.len(),
            device_ids,
        }
    }
}
51
/// How attention heads, KV heads and the FFN dimension are divided evenly
/// across the devices of a tensor-parallel group.
#[derive(Debug, Clone)]
pub struct ShardingPlan {
    /// `total_heads / world_size`.
    pub heads_per_device: usize,
    /// `total_kv_heads / world_size`.
    pub kv_heads_per_device: usize,
    /// `total_ffn_dim / world_size`.
    pub ffn_dim_per_device: usize,
    /// The `num_heads` value the plan was built from.
    pub total_heads: usize,
    /// The `num_kv_heads` value the plan was built from.
    pub total_kv_heads: usize,
    /// The `ffn_dim` value the plan was built from.
    pub total_ffn_dim: usize,
}

impl ShardingPlan {
    /// Builds a plan that splits each quantity evenly over `world_size`
    /// devices.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `world_size` is zero, or when any
    /// of `num_heads`, `num_kv_heads` or `ffn_dim` (checked in that order) is
    /// not an exact multiple of `world_size`.
    pub fn from_config(
        num_heads: usize,
        num_kv_heads: usize,
        ffn_dim: usize,
        world_size: usize,
    ) -> Result<Self, String> {
        // `%` and `/` by zero panic for integers, so reject an empty group up
        // front instead of crashing the caller.
        if world_size == 0 {
            return Err("world_size must be greater than zero".to_string());
        }

        // Validates divisibility of one quantity and yields its per-device share.
        let per_device = |name: &str, total: usize| -> Result<usize, String> {
            if total % world_size != 0 {
                Err(format!(
                    "{} ({}) must be divisible by world_size ({})",
                    name, total, world_size
                ))
            } else {
                Ok(total / world_size)
            }
        };

        let heads_per_device = per_device("num_heads", num_heads)?;
        let kv_heads_per_device = per_device("num_kv_heads", num_kv_heads)?;
        let ffn_dim_per_device = per_device("ffn_dim", ffn_dim)?;

        Ok(Self {
            heads_per_device,
            kv_heads_per_device,
            ffn_dim_per_device,
            total_heads: num_heads,
            total_kv_heads: num_kv_heads,
            total_ffn_dim: ffn_dim,
        })
    }
}
107
108pub fn shard_weight(
116 weight: &Tensor,
117 dim: usize,
118 rank: usize,
119 world_size: usize,
120) -> Result<Tensor, BackendError> {
121 let shape = weight.shape();
122 if dim >= shape.len() {
123 return Err(BackendError::InvalidArgument(format!(
124 "dim {} out of range for shape {:?}",
125 dim, shape
126 )));
127 }
128 if rank >= world_size {
129 return Err(BackendError::InvalidArgument(format!(
130 "rank {} must be < world_size {}",
131 rank, world_size
132 )));
133 }
134 let dim_size = shape[dim];
135 if dim_size % world_size != 0 {
136 return Err(BackendError::InvalidArgument(format!(
137 "shape[{}] ({}) must be divisible by world_size ({})",
138 dim, dim_size, world_size
139 )));
140 }
141 if !weight.is_contiguous() {
142 return Err(BackendError::InvalidArgument(
143 "weight must be contiguous for sharding".into(),
144 ));
145 }
146 if weight.dtype().is_quantized() {
147 return Err(BackendError::Unsupported(
148 "shard_weight does not support quantized tensors".into(),
149 ));
150 }
151
152 let chunk_size = dim_size / world_size;
153 let start_idx = rank * chunk_size;
154
155 let mut out_shape = shape.to_vec();
157 out_shape[dim] = chunk_size;
158
159 let out_numel: usize = out_shape.iter().product();
160 let elem_size = weight.dtype().size_for_elements(1);
161 let out_bytes = weight.dtype().size_for_elements(out_numel);
162 let mut out_data = vec![0u8; out_bytes];
163
164 let in_strides = weight.strides();
165 let in_data = weight.data();
166
167 for out_linear in 0..out_numel {
169 let mut out_idx = vec![0; out_shape.len()];
171 let mut rem = out_linear;
172 for d in (0..out_shape.len()).rev() {
173 out_idx[d] = rem % out_shape[d];
174 rem /= out_shape[d];
175 }
176 let mut in_idx = out_idx.clone();
178 in_idx[dim] += start_idx;
179 let in_linear: usize = in_idx.iter().zip(in_strides.iter()).map(|(i, s)| i * s).sum();
181 let src_off = in_linear * elem_size;
182 let dst_off = out_linear * elem_size;
183 out_data[dst_off..dst_off + elem_size]
184 .copy_from_slice(&in_data[src_off..src_off + elem_size]);
185 }
186
187 Tensor::new(out_data, out_shape, weight.dtype())
188 .map_err(|e| BackendError::OperationFailed(format!("{}", e)))
189}
190
191pub fn merge_shards(shards: &[Tensor], dim: usize) -> Result<Tensor, BackendError> {
196 if shards.is_empty() {
197 return Err(BackendError::InvalidArgument(
198 "merge_shards requires at least one shard".into(),
199 ));
200 }
201 let dtype = shards[0].dtype();
202 if dtype.is_quantized() {
203 return Err(BackendError::Unsupported(
204 "merge_shards does not support quantized tensors".into(),
205 ));
206 }
207 for s in shards {
208 if s.dtype() != dtype {
209 return Err(BackendError::InvalidArgument(
210 "all shards must have the same dtype".into(),
211 ));
212 }
213 if !s.is_contiguous() {
214 return Err(BackendError::InvalidArgument(
215 "all shards must be contiguous".into(),
216 ));
217 }
218 }
219
220 let first_shape = shards[0].shape();
221 if dim >= first_shape.len() {
222 return Err(BackendError::InvalidArgument(format!(
223 "dim {} out of range for shape {:?}",
224 dim, first_shape
225 )));
226 }
227
228 let mut merged_shape = first_shape.to_vec();
230 merged_shape[dim] = 0;
231 for s in shards {
232 let s_shape = s.shape();
233 if s_shape.len() != merged_shape.len() {
234 return Err(BackendError::InvalidArgument(
235 "all shards must have same number of dimensions".into(),
236 ));
237 }
238 for (i, (m, &ss)) in merged_shape.iter_mut().zip(s_shape.iter()).enumerate() {
239 if i == dim {
240 *m += ss;
241 } else if *m != ss {
242 return Err(BackendError::InvalidArgument(format!(
243 "shard shape mismatch at dim {}: expected {}, got {}",
244 i, m, ss
245 )));
246 }
247 }
248 }
249
250 let merged_numel: usize = merged_shape.iter().product();
251 let elem_size = dtype.size_for_elements(1);
252 let merged_bytes = dtype.size_for_elements(merged_numel);
253 let mut merged_data = vec![0u8; merged_bytes];
254 let merged_strides = compute_strides(&merged_shape);
255
256 let mut offset_along_dim = 0;
257 for shard in shards {
258 let shard_shape = shard.shape();
259 let shard_size = shard_shape[dim];
260 let shard_numel: usize = shard_shape.iter().product();
261 let shard_data = shard.data();
262
263 for shard_linear in 0..shard_numel {
264 let mut shard_idx = vec![0; shard_shape.len()];
265 let mut rem = shard_linear;
266 for d in (0..shard_shape.len()).rev() {
267 shard_idx[d] = rem % shard_shape[d];
268 rem /= shard_shape[d];
269 }
270 let mut merged_idx = shard_idx.clone();
271 merged_idx[dim] += offset_along_dim;
272 let merged_linear: usize = merged_idx
273 .iter()
274 .zip(merged_strides.iter())
275 .map(|(i, s)| i * s)
276 .sum();
277 let src_off = shard_linear * elem_size;
278 let dst_off = merged_linear * elem_size;
279 merged_data[dst_off..dst_off + elem_size]
280 .copy_from_slice(&shard_data[src_off..src_off + elem_size]);
281 }
282 offset_along_dim += shard_size;
283 }
284
285 Tensor::new(merged_data, merged_shape, dtype)
286 .map_err(|e| BackendError::OperationFailed(format!("{}", e)))
287}
288
289pub struct SingleDeviceTP;
291
292impl TensorParallel for SingleDeviceTP {
293 fn world_size(&self) -> usize {
294 1
295 }
296 fn rank(&self) -> usize {
297 0
298 }
299 fn all_reduce_sum(&self, _tensor: &mut Tensor) -> BackendResult<()> {
300 Ok(())
301 }
302 fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()> {
303 let local_data = local.data();
304 let out_data = output
305 .data_mut()
306 .ok_or_else(|| BackendError::InvalidArgument("output must be mutable".into()))?;
307 let copy_len = local_data.len().min(out_data.len());
308 out_data[..copy_len].copy_from_slice(&local_data[..copy_len]);
309 Ok(())
310 }
311 fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()> {
312 let input_data = input.data();
313 let out_data = output
314 .data_mut()
315 .ok_or_else(|| BackendError::InvalidArgument("output must be mutable".into()))?;
316 let copy_len = input_data.len().min(out_data.len());
317 out_data[..copy_len].copy_from_slice(&input_data[..copy_len]);
318 Ok(())
319 }
320 fn barrier(&self) -> BackendResult<()> {
321 Ok(())
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `[start, start + 1, ..., end - 1]` as `f32` values.
    fn ramp(start: i32, end: i32) -> Vec<f32> {
        (start..end).map(|v| v as f32).collect()
    }

    // A divisible configuration yields the expected per-device shares and
    // echoes the totals back.
    #[test]
    fn test_sharding_plan_valid() {
        let ShardingPlan {
            heads_per_device,
            kv_heads_per_device,
            ffn_dim_per_device,
            total_heads,
            total_kv_heads,
            total_ffn_dim,
        } = ShardingPlan::from_config(32, 8, 11008, 2).unwrap();

        assert_eq!(heads_per_device, 16);
        assert_eq!(kv_heads_per_device, 4);
        assert_eq!(ffn_dim_per_device, 5504);
        assert_eq!(total_heads, 32);
        assert_eq!(total_kv_heads, 8);
        assert_eq!(total_ffn_dim, 11008);
    }

    // Each case makes exactly one of the three quantities indivisible by 2.
    #[test]
    fn test_sharding_plan_invalid() {
        let indivisible = [(31, 8, 11008), (32, 7, 11008), (32, 8, 11007)];
        for &(heads, kv_heads, ffn) in indivisible.iter() {
            assert!(ShardingPlan::from_config(heads, kv_heads, ffn, 2).is_err());
        }
    }

    // Splitting an 8x4 ramp along dim 0 gives rank 0 the first 16 values and
    // rank 1 the last 16.
    #[test]
    fn test_shard_weight() {
        let data = ramp(0, 32);
        let weight = Tensor::from_f32(&data, vec![8, 4]).unwrap();

        for &(rank, first) in [(0usize, 0i32), (1, 16)].iter() {
            let shard = shard_weight(&weight, 0, rank, 2).unwrap();
            assert_eq!(shard.shape(), &[4, 4]);
            let expected = ramp(first, first + 16);
            assert_eq!(shard.as_f32().unwrap(), expected.as_slice());
        }
    }

    // Merging the two 4x4 halves along dim 0 reproduces the 8x4 ramp.
    #[test]
    fn test_merge_shards() {
        let lower = ramp(0, 16);
        let upper = ramp(16, 32);
        let shard0 = Tensor::from_f32(&lower, vec![4, 4]).unwrap();
        let shard1 = Tensor::from_f32(&upper, vec![4, 4]).unwrap();

        let merged = merge_shards(&[shard0, shard1], 0).unwrap();
        assert_eq!(merged.shape(), &[8, 4]);

        let expected = ramp(0, 32);
        assert_eq!(merged.as_f32().unwrap(), expected.as_slice());
    }

    // The single-device implementation reports a group of one, leaves the
    // reduced tensor unchanged, and copies data through gather and scatter.
    #[test]
    fn test_single_device_tp() {
        let tp = SingleDeviceTP;
        assert_eq!(tp.world_size(), 1);
        assert_eq!(tp.rank(), 0);

        let mut reduced = Tensor::from_f32(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tp.all_reduce_sum(&mut reduced).unwrap();
        assert_eq!(reduced.as_f32().unwrap(), &[1.0, 2.0, 3.0]);

        let local = Tensor::from_f32(&[1.0, 2.0], vec![2]).unwrap();
        let mut gathered = Tensor::zeros(vec![2], DType::F32);
        tp.all_gather(&local, &mut gathered).unwrap();
        assert_eq!(gathered.as_f32().unwrap(), &[1.0, 2.0]);

        let input = Tensor::from_f32(&[1.0, 2.0], vec![2]).unwrap();
        let mut scattered = Tensor::zeros(vec![2], DType::F32);
        tp.scatter(&input, &mut scattered).unwrap();
        assert_eq!(scattered.as_f32().unwrap(), &[1.0, 2.0]);

        tp.barrier().unwrap();
    }
}