Skip to main content

ferrotorch_distributed/
checkpoint.rs

1//! Distributed checkpointing with per-rank shard saving, loading, and resharding.
2//!
3//! Each rank saves its own shard as a SafeTensors file under a shared checkpoint
4//! directory. Rank 0 additionally writes a `metadata.json` file describing how
5//! tensors are distributed across ranks. When loading, if the world size has
6//! changed, the resharding logic automatically splits or merges shards so that
7//! each new rank receives the correct slice of each tensor.
8//!
9//! # Async checkpointing
10//!
//! [`AsyncCheckpointer`] stages (copies) tensor data into CPU memory on the
//! calling thread, then serializes and writes it to disk in a background
//! thread so that file I/O does not block the training loop. Call
//! [`save_async`](AsyncCheckpointer::save_async) to start a checkpoint, and
//! [`CheckpointFuture::wait`] when you need to ensure the write completed.
15
16use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18use std::sync::{Arc, Mutex};
19use std::thread;
20
21use safetensors::serialize_to_file;
22use safetensors::tensor::{Dtype, SafeTensors, TensorView};
23use serde::{Deserialize, Serialize};
24
25use ferrotorch_core::storage::TensorStorage;
26use ferrotorch_core::{FerrotorchError, Float, Tensor};
27
28// ---------------------------------------------------------------------------
29// Error type
30// ---------------------------------------------------------------------------
31
32/// Errors specific to distributed checkpointing.
33#[derive(Debug, thiserror::Error)]
34#[non_exhaustive]
35pub enum DistCheckpointError {
36    #[error("I/O error: {message}")]
37    Io { message: String },
38
39    #[error("serialization error: {message}")]
40    Serialization { message: String },
41
42    #[error("metadata error: {message}")]
43    Metadata { message: String },
44
45    #[error("shard file missing: {path}")]
46    MissingShard { path: String },
47
48    #[error("tensor error: {message}")]
49    Tensor { message: String },
50
51    #[error("invalid argument: {message}")]
52    InvalidArgument { message: String },
53
54    #[error("async checkpoint failed: {message}")]
55    AsyncFailed { message: String },
56}
57
58impl From<DistCheckpointError> for FerrotorchError {
59    fn from(e: DistCheckpointError) -> Self {
60        FerrotorchError::InvalidArgument {
61            message: e.to_string(),
62        }
63    }
64}
65
66impl From<std::io::Error> for DistCheckpointError {
67    fn from(e: std::io::Error) -> Self {
68        DistCheckpointError::Io {
69            message: e.to_string(),
70        }
71    }
72}
73
74// ---------------------------------------------------------------------------
75// Metadata types
76// ---------------------------------------------------------------------------
77
78/// Describes how a single tensor is sharded across ranks.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct TensorShardSpec {
81    /// Full (unsharded) shape of the tensor.
82    pub full_shape: Vec<usize>,
83    /// Which dimension is split across ranks.
84    pub shard_dim: usize,
85    /// Size along `shard_dim` for each rank. The sum must equal
86    /// `full_shape[shard_dim]`.
87    pub shard_sizes: Vec<usize>,
88}
89
90/// Metadata for a distributed checkpoint: how many ranks participated and
91/// how each tensor is sharded.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ShardMetadata {
94    /// Number of ranks that produced this checkpoint.
95    pub num_ranks: usize,
96    /// Per-tensor sharding specification, keyed by tensor name.
97    pub tensor_specs: HashMap<String, TensorShardSpec>,
98}
99
100/// Handle for a distributed checkpoint directory.
101///
102/// Wraps the directory path and the shard metadata. Typically created by
103/// [`save_distributed`] (which writes files) or by reading an existing
104/// checkpoint's `metadata.json`.
105pub struct DistributedCheckpoint {
106    /// Directory where shard files are stored.
107    pub checkpoint_dir: PathBuf,
108    /// Metadata about how tensors are sharded across ranks.
109    pub shard_metadata: ShardMetadata,
110}
111
112// ---------------------------------------------------------------------------
113// Helpers
114// ---------------------------------------------------------------------------
115
116/// SafeTensors dtype for a given `Float` type.
117fn st_dtype<T: Float>() -> Result<Dtype, DistCheckpointError> {
118    match std::mem::size_of::<T>() {
119        4 => Ok(Dtype::F32),
120        8 => Ok(Dtype::F64),
121        other => Err(DistCheckpointError::InvalidArgument {
122            message: format!("unsupported element size {other} for safetensors serialization"),
123        }),
124    }
125}
126
127/// Reinterpret a `&[T]` as a byte slice (little-endian platforms).
128fn as_le_bytes<T: Float>(data: &[T]) -> &[u8] {
129    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) }
130}
131
/// Location of the SafeTensors shard written by `rank` inside `dir`
/// (`<dir>/rank_<rank>.safetensors`).
fn shard_path(dir: &Path, rank: usize) -> PathBuf {
    let mut target = dir.to_path_buf();
    target.push(format!("rank_{rank}.safetensors"));
    target
}
136
/// Location of the checkpoint's metadata file (`<dir>/metadata.json`).
fn metadata_path(dir: &Path) -> PathBuf {
    let mut target = dir.to_path_buf();
    target.push("metadata.json");
    target
}
141
142/// Save a `HashMap<String, Tensor<T>>` as a single SafeTensors file.
143fn save_tensors_to_file<T: Float>(
144    tensors: &HashMap<String, Tensor<T>>,
145    path: &Path,
146) -> Result<(), DistCheckpointError> {
147    let dtype = st_dtype::<T>()?;
148
149    let mut keys: Vec<&String> = tensors.keys().collect();
150    keys.sort();
151
152    struct Entry<'a> {
153        name: String,
154        shape: Vec<usize>,
155        data: &'a [u8],
156    }
157
158    let mut entries: Vec<Entry<'_>> = Vec::with_capacity(keys.len());
159    for key in &keys {
160        let tensor = &tensors[*key];
161        let data_slice = tensor.data().map_err(|e| DistCheckpointError::Tensor {
162            message: format!("failed to read tensor \"{key}\": {e}"),
163        })?;
164        entries.push(Entry {
165            name: (*key).clone(),
166            shape: tensor.shape().to_vec(),
167            data: as_le_bytes(data_slice),
168        });
169    }
170
171    let views: Vec<(String, TensorView<'_>)> = entries
172        .iter()
173        .map(|entry| {
174            TensorView::new(dtype, entry.shape.clone(), entry.data)
175                .map(|v| (entry.name.clone(), v))
176                .map_err(|e| DistCheckpointError::Serialization {
177                    message: format!("TensorView for \"{}\": {e}", entry.name),
178                })
179        })
180        .collect::<Result<Vec<_>, _>>()?;
181
182    serialize_to_file(views, &None, path).map_err(|e| DistCheckpointError::Serialization {
183        message: format!("safetensors write to {}: {e}", path.display()),
184    })?;
185
186    Ok(())
187}
188
189/// Load a SafeTensors file into `HashMap<String, Tensor<T>>`.
190fn load_tensors_from_file<T: Float>(
191    path: &Path,
192) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
193    let elem_size = std::mem::size_of::<T>();
194    let expected = st_dtype::<T>()?;
195
196    let file_data = std::fs::read(path).map_err(|e| DistCheckpointError::Io {
197        message: format!("reading {}: {e}", path.display()),
198    })?;
199
200    let st =
201        SafeTensors::deserialize(&file_data).map_err(|e| DistCheckpointError::Serialization {
202            message: format!("parsing {}: {e}", path.display()),
203        })?;
204
205    let tensor_list = st.tensors();
206    let mut result: HashMap<String, Tensor<T>> = HashMap::with_capacity(tensor_list.len());
207
208    for (name, view) in &tensor_list {
209        if view.dtype() != expected {
210            return Err(DistCheckpointError::Tensor {
211                message: format!(
212                    "tensor \"{name}\" has dtype {:?}, expected {:?}",
213                    view.dtype(),
214                    expected
215                ),
216            });
217        }
218
219        let shape = view.shape().to_vec();
220        let byte_data = view.data();
221        let numel: usize = if shape.is_empty() {
222            1
223        } else {
224            shape.iter().product()
225        };
226        let expected_bytes = numel * elem_size;
227
228        if byte_data.len() != expected_bytes {
229            return Err(DistCheckpointError::Tensor {
230                message: format!(
231                    "tensor \"{name}\" has {} bytes but shape {shape:?} requires {expected_bytes}",
232                    byte_data.len()
233                ),
234            });
235        }
236
237        let data: Vec<T> = byte_data
238            .chunks_exact(elem_size)
239            .map(|chunk| {
240                let mut bytes = [0u8; 8];
241                bytes[..elem_size].copy_from_slice(chunk);
242                unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const T) }
243            })
244            .collect();
245
246        let storage = TensorStorage::cpu(data);
247        let tensor = Tensor::from_storage(storage, shape, false).map_err(|e| {
248            DistCheckpointError::Tensor {
249                message: format!("creating tensor \"{name}\": {e}"),
250            }
251        })?;
252        result.insert(name.clone(), tensor);
253    }
254
255    Ok(result)
256}
257
258// ---------------------------------------------------------------------------
259// Save / Load
260// ---------------------------------------------------------------------------
261
262/// Save a distributed checkpoint.
263///
264/// Each rank saves its own shard of the state dict to
265/// `dir/rank_<rank>.safetensors`. Rank 0 additionally writes
266/// `dir/metadata.json` containing the [`ShardMetadata`] so that future loads
267/// (potentially with a different world size) know how the data is distributed.
268///
269/// The `state_dict` should contain this rank's shard of each tensor.
270/// `shard_spec` describes the full shapes and how they are sharded.
271///
272/// # Errors
273///
274/// Returns an error if the directory cannot be created, if tensor data cannot
275/// be read (e.g., GPU tensor without prior `.cpu()`), or if serialization
276/// fails.
277pub fn save_distributed<T: Float>(
278    state_dict: &HashMap<String, Tensor<T>>,
279    dir: &Path,
280    rank: usize,
281    world_size: usize,
282    shard_spec: &ShardMetadata,
283) -> Result<(), DistCheckpointError> {
284    // Validate basic inputs.
285    if world_size == 0 {
286        return Err(DistCheckpointError::InvalidArgument {
287            message: "world_size must be >= 1".into(),
288        });
289    }
290    if rank >= world_size {
291        return Err(DistCheckpointError::InvalidArgument {
292            message: format!("rank {rank} >= world_size {world_size}"),
293        });
294    }
295
296    // Create the checkpoint directory if it doesn't exist.
297    std::fs::create_dir_all(dir)?;
298
299    // Save this rank's shard.
300    let path = shard_path(dir, rank);
301    save_tensors_to_file(state_dict, &path)?;
302
303    // Rank 0 writes the metadata file.
304    if rank == 0 {
305        let json = serde_json::to_string_pretty(shard_spec).map_err(|e| {
306            DistCheckpointError::Serialization {
307                message: format!("serializing metadata: {e}"),
308            }
309        })?;
310        std::fs::write(metadata_path(dir), json)?;
311    }
312
313    Ok(())
314}
315
316/// Load a distributed checkpoint for a specific rank.
317///
318/// Reads `dir/metadata.json` to discover the original sharding layout. If the
319/// current `world_size` matches the saved metadata, each rank simply loads its
320/// own shard file. If world sizes differ, automatic resharding is performed
321/// via [`reshard`].
322///
323/// # Errors
324///
325/// Returns an error if metadata or shard files are missing, if tensors have
326/// unexpected dtypes, or if resharding fails.
327pub fn load_distributed<T: Float>(
328    dir: &Path,
329    rank: usize,
330    world_size: usize,
331) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
332    if world_size == 0 {
333        return Err(DistCheckpointError::InvalidArgument {
334            message: "world_size must be >= 1".into(),
335        });
336    }
337    if rank >= world_size {
338        return Err(DistCheckpointError::InvalidArgument {
339            message: format!("rank {rank} >= world_size {world_size}"),
340        });
341    }
342
343    // Read metadata.
344    let meta_path = metadata_path(dir);
345    let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
346        message: format!("reading {}: {e}", meta_path.display()),
347    })?;
348    let metadata: ShardMetadata =
349        serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
350            message: format!("parsing metadata: {e}"),
351        })?;
352
353    let old_world_size = metadata.num_ranks;
354
355    if old_world_size == world_size {
356        // Same world size — just load this rank's shard directly.
357        let path = shard_path(dir, rank);
358        if !path.exists() {
359            // If there are no tensor specs, this is a legitimately empty
360            // checkpoint — return an empty map.
361            if metadata.tensor_specs.is_empty() {
362                return Ok(HashMap::new());
363            }
364            return Err(DistCheckpointError::MissingShard {
365                path: path.display().to_string(),
366            });
367        }
368        load_tensors_from_file(&path)
369    } else {
370        // Different world size — need to reshard.
371        reshard(dir, old_world_size, world_size, rank)
372    }
373}
374
375// ---------------------------------------------------------------------------
376// Resharding
377// ---------------------------------------------------------------------------
378
379/// Reshard a checkpoint from `old_world_size` ranks to `new_world_size` ranks,
380/// returning the data for `new_rank`.
381///
382/// For each tensor described in the checkpoint metadata, this function:
383///
384/// 1. Reconstructs the full (unsharded) tensor by loading and concatenating
385///    all old shard files along the shard dimension.
386/// 2. Splits the full tensor along the shard dimension into `new_world_size`
387///    pieces and returns the piece for `new_rank`.
388///
389/// This handles both scale-up (e.g., 4 GPUs to 8) and scale-down (e.g., 8
390/// GPUs to 4) as well as arbitrary remappings.
391///
392/// # Errors
393///
394/// Returns an error if shard files are missing, if tensors have unexpected
395/// shapes, or if the full tensor cannot be evenly divided for resharding.
396pub fn reshard<T: Float>(
397    dir: &Path,
398    old_world_size: usize,
399    new_world_size: usize,
400    new_rank: usize,
401) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
402    if new_world_size == 0 {
403        return Err(DistCheckpointError::InvalidArgument {
404            message: "new_world_size must be >= 1".into(),
405        });
406    }
407    if new_rank >= new_world_size {
408        return Err(DistCheckpointError::InvalidArgument {
409            message: format!("new_rank {new_rank} >= new_world_size {new_world_size}"),
410        });
411    }
412    if old_world_size == 0 {
413        return Err(DistCheckpointError::InvalidArgument {
414            message: "old_world_size must be >= 1".into(),
415        });
416    }
417
418    // Read metadata.
419    let meta_path = metadata_path(dir);
420    let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
421        message: format!("reading {}: {e}", meta_path.display()),
422    })?;
423    let metadata: ShardMetadata =
424        serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
425            message: format!("parsing metadata: {e}"),
426        })?;
427
428    // Load all old shards into memory.
429    let mut old_shards: Vec<HashMap<String, Tensor<T>>> = Vec::with_capacity(old_world_size);
430    for old_rank in 0..old_world_size {
431        let path = shard_path(dir, old_rank);
432        if !path.exists() {
433            return Err(DistCheckpointError::MissingShard {
434                path: path.display().to_string(),
435            });
436        }
437        old_shards.push(load_tensors_from_file(&path)?);
438    }
439
440    // For each tensor in the metadata, reconstruct full tensor then re-split.
441    let mut result: HashMap<String, Tensor<T>> = HashMap::new();
442
443    for (name, spec) in &metadata.tensor_specs {
444        let shard_dim = spec.shard_dim;
445        let full_shape = &spec.full_shape;
446
447        // Collect shard data from each old rank for this tensor.
448        let mut shard_datas: Vec<Vec<T>> = Vec::with_capacity(old_world_size);
449        let mut shard_shapes: Vec<Vec<usize>> = Vec::with_capacity(old_world_size);
450
451        for (old_rank, shard) in old_shards.iter().enumerate().take(old_world_size) {
452            let tensor = shard.get(name).ok_or_else(|| DistCheckpointError::Tensor {
453                message: format!("tensor \"{name}\" missing from rank {old_rank} shard"),
454            })?;
455            shard_datas.push(tensor.data_vec().map_err(|e| DistCheckpointError::Tensor {
456                message: format!("reading tensor \"{name}\" from rank {old_rank}: {e}"),
457            })?);
458            shard_shapes.push(tensor.shape().to_vec());
459        }
460
461        // Reconstruct the full tensor by concatenating along shard_dim.
462        let full_data = concat_along_dim(&shard_datas, &shard_shapes, shard_dim, full_shape)?;
463
464        // Now split the full tensor along shard_dim for the new world size.
465        let full_dim_size = full_shape[shard_dim];
466        let new_shard_sizes = compute_shard_sizes(full_dim_size, new_world_size);
467        let new_offset: usize = new_shard_sizes[..new_rank].iter().sum();
468        let new_size = new_shard_sizes[new_rank];
469
470        // Extract this rank's slice from the full data.
471        let mut new_shape = full_shape.clone();
472        new_shape[shard_dim] = new_size;
473
474        let new_data = slice_along_dim(&full_data, full_shape, shard_dim, new_offset, new_size);
475
476        let tensor =
477            Tensor::from_storage(TensorStorage::cpu(new_data), new_shape, false).map_err(|e| {
478                DistCheckpointError::Tensor {
479                    message: format!("creating resharded tensor \"{name}\": {e}"),
480                }
481            })?;
482
483        result.insert(name.clone(), tensor);
484    }
485
486    Ok(result)
487}
488
/// Split `total` elements into `num_parts` near-equal parts.
///
/// Every part receives `total / num_parts` elements and the first
/// `total % num_parts` parts receive one more, so the sizes always sum to
/// `total`.
///
/// # Panics
///
/// Panics (division by zero) if `num_parts` is 0.
fn compute_shard_sizes(total: usize, num_parts: usize) -> Vec<usize> {
    let (quotient, extra) = (total / num_parts, total % num_parts);
    let mut sizes = vec![quotient; num_parts];
    for size in sizes.iter_mut().take(extra) {
        *size += 1;
    }
    sizes
}
498
499/// Concatenate tensor data from multiple shards along a given dimension.
500///
501/// Each entry in `shard_datas` is a flat (row-major) data array for that shard.
502/// `shard_shapes` gives the shape of each shard. `full_shape` is the expected
503/// output shape after concatenation.
504fn concat_along_dim<T: Float>(
505    shard_datas: &[Vec<T>],
506    shard_shapes: &[Vec<usize>],
507    dim: usize,
508    full_shape: &[usize],
509) -> Result<Vec<T>, DistCheckpointError> {
510    let ndim = full_shape.len();
511    if dim >= ndim {
512        return Err(DistCheckpointError::InvalidArgument {
513            message: format!("shard_dim {dim} >= ndim {ndim}"),
514        });
515    }
516
517    let full_numel: usize = full_shape.iter().product();
518    let mut full_data = vec![<T as num_traits::Zero>::zero(); full_numel];
519
520    // For concatenation along `dim`, we think of the tensor as having three
521    // logical parts:
522    //   outer = product of dims before `dim`
523    //   middle = dim (varies per shard)
524    //   inner = product of dims after `dim`
525    let outer: usize = full_shape[..dim].iter().product();
526    let inner: usize = full_shape[dim + 1..].iter().product();
527    let full_middle = full_shape[dim];
528
529    // Walk through shards and copy their data into the right offsets.
530    let mut dim_offset = 0;
531    for (shard_idx, shard_data) in shard_datas.iter().enumerate() {
532        let shard_middle = shard_shapes[shard_idx][dim];
533
534        // Validate shard shape dimensions except along the shard dim.
535        for d in 0..ndim {
536            if d != dim && shard_shapes[shard_idx][d] != full_shape[d] {
537                return Err(DistCheckpointError::Tensor {
538                    message: format!(
539                        "shard {shard_idx} has shape {:?} but expected dim {d} to be {} (full shape {full_shape:?})",
540                        shard_shapes[shard_idx], full_shape[d]
541                    ),
542                });
543            }
544        }
545
546        for o in 0..outer {
547            let src_start = o * shard_middle * inner;
548            let dst_start = o * full_middle * inner + dim_offset * inner;
549            let count = shard_middle * inner;
550
551            full_data[dst_start..dst_start + count]
552                .copy_from_slice(&shard_data[src_start..src_start + count]);
553        }
554
555        dim_offset += shard_middle;
556    }
557
558    if dim_offset != full_middle {
559        return Err(DistCheckpointError::Tensor {
560            message: format!(
561                "shard sizes along dim {dim} sum to {dim_offset}, expected {full_middle}"
562            ),
563        });
564    }
565
566    Ok(full_data)
567}
568
569/// Extract a contiguous slice of a tensor along a given dimension.
570///
571/// Returns the flat data for `shape` with `shape[dim]` replaced by `size`,
572/// starting at offset `offset` along that dimension.
573fn slice_along_dim<T: Float>(
574    data: &[T],
575    shape: &[usize],
576    dim: usize,
577    offset: usize,
578    size: usize,
579) -> Vec<T> {
580    let outer: usize = shape[..dim].iter().product();
581    let full_middle = shape[dim];
582    let inner: usize = shape[dim + 1..].iter().product();
583
584    let out_numel = outer * size * inner;
585    let mut result = Vec::with_capacity(out_numel);
586
587    for o in 0..outer {
588        let src_start = o * full_middle * inner + offset * inner;
589        let count = size * inner;
590        result.extend_from_slice(&data[src_start..src_start + count]);
591    }
592
593    result
594}
595
596// ---------------------------------------------------------------------------
597// Async checkpointing
598// ---------------------------------------------------------------------------
599
600/// Result of an asynchronous checkpoint operation.
601///
602/// The checkpoint is being written in a background thread. Call [`wait`](Self::wait)
603/// to block until completion and retrieve any error.
604pub struct CheckpointFuture {
605    handle: Option<thread::JoinHandle<Result<(), DistCheckpointError>>>,
606    /// Cached result after the thread has been joined.
607    result: Option<Result<(), DistCheckpointError>>,
608}
609
610impl CheckpointFuture {
611    /// Block until the checkpoint write completes.
612    ///
613    /// Returns `Ok(())` if the write succeeded, or the error encountered
614    /// during serialization/I/O.
615    ///
616    /// Calling `wait()` multiple times is safe — subsequent calls return
617    /// the cached result.
618    pub fn wait(&mut self) -> Result<(), DistCheckpointError> {
619        if let Some(handle) = self.handle.take() {
620            let res = handle
621                .join()
622                .map_err(|_| DistCheckpointError::AsyncFailed {
623                    message: "background checkpoint thread panicked".into(),
624                })?;
625            self.result = Some(res);
626        }
627
628        match &self.result {
629            Some(Ok(())) => Ok(()),
630            Some(Err(e)) => Err(DistCheckpointError::AsyncFailed {
631                message: format!("{e}"),
632            }),
633            None => Err(DistCheckpointError::AsyncFailed {
634                message: "no checkpoint was started".into(),
635            }),
636        }
637    }
638
639    /// Returns `true` if the background write has completed (success or failure).
640    pub fn is_done(&self) -> bool {
641        if self.result.is_some() {
642            return true;
643        }
644        match &self.handle {
645            Some(h) => h.is_finished(),
646            None => true,
647        }
648    }
649}
650
651/// Asynchronous checkpointer that writes shards to disk in a background thread.
652///
653/// Before writing, tensor data is staged (copied) into CPU memory so that
654/// training can continue on GPU without blocking. The actual file I/O happens
655/// in a separate thread.
656///
657/// # Usage
658///
659/// ```ignore
660/// let ckpt = AsyncCheckpointer::new(
661///     checkpoint_dir.clone(),
662///     rank,
663///     world_size,
664///     shard_metadata.clone(),
665/// );
666///
667/// // Non-blocking: training continues immediately.
668/// let mut future = ckpt.save_async(&state_dict)?;
669///
670/// // ... continue training ...
671///
672/// // When you need to be sure it finished:
673/// future.wait()?;
674/// ```
675pub struct AsyncCheckpointer {
676    dir: PathBuf,
677    rank: usize,
678    world_size: usize,
679    shard_spec: ShardMetadata,
680    /// Guard against concurrent saves: only one checkpoint at a time.
681    in_flight: Arc<Mutex<bool>>,
682}
683
684impl AsyncCheckpointer {
685    /// Create a new async checkpointer.
686    ///
687    /// - `dir`: directory where shard files will be written.
688    /// - `rank`: this process's rank.
689    /// - `world_size`: total number of ranks.
690    /// - `shard_spec`: metadata describing how tensors are sharded.
691    pub fn new(dir: PathBuf, rank: usize, world_size: usize, shard_spec: ShardMetadata) -> Self {
692        Self {
693            dir,
694            rank,
695            world_size,
696            shard_spec,
697            in_flight: Arc::new(Mutex::new(false)),
698        }
699    }
700
701    /// The checkpoint directory.
702    pub fn dir(&self) -> &Path {
703        &self.dir
704    }
705
706    /// This process's rank.
707    pub fn rank(&self) -> usize {
708        self.rank
709    }
710
711    /// Total number of ranks.
712    pub fn world_size(&self) -> usize {
713        self.world_size
714    }
715
716    /// Start an asynchronous checkpoint.
717    ///
718    /// This immediately copies all tensor data to CPU-owned `Vec`s (staging),
719    /// then spawns a background thread that serializes and writes the shard
720    /// file. For GPU tensors, the staging step transfers data to host memory.
721    ///
722    /// Returns a [`CheckpointFuture`] that can be polled or waited on.
723    ///
724    /// # Errors
725    ///
726    /// Returns an error immediately if another async save is already in flight,
727    /// or if staging (GPU-to-CPU copy) fails.
728    pub fn save_async(
729        &self,
730        state_dict: &HashMap<String, Tensor<f32>>,
731    ) -> Result<CheckpointFuture, DistCheckpointError> {
732        // Check that no other save is in progress.
733        {
734            let mut guard =
735                self.in_flight
736                    .lock()
737                    .map_err(|e| DistCheckpointError::AsyncFailed {
738                        message: format!("lock poisoned: {e}"),
739                    })?;
740            if *guard {
741                return Err(DistCheckpointError::AsyncFailed {
742                    message: "another async checkpoint is already in flight".into(),
743                });
744            }
745            *guard = true;
746        }
747
748        // Stage: copy all tensor data to CPU-owned Vecs. This is the only part
749        // that touches GPU memory and must happen on the calling thread.
750        let mut staged: HashMap<String, (Vec<f32>, Vec<usize>)> =
751            HashMap::with_capacity(state_dict.len());
752
753        for (name, tensor) in state_dict {
754            let data = tensor.data_vec().map_err(|e| {
755                // Release the lock on error.
756                if let Ok(mut g) = self.in_flight.lock() {
757                    *g = false;
758                }
759                DistCheckpointError::Tensor {
760                    message: format!("staging tensor \"{name}\": {e}"),
761                }
762            })?;
763            let shape = tensor.shape().to_vec();
764            staged.insert(name.clone(), (data, shape));
765        }
766
767        // Capture everything the background thread needs.
768        let dir = self.dir.clone();
769        let rank = self.rank;
770        let shard_spec = self.shard_spec.clone();
771        let in_flight = Arc::clone(&self.in_flight);
772
773        let handle = thread::spawn(move || {
774            let result = (|| -> Result<(), DistCheckpointError> {
775                // Rebuild tensors from staged data.
776                let mut tensors: HashMap<String, Tensor<f32>> =
777                    HashMap::with_capacity(staged.len());
778                for (name, (data, shape)) in staged {
779                    let tensor = Tensor::from_storage(TensorStorage::cpu(data), shape, false)
780                        .map_err(|e| DistCheckpointError::Tensor {
781                            message: format!("rebuilding tensor \"{name}\": {e}"),
782                        })?;
783                    tensors.insert(name, tensor);
784                }
785
786                // Write to disk.
787                std::fs::create_dir_all(&dir)?;
788                let path = shard_path(&dir, rank);
789                save_tensors_to_file(&tensors, &path)?;
790
791                // Rank 0 writes metadata.
792                if rank == 0 {
793                    let json = serde_json::to_string_pretty(&shard_spec).map_err(|e| {
794                        DistCheckpointError::Serialization {
795                            message: format!("serializing metadata: {e}"),
796                        }
797                    })?;
798                    std::fs::write(metadata_path(&dir), json)?;
799                }
800
801                Ok(())
802            })();
803
804            // Release the in-flight guard.
805            if let Ok(mut g) = in_flight.lock() {
806                *g = false;
807            }
808
809            result
810        });
811
812        Ok(CheckpointFuture {
813            handle: Some(handle),
814            result: None,
815        })
816    }
817}
818
819// ---------------------------------------------------------------------------
820// Convenience: auto-generate ShardMetadata from a flat state dict
821// ---------------------------------------------------------------------------
822
823/// Build a [`ShardMetadata`] where every tensor is sharded along dimension 0
824/// with equal shard sizes.
825///
826/// This is the common case for FSDP where parameters are simply chunked along
827/// the flattened (dim-0) axis.
828pub fn flat_shard_metadata(
829    state_dict: &HashMap<String, Tensor<f32>>,
830    world_size: usize,
831) -> ShardMetadata {
832    let mut tensor_specs = HashMap::new();
833    for (name, tensor) in state_dict {
834        let shape = tensor.shape();
835        // For FSDP-style flat sharding, the shard tensor is 1-D with shape
836        // [chunk_size]. The full tensor has shape [chunk_size * world_size].
837        let shard_numel = shape.iter().product::<usize>();
838        let full_numel = shard_numel * world_size;
839        let shard_sizes = vec![shard_numel; world_size];
840        tensor_specs.insert(
841            name.clone(),
842            TensorShardSpec {
843                full_shape: vec![full_numel],
844                shard_dim: 0,
845                shard_sizes,
846            },
847        );
848    }
849    ShardMetadata {
850        num_ranks: world_size,
851        tensor_specs,
852    }
853}
854
855// ---------------------------------------------------------------------------
856// Tests
857// ---------------------------------------------------------------------------
858
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::Tensor;
    use ferrotorch_core::storage::TensorStorage;
    use std::collections::HashMap;

    // --- Fixtures ---

    /// Wraps `data` in CPU storage and builds a tensor with the given `shape`.
    fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape, false).unwrap()
    }

    /// Per-test scratch directory under the system temp dir.
    fn temp_dir(name: &str) -> PathBuf {
        let mut path = std::env::temp_dir();
        path.push("ferrotorch_test_dist_ckpt");
        path.push(name);
        path
    }

    /// Best-effort removal of a scratch directory; a missing directory is fine.
    fn cleanup(dir: &Path) {
        let _ = std::fs::remove_dir_all(dir);
    }

    /// Shard spec for a tensor of `full_shape` split along dimension 0 into
    /// pieces of `shard_sizes`.
    fn dim0_spec(full_shape: &[usize], shard_sizes: &[usize]) -> TensorShardSpec {
        TensorShardSpec {
            full_shape: full_shape.to_vec(),
            shard_dim: 0,
            shard_sizes: shard_sizes.to_vec(),
        }
    }

    /// Metadata for `num_ranks` ranks built from `(tensor name, spec)` pairs.
    fn metadata_of(num_ranks: usize, specs: Vec<(&str, TensorShardSpec)>) -> ShardMetadata {
        ShardMetadata {
            num_ranks,
            tensor_specs: specs
                .into_iter()
                .map(|(name, spec)| (name.to_string(), spec))
                .collect(),
        }
    }

    /// State dict built from `(tensor name, tensor)` pairs.
    fn state_dict_of(entries: Vec<(&str, Tensor<f32>)>) -> HashMap<String, Tensor<f32>> {
        entries
            .into_iter()
            .map(|(name, tensor)| (name.to_string(), tensor))
            .collect()
    }

    // --- save_distributed / load_distributed roundtrip ---

    #[test]
    fn test_save_load_single_rank() {
        let dir = temp_dir("single_rank");
        cleanup(&dir);

        let shard = state_dict_of(vec![
            ("weight", make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4])),
            ("bias", make_tensor(vec![0.1, 0.2], vec![2])),
        ]);
        let meta = metadata_of(
            1,
            vec![
                ("weight", dim0_spec(&[4], &[4])),
                ("bias", dim0_spec(&[2], &[2])),
            ],
        );

        save_distributed(&shard, &dir, 0, 1, &meta).unwrap();
        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(loaded["bias"].data().unwrap(), &[0.1, 0.2]);

        cleanup(&dir);
    }

    #[test]
    fn test_save_load_two_ranks() {
        let dir = temp_dir("two_ranks");
        cleanup(&dir);

        // Rank 0 holds the first half of `weight`, rank 1 the second half.
        let shard0 = state_dict_of(vec![("weight", make_tensor(vec![1.0, 2.0], vec![2]))]);
        let shard1 = state_dict_of(vec![("weight", make_tensor(vec![3.0, 4.0], vec![2]))]);
        let meta = metadata_of(2, vec![("weight", dim0_spec(&[4], &[2, 2]))]);

        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // Same world size on load: each rank gets its own shard back.
        let loaded0 = load_distributed::<f32>(&dir, 0, 2).unwrap();
        let loaded1 = load_distributed::<f32>(&dir, 1, 2).unwrap();

        assert_eq!(loaded0["weight"].data().unwrap(), &[1.0, 2.0]);
        assert_eq!(loaded1["weight"].data().unwrap(), &[3.0, 4.0]);

        cleanup(&dir);
    }

    // --- Resharding tests ---

    #[test]
    fn test_reshard_2_to_4() {
        // Saved with 2 ranks, loaded with 4.
        let dir = temp_dir("reshard_2_to_4");
        cleanup(&dir);

        // Full tensor [1..=8], shape [8]: rank 0 = [1,2,3,4], rank 1 = [5,6,7,8].
        let shard0 = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]))]);
        let shard1 = state_dict_of(vec![("w", make_tensor(vec![5.0, 6.0, 7.0, 8.0], vec![4]))]);
        let meta = metadata_of(2, vec![("w", dim0_spec(&[8], &[4, 4]))]);

        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // Each of the 4 new ranks receives 2 elements.
        let parts: Vec<_> = (0..4)
            .map(|rank| reshard::<f32>(&dir, 2, 4, rank).unwrap())
            .collect();

        assert_eq!(parts[0]["w"].data().unwrap(), &[1.0, 2.0]);
        assert_eq!(parts[1]["w"].data().unwrap(), &[3.0, 4.0]);
        assert_eq!(parts[2]["w"].data().unwrap(), &[5.0, 6.0]);
        assert_eq!(parts[3]["w"].data().unwrap(), &[7.0, 8.0]);

        cleanup(&dir);
    }

    #[test]
    fn test_reshard_4_to_2() {
        // Saved with 4 ranks, loaded with 2.
        let dir = temp_dir("reshard_4_to_2");
        cleanup(&dir);

        let meta = metadata_of(4, vec![("w", dim0_spec(&[8], &[2, 2, 2, 2]))]);

        for rank in 0..4 {
            let first = rank as f32 * 2.0 + 1.0;
            let shard = state_dict_of(vec![("w", make_tensor(vec![first, first + 1.0], vec![2]))]);
            save_distributed(&shard, &dir, rank, 4, &meta).unwrap();
        }

        // Each of the 2 new ranks receives 4 elements.
        let lower = reshard::<f32>(&dir, 4, 2, 0).unwrap();
        let upper = reshard::<f32>(&dir, 4, 2, 1).unwrap();

        assert_eq!(lower["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(upper["w"].data().unwrap(), &[5.0, 6.0, 7.0, 8.0]);

        cleanup(&dir);
    }

    #[test]
    fn test_reshard_2d_tensor() {
        // Full tensor [[1,2,3],[4,5,6],[7,8,9],[10,11,12]], shape [4, 3],
        // sharded along dim 0: rank 0 owns rows 0-1, rank 1 owns rows 2-3.
        let dir = temp_dir("reshard_2d");
        cleanup(&dir);

        let shard0 = state_dict_of(vec![(
            "w",
            make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]),
        )]);
        let shard1 = state_dict_of(vec![(
            "w",
            make_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![2, 3]),
        )]);
        let meta = metadata_of(2, vec![("w", dim0_spec(&[4, 3], &[2, 2]))]);

        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // 4 ranks: one row each.
        let parts: Vec<_> = (0..4)
            .map(|rank| reshard::<f32>(&dir, 2, 4, rank).unwrap())
            .collect();

        assert_eq!(parts[0]["w"].shape(), &[1, 3]);
        assert_eq!(parts[0]["w"].data().unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(parts[1]["w"].shape(), &[1, 3]);
        assert_eq!(parts[1]["w"].data().unwrap(), &[4.0, 5.0, 6.0]);
        assert_eq!(parts[2]["w"].shape(), &[1, 3]);
        assert_eq!(parts[2]["w"].data().unwrap(), &[7.0, 8.0, 9.0]);
        assert_eq!(parts[3]["w"].shape(), &[1, 3]);
        assert_eq!(parts[3]["w"].data().unwrap(), &[10.0, 11.0, 12.0]);

        cleanup(&dir);
    }

    #[test]
    fn test_reshard_dim1() {
        // Full tensor [[1,2,3,4],[5,6,7,8]], shape [2, 4],
        // sharded along dim 1: rank 0 owns cols 0-1, rank 1 owns cols 2-3.
        let dir = temp_dir("reshard_dim1");
        cleanup(&dir);

        let shard0 = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0, 5.0, 6.0], vec![2, 2]))]);
        let shard1 = state_dict_of(vec![("w", make_tensor(vec![3.0, 4.0, 7.0, 8.0], vec![2, 2]))]);
        let column_spec = TensorShardSpec {
            shard_dim: 1,
            ..dim0_spec(&[2, 4], &[2, 2])
        };
        let meta = metadata_of(2, vec![("w", column_spec)]);

        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // 4 ranks: one column each.
        let parts: Vec<_> = (0..4)
            .map(|rank| reshard::<f32>(&dir, 2, 4, rank).unwrap())
            .collect();

        assert_eq!(parts[0]["w"].shape(), &[2, 1]);
        assert_eq!(parts[0]["w"].data().unwrap(), &[1.0, 5.0]);
        assert_eq!(parts[1]["w"].shape(), &[2, 1]);
        assert_eq!(parts[1]["w"].data().unwrap(), &[2.0, 6.0]);
        assert_eq!(parts[2]["w"].shape(), &[2, 1]);
        assert_eq!(parts[2]["w"].data().unwrap(), &[3.0, 7.0]);
        assert_eq!(parts[3]["w"].shape(), &[2, 1]);
        assert_eq!(parts[3]["w"].data().unwrap(), &[4.0, 8.0]);

        cleanup(&dir);
    }

    #[test]
    fn test_reshard_3_to_2_uneven() {
        // Full tensor [1..=9], shape [9], saved by 3 ranks holding 3 elements each.
        let dir = temp_dir("reshard_3_to_2");
        cleanup(&dir);

        let meta = metadata_of(3, vec![("w", dim0_spec(&[9], &[3, 3, 3]))]);

        for rank in 0..3usize {
            let first = rank as f32 * 3.0 + 1.0;
            let shard = state_dict_of(vec![(
                "w",
                make_tensor(vec![first, first + 1.0, first + 2.0], vec![3]),
            )]);
            save_distributed(&shard, &dir, rank, 3, &meta).unwrap();
        }

        // 9 / 2 = 4 rem 1, so rank 0 receives 5 elements and rank 1 receives 4.
        let lower = reshard::<f32>(&dir, 3, 2, 0).unwrap();
        let upper = reshard::<f32>(&dir, 3, 2, 1).unwrap();

        assert_eq!(lower["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(upper["w"].data().unwrap(), &[6.0, 7.0, 8.0, 9.0]);

        cleanup(&dir);
    }

    // --- load_distributed triggers resharding ---

    #[test]
    fn test_load_distributed_reshards_when_world_size_differs() {
        let dir = temp_dir("load_reshard");
        cleanup(&dir);

        // Checkpoint written by 2 ranks.
        let meta = metadata_of(2, vec![("w", dim0_spec(&[4], &[2, 2]))]);

        let shard0 = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0], vec![2]))]);
        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();

        let shard1 = state_dict_of(vec![("w", make_tensor(vec![3.0, 4.0], vec![2]))]);
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // Loading with 4 ranks must go through the resharding path.
        let parts: Vec<_> = (0..4)
            .map(|rank| load_distributed::<f32>(&dir, rank, 4).unwrap())
            .collect();

        assert_eq!(parts[0]["w"].data().unwrap(), &[1.0]);
        assert_eq!(parts[1]["w"].data().unwrap(), &[2.0]);
        assert_eq!(parts[2]["w"].data().unwrap(), &[3.0]);
        assert_eq!(parts[3]["w"].data().unwrap(), &[4.0]);

        cleanup(&dir);
    }

    // --- Metadata serialization ---

    #[test]
    fn test_metadata_roundtrip() {
        let meta = metadata_of(
            4,
            vec![
                ("layer.weight", dim0_spec(&[256, 512], &[64, 64, 64, 64])),
                ("layer.bias", dim0_spec(&[256], &[64, 64, 64, 64])),
            ],
        );

        let json = serde_json::to_string_pretty(&meta).unwrap();
        let decoded: ShardMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.num_ranks, 4);
        assert_eq!(decoded.tensor_specs.len(), 2);

        let weight = &decoded.tensor_specs["layer.weight"];
        assert_eq!(weight.full_shape, vec![256, 512]);
        assert_eq!(weight.shard_dim, 0);
        assert_eq!(weight.shard_sizes, vec![64, 64, 64, 64]);
    }

    // --- compute_shard_sizes ---

    #[test]
    fn test_compute_shard_sizes_even() {
        let quarters = compute_shard_sizes(8, 4);
        assert_eq!(quarters, vec![2, 2, 2, 2]);

        let thirds = compute_shard_sizes(12, 3);
        assert_eq!(thirds, vec![4, 4, 4]);
    }

    #[test]
    fn test_compute_shard_sizes_uneven() {
        // The remainder is handed out one element at a time to the lowest ranks.
        // 9 / 2 = 4 rem 1 -> [5, 4]
        assert_eq!(compute_shard_sizes(9, 2), vec![5, 4]);
        // 10 / 3 = 3 rem 1 -> [4, 3, 3]
        assert_eq!(compute_shard_sizes(10, 3), vec![4, 3, 3]);
        // 7 / 4 = 1 rem 3 -> [2, 2, 2, 1]
        assert_eq!(compute_shard_sizes(7, 4), vec![2, 2, 2, 1]);
    }

    // --- concat_along_dim ---

    #[test]
    fn test_concat_1d() {
        let head = vec![1.0f32, 2.0];
        let tail = vec![3.0f32, 4.0, 5.0];
        let full_shape = vec![5];

        let joined =
            concat_along_dim(&[head, tail], &[vec![2], vec![3]], 0, &full_shape).unwrap();

        assert_eq!(joined, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_concat_2d_dim0() {
        // Two 1x3 matrices stacked into a 2x3 matrix.
        let top = vec![1.0f32, 2.0, 3.0];
        let bottom = vec![4.0f32, 5.0, 6.0];
        let full_shape = vec![2, 3];

        let joined =
            concat_along_dim(&[top, bottom], &[vec![1, 3], vec![1, 3]], 0, &full_shape).unwrap();

        assert_eq!(joined, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_concat_2d_dim1() {
        // Two 2x1 matrices placed side by side into a 2x2 matrix:
        // left = [[1],[3]], right = [[2],[4]].
        let left = vec![1.0f32, 3.0];
        let right = vec![2.0f32, 4.0];
        let full_shape = vec![2, 2];

        let joined =
            concat_along_dim(&[left, right], &[vec![2, 1], vec![2, 1]], 1, &full_shape).unwrap();

        assert_eq!(joined, vec![1.0, 2.0, 3.0, 4.0]);
    }

    // --- slice_along_dim ---

    #[test]
    fn test_slice_1d() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];

        let front = slice_along_dim(&data, &shape, 0, 0, 2);
        assert_eq!(front, vec![1.0, 2.0]);

        let back = slice_along_dim(&data, &shape, 0, 2, 3);
        assert_eq!(back, vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_slice_2d_dim0() {
        // [[1,2,3],[4,5,6],[7,8,9],[10,11,12]], shape [4,3]
        let data = vec![
            1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let shape = vec![4, 3];

        // Rows 1 and 2.
        let middle_rows = slice_along_dim(&data, &shape, 0, 1, 2);
        assert_eq!(middle_rows, vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_slice_2d_dim1() {
        // [[1,2,3,4],[5,6,7,8]], shape [2,4]
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let shape = vec![2, 4];

        // Columns 1 and 2 of every row.
        let middle_cols = slice_along_dim(&data, &shape, 1, 1, 2);
        assert_eq!(middle_cols, vec![2.0, 3.0, 6.0, 7.0]);
    }

    // --- flat_shard_metadata ---

    #[test]
    fn test_flat_shard_metadata() {
        let local = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0, 3.0], vec![3]))]);

        let meta = flat_shard_metadata(&local, 4);
        assert_eq!(meta.num_ranks, 4);

        let w = &meta.tensor_specs["w"];
        assert_eq!(w.full_shape, vec![12]); // 3 * 4
        assert_eq!(w.shard_dim, 0);
        assert_eq!(w.shard_sizes, vec![3, 3, 3, 3]);
    }

    // --- AsyncCheckpointer ---

    #[test]
    fn test_async_checkpoint_basic() {
        let dir = temp_dir("async_basic");
        cleanup(&dir);

        let shard = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]))]);
        let meta = metadata_of(1, vec![("w", dim0_spec(&[4], &[4]))]);

        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, meta);
        let mut pending = ckpt.save_async(&shard).unwrap();
        pending.wait().unwrap();

        // The shard on disk must round-trip through the regular loader.
        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
        assert_eq!(loaded["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);

        cleanup(&dir);
    }

    #[test]
    fn test_async_checkpoint_wait_idempotent() {
        let dir = temp_dir("async_idempotent");
        cleanup(&dir);

        let shard = state_dict_of(vec![("x", make_tensor(vec![42.0], vec![1]))]);
        let meta = metadata_of(1, vec![("x", dim0_spec(&[1], &[1]))]);

        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, meta);
        let mut pending = ckpt.save_async(&shard).unwrap();

        // A second wait after completion must not panic.
        pending.wait().unwrap();
        pending.wait().unwrap();

        cleanup(&dir);
    }

    #[test]
    fn test_async_checkpoint_is_done() {
        let dir = temp_dir("async_is_done");
        cleanup(&dir);

        let shard = state_dict_of(vec![("x", make_tensor(vec![1.0], vec![1]))]);
        let meta = metadata_of(1, vec![("x", dim0_spec(&[1], &[1]))]);

        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, meta);
        let mut pending = ckpt.save_async(&shard).unwrap();
        pending.wait().unwrap();
        assert!(pending.is_done());

        cleanup(&dir);
    }

    // --- Error cases ---

    #[test]
    fn test_save_invalid_rank() {
        let dir = temp_dir("invalid_rank");
        let empty: HashMap<String, Tensor<f32>> = HashMap::new();
        let meta = metadata_of(2, Vec::new());

        // Rank 5 does not exist in a world of size 2.
        let outcome = save_distributed(&empty, &dir, 5, 2, &meta);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_missing_metadata() {
        let dir = temp_dir("missing_meta");
        cleanup(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // The directory exists but holds no metadata file.
        let outcome = load_distributed::<f32>(&dir, 0, 1);
        assert!(outcome.is_err());

        cleanup(&dir);
    }

    #[test]
    fn test_load_missing_shard() {
        let dir = temp_dir("missing_shard");
        cleanup(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // Only the metadata file is written; no shard file exists.
        let meta = metadata_of(1, Vec::new());
        let json = serde_json::to_string_pretty(&meta).unwrap();
        std::fs::write(metadata_path(&dir), json).unwrap();

        // With empty tensor_specs the direct (same world size) load succeeds
        // without a shard file. Loading with a different world size would
        // trigger resharding, which does need the shard files.
        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
        assert!(loaded.is_empty());

        cleanup(&dir);
    }

    // --- Multiple tensors ---

    #[test]
    fn test_reshard_multiple_tensors() {
        let dir = temp_dir("reshard_multi");
        cleanup(&dir);

        let meta = metadata_of(
            2,
            vec![
                ("weight", dim0_spec(&[4], &[2, 2])),
                ("bias", dim0_spec(&[6], &[3, 3])),
            ],
        );

        let shard0 = state_dict_of(vec![
            ("weight", make_tensor(vec![1.0, 2.0], vec![2])),
            ("bias", make_tensor(vec![10.0, 20.0, 30.0], vec![3])),
        ]);
        let shard1 = state_dict_of(vec![
            ("weight", make_tensor(vec![3.0, 4.0], vec![2])),
            ("bias", make_tensor(vec![40.0, 50.0, 60.0], vec![3])),
        ]);

        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // Consolidate both shards onto a single rank.
        let merged = load_distributed::<f32>(&dir, 0, 1).unwrap();

        assert_eq!(merged["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(
            merged["bias"].data().unwrap(),
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );

        cleanup(&dir);
    }

    // --- DistributedCheckpoint struct ---

    #[test]
    fn test_distributed_checkpoint_struct() {
        let ckpt = DistributedCheckpoint {
            checkpoint_dir: PathBuf::from("/tmp/test"),
            shard_metadata: metadata_of(2, Vec::new()),
        };
        assert_eq!(ckpt.checkpoint_dir, PathBuf::from("/tmp/test"));
        assert_eq!(ckpt.shard_metadata.num_ranks, 2);
    }

    // Edge case: reshard to same world_size (no-op path through reshard fn)
    #[test]
    fn test_reshard_same_world_size() {
        let dir = temp_dir("reshard_same");
        cleanup(&dir);

        let meta = metadata_of(2, vec![("w", dim0_spec(&[4], &[2, 2]))]);

        let shard0 = state_dict_of(vec![("w", make_tensor(vec![1.0, 2.0], vec![2]))]);
        save_distributed(&shard0, &dir, 0, 2, &meta).unwrap();

        let shard1 = state_dict_of(vec![("w", make_tensor(vec![3.0, 4.0], vec![2]))]);
        save_distributed(&shard1, &dir, 1, 2, &meta).unwrap();

        // Old and new world size match, so every rank gets its own shard back.
        let same0 = reshard::<f32>(&dir, 2, 2, 0).unwrap();
        let same1 = reshard::<f32>(&dir, 2, 2, 1).unwrap();

        assert_eq!(same0["w"].data().unwrap(), &[1.0, 2.0]);
        assert_eq!(same1["w"].data().unwrap(), &[3.0, 4.0]);

        cleanup(&dir);
    }
}