1use std::collections::HashMap;
38use std::path::{Path, PathBuf};
39use std::sync::{Arc, Mutex};
40use std::thread;
41
42use safetensors::serialize_to_file;
43use safetensors::tensor::{Dtype, SafeTensors, TensorView};
44use serde::{Deserialize, Serialize};
45
46use ferrotorch_core::storage::TensorStorage;
47use ferrotorch_core::{FerrotorchError, Float, Tensor};
48
/// Errors reported by the distributed checkpoint routines in this file.
///
/// Every variant carries a plain `String` payload that `thiserror`
/// interpolates into the `Display` text given in its `#[error(...)]`
/// attribute.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DistCheckpointError {
    /// A filesystem operation failed (also the target of the
    /// `From<std::io::Error>` conversion).
    #[error("I/O error: {message}")]
    Io { message: String },

    /// Encoding or decoding a safetensors file or the JSON metadata failed.
    #[error("serialization error: {message}")]
    Serialization { message: String },

    /// Metadata-related failure.
    // NOTE(review): not constructed anywhere in the code visible here.
    #[error("metadata error: {message}")]
    Metadata { message: String },

    /// A per-rank shard file that was expected on disk does not exist.
    #[error("shard file missing: {path}")]
    MissingShard { path: String },

    /// A tensor could not be read, rebuilt, or was inconsistent with what
    /// the caller expected (dtype, byte length, shape).
    #[error("tensor error: {message}")]
    Tensor { message: String },

    /// A caller-supplied argument (rank, world size, shard dimension) was
    /// rejected.
    #[error("invalid argument: {message}")]
    InvalidArgument { message: String },

    /// An asynchronous checkpoint could not be started or did not complete
    /// successfully.
    #[error("async checkpoint failed: {message}")]
    AsyncFailed { message: String },
}
78
79impl From<DistCheckpointError> for FerrotorchError {
80 fn from(e: DistCheckpointError) -> Self {
81 FerrotorchError::InvalidArgument {
82 message: e.to_string(),
83 }
84 }
85}
86
87impl From<std::io::Error> for DistCheckpointError {
88 fn from(e: std::io::Error) -> Self {
89 DistCheckpointError::Io {
90 message: e.to_string(),
91 }
92 }
93}
94
/// Describes how one logical tensor is split across ranks.
///
/// Serialized as part of [`ShardMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TensorShardSpec {
    /// Shape of the full (un-sharded) tensor; resharding reassembles the
    /// per-rank pieces into a buffer of this shape.
    pub full_shape: Vec<usize>,
    /// Index of the dimension along which the tensor is split.
    pub shard_dim: usize,
    /// Per-rank sizes.
    // NOTE(review): not read by the code visible here; the tests fill it
    // with each rank's extent along `shard_dim` — confirm intended meaning.
    pub shard_sizes: Vec<usize>,
}
116
/// Checkpoint-wide sharding description.
///
/// Rank 0 writes this as pretty-printed JSON to `metadata.json` in the
/// checkpoint directory; loading and resharding read it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ShardMetadata {
    /// Number of ranks the checkpoint was written with; compared against the
    /// requested world size on load to decide whether resharding is needed.
    pub num_ranks: usize,
    /// Shard layout for each tensor, keyed by tensor name.
    pub tensor_specs: HashMap<String, TensorShardSpec>,
}
130
/// Pairs a checkpoint directory with its shard metadata.
// NOTE(review): not referenced by any function visible in this part of the
// file — confirm how callers are expected to use it.
#[non_exhaustive]
pub struct DistributedCheckpoint {
    /// Directory holding the checkpoint files.
    pub checkpoint_dir: PathBuf,
    /// Sharding description for the checkpoint.
    pub shard_metadata: ShardMetadata,
}
146
147fn st_dtype<T: Float>() -> Result<Dtype, DistCheckpointError> {
153 match std::mem::size_of::<T>() {
154 4 => Ok(Dtype::F32),
155 8 => Ok(Dtype::F64),
156 other => Err(DistCheckpointError::InvalidArgument {
157 message: format!("unsupported element size {other} for safetensors serialization"),
158 }),
159 }
160}
161
162fn as_le_bytes<T: Float>(data: &[T]) -> &[u8] {
164 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) }
186}
187
/// Path of the safetensors shard file written by `rank` inside `dir`
/// (`rank_<rank>.safetensors`).
fn shard_path(dir: &Path, rank: usize) -> PathBuf {
    let mut path = dir.to_path_buf();
    path.push(format!("rank_{rank}.safetensors"));
    path
}
192
/// Path of the JSON metadata file (`metadata.json`) inside `dir`.
fn metadata_path(dir: &Path) -> PathBuf {
    let mut path = dir.to_path_buf();
    path.push("metadata.json");
    path
}
197
198fn save_tensors_to_file<T: Float>(
200 tensors: &HashMap<String, Tensor<T>>,
201 path: &Path,
202) -> Result<(), DistCheckpointError> {
203 let dtype = st_dtype::<T>()?;
204
205 let mut keys: Vec<&str> = tensors.keys().map(String::as_str).collect();
206 keys.sort_unstable();
207
208 struct Entry<'a> {
209 name: String,
210 shape: Vec<usize>,
211 data: &'a [u8],
212 }
213
214 let mut entries: Vec<Entry<'_>> = Vec::with_capacity(keys.len());
215 for key in &keys {
216 let tensor = &tensors[*key];
217 let data_slice = tensor.data().map_err(|e| DistCheckpointError::Tensor {
218 message: format!("failed to read tensor \"{key}\": {e}"),
219 })?;
220 entries.push(Entry {
221 name: (*key).to_owned(),
222 shape: tensor.shape().to_vec(),
223 data: as_le_bytes(data_slice),
224 });
225 }
226
227 let views: Vec<(String, TensorView<'_>)> = entries
228 .iter()
229 .map(|entry| {
230 TensorView::new(dtype, entry.shape.clone(), entry.data)
231 .map(|v| (entry.name.clone(), v))
232 .map_err(|e| DistCheckpointError::Serialization {
233 message: format!("TensorView for \"{}\": {e}", entry.name),
234 })
235 })
236 .collect::<Result<Vec<_>, _>>()?;
237
238 serialize_to_file(views, &None, path).map_err(|e| DistCheckpointError::Serialization {
239 message: format!("safetensors write to {}: {e}", path.display()),
240 })?;
241
242 Ok(())
243}
244
245fn load_tensors_from_file<T: Float>(
247 path: &Path,
248) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
249 let elem_size = std::mem::size_of::<T>();
250 let expected = st_dtype::<T>()?;
251
252 let file_data = std::fs::read(path).map_err(|e| DistCheckpointError::Io {
253 message: format!("reading {}: {e}", path.display()),
254 })?;
255
256 let st =
257 SafeTensors::deserialize(&file_data).map_err(|e| DistCheckpointError::Serialization {
258 message: format!("parsing {}: {e}", path.display()),
259 })?;
260
261 let tensor_list = st.tensors();
262 let mut result: HashMap<String, Tensor<T>> = HashMap::with_capacity(tensor_list.len());
263
264 for (name, view) in &tensor_list {
265 if view.dtype() != expected {
266 return Err(DistCheckpointError::Tensor {
267 message: format!(
268 "tensor \"{name}\" has dtype {:?}, expected {:?}",
269 view.dtype(),
270 expected
271 ),
272 });
273 }
274
275 let shape = view.shape().to_vec();
276 let byte_data = view.data();
277 let numel: usize = if shape.is_empty() {
278 1
279 } else {
280 shape.iter().product()
281 };
282 let expected_bytes = numel * elem_size;
283
284 if byte_data.len() != expected_bytes {
285 return Err(DistCheckpointError::Tensor {
286 message: format!(
287 "tensor \"{name}\" has {} bytes but shape {shape:?} requires {expected_bytes}",
288 byte_data.len()
289 ),
290 });
291 }
292
293 let data: Vec<T> = byte_data
294 .chunks_exact(elem_size)
295 .map(|chunk| {
296 let mut bytes = [0u8; 8];
297 bytes[..elem_size].copy_from_slice(chunk);
298 unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const T) }
325 })
326 .collect();
327
328 let storage = TensorStorage::cpu(data);
329 let tensor = Tensor::from_storage(storage, shape, false).map_err(|e| {
330 DistCheckpointError::Tensor {
331 message: format!("creating tensor \"{name}\": {e}"),
332 }
333 })?;
334 result.insert(name.clone(), tensor);
335 }
336
337 Ok(result)
338}
339
340pub fn save_distributed<T: Float>(
360 state_dict: &HashMap<String, Tensor<T>>,
361 dir: &Path,
362 rank: usize,
363 world_size: usize,
364 shard_spec: &ShardMetadata,
365) -> Result<(), DistCheckpointError> {
366 if world_size == 0 {
368 return Err(DistCheckpointError::InvalidArgument {
369 message: "world_size must be >= 1".into(),
370 });
371 }
372 if rank >= world_size {
373 return Err(DistCheckpointError::InvalidArgument {
374 message: format!("rank {rank} >= world_size {world_size}"),
375 });
376 }
377
378 std::fs::create_dir_all(dir)?;
380
381 let path = shard_path(dir, rank);
383 save_tensors_to_file(state_dict, &path)?;
384
385 if rank == 0 {
387 let json = serde_json::to_string_pretty(shard_spec).map_err(|e| {
388 DistCheckpointError::Serialization {
389 message: format!("serializing metadata: {e}"),
390 }
391 })?;
392 std::fs::write(metadata_path(dir), json)?;
393 }
394
395 Ok(())
396}
397
398pub fn load_distributed<T: Float>(
410 dir: &Path,
411 rank: usize,
412 world_size: usize,
413) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
414 if world_size == 0 {
415 return Err(DistCheckpointError::InvalidArgument {
416 message: "world_size must be >= 1".into(),
417 });
418 }
419 if rank >= world_size {
420 return Err(DistCheckpointError::InvalidArgument {
421 message: format!("rank {rank} >= world_size {world_size}"),
422 });
423 }
424
425 let meta_path = metadata_path(dir);
427 let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
428 message: format!("reading {}: {e}", meta_path.display()),
429 })?;
430 let metadata: ShardMetadata =
431 serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
432 message: format!("parsing metadata: {e}"),
433 })?;
434
435 let old_world_size = metadata.num_ranks;
436
437 if old_world_size == world_size {
438 let path = shard_path(dir, rank);
440 if !path.exists() {
441 if metadata.tensor_specs.is_empty() {
444 return Ok(HashMap::new());
445 }
446 return Err(DistCheckpointError::MissingShard {
447 path: path.display().to_string(),
448 });
449 }
450 load_tensors_from_file(&path)
451 } else {
452 reshard(dir, old_world_size, world_size, rank)
454 }
455}
456
457pub fn reshard<T: Float>(
479 dir: &Path,
480 old_world_size: usize,
481 new_world_size: usize,
482 new_rank: usize,
483) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
484 if new_world_size == 0 {
485 return Err(DistCheckpointError::InvalidArgument {
486 message: "new_world_size must be >= 1".into(),
487 });
488 }
489 if new_rank >= new_world_size {
490 return Err(DistCheckpointError::InvalidArgument {
491 message: format!("new_rank {new_rank} >= new_world_size {new_world_size}"),
492 });
493 }
494 if old_world_size == 0 {
495 return Err(DistCheckpointError::InvalidArgument {
496 message: "old_world_size must be >= 1".into(),
497 });
498 }
499
500 let meta_path = metadata_path(dir);
502 let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
503 message: format!("reading {}: {e}", meta_path.display()),
504 })?;
505 let metadata: ShardMetadata =
506 serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
507 message: format!("parsing metadata: {e}"),
508 })?;
509
510 let mut old_shards: Vec<HashMap<String, Tensor<T>>> = Vec::with_capacity(old_world_size);
512 for old_rank in 0..old_world_size {
513 let path = shard_path(dir, old_rank);
514 if !path.exists() {
515 return Err(DistCheckpointError::MissingShard {
516 path: path.display().to_string(),
517 });
518 }
519 old_shards.push(load_tensors_from_file(&path)?);
520 }
521
522 let mut result: HashMap<String, Tensor<T>> = HashMap::new();
524
525 for (name, spec) in &metadata.tensor_specs {
526 let shard_dim = spec.shard_dim;
527 let full_shape = &spec.full_shape;
528
529 let mut shard_datas: Vec<Vec<T>> = Vec::with_capacity(old_world_size);
531 let mut shard_shapes: Vec<Vec<usize>> = Vec::with_capacity(old_world_size);
532
533 for (old_rank, shard) in old_shards.iter().enumerate().take(old_world_size) {
534 let tensor = shard.get(name).ok_or_else(|| DistCheckpointError::Tensor {
535 message: format!("tensor \"{name}\" missing from rank {old_rank} shard"),
536 })?;
537 shard_datas.push(tensor.data_vec().map_err(|e| DistCheckpointError::Tensor {
538 message: format!("reading tensor \"{name}\" from rank {old_rank}: {e}"),
539 })?);
540 shard_shapes.push(tensor.shape().to_vec());
541 }
542
543 let full_data = concat_along_dim(&shard_datas, &shard_shapes, shard_dim, full_shape)?;
545
546 let full_dim_size = full_shape[shard_dim];
548 let new_shard_sizes = compute_shard_sizes(full_dim_size, new_world_size);
549 let new_offset: usize = new_shard_sizes[..new_rank].iter().sum();
550 let new_size = new_shard_sizes[new_rank];
551
552 let mut new_shape = full_shape.clone();
554 new_shape[shard_dim] = new_size;
555
556 let new_data = slice_along_dim(&full_data, full_shape, shard_dim, new_offset, new_size);
557
558 let tensor =
559 Tensor::from_storage(TensorStorage::cpu(new_data), new_shape, false).map_err(|e| {
560 DistCheckpointError::Tensor {
561 message: format!("creating resharded tensor \"{name}\": {e}"),
562 }
563 })?;
564
565 result.insert(name.clone(), tensor);
566 }
567
568 Ok(result)
569}
570
/// Splits `total` items into `num_parts` contiguous parts that are as even
/// as possible.
///
/// Every part gets `total / num_parts` items and the first
/// `total % num_parts` parts get one extra, so the sizes always sum to
/// `total`. Returns an empty vector when `num_parts` is 0 instead of
/// dividing by zero.
fn compute_shard_sizes(total: usize, num_parts: usize) -> Vec<usize> {
    if num_parts == 0 {
        return Vec::new();
    }
    let base = total / num_parts;
    let remainder = total % num_parts;
    (0..num_parts)
        .map(|part| base + usize::from(part < remainder))
        .collect()
}
580
581fn concat_along_dim<T: Float>(
587 shard_datas: &[Vec<T>],
588 shard_shapes: &[Vec<usize>],
589 dim: usize,
590 full_shape: &[usize],
591) -> Result<Vec<T>, DistCheckpointError> {
592 let ndim = full_shape.len();
593 if dim >= ndim {
594 return Err(DistCheckpointError::InvalidArgument {
595 message: format!("shard_dim {dim} >= ndim {ndim}"),
596 });
597 }
598
599 let full_numel: usize = full_shape.iter().product();
600 let mut full_data = vec![<T as num_traits::Zero>::zero(); full_numel];
601
602 let outer: usize = full_shape[..dim].iter().product();
608 let inner: usize = full_shape[dim + 1..].iter().product();
609 let full_middle = full_shape[dim];
610
611 let mut dim_offset = 0;
613 for (shard_idx, shard_data) in shard_datas.iter().enumerate() {
614 let shard_middle = shard_shapes[shard_idx][dim];
615
616 for d in 0..ndim {
618 if d != dim && shard_shapes[shard_idx][d] != full_shape[d] {
619 return Err(DistCheckpointError::Tensor {
620 message: format!(
621 "shard {shard_idx} has shape {:?} but expected dim {d} to be {} (full shape {full_shape:?})",
622 shard_shapes[shard_idx], full_shape[d]
623 ),
624 });
625 }
626 }
627
628 for o in 0..outer {
629 let src_start = o * shard_middle * inner;
630 let dst_start = o * full_middle * inner + dim_offset * inner;
631 let count = shard_middle * inner;
632
633 full_data[dst_start..dst_start + count]
634 .copy_from_slice(&shard_data[src_start..src_start + count]);
635 }
636
637 dim_offset += shard_middle;
638 }
639
640 if dim_offset != full_middle {
641 return Err(DistCheckpointError::Tensor {
642 message: format!(
643 "shard sizes along dim {dim} sum to {dim_offset}, expected {full_middle}"
644 ),
645 });
646 }
647
648 Ok(full_data)
649}
650
651fn slice_along_dim<T: Float>(
656 data: &[T],
657 shape: &[usize],
658 dim: usize,
659 offset: usize,
660 size: usize,
661) -> Vec<T> {
662 let outer: usize = shape[..dim].iter().product();
663 let full_middle = shape[dim];
664 let inner: usize = shape[dim + 1..].iter().product();
665
666 let out_numel = outer * size * inner;
667 let mut result = Vec::with_capacity(out_numel);
668
669 for o in 0..outer {
670 let src_start = o * full_middle * inner + offset * inner;
671 let count = size * inner;
672 result.extend_from_slice(&data[src_start..src_start + count]);
673 }
674
675 result
676}
677
678pub struct CheckpointFuture {
687 handle: Option<thread::JoinHandle<Result<(), DistCheckpointError>>>,
688 result: Option<Result<(), DistCheckpointError>>,
690}
691
692impl CheckpointFuture {
693 pub fn wait(&mut self) -> Result<(), DistCheckpointError> {
701 if let Some(handle) = self.handle.take() {
702 let res = handle
703 .join()
704 .map_err(|_| DistCheckpointError::AsyncFailed {
705 message: "background checkpoint thread panicked".into(),
706 })?;
707 self.result = Some(res);
708 }
709
710 match &self.result {
711 Some(Ok(())) => Ok(()),
712 Some(Err(e)) => Err(DistCheckpointError::AsyncFailed {
713 message: format!("{e}"),
714 }),
715 None => Err(DistCheckpointError::AsyncFailed {
716 message: "no checkpoint was started".into(),
717 }),
718 }
719 }
720
721 pub fn is_done(&self) -> bool {
723 if self.result.is_some() {
724 return true;
725 }
726 match &self.handle {
727 Some(h) => h.is_finished(),
728 None => true,
729 }
730 }
731}
732
733pub struct AsyncCheckpointer {
758 dir: PathBuf,
759 rank: usize,
760 world_size: usize,
761 shard_spec: ShardMetadata,
762 in_flight: Arc<Mutex<bool>>,
764}
765
766impl AsyncCheckpointer {
767 pub fn new(dir: PathBuf, rank: usize, world_size: usize, shard_spec: ShardMetadata) -> Self {
774 Self {
775 dir,
776 rank,
777 world_size,
778 shard_spec,
779 in_flight: Arc::new(Mutex::new(false)),
780 }
781 }
782
783 pub fn dir(&self) -> &Path {
785 &self.dir
786 }
787
788 pub fn rank(&self) -> usize {
790 self.rank
791 }
792
793 pub fn world_size(&self) -> usize {
795 self.world_size
796 }
797
798 pub fn save_async(
811 &self,
812 state_dict: &HashMap<String, Tensor<f32>>,
813 ) -> Result<CheckpointFuture, DistCheckpointError> {
814 {
816 let mut guard =
817 self.in_flight
818 .lock()
819 .map_err(|e| DistCheckpointError::AsyncFailed {
820 message: format!("lock poisoned: {e}"),
821 })?;
822 if *guard {
823 return Err(DistCheckpointError::AsyncFailed {
824 message: "another async checkpoint is already in flight".into(),
825 });
826 }
827 *guard = true;
828 }
829
830 let mut staged: HashMap<String, (Vec<f32>, Vec<usize>)> =
833 HashMap::with_capacity(state_dict.len());
834
835 for (name, tensor) in state_dict {
836 let data = tensor.data_vec().map_err(|e| {
837 if let Ok(mut g) = self.in_flight.lock() {
839 *g = false;
840 }
841 DistCheckpointError::Tensor {
842 message: format!("staging tensor \"{name}\": {e}"),
843 }
844 })?;
845 let shape = tensor.shape().to_vec();
846 staged.insert(name.clone(), (data, shape));
847 }
848
849 let dir = self.dir.clone();
851 let rank = self.rank;
852 let shard_spec = self.shard_spec.clone();
853 let in_flight = Arc::clone(&self.in_flight);
854
855 let handle = thread::spawn(move || {
856 let result = (|| -> Result<(), DistCheckpointError> {
857 let mut tensors: HashMap<String, Tensor<f32>> =
859 HashMap::with_capacity(staged.len());
860 for (name, (data, shape)) in staged {
861 let tensor = Tensor::from_storage(TensorStorage::cpu(data), shape, false)
862 .map_err(|e| DistCheckpointError::Tensor {
863 message: format!("rebuilding tensor \"{name}\": {e}"),
864 })?;
865 tensors.insert(name, tensor);
866 }
867
868 std::fs::create_dir_all(&dir)?;
870 let path = shard_path(&dir, rank);
871 save_tensors_to_file(&tensors, &path)?;
872
873 if rank == 0 {
875 let json = serde_json::to_string_pretty(&shard_spec).map_err(|e| {
876 DistCheckpointError::Serialization {
877 message: format!("serializing metadata: {e}"),
878 }
879 })?;
880 std::fs::write(metadata_path(&dir), json)?;
881 }
882
883 Ok(())
884 })();
885
886 if let Ok(mut g) = in_flight.lock() {
888 *g = false;
889 }
890
891 result
892 });
893
894 Ok(CheckpointFuture {
895 handle: Some(handle),
896 result: None,
897 })
898 }
899}
900
901pub fn flat_shard_metadata(
911 state_dict: &HashMap<String, Tensor<f32>>,
912 world_size: usize,
913) -> ShardMetadata {
914 let mut tensor_specs = HashMap::new();
915 for (name, tensor) in state_dict {
916 let shape = tensor.shape();
917 let shard_numel = shape.iter().product::<usize>();
920 let full_numel = shard_numel * world_size;
921 let shard_sizes = vec![shard_numel; world_size];
922 tensor_specs.insert(
923 name.clone(),
924 TensorShardSpec {
925 full_shape: vec![full_numel],
926 shard_dim: 0,
927 shard_sizes,
928 },
929 );
930 }
931 ShardMetadata {
932 num_ranks: world_size,
933 tensor_specs,
934 }
935}
936
937#[cfg(test)]
942mod tests {
943 use super::*;
944 use ferrotorch_core::Tensor;
945 use ferrotorch_core::storage::TensorStorage;
946 use std::collections::HashMap;
947
948 fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
949 Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
950 }
951
952 fn temp_dir(name: &str) -> PathBuf {
953 std::env::temp_dir()
954 .join("ferrotorch_test_dist_ckpt")
955 .join(name)
956 }
957
958 fn cleanup(dir: &Path) {
959 let _ = std::fs::remove_dir_all(dir);
960 }
961
962 #[test]
965 fn test_save_load_single_rank() {
966 let dir = temp_dir("single_rank");
967 cleanup(&dir);
968
969 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
970 state.insert(
971 "weight".into(),
972 make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
973 );
974 state.insert("bias".into(), make_tensor(vec![0.1, 0.2], vec![2]));
975
976 let spec = ShardMetadata {
977 num_ranks: 1,
978 tensor_specs: {
979 let mut m = HashMap::new();
980 m.insert(
981 "weight".into(),
982 TensorShardSpec {
983 full_shape: vec![4],
984 shard_dim: 0,
985 shard_sizes: vec![4],
986 },
987 );
988 m.insert(
989 "bias".into(),
990 TensorShardSpec {
991 full_shape: vec![2],
992 shard_dim: 0,
993 shard_sizes: vec![2],
994 },
995 );
996 m
997 },
998 };
999
1000 save_distributed(&state, &dir, 0, 1, &spec).unwrap();
1001 let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1002
1003 assert_eq!(loaded.len(), 2);
1004 assert_eq!(loaded["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1005 assert_eq!(loaded["bias"].data().unwrap(), &[0.1, 0.2]);
1006
1007 cleanup(&dir);
1008 }
1009
1010 #[test]
1011 fn test_save_load_two_ranks() {
1012 let dir = temp_dir("two_ranks");
1013 cleanup(&dir);
1014
1015 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1017 state0.insert("weight".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1018
1019 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1021 state1.insert("weight".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1022
1023 let spec = ShardMetadata {
1024 num_ranks: 2,
1025 tensor_specs: {
1026 let mut m = HashMap::new();
1027 m.insert(
1028 "weight".into(),
1029 TensorShardSpec {
1030 full_shape: vec![4],
1031 shard_dim: 0,
1032 shard_sizes: vec![2, 2],
1033 },
1034 );
1035 m
1036 },
1037 };
1038
1039 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1040 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1041
1042 let loaded0 = load_distributed::<f32>(&dir, 0, 2).unwrap();
1044 let loaded1 = load_distributed::<f32>(&dir, 1, 2).unwrap();
1045
1046 assert_eq!(loaded0["weight"].data().unwrap(), &[1.0, 2.0]);
1047 assert_eq!(loaded1["weight"].data().unwrap(), &[3.0, 4.0]);
1048
1049 cleanup(&dir);
1050 }
1051
1052 #[test]
1055 fn test_reshard_2_to_4() {
1056 let dir = temp_dir("reshard_2_to_4");
1058 cleanup(&dir);
1059
1060 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1063 state0.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]));
1064
1065 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1066 state1.insert("w".into(), make_tensor(vec![5.0, 6.0, 7.0, 8.0], vec![4]));
1067
1068 let spec = ShardMetadata {
1069 num_ranks: 2,
1070 tensor_specs: {
1071 let mut m = HashMap::new();
1072 m.insert(
1073 "w".into(),
1074 TensorShardSpec {
1075 full_shape: vec![8],
1076 shard_dim: 0,
1077 shard_sizes: vec![4, 4],
1078 },
1079 );
1080 m
1081 },
1082 };
1083
1084 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1085 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1086
1087 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1089 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1090 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1091 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1092
1093 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0]);
1094 assert_eq!(r1["w"].data().unwrap(), &[3.0, 4.0]);
1095 assert_eq!(r2["w"].data().unwrap(), &[5.0, 6.0]);
1096 assert_eq!(r3["w"].data().unwrap(), &[7.0, 8.0]);
1097
1098 cleanup(&dir);
1099 }
1100
1101 #[test]
1102 fn test_reshard_4_to_2() {
1103 let dir = temp_dir("reshard_4_to_2");
1105 cleanup(&dir);
1106
1107 let spec = ShardMetadata {
1108 num_ranks: 4,
1109 tensor_specs: {
1110 let mut m = HashMap::new();
1111 m.insert(
1112 "w".into(),
1113 TensorShardSpec {
1114 full_shape: vec![8],
1115 shard_dim: 0,
1116 shard_sizes: vec![2, 2, 2, 2],
1117 },
1118 );
1119 m
1120 },
1121 };
1122
1123 for rank in 0..4 {
1124 let start = rank as f32 * 2.0 + 1.0;
1125 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1126 state.insert("w".into(), make_tensor(vec![start, start + 1.0], vec![2]));
1127 save_distributed(&state, &dir, rank, 4, &spec).unwrap();
1128 }
1129
1130 let r0 = reshard::<f32>(&dir, 4, 2, 0).unwrap();
1132 let r1 = reshard::<f32>(&dir, 4, 2, 1).unwrap();
1133
1134 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1135 assert_eq!(r1["w"].data().unwrap(), &[5.0, 6.0, 7.0, 8.0]);
1136
1137 cleanup(&dir);
1138 }
1139
1140 #[test]
1141 fn test_reshard_2d_tensor() {
1142 let dir = temp_dir("reshard_2d");
1145 cleanup(&dir);
1146
1147 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1148 state0.insert(
1149 "w".into(),
1150 make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]),
1151 );
1152
1153 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1154 state1.insert(
1155 "w".into(),
1156 make_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![2, 3]),
1157 );
1158
1159 let spec = ShardMetadata {
1160 num_ranks: 2,
1161 tensor_specs: {
1162 let mut m = HashMap::new();
1163 m.insert(
1164 "w".into(),
1165 TensorShardSpec {
1166 full_shape: vec![4, 3],
1167 shard_dim: 0,
1168 shard_sizes: vec![2, 2],
1169 },
1170 );
1171 m
1172 },
1173 };
1174
1175 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1176 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1177
1178 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1180 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1181 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1182 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1183
1184 assert_eq!(r0["w"].shape(), &[1, 3]);
1185 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0]);
1186 assert_eq!(r1["w"].shape(), &[1, 3]);
1187 assert_eq!(r1["w"].data().unwrap(), &[4.0, 5.0, 6.0]);
1188 assert_eq!(r2["w"].shape(), &[1, 3]);
1189 assert_eq!(r2["w"].data().unwrap(), &[7.0, 8.0, 9.0]);
1190 assert_eq!(r3["w"].shape(), &[1, 3]);
1191 assert_eq!(r3["w"].data().unwrap(), &[10.0, 11.0, 12.0]);
1192
1193 cleanup(&dir);
1194 }
1195
1196 #[test]
1197 fn test_reshard_dim1() {
1198 let dir = temp_dir("reshard_dim1");
1201 cleanup(&dir);
1202
1203 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1204 state0.insert(
1205 "w".into(),
1206 make_tensor(vec![1.0, 2.0, 5.0, 6.0], vec![2, 2]),
1207 );
1208
1209 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1210 state1.insert(
1211 "w".into(),
1212 make_tensor(vec![3.0, 4.0, 7.0, 8.0], vec![2, 2]),
1213 );
1214
1215 let spec = ShardMetadata {
1216 num_ranks: 2,
1217 tensor_specs: {
1218 let mut m = HashMap::new();
1219 m.insert(
1220 "w".into(),
1221 TensorShardSpec {
1222 full_shape: vec![2, 4],
1223 shard_dim: 1,
1224 shard_sizes: vec![2, 2],
1225 },
1226 );
1227 m
1228 },
1229 };
1230
1231 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1232 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1233
1234 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1236 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1237 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1238 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1239
1240 assert_eq!(r0["w"].shape(), &[2, 1]);
1241 assert_eq!(r0["w"].data().unwrap(), &[1.0, 5.0]);
1242 assert_eq!(r1["w"].shape(), &[2, 1]);
1243 assert_eq!(r1["w"].data().unwrap(), &[2.0, 6.0]);
1244 assert_eq!(r2["w"].shape(), &[2, 1]);
1245 assert_eq!(r2["w"].data().unwrap(), &[3.0, 7.0]);
1246 assert_eq!(r3["w"].shape(), &[2, 1]);
1247 assert_eq!(r3["w"].data().unwrap(), &[4.0, 8.0]);
1248
1249 cleanup(&dir);
1250 }
1251
1252 #[test]
1253 fn test_reshard_3_to_2_uneven() {
1254 let dir = temp_dir("reshard_3_to_2");
1257 cleanup(&dir);
1258
1259 let spec = ShardMetadata {
1260 num_ranks: 3,
1261 tensor_specs: {
1262 let mut m = HashMap::new();
1263 m.insert(
1264 "w".into(),
1265 TensorShardSpec {
1266 full_shape: vec![9],
1267 shard_dim: 0,
1268 shard_sizes: vec![3, 3, 3],
1269 },
1270 );
1271 m
1272 },
1273 };
1274
1275 for rank in 0..3usize {
1276 let start = rank as f32 * 3.0 + 1.0;
1277 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1278 state.insert(
1279 "w".into(),
1280 make_tensor(vec![start, start + 1.0, start + 2.0], vec![3]),
1281 );
1282 save_distributed(&state, &dir, rank, 3, &spec).unwrap();
1283 }
1284
1285 let r0 = reshard::<f32>(&dir, 3, 2, 0).unwrap();
1287 let r1 = reshard::<f32>(&dir, 3, 2, 1).unwrap();
1288
1289 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
1290 assert_eq!(r1["w"].data().unwrap(), &[6.0, 7.0, 8.0, 9.0]);
1291
1292 cleanup(&dir);
1293 }
1294
1295 #[test]
1298 fn test_load_distributed_reshards_when_world_size_differs() {
1299 let dir = temp_dir("load_reshard");
1300 cleanup(&dir);
1301
1302 let spec = ShardMetadata {
1304 num_ranks: 2,
1305 tensor_specs: {
1306 let mut m = HashMap::new();
1307 m.insert(
1308 "w".into(),
1309 TensorShardSpec {
1310 full_shape: vec![4],
1311 shard_dim: 0,
1312 shard_sizes: vec![2, 2],
1313 },
1314 );
1315 m
1316 },
1317 };
1318
1319 let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1320 s0.insert("w".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1321 save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1322
1323 let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1324 s1.insert("w".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1325 save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1326
1327 let r0 = load_distributed::<f32>(&dir, 0, 4).unwrap();
1329 let r1 = load_distributed::<f32>(&dir, 1, 4).unwrap();
1330 let r2 = load_distributed::<f32>(&dir, 2, 4).unwrap();
1331 let r3 = load_distributed::<f32>(&dir, 3, 4).unwrap();
1332
1333 assert_eq!(r0["w"].data().unwrap(), &[1.0]);
1334 assert_eq!(r1["w"].data().unwrap(), &[2.0]);
1335 assert_eq!(r2["w"].data().unwrap(), &[3.0]);
1336 assert_eq!(r3["w"].data().unwrap(), &[4.0]);
1337
1338 cleanup(&dir);
1339 }
1340
1341 #[test]
1344 fn test_metadata_roundtrip() {
1345 let spec = ShardMetadata {
1346 num_ranks: 4,
1347 tensor_specs: {
1348 let mut m = HashMap::new();
1349 m.insert(
1350 "layer.weight".into(),
1351 TensorShardSpec {
1352 full_shape: vec![256, 512],
1353 shard_dim: 0,
1354 shard_sizes: vec![64, 64, 64, 64],
1355 },
1356 );
1357 m.insert(
1358 "layer.bias".into(),
1359 TensorShardSpec {
1360 full_shape: vec![256],
1361 shard_dim: 0,
1362 shard_sizes: vec![64, 64, 64, 64],
1363 },
1364 );
1365 m
1366 },
1367 };
1368
1369 let json = serde_json::to_string_pretty(&spec).unwrap();
1370 let loaded: ShardMetadata = serde_json::from_str(&json).unwrap();
1371
1372 assert_eq!(loaded.num_ranks, 4);
1373 assert_eq!(loaded.tensor_specs.len(), 2);
1374 assert_eq!(
1375 loaded.tensor_specs["layer.weight"].full_shape,
1376 vec![256, 512]
1377 );
1378 assert_eq!(loaded.tensor_specs["layer.weight"].shard_dim, 0);
1379 assert_eq!(
1380 loaded.tensor_specs["layer.weight"].shard_sizes,
1381 vec![64, 64, 64, 64]
1382 );
1383 }
1384
1385 #[test]
1388 fn test_compute_shard_sizes_even() {
1389 assert_eq!(compute_shard_sizes(8, 4), vec![2, 2, 2, 2]);
1390 assert_eq!(compute_shard_sizes(12, 3), vec![4, 4, 4]);
1391 }
1392
1393 #[test]
1394 fn test_compute_shard_sizes_uneven() {
1395 assert_eq!(compute_shard_sizes(9, 2), vec![5, 4]);
1397 assert_eq!(compute_shard_sizes(10, 3), vec![4, 3, 3]);
1399 assert_eq!(compute_shard_sizes(7, 4), vec![2, 2, 2, 1]);
1401 }
1402
1403 #[test]
1406 fn test_concat_1d() {
1407 let data0 = vec![1.0f32, 2.0];
1408 let data1 = vec![3.0f32, 4.0, 5.0];
1409 let full_shape = vec![5];
1410
1411 let result =
1412 concat_along_dim(&[data0, data1], &[vec![2], vec![3]], 0, &full_shape).unwrap();
1413
1414 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1415 }
1416
1417 #[test]
1418 fn test_concat_2d_dim0() {
1419 let data0 = vec![1.0f32, 2.0, 3.0];
1421 let data1 = vec![4.0f32, 5.0, 6.0];
1422 let full_shape = vec![2, 3];
1423
1424 let result =
1425 concat_along_dim(&[data0, data1], &[vec![1, 3], vec![1, 3]], 0, &full_shape).unwrap();
1426
1427 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1428 }
1429
1430 #[test]
1431 fn test_concat_2d_dim1() {
1432 let data0 = vec![1.0f32, 3.0];
1435 let data1 = vec![2.0f32, 4.0];
1436 let full_shape = vec![2, 2];
1437
1438 let result =
1439 concat_along_dim(&[data0, data1], &[vec![2, 1], vec![2, 1]], 1, &full_shape).unwrap();
1440
1441 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
1442 }
1443
1444 #[test]
1447 fn test_slice_1d() {
1448 let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
1449 let shape = vec![5];
1450
1451 let s0 = slice_along_dim(&data, &shape, 0, 0, 2);
1452 assert_eq!(s0, vec![1.0, 2.0]);
1453
1454 let s1 = slice_along_dim(&data, &shape, 0, 2, 3);
1455 assert_eq!(s1, vec![3.0, 4.0, 5.0]);
1456 }
1457
1458 #[test]
1459 fn test_slice_2d_dim0() {
1460 let data = vec![
1462 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1463 ];
1464 let shape = vec![4, 3];
1465
1466 let s = slice_along_dim(&data, &shape, 0, 1, 2);
1467 assert_eq!(s, vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
1468 }
1469
1470 #[test]
1471 fn test_slice_2d_dim1() {
1472 let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1474 let shape = vec![2, 4];
1475
1476 let s = slice_along_dim(&data, &shape, 1, 1, 2);
1477 assert_eq!(s, vec![2.0, 3.0, 6.0, 7.0]);
1478 }
1479
/// `flat_shard_metadata` for a single 3-element tensor and 4 ranks must
/// report 4 ranks, a full shape of `[12]`, shard dim 0 and four shards of 3.
#[test]
fn test_flat_shard_metadata() {
    let state: HashMap<String, Tensor<f32>> = HashMap::from([(
        String::from("w"),
        make_tensor(vec![1.0, 2.0, 3.0], vec![3]),
    )]);

    let metadata = flat_shard_metadata(&state, 4);
    assert_eq!(metadata.num_ranks, 4);

    let w_spec = &metadata.tensor_specs["w"];
    assert_eq!(w_spec.full_shape, vec![12]);
    assert_eq!(w_spec.shard_dim, 0);
    assert_eq!(w_spec.shard_sizes, vec![3, 3, 3, 3]);
}
1495
/// Round trip through the async path: save a single 4-element tensor with
/// `AsyncCheckpointer`, wait for completion, then read it back with
/// `load_distributed` and compare the data.
#[test]
fn test_async_checkpoint_basic() {
    let dir = temp_dir("async_basic");
    cleanup(&dir);

    let state: HashMap<String, Tensor<f32>> = HashMap::from([(
        String::from("w"),
        make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
    )]);

    // Single-rank layout: the one shard covers the whole tensor.
    let tensor_specs = HashMap::from([(
        String::from("w"),
        TensorShardSpec {
            full_shape: vec![4],
            shard_dim: 0,
            shard_sizes: vec![4],
        },
    )]);
    let spec = ShardMetadata {
        num_ranks: 1,
        tensor_specs,
    };

    let checkpointer = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
    let mut pending = checkpointer.save_async(&state).unwrap();
    pending.wait().unwrap();

    let restored = load_distributed::<f32>(&dir, 0, 1).unwrap();
    assert_eq!(restored["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);

    cleanup(&dir);
}
1532
/// Calling `wait()` a second time on the same async-save handle must also
/// succeed.
#[test]
fn test_async_checkpoint_wait_idempotent() {
    let dir = temp_dir("async_idempotent");
    cleanup(&dir);

    let state: HashMap<String, Tensor<f32>> =
        HashMap::from([(String::from("x"), make_tensor(vec![42.0], vec![1]))]);

    let tensor_specs = HashMap::from([(
        String::from("x"),
        TensorShardSpec {
            full_shape: vec![1],
            shard_dim: 0,
            shard_sizes: vec![1],
        },
    )]);
    let spec = ShardMetadata {
        num_ranks: 1,
        tensor_specs,
    };

    let checkpointer = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
    let mut pending = checkpointer.save_async(&state).unwrap();

    // First wait, then a repeat on the same handle; both must return Ok.
    pending.wait().unwrap();
    pending.wait().unwrap();

    cleanup(&dir);
}
1566
/// After `wait()` returns successfully, `is_done()` on the same handle must
/// report `true`.
#[test]
fn test_async_checkpoint_is_done() {
    let dir = temp_dir("async_is_done");
    cleanup(&dir);

    let state: HashMap<String, Tensor<f32>> =
        HashMap::from([(String::from("x"), make_tensor(vec![1.0], vec![1]))]);

    let tensor_specs = HashMap::from([(
        String::from("x"),
        TensorShardSpec {
            full_shape: vec![1],
            shard_dim: 0,
            shard_sizes: vec![1],
        },
    )]);
    let spec = ShardMetadata {
        num_ranks: 1,
        tensor_specs,
    };

    let checkpointer = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
    let mut pending = checkpointer.save_async(&state).unwrap();
    pending.wait().unwrap();
    assert!(pending.is_done());

    cleanup(&dir);
}
1598
/// `save_distributed` called with rank 5 and a rank count of 2 must return
/// an error.
#[test]
fn test_save_invalid_rank() {
    let dir = temp_dir("invalid_rank");
    let empty_state: HashMap<String, Tensor<f32>> = HashMap::new();
    let spec = ShardMetadata {
        num_ranks: 2,
        tensor_specs: HashMap::new(),
    };

    // The rank argument (5) is not below the rank count (2).
    assert!(save_distributed(&empty_state, &dir, 5, 2, &spec).is_err());
}
1613
/// `load_distributed` on a freshly created directory, into which this test
/// writes no files, must return an error.
#[test]
fn test_load_missing_metadata() {
    let dir = temp_dir("missing_meta");
    cleanup(&dir);
    std::fs::create_dir_all(&dir).unwrap();

    let outcome = load_distributed::<f32>(&dir, 0, 1);
    assert!(outcome.is_err());

    cleanup(&dir);
}
1625
/// With only a metadata file on disk that declares one rank and no tensor
/// specs, and no shard file written by this test, `load_distributed` must
/// succeed and return an empty map.
#[test]
fn test_load_missing_shard() {
    let dir = temp_dir("missing_shard");
    cleanup(&dir);
    std::fs::create_dir_all(&dir).unwrap();

    // Write the metadata by hand; nothing else is placed in the directory.
    let empty_spec = ShardMetadata {
        num_ranks: 1,
        tensor_specs: HashMap::new(),
    };
    let metadata_json = serde_json::to_string_pretty(&empty_spec).unwrap();
    std::fs::write(metadata_path(&dir), metadata_json).unwrap();

    let restored = load_distributed::<f32>(&dir, 0, 1).unwrap();
    assert!(restored.is_empty());

    cleanup(&dir);
}
1648
/// Two tensors ("weight" and "bias") are each saved as two dim-0 shards;
/// loading with `(0, 1)` (presumably rank 0 of a one-rank world) must return
/// both tensors with the shard data joined in rank order.
#[test]
fn test_reshard_multiple_tensors() {
    let dir = temp_dir("reshard_multi");
    cleanup(&dir);

    let tensor_specs = HashMap::from([
        (
            String::from("weight"),
            TensorShardSpec {
                full_shape: vec![4],
                shard_dim: 0,
                shard_sizes: vec![2, 2],
            },
        ),
        (
            String::from("bias"),
            TensorShardSpec {
                full_shape: vec![6],
                shard_dim: 0,
                shard_sizes: vec![3, 3],
            },
        ),
    ]);
    let spec = ShardMetadata {
        num_ranks: 2,
        tensor_specs,
    };

    let rank0_state: HashMap<String, Tensor<f32>> = HashMap::from([
        (String::from("weight"), make_tensor(vec![1.0, 2.0], vec![2])),
        (
            String::from("bias"),
            make_tensor(vec![10.0, 20.0, 30.0], vec![3]),
        ),
    ]);
    let rank1_state: HashMap<String, Tensor<f32>> = HashMap::from([
        (String::from("weight"), make_tensor(vec![3.0, 4.0], vec![2])),
        (
            String::from("bias"),
            make_tensor(vec![40.0, 50.0, 60.0], vec![3]),
        ),
    ]);

    save_distributed(&rank0_state, &dir, 0, 2, &spec).unwrap();
    save_distributed(&rank1_state, &dir, 1, 2, &spec).unwrap();

    let merged = load_distributed::<f32>(&dir, 0, 1).unwrap();

    assert_eq!(merged["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(
        merged["bias"].data().unwrap(),
        &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
    );

    cleanup(&dir);
}
1702
/// `DistributedCheckpoint` can be built with a struct literal and its fields
/// read back unchanged.
#[test]
fn test_distributed_checkpoint_struct() {
    let expected_dir = PathBuf::from("/tmp/test");
    let metadata = ShardMetadata {
        num_ranks: 2,
        tensor_specs: HashMap::new(),
    };

    let checkpoint = DistributedCheckpoint {
        checkpoint_dir: expected_dir.clone(),
        shard_metadata: metadata,
    };

    assert_eq!(checkpoint.checkpoint_dir, expected_dir);
    assert_eq!(checkpoint.shard_metadata.num_ranks, 2);
}
1717
/// A tensor saved as two dim-0 shards and passed through `reshard` with
/// arguments `(2, 2, rank)` must hand each rank back the same two values it
/// originally saved.
#[test]
fn test_reshard_same_world_size() {
    let dir = temp_dir("reshard_same");
    cleanup(&dir);

    let tensor_specs = HashMap::from([(
        String::from("w"),
        TensorShardSpec {
            full_shape: vec![4],
            shard_dim: 0,
            shard_sizes: vec![2, 2],
        },
    )]);
    let spec = ShardMetadata {
        num_ranks: 2,
        tensor_specs,
    };

    // Build and save each rank's shard in turn: rank 0 first, then rank 1.
    for (rank, values) in [(0, vec![1.0, 2.0]), (1, vec![3.0, 4.0])] {
        let shard_state: HashMap<String, Tensor<f32>> =
            HashMap::from([(String::from("w"), make_tensor(values, vec![2]))]);
        save_distributed(&shard_state, &dir, rank, 2, &spec).unwrap();
    }

    let rank0_view = reshard::<f32>(&dir, 2, 2, 0).unwrap();
    let rank1_view = reshard::<f32>(&dir, 2, 2, 1).unwrap();

    assert_eq!(rank0_view["w"].data().unwrap(), &[1.0, 2.0]);
    assert_eq!(rank1_view["w"].data().unwrap(), &[3.0, 4.0]);

    cleanup(&dir);
}
1757}