1use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18use std::sync::{Arc, Mutex};
19use std::thread;
20
21use safetensors::serialize_to_file;
22use safetensors::tensor::{Dtype, SafeTensors, TensorView};
23use serde::{Deserialize, Serialize};
24
25use ferrotorch_core::storage::TensorStorage;
26use ferrotorch_core::{FerrotorchError, Float, Tensor};
27
/// Errors produced by the distributed checkpoint save, load, and reshard routines.
///
/// Every variant carries a pre-formatted human-readable string; underlying error values
/// (I/O, serde_json, safetensors, tensor errors) are converted to text rather than kept as
/// error sources.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DistCheckpointError {
    /// Filesystem failure while reading or writing checkpoint files or directories.
    #[error("I/O error: {message}")]
    Io { message: String },

    /// Failure encoding/decoding a safetensors shard file or the JSON metadata.
    #[error("serialization error: {message}")]
    Serialization { message: String },

    /// Problem with the checkpoint metadata contents (displayed as `metadata error: …`).
    #[error("metadata error: {message}")]
    Metadata { message: String },

    /// A shard file expected on disk does not exist; `path` is the missing file.
    #[error("shard file missing: {path}")]
    MissingShard { path: String },

    /// A tensor could not be read or rebuilt, or its dtype/shape/byte size is inconsistent.
    #[error("tensor error: {message}")]
    Tensor { message: String },

    /// A caller-supplied argument (rank, world size, shard dimension, element type) is invalid.
    #[error("invalid argument: {message}")]
    InvalidArgument { message: String },

    /// Failure in the asynchronous checkpoint path (flag contention, worker error or panic).
    #[error("async checkpoint failed: {message}")]
    AsyncFailed { message: String },
}
57
58impl From<DistCheckpointError> for FerrotorchError {
59 fn from(e: DistCheckpointError) -> Self {
60 FerrotorchError::InvalidArgument {
61 message: e.to_string(),
62 }
63 }
64}
65
66impl From<std::io::Error> for DistCheckpointError {
67 fn from(e: std::io::Error) -> Self {
68 DistCheckpointError::Io {
69 message: e.to_string(),
70 }
71 }
72}
73
/// Sharding layout of one named tensor, as recorded in the checkpoint metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorShardSpec {
    /// Shape of the full (unsharded) tensor; resharding concatenates the saved shards back
    /// into this shape.
    pub full_shape: Vec<usize>,
    /// Index into `full_shape` of the dimension the tensor is split along.
    pub shard_dim: usize,
    /// Extent of each shard along `shard_dim`, presumably one entry per rank in rank order.
    /// NOTE(review): resharding in this module derives shard extents from the stored tensors'
    /// shapes and does not read this field — confirm whether anything else relies on it.
    pub shard_sizes: Vec<usize>,
}
89
/// Checkpoint-wide sharding metadata, serialized as `metadata.json` in the checkpoint
/// directory (written by rank 0 when saving).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMetadata {
    /// Number of ranks the checkpoint was saved with, i.e. how many
    /// `rank_<r>.safetensors` files loading expects to find.
    pub num_ranks: usize,
    /// Per-tensor sharding layout, keyed by state-dict name.
    pub tensor_specs: HashMap<String, TensorShardSpec>,
}
99
/// Pairs a checkpoint directory with its sharding metadata.
///
/// NOTE(review): not constructed or consumed by the functions in the reviewed part of this
/// file; presumably a convenience type for callers — confirm intended use.
pub struct DistributedCheckpoint {
    /// Directory holding `metadata.json` and the per-rank shard files.
    pub checkpoint_dir: PathBuf,
    /// Sharding metadata describing the checkpoint in `checkpoint_dir`.
    pub shard_metadata: ShardMetadata,
}
111
112fn st_dtype<T: Float>() -> Result<Dtype, DistCheckpointError> {
118 match std::mem::size_of::<T>() {
119 4 => Ok(Dtype::F32),
120 8 => Ok(Dtype::F64),
121 other => Err(DistCheckpointError::InvalidArgument {
122 message: format!("unsupported element size {other} for safetensors serialization"),
123 }),
124 }
125}
126
127fn as_le_bytes<T: Float>(data: &[T]) -> &[u8] {
129 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) }
130}
131
/// Path of the safetensors shard written by `rank`: `<dir>/rank_<rank>.safetensors`.
fn shard_path(dir: &Path, rank: usize) -> PathBuf {
    let mut path = dir.to_path_buf();
    path.push(format!("rank_{rank}.safetensors"));
    path
}
136
/// Path of the checkpoint-wide metadata file: `<dir>/metadata.json`.
fn metadata_path(dir: &Path) -> PathBuf {
    let mut path = PathBuf::from(dir);
    path.push("metadata.json");
    path
}
141
142fn save_tensors_to_file<T: Float>(
144 tensors: &HashMap<String, Tensor<T>>,
145 path: &Path,
146) -> Result<(), DistCheckpointError> {
147 let dtype = st_dtype::<T>()?;
148
149 let mut keys: Vec<&String> = tensors.keys().collect();
150 keys.sort();
151
152 struct Entry<'a> {
153 name: String,
154 shape: Vec<usize>,
155 data: &'a [u8],
156 }
157
158 let mut entries: Vec<Entry<'_>> = Vec::with_capacity(keys.len());
159 for key in &keys {
160 let tensor = &tensors[*key];
161 let data_slice = tensor.data().map_err(|e| DistCheckpointError::Tensor {
162 message: format!("failed to read tensor \"{key}\": {e}"),
163 })?;
164 entries.push(Entry {
165 name: (*key).clone(),
166 shape: tensor.shape().to_vec(),
167 data: as_le_bytes(data_slice),
168 });
169 }
170
171 let views: Vec<(String, TensorView<'_>)> = entries
172 .iter()
173 .map(|entry| {
174 TensorView::new(dtype, entry.shape.clone(), entry.data)
175 .map(|v| (entry.name.clone(), v))
176 .map_err(|e| DistCheckpointError::Serialization {
177 message: format!("TensorView for \"{}\": {e}", entry.name),
178 })
179 })
180 .collect::<Result<Vec<_>, _>>()?;
181
182 serialize_to_file(views, &None, path).map_err(|e| DistCheckpointError::Serialization {
183 message: format!("safetensors write to {}: {e}", path.display()),
184 })?;
185
186 Ok(())
187}
188
189fn load_tensors_from_file<T: Float>(
191 path: &Path,
192) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
193 let elem_size = std::mem::size_of::<T>();
194 let expected = st_dtype::<T>()?;
195
196 let file_data = std::fs::read(path).map_err(|e| DistCheckpointError::Io {
197 message: format!("reading {}: {e}", path.display()),
198 })?;
199
200 let st =
201 SafeTensors::deserialize(&file_data).map_err(|e| DistCheckpointError::Serialization {
202 message: format!("parsing {}: {e}", path.display()),
203 })?;
204
205 let tensor_list = st.tensors();
206 let mut result: HashMap<String, Tensor<T>> = HashMap::with_capacity(tensor_list.len());
207
208 for (name, view) in &tensor_list {
209 if view.dtype() != expected {
210 return Err(DistCheckpointError::Tensor {
211 message: format!(
212 "tensor \"{name}\" has dtype {:?}, expected {:?}",
213 view.dtype(),
214 expected
215 ),
216 });
217 }
218
219 let shape = view.shape().to_vec();
220 let byte_data = view.data();
221 let numel: usize = if shape.is_empty() {
222 1
223 } else {
224 shape.iter().product()
225 };
226 let expected_bytes = numel * elem_size;
227
228 if byte_data.len() != expected_bytes {
229 return Err(DistCheckpointError::Tensor {
230 message: format!(
231 "tensor \"{name}\" has {} bytes but shape {shape:?} requires {expected_bytes}",
232 byte_data.len()
233 ),
234 });
235 }
236
237 let data: Vec<T> = byte_data
238 .chunks_exact(elem_size)
239 .map(|chunk| {
240 let mut bytes = [0u8; 8];
241 bytes[..elem_size].copy_from_slice(chunk);
242 unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const T) }
243 })
244 .collect();
245
246 let storage = TensorStorage::cpu(data);
247 let tensor = Tensor::from_storage(storage, shape, false).map_err(|e| {
248 DistCheckpointError::Tensor {
249 message: format!("creating tensor \"{name}\": {e}"),
250 }
251 })?;
252 result.insert(name.clone(), tensor);
253 }
254
255 Ok(result)
256}
257
258pub fn save_distributed<T: Float>(
278 state_dict: &HashMap<String, Tensor<T>>,
279 dir: &Path,
280 rank: usize,
281 world_size: usize,
282 shard_spec: &ShardMetadata,
283) -> Result<(), DistCheckpointError> {
284 if world_size == 0 {
286 return Err(DistCheckpointError::InvalidArgument {
287 message: "world_size must be >= 1".into(),
288 });
289 }
290 if rank >= world_size {
291 return Err(DistCheckpointError::InvalidArgument {
292 message: format!("rank {rank} >= world_size {world_size}"),
293 });
294 }
295
296 std::fs::create_dir_all(dir)?;
298
299 let path = shard_path(dir, rank);
301 save_tensors_to_file(state_dict, &path)?;
302
303 if rank == 0 {
305 let json = serde_json::to_string_pretty(shard_spec).map_err(|e| {
306 DistCheckpointError::Serialization {
307 message: format!("serializing metadata: {e}"),
308 }
309 })?;
310 std::fs::write(metadata_path(dir), json)?;
311 }
312
313 Ok(())
314}
315
316pub fn load_distributed<T: Float>(
328 dir: &Path,
329 rank: usize,
330 world_size: usize,
331) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
332 if world_size == 0 {
333 return Err(DistCheckpointError::InvalidArgument {
334 message: "world_size must be >= 1".into(),
335 });
336 }
337 if rank >= world_size {
338 return Err(DistCheckpointError::InvalidArgument {
339 message: format!("rank {rank} >= world_size {world_size}"),
340 });
341 }
342
343 let meta_path = metadata_path(dir);
345 let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
346 message: format!("reading {}: {e}", meta_path.display()),
347 })?;
348 let metadata: ShardMetadata =
349 serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
350 message: format!("parsing metadata: {e}"),
351 })?;
352
353 let old_world_size = metadata.num_ranks;
354
355 if old_world_size == world_size {
356 let path = shard_path(dir, rank);
358 if !path.exists() {
359 if metadata.tensor_specs.is_empty() {
362 return Ok(HashMap::new());
363 }
364 return Err(DistCheckpointError::MissingShard {
365 path: path.display().to_string(),
366 });
367 }
368 load_tensors_from_file(&path)
369 } else {
370 reshard(dir, old_world_size, world_size, rank)
372 }
373}
374
375pub fn reshard<T: Float>(
397 dir: &Path,
398 old_world_size: usize,
399 new_world_size: usize,
400 new_rank: usize,
401) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
402 if new_world_size == 0 {
403 return Err(DistCheckpointError::InvalidArgument {
404 message: "new_world_size must be >= 1".into(),
405 });
406 }
407 if new_rank >= new_world_size {
408 return Err(DistCheckpointError::InvalidArgument {
409 message: format!("new_rank {new_rank} >= new_world_size {new_world_size}"),
410 });
411 }
412 if old_world_size == 0 {
413 return Err(DistCheckpointError::InvalidArgument {
414 message: "old_world_size must be >= 1".into(),
415 });
416 }
417
418 let meta_path = metadata_path(dir);
420 let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
421 message: format!("reading {}: {e}", meta_path.display()),
422 })?;
423 let metadata: ShardMetadata =
424 serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
425 message: format!("parsing metadata: {e}"),
426 })?;
427
428 let mut old_shards: Vec<HashMap<String, Tensor<T>>> = Vec::with_capacity(old_world_size);
430 for old_rank in 0..old_world_size {
431 let path = shard_path(dir, old_rank);
432 if !path.exists() {
433 return Err(DistCheckpointError::MissingShard {
434 path: path.display().to_string(),
435 });
436 }
437 old_shards.push(load_tensors_from_file(&path)?);
438 }
439
440 let mut result: HashMap<String, Tensor<T>> = HashMap::new();
442
443 for (name, spec) in &metadata.tensor_specs {
444 let shard_dim = spec.shard_dim;
445 let full_shape = &spec.full_shape;
446
447 let mut shard_datas: Vec<Vec<T>> = Vec::with_capacity(old_world_size);
449 let mut shard_shapes: Vec<Vec<usize>> = Vec::with_capacity(old_world_size);
450
451 for (old_rank, shard) in old_shards.iter().enumerate().take(old_world_size) {
452 let tensor = shard.get(name).ok_or_else(|| DistCheckpointError::Tensor {
453 message: format!("tensor \"{name}\" missing from rank {old_rank} shard"),
454 })?;
455 shard_datas.push(tensor.data_vec().map_err(|e| DistCheckpointError::Tensor {
456 message: format!("reading tensor \"{name}\" from rank {old_rank}: {e}"),
457 })?);
458 shard_shapes.push(tensor.shape().to_vec());
459 }
460
461 let full_data = concat_along_dim(&shard_datas, &shard_shapes, shard_dim, full_shape)?;
463
464 let full_dim_size = full_shape[shard_dim];
466 let new_shard_sizes = compute_shard_sizes(full_dim_size, new_world_size);
467 let new_offset: usize = new_shard_sizes[..new_rank].iter().sum();
468 let new_size = new_shard_sizes[new_rank];
469
470 let mut new_shape = full_shape.clone();
472 new_shape[shard_dim] = new_size;
473
474 let new_data = slice_along_dim(&full_data, full_shape, shard_dim, new_offset, new_size);
475
476 let tensor =
477 Tensor::from_storage(TensorStorage::cpu(new_data), new_shape, false).map_err(|e| {
478 DistCheckpointError::Tensor {
479 message: format!("creating resharded tensor \"{name}\": {e}"),
480 }
481 })?;
482
483 result.insert(name.clone(), tensor);
484 }
485
486 Ok(result)
487}
488
/// Splits `total` items into `num_parts` contiguous shard sizes that differ by at most one.
///
/// Every part gets `total / num_parts`, and the first `total % num_parts` parts get one
/// extra, e.g. `(10, 3)` → `[4, 3, 3]`.
///
/// # Panics
///
/// Panics (division by zero) if `num_parts == 0`.
fn compute_shard_sizes(total: usize, num_parts: usize) -> Vec<usize> {
    let quotient = total / num_parts;
    let extra = total % num_parts;
    let mut sizes = vec![quotient; num_parts];
    sizes.iter_mut().take(extra).for_each(|size| *size += 1);
    sizes
}
498
499fn concat_along_dim<T: Float>(
505 shard_datas: &[Vec<T>],
506 shard_shapes: &[Vec<usize>],
507 dim: usize,
508 full_shape: &[usize],
509) -> Result<Vec<T>, DistCheckpointError> {
510 let ndim = full_shape.len();
511 if dim >= ndim {
512 return Err(DistCheckpointError::InvalidArgument {
513 message: format!("shard_dim {dim} >= ndim {ndim}"),
514 });
515 }
516
517 let full_numel: usize = full_shape.iter().product();
518 let mut full_data = vec![<T as num_traits::Zero>::zero(); full_numel];
519
520 let outer: usize = full_shape[..dim].iter().product();
526 let inner: usize = full_shape[dim + 1..].iter().product();
527 let full_middle = full_shape[dim];
528
529 let mut dim_offset = 0;
531 for (shard_idx, shard_data) in shard_datas.iter().enumerate() {
532 let shard_middle = shard_shapes[shard_idx][dim];
533
534 for d in 0..ndim {
536 if d != dim && shard_shapes[shard_idx][d] != full_shape[d] {
537 return Err(DistCheckpointError::Tensor {
538 message: format!(
539 "shard {shard_idx} has shape {:?} but expected dim {d} to be {} (full shape {full_shape:?})",
540 shard_shapes[shard_idx], full_shape[d]
541 ),
542 });
543 }
544 }
545
546 for o in 0..outer {
547 let src_start = o * shard_middle * inner;
548 let dst_start = o * full_middle * inner + dim_offset * inner;
549 let count = shard_middle * inner;
550
551 full_data[dst_start..dst_start + count]
552 .copy_from_slice(&shard_data[src_start..src_start + count]);
553 }
554
555 dim_offset += shard_middle;
556 }
557
558 if dim_offset != full_middle {
559 return Err(DistCheckpointError::Tensor {
560 message: format!(
561 "shard sizes along dim {dim} sum to {dim_offset}, expected {full_middle}"
562 ),
563 });
564 }
565
566 Ok(full_data)
567}
568
569fn slice_along_dim<T: Float>(
574 data: &[T],
575 shape: &[usize],
576 dim: usize,
577 offset: usize,
578 size: usize,
579) -> Vec<T> {
580 let outer: usize = shape[..dim].iter().product();
581 let full_middle = shape[dim];
582 let inner: usize = shape[dim + 1..].iter().product();
583
584 let out_numel = outer * size * inner;
585 let mut result = Vec::with_capacity(out_numel);
586
587 for o in 0..outer {
588 let src_start = o * full_middle * inner + offset * inner;
589 let count = size * inner;
590 result.extend_from_slice(&data[src_start..src_start + count]);
591 }
592
593 result
594}
595
/// Handle to a checkpoint being written on a background thread, as returned by
/// `AsyncCheckpointer::save_async`.
///
/// Use `wait` to block for the outcome or `is_done` to poll without blocking.
pub struct CheckpointFuture {
    /// Worker thread handle; taken (left `None`) once `wait` joins it.
    handle: Option<thread::JoinHandle<Result<(), DistCheckpointError>>>,
    /// Worker outcome cached by `wait` after collecting it; `None` before that.
    result: Option<Result<(), DistCheckpointError>>,
}
609
610impl CheckpointFuture {
611 pub fn wait(&mut self) -> Result<(), DistCheckpointError> {
619 if let Some(handle) = self.handle.take() {
620 let res = handle
621 .join()
622 .map_err(|_| DistCheckpointError::AsyncFailed {
623 message: "background checkpoint thread panicked".into(),
624 })?;
625 self.result = Some(res);
626 }
627
628 match &self.result {
629 Some(Ok(())) => Ok(()),
630 Some(Err(e)) => Err(DistCheckpointError::AsyncFailed {
631 message: format!("{e}"),
632 }),
633 None => Err(DistCheckpointError::AsyncFailed {
634 message: "no checkpoint was started".into(),
635 }),
636 }
637 }
638
639 pub fn is_done(&self) -> bool {
641 if self.result.is_some() {
642 return true;
643 }
644 match &self.handle {
645 Some(h) => h.is_finished(),
646 None => true,
647 }
648 }
649}
650
/// Writes one rank's shard on a background thread, so the caller is not blocked on the disk
/// write. Tensor data is copied (staged) before the thread starts.
///
/// Only one asynchronous save may be in flight per checkpointer at a time.
pub struct AsyncCheckpointer {
    /// Checkpoint directory written into.
    dir: PathBuf,
    /// Rank whose `rank_<r>.safetensors` file this checkpointer writes.
    rank: usize,
    /// Total number of ranks the checkpoint is saved with.
    world_size: usize,
    /// Metadata written to `metadata.json` when `rank == 0`.
    shard_spec: ShardMetadata,
    /// `true` while a save is in progress; shared with the worker thread so it can clear it.
    in_flight: Arc<Mutex<bool>>,
}
683
684impl AsyncCheckpointer {
685 pub fn new(dir: PathBuf, rank: usize, world_size: usize, shard_spec: ShardMetadata) -> Self {
692 Self {
693 dir,
694 rank,
695 world_size,
696 shard_spec,
697 in_flight: Arc::new(Mutex::new(false)),
698 }
699 }
700
701 pub fn dir(&self) -> &Path {
703 &self.dir
704 }
705
706 pub fn rank(&self) -> usize {
708 self.rank
709 }
710
711 pub fn world_size(&self) -> usize {
713 self.world_size
714 }
715
716 pub fn save_async(
729 &self,
730 state_dict: &HashMap<String, Tensor<f32>>,
731 ) -> Result<CheckpointFuture, DistCheckpointError> {
732 {
734 let mut guard =
735 self.in_flight
736 .lock()
737 .map_err(|e| DistCheckpointError::AsyncFailed {
738 message: format!("lock poisoned: {e}"),
739 })?;
740 if *guard {
741 return Err(DistCheckpointError::AsyncFailed {
742 message: "another async checkpoint is already in flight".into(),
743 });
744 }
745 *guard = true;
746 }
747
748 let mut staged: HashMap<String, (Vec<f32>, Vec<usize>)> =
751 HashMap::with_capacity(state_dict.len());
752
753 for (name, tensor) in state_dict {
754 let data = tensor.data_vec().map_err(|e| {
755 if let Ok(mut g) = self.in_flight.lock() {
757 *g = false;
758 }
759 DistCheckpointError::Tensor {
760 message: format!("staging tensor \"{name}\": {e}"),
761 }
762 })?;
763 let shape = tensor.shape().to_vec();
764 staged.insert(name.clone(), (data, shape));
765 }
766
767 let dir = self.dir.clone();
769 let rank = self.rank;
770 let shard_spec = self.shard_spec.clone();
771 let in_flight = Arc::clone(&self.in_flight);
772
773 let handle = thread::spawn(move || {
774 let result = (|| -> Result<(), DistCheckpointError> {
775 let mut tensors: HashMap<String, Tensor<f32>> =
777 HashMap::with_capacity(staged.len());
778 for (name, (data, shape)) in staged {
779 let tensor = Tensor::from_storage(TensorStorage::cpu(data), shape, false)
780 .map_err(|e| DistCheckpointError::Tensor {
781 message: format!("rebuilding tensor \"{name}\": {e}"),
782 })?;
783 tensors.insert(name, tensor);
784 }
785
786 std::fs::create_dir_all(&dir)?;
788 let path = shard_path(&dir, rank);
789 save_tensors_to_file(&tensors, &path)?;
790
791 if rank == 0 {
793 let json = serde_json::to_string_pretty(&shard_spec).map_err(|e| {
794 DistCheckpointError::Serialization {
795 message: format!("serializing metadata: {e}"),
796 }
797 })?;
798 std::fs::write(metadata_path(&dir), json)?;
799 }
800
801 Ok(())
802 })();
803
804 if let Ok(mut g) = in_flight.lock() {
806 *g = false;
807 }
808
809 result
810 });
811
812 Ok(CheckpointFuture {
813 handle: Some(handle),
814 result: None,
815 })
816 }
817}
818
819pub fn flat_shard_metadata(
829 state_dict: &HashMap<String, Tensor<f32>>,
830 world_size: usize,
831) -> ShardMetadata {
832 let mut tensor_specs = HashMap::new();
833 for (name, tensor) in state_dict {
834 let shape = tensor.shape();
835 let shard_numel = shape.iter().product::<usize>();
838 let full_numel = shard_numel * world_size;
839 let shard_sizes = vec![shard_numel; world_size];
840 tensor_specs.insert(
841 name.clone(),
842 TensorShardSpec {
843 full_shape: vec![full_numel],
844 shard_dim: 0,
845 shard_sizes,
846 },
847 );
848 }
849 ShardMetadata {
850 num_ranks: world_size,
851 tensor_specs,
852 }
853}
854
855#[cfg(test)]
860mod tests {
861 use super::*;
862 use ferrotorch_core::Tensor;
863 use ferrotorch_core::storage::TensorStorage;
864 use std::collections::HashMap;
865
866 fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
867 Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
868 }
869
870 fn temp_dir(name: &str) -> PathBuf {
871 std::env::temp_dir()
872 .join("ferrotorch_test_dist_ckpt")
873 .join(name)
874 }
875
876 fn cleanup(dir: &Path) {
877 let _ = std::fs::remove_dir_all(dir);
878 }
879
880 #[test]
883 fn test_save_load_single_rank() {
884 let dir = temp_dir("single_rank");
885 cleanup(&dir);
886
887 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
888 state.insert(
889 "weight".into(),
890 make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
891 );
892 state.insert("bias".into(), make_tensor(vec![0.1, 0.2], vec![2]));
893
894 let spec = ShardMetadata {
895 num_ranks: 1,
896 tensor_specs: {
897 let mut m = HashMap::new();
898 m.insert(
899 "weight".into(),
900 TensorShardSpec {
901 full_shape: vec![4],
902 shard_dim: 0,
903 shard_sizes: vec![4],
904 },
905 );
906 m.insert(
907 "bias".into(),
908 TensorShardSpec {
909 full_shape: vec![2],
910 shard_dim: 0,
911 shard_sizes: vec![2],
912 },
913 );
914 m
915 },
916 };
917
918 save_distributed(&state, &dir, 0, 1, &spec).unwrap();
919 let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
920
921 assert_eq!(loaded.len(), 2);
922 assert_eq!(loaded["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
923 assert_eq!(loaded["bias"].data().unwrap(), &[0.1, 0.2]);
924
925 cleanup(&dir);
926 }
927
928 #[test]
929 fn test_save_load_two_ranks() {
930 let dir = temp_dir("two_ranks");
931 cleanup(&dir);
932
933 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
935 state0.insert("weight".into(), make_tensor(vec![1.0, 2.0], vec![2]));
936
937 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
939 state1.insert("weight".into(), make_tensor(vec![3.0, 4.0], vec![2]));
940
941 let spec = ShardMetadata {
942 num_ranks: 2,
943 tensor_specs: {
944 let mut m = HashMap::new();
945 m.insert(
946 "weight".into(),
947 TensorShardSpec {
948 full_shape: vec![4],
949 shard_dim: 0,
950 shard_sizes: vec![2, 2],
951 },
952 );
953 m
954 },
955 };
956
957 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
958 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
959
960 let loaded0 = load_distributed::<f32>(&dir, 0, 2).unwrap();
962 let loaded1 = load_distributed::<f32>(&dir, 1, 2).unwrap();
963
964 assert_eq!(loaded0["weight"].data().unwrap(), &[1.0, 2.0]);
965 assert_eq!(loaded1["weight"].data().unwrap(), &[3.0, 4.0]);
966
967 cleanup(&dir);
968 }
969
970 #[test]
973 fn test_reshard_2_to_4() {
974 let dir = temp_dir("reshard_2_to_4");
976 cleanup(&dir);
977
978 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
981 state0.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]));
982
983 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
984 state1.insert("w".into(), make_tensor(vec![5.0, 6.0, 7.0, 8.0], vec![4]));
985
986 let spec = ShardMetadata {
987 num_ranks: 2,
988 tensor_specs: {
989 let mut m = HashMap::new();
990 m.insert(
991 "w".into(),
992 TensorShardSpec {
993 full_shape: vec![8],
994 shard_dim: 0,
995 shard_sizes: vec![4, 4],
996 },
997 );
998 m
999 },
1000 };
1001
1002 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1003 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1004
1005 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1007 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1008 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1009 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1010
1011 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0]);
1012 assert_eq!(r1["w"].data().unwrap(), &[3.0, 4.0]);
1013 assert_eq!(r2["w"].data().unwrap(), &[5.0, 6.0]);
1014 assert_eq!(r3["w"].data().unwrap(), &[7.0, 8.0]);
1015
1016 cleanup(&dir);
1017 }
1018
1019 #[test]
1020 fn test_reshard_4_to_2() {
1021 let dir = temp_dir("reshard_4_to_2");
1023 cleanup(&dir);
1024
1025 let spec = ShardMetadata {
1026 num_ranks: 4,
1027 tensor_specs: {
1028 let mut m = HashMap::new();
1029 m.insert(
1030 "w".into(),
1031 TensorShardSpec {
1032 full_shape: vec![8],
1033 shard_dim: 0,
1034 shard_sizes: vec![2, 2, 2, 2],
1035 },
1036 );
1037 m
1038 },
1039 };
1040
1041 for rank in 0..4 {
1042 let start = rank as f32 * 2.0 + 1.0;
1043 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1044 state.insert("w".into(), make_tensor(vec![start, start + 1.0], vec![2]));
1045 save_distributed(&state, &dir, rank, 4, &spec).unwrap();
1046 }
1047
1048 let r0 = reshard::<f32>(&dir, 4, 2, 0).unwrap();
1050 let r1 = reshard::<f32>(&dir, 4, 2, 1).unwrap();
1051
1052 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1053 assert_eq!(r1["w"].data().unwrap(), &[5.0, 6.0, 7.0, 8.0]);
1054
1055 cleanup(&dir);
1056 }
1057
1058 #[test]
1059 fn test_reshard_2d_tensor() {
1060 let dir = temp_dir("reshard_2d");
1063 cleanup(&dir);
1064
1065 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1066 state0.insert(
1067 "w".into(),
1068 make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]),
1069 );
1070
1071 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1072 state1.insert(
1073 "w".into(),
1074 make_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![2, 3]),
1075 );
1076
1077 let spec = ShardMetadata {
1078 num_ranks: 2,
1079 tensor_specs: {
1080 let mut m = HashMap::new();
1081 m.insert(
1082 "w".into(),
1083 TensorShardSpec {
1084 full_shape: vec![4, 3],
1085 shard_dim: 0,
1086 shard_sizes: vec![2, 2],
1087 },
1088 );
1089 m
1090 },
1091 };
1092
1093 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1094 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1095
1096 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1098 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1099 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1100 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1101
1102 assert_eq!(r0["w"].shape(), &[1, 3]);
1103 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0]);
1104 assert_eq!(r1["w"].shape(), &[1, 3]);
1105 assert_eq!(r1["w"].data().unwrap(), &[4.0, 5.0, 6.0]);
1106 assert_eq!(r2["w"].shape(), &[1, 3]);
1107 assert_eq!(r2["w"].data().unwrap(), &[7.0, 8.0, 9.0]);
1108 assert_eq!(r3["w"].shape(), &[1, 3]);
1109 assert_eq!(r3["w"].data().unwrap(), &[10.0, 11.0, 12.0]);
1110
1111 cleanup(&dir);
1112 }
1113
1114 #[test]
1115 fn test_reshard_dim1() {
1116 let dir = temp_dir("reshard_dim1");
1119 cleanup(&dir);
1120
1121 let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1122 state0.insert(
1123 "w".into(),
1124 make_tensor(vec![1.0, 2.0, 5.0, 6.0], vec![2, 2]),
1125 );
1126
1127 let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1128 state1.insert(
1129 "w".into(),
1130 make_tensor(vec![3.0, 4.0, 7.0, 8.0], vec![2, 2]),
1131 );
1132
1133 let spec = ShardMetadata {
1134 num_ranks: 2,
1135 tensor_specs: {
1136 let mut m = HashMap::new();
1137 m.insert(
1138 "w".into(),
1139 TensorShardSpec {
1140 full_shape: vec![2, 4],
1141 shard_dim: 1,
1142 shard_sizes: vec![2, 2],
1143 },
1144 );
1145 m
1146 },
1147 };
1148
1149 save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1150 save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1151
1152 let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1154 let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1155 let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1156 let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1157
1158 assert_eq!(r0["w"].shape(), &[2, 1]);
1159 assert_eq!(r0["w"].data().unwrap(), &[1.0, 5.0]);
1160 assert_eq!(r1["w"].shape(), &[2, 1]);
1161 assert_eq!(r1["w"].data().unwrap(), &[2.0, 6.0]);
1162 assert_eq!(r2["w"].shape(), &[2, 1]);
1163 assert_eq!(r2["w"].data().unwrap(), &[3.0, 7.0]);
1164 assert_eq!(r3["w"].shape(), &[2, 1]);
1165 assert_eq!(r3["w"].data().unwrap(), &[4.0, 8.0]);
1166
1167 cleanup(&dir);
1168 }
1169
1170 #[test]
1171 fn test_reshard_3_to_2_uneven() {
1172 let dir = temp_dir("reshard_3_to_2");
1175 cleanup(&dir);
1176
1177 let spec = ShardMetadata {
1178 num_ranks: 3,
1179 tensor_specs: {
1180 let mut m = HashMap::new();
1181 m.insert(
1182 "w".into(),
1183 TensorShardSpec {
1184 full_shape: vec![9],
1185 shard_dim: 0,
1186 shard_sizes: vec![3, 3, 3],
1187 },
1188 );
1189 m
1190 },
1191 };
1192
1193 for rank in 0..3usize {
1194 let start = rank as f32 * 3.0 + 1.0;
1195 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1196 state.insert(
1197 "w".into(),
1198 make_tensor(vec![start, start + 1.0, start + 2.0], vec![3]),
1199 );
1200 save_distributed(&state, &dir, rank, 3, &spec).unwrap();
1201 }
1202
1203 let r0 = reshard::<f32>(&dir, 3, 2, 0).unwrap();
1205 let r1 = reshard::<f32>(&dir, 3, 2, 1).unwrap();
1206
1207 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
1208 assert_eq!(r1["w"].data().unwrap(), &[6.0, 7.0, 8.0, 9.0]);
1209
1210 cleanup(&dir);
1211 }
1212
1213 #[test]
1216 fn test_load_distributed_reshards_when_world_size_differs() {
1217 let dir = temp_dir("load_reshard");
1218 cleanup(&dir);
1219
1220 let spec = ShardMetadata {
1222 num_ranks: 2,
1223 tensor_specs: {
1224 let mut m = HashMap::new();
1225 m.insert(
1226 "w".into(),
1227 TensorShardSpec {
1228 full_shape: vec![4],
1229 shard_dim: 0,
1230 shard_sizes: vec![2, 2],
1231 },
1232 );
1233 m
1234 },
1235 };
1236
1237 let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1238 s0.insert("w".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1239 save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1240
1241 let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1242 s1.insert("w".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1243 save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1244
1245 let r0 = load_distributed::<f32>(&dir, 0, 4).unwrap();
1247 let r1 = load_distributed::<f32>(&dir, 1, 4).unwrap();
1248 let r2 = load_distributed::<f32>(&dir, 2, 4).unwrap();
1249 let r3 = load_distributed::<f32>(&dir, 3, 4).unwrap();
1250
1251 assert_eq!(r0["w"].data().unwrap(), &[1.0]);
1252 assert_eq!(r1["w"].data().unwrap(), &[2.0]);
1253 assert_eq!(r2["w"].data().unwrap(), &[3.0]);
1254 assert_eq!(r3["w"].data().unwrap(), &[4.0]);
1255
1256 cleanup(&dir);
1257 }
1258
1259 #[test]
1262 fn test_metadata_roundtrip() {
1263 let spec = ShardMetadata {
1264 num_ranks: 4,
1265 tensor_specs: {
1266 let mut m = HashMap::new();
1267 m.insert(
1268 "layer.weight".into(),
1269 TensorShardSpec {
1270 full_shape: vec![256, 512],
1271 shard_dim: 0,
1272 shard_sizes: vec![64, 64, 64, 64],
1273 },
1274 );
1275 m.insert(
1276 "layer.bias".into(),
1277 TensorShardSpec {
1278 full_shape: vec![256],
1279 shard_dim: 0,
1280 shard_sizes: vec![64, 64, 64, 64],
1281 },
1282 );
1283 m
1284 },
1285 };
1286
1287 let json = serde_json::to_string_pretty(&spec).unwrap();
1288 let loaded: ShardMetadata = serde_json::from_str(&json).unwrap();
1289
1290 assert_eq!(loaded.num_ranks, 4);
1291 assert_eq!(loaded.tensor_specs.len(), 2);
1292 assert_eq!(
1293 loaded.tensor_specs["layer.weight"].full_shape,
1294 vec![256, 512]
1295 );
1296 assert_eq!(loaded.tensor_specs["layer.weight"].shard_dim, 0);
1297 assert_eq!(
1298 loaded.tensor_specs["layer.weight"].shard_sizes,
1299 vec![64, 64, 64, 64]
1300 );
1301 }
1302
1303 #[test]
1306 fn test_compute_shard_sizes_even() {
1307 assert_eq!(compute_shard_sizes(8, 4), vec![2, 2, 2, 2]);
1308 assert_eq!(compute_shard_sizes(12, 3), vec![4, 4, 4]);
1309 }
1310
1311 #[test]
1312 fn test_compute_shard_sizes_uneven() {
1313 assert_eq!(compute_shard_sizes(9, 2), vec![5, 4]);
1315 assert_eq!(compute_shard_sizes(10, 3), vec![4, 3, 3]);
1317 assert_eq!(compute_shard_sizes(7, 4), vec![2, 2, 2, 1]);
1319 }
1320
1321 #[test]
1324 fn test_concat_1d() {
1325 let data0 = vec![1.0f32, 2.0];
1326 let data1 = vec![3.0f32, 4.0, 5.0];
1327 let full_shape = vec![5];
1328
1329 let result =
1330 concat_along_dim(&[data0, data1], &[vec![2], vec![3]], 0, &full_shape).unwrap();
1331
1332 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1333 }
1334
1335 #[test]
1336 fn test_concat_2d_dim0() {
1337 let data0 = vec![1.0f32, 2.0, 3.0];
1339 let data1 = vec![4.0f32, 5.0, 6.0];
1340 let full_shape = vec![2, 3];
1341
1342 let result =
1343 concat_along_dim(&[data0, data1], &[vec![1, 3], vec![1, 3]], 0, &full_shape).unwrap();
1344
1345 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1346 }
1347
1348 #[test]
1349 fn test_concat_2d_dim1() {
1350 let data0 = vec![1.0f32, 3.0];
1353 let data1 = vec![2.0f32, 4.0];
1354 let full_shape = vec![2, 2];
1355
1356 let result =
1357 concat_along_dim(&[data0, data1], &[vec![2, 1], vec![2, 1]], 1, &full_shape).unwrap();
1358
1359 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
1360 }
1361
1362 #[test]
1365 fn test_slice_1d() {
1366 let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
1367 let shape = vec![5];
1368
1369 let s0 = slice_along_dim(&data, &shape, 0, 0, 2);
1370 assert_eq!(s0, vec![1.0, 2.0]);
1371
1372 let s1 = slice_along_dim(&data, &shape, 0, 2, 3);
1373 assert_eq!(s1, vec![3.0, 4.0, 5.0]);
1374 }
1375
1376 #[test]
1377 fn test_slice_2d_dim0() {
1378 let data = vec![
1380 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1381 ];
1382 let shape = vec![4, 3];
1383
1384 let s = slice_along_dim(&data, &shape, 0, 1, 2);
1385 assert_eq!(s, vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
1386 }
1387
1388 #[test]
1389 fn test_slice_2d_dim1() {
1390 let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1392 let shape = vec![2, 4];
1393
1394 let s = slice_along_dim(&data, &shape, 1, 1, 2);
1395 assert_eq!(s, vec![2.0, 3.0, 6.0, 7.0]);
1396 }
1397
1398 #[test]
1401 fn test_flat_shard_metadata() {
1402 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1403 state.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0], vec![3]));
1404
1405 let meta = flat_shard_metadata(&state, 4);
1406 assert_eq!(meta.num_ranks, 4);
1407
1408 let spec = &meta.tensor_specs["w"];
1409 assert_eq!(spec.full_shape, vec![12]); assert_eq!(spec.shard_dim, 0);
1411 assert_eq!(spec.shard_sizes, vec![3, 3, 3, 3]);
1412 }
1413
1414 #[test]
1417 fn test_async_checkpoint_basic() {
1418 let dir = temp_dir("async_basic");
1419 cleanup(&dir);
1420
1421 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1422 state.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]));
1423
1424 let spec = ShardMetadata {
1425 num_ranks: 1,
1426 tensor_specs: {
1427 let mut m = HashMap::new();
1428 m.insert(
1429 "w".into(),
1430 TensorShardSpec {
1431 full_shape: vec![4],
1432 shard_dim: 0,
1433 shard_sizes: vec![4],
1434 },
1435 );
1436 m
1437 },
1438 };
1439
1440 let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1441 let mut future = ckpt.save_async(&state).unwrap();
1442 future.wait().unwrap();
1443
1444 let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1446 assert_eq!(loaded["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1447
1448 cleanup(&dir);
1449 }
1450
1451 #[test]
1452 fn test_async_checkpoint_wait_idempotent() {
1453 let dir = temp_dir("async_idempotent");
1454 cleanup(&dir);
1455
1456 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1457 state.insert("x".into(), make_tensor(vec![42.0], vec![1]));
1458
1459 let spec = ShardMetadata {
1460 num_ranks: 1,
1461 tensor_specs: {
1462 let mut m = HashMap::new();
1463 m.insert(
1464 "x".into(),
1465 TensorShardSpec {
1466 full_shape: vec![1],
1467 shard_dim: 0,
1468 shard_sizes: vec![1],
1469 },
1470 );
1471 m
1472 },
1473 };
1474
1475 let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1476 let mut future = ckpt.save_async(&state).unwrap();
1477
1478 future.wait().unwrap();
1480 future.wait().unwrap();
1481
1482 cleanup(&dir);
1483 }
1484
1485 #[test]
1486 fn test_async_checkpoint_is_done() {
1487 let dir = temp_dir("async_is_done");
1488 cleanup(&dir);
1489
1490 let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1491 state.insert("x".into(), make_tensor(vec![1.0], vec![1]));
1492
1493 let spec = ShardMetadata {
1494 num_ranks: 1,
1495 tensor_specs: {
1496 let mut m = HashMap::new();
1497 m.insert(
1498 "x".into(),
1499 TensorShardSpec {
1500 full_shape: vec![1],
1501 shard_dim: 0,
1502 shard_sizes: vec![1],
1503 },
1504 );
1505 m
1506 },
1507 };
1508
1509 let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1510 let mut future = ckpt.save_async(&state).unwrap();
1511 future.wait().unwrap();
1512 assert!(future.is_done());
1513
1514 cleanup(&dir);
1515 }
1516
1517 #[test]
1520 fn test_save_invalid_rank() {
1521 let dir = temp_dir("invalid_rank");
1522 let state: HashMap<String, Tensor<f32>> = HashMap::new();
1523 let spec = ShardMetadata {
1524 num_ranks: 2,
1525 tensor_specs: HashMap::new(),
1526 };
1527
1528 let result = save_distributed(&state, &dir, 5, 2, &spec);
1529 assert!(result.is_err());
1530 }
1531
1532 #[test]
1533 fn test_load_missing_metadata() {
1534 let dir = temp_dir("missing_meta");
1535 cleanup(&dir);
1536 std::fs::create_dir_all(&dir).unwrap();
1537
1538 let result = load_distributed::<f32>(&dir, 0, 1);
1539 assert!(result.is_err());
1540
1541 cleanup(&dir);
1542 }
1543
1544 #[test]
1545 fn test_load_missing_shard() {
1546 let dir = temp_dir("missing_shard");
1547 cleanup(&dir);
1548 std::fs::create_dir_all(&dir).unwrap();
1549
1550 let spec = ShardMetadata {
1552 num_ranks: 1,
1553 tensor_specs: HashMap::new(),
1554 };
1555 let json = serde_json::to_string_pretty(&spec).unwrap();
1556 std::fs::write(metadata_path(&dir), json).unwrap();
1557
1558 let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1562 assert!(loaded.is_empty());
1563
1564 cleanup(&dir);
1565 }
1566
1567 #[test]
1570 fn test_reshard_multiple_tensors() {
1571 let dir = temp_dir("reshard_multi");
1572 cleanup(&dir);
1573
1574 let spec = ShardMetadata {
1575 num_ranks: 2,
1576 tensor_specs: {
1577 let mut m = HashMap::new();
1578 m.insert(
1579 "weight".into(),
1580 TensorShardSpec {
1581 full_shape: vec![4],
1582 shard_dim: 0,
1583 shard_sizes: vec![2, 2],
1584 },
1585 );
1586 m.insert(
1587 "bias".into(),
1588 TensorShardSpec {
1589 full_shape: vec![6],
1590 shard_dim: 0,
1591 shard_sizes: vec![3, 3],
1592 },
1593 );
1594 m
1595 },
1596 };
1597
1598 let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1599 s0.insert("weight".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1600 s0.insert("bias".into(), make_tensor(vec![10.0, 20.0, 30.0], vec![3]));
1601
1602 let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1603 s1.insert("weight".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1604 s1.insert("bias".into(), make_tensor(vec![40.0, 50.0, 60.0], vec![3]));
1605
1606 save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1607 save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1608
1609 let r0 = load_distributed::<f32>(&dir, 0, 1).unwrap();
1611
1612 assert_eq!(r0["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1613 assert_eq!(
1614 r0["bias"].data().unwrap(),
1615 &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
1616 );
1617
1618 cleanup(&dir);
1619 }
1620
1621 #[test]
1624 fn test_distributed_checkpoint_struct() {
1625 let ckpt = DistributedCheckpoint {
1626 checkpoint_dir: PathBuf::from("/tmp/test"),
1627 shard_metadata: ShardMetadata {
1628 num_ranks: 2,
1629 tensor_specs: HashMap::new(),
1630 },
1631 };
1632 assert_eq!(ckpt.checkpoint_dir, PathBuf::from("/tmp/test"));
1633 assert_eq!(ckpt.shard_metadata.num_ranks, 2);
1634 }
1635
1636 #[test]
1638 fn test_reshard_same_world_size() {
1639 let dir = temp_dir("reshard_same");
1640 cleanup(&dir);
1641
1642 let spec = ShardMetadata {
1643 num_ranks: 2,
1644 tensor_specs: {
1645 let mut m = HashMap::new();
1646 m.insert(
1647 "w".into(),
1648 TensorShardSpec {
1649 full_shape: vec![4],
1650 shard_dim: 0,
1651 shard_sizes: vec![2, 2],
1652 },
1653 );
1654 m
1655 },
1656 };
1657
1658 let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1659 s0.insert("w".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1660 save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1661
1662 let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1663 s1.insert("w".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1664 save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1665
1666 let r0 = reshard::<f32>(&dir, 2, 2, 0).unwrap();
1668 let r1 = reshard::<f32>(&dir, 2, 2, 1).unwrap();
1669
1670 assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0]);
1671 assert_eq!(r1["w"].data().unwrap(), &[3.0, 4.0]);
1672
1673 cleanup(&dir);
1674 }
1675}