Skip to main content

ferrotorch_distributed/
checkpoint.rs

1//! Distributed checkpointing with per-rank shard saving, loading, and resharding.
2//!
3//! Each rank saves its own shard as a SafeTensors file under a shared checkpoint
4//! directory. Rank 0 additionally writes a `metadata.json` file describing how
5//! tensors are distributed across ranks. When loading, if the world size has
6//! changed, the resharding logic automatically splits or merges shards so that
7//! each new rank receives the correct slice of each tensor.
8//!
9//! # Async checkpointing
10//!
//! [`AsyncCheckpointer`] copies tensor data to CPU memory up front, then writes
//! it to disk in a background thread without blocking the training loop. Call
13//! [`save_async`](AsyncCheckpointer::save_async) to start a checkpoint, and
14//! [`CheckpointFuture::wait`] when you need to ensure the write completed.
15//!
16//! ## REQ status (per `.design/ferrotorch-distributed/checkpoint.md`)
17//!
18//! Full evidence rows (impl + non-test production consumer + upstream
19//! cites) live in the design doc; this synopsis is a one-line summary per
20//! REQ.
21//!
22//! | REQ | Status | Evidence |
23//! |---|---|---|
24//! | REQ-1 (`DistCheckpointError` + `From` bridges) | SHIPPED | `pub enum DistCheckpointError` + `impl From<DistCheckpointError> for FerrotorchError` + `impl From<std::io::Error> for DistCheckpointError` in `checkpoint.rs` mirror `class CheckpointException` in `torch/distributed/checkpoint/api.py`; consumer `pub use checkpoint::{...DistCheckpointError...}` in `lib.rs` |
25//! | REQ-2 (`TensorShardSpec`) | SHIPPED | `pub struct TensorShardSpec` in `checkpoint.rs`; consumer field of `ShardMetadata::tensor_specs` plus `lib.rs` re-export |
26//! | REQ-3 (`ShardMetadata`) | SHIPPED | `pub struct ShardMetadata` in `checkpoint.rs`; consumer of `save_distributed` / `load_distributed` / `reshard` / `flat_shard_metadata` / `AsyncCheckpointer` plus `lib.rs` re-export |
27//! | REQ-4 (`DistributedCheckpoint` handle) | SHIPPED | `pub struct DistributedCheckpoint` in `checkpoint.rs`; consumer `pub use checkpoint::DistributedCheckpoint` in `lib.rs` |
28//! | REQ-5 (`save_distributed`) | SHIPPED | `pub fn save_distributed<T>` in `checkpoint.rs` mirrors `torch.distributed.checkpoint.save` in `torch/distributed/checkpoint/__init__.py`; consumer of `fn save_tensors_to_file` (same file) plus `lib.rs` re-export |
29//! | REQ-6 (`load_distributed`) | SHIPPED | `pub fn load_distributed<T>` in `checkpoint.rs` mirrors `torch.distributed.checkpoint.load` in `torch/distributed/checkpoint/__init__.py`; consumer of `pub fn reshard` on world-size mismatch (same file) plus `lib.rs` re-export |
30//! | REQ-7 (`reshard`) | SHIPPED | `pub fn reshard<T>` in `checkpoint.rs`; consumer `pub fn load_distributed` (same file) on world-size mismatch plus `lib.rs` re-export |
31//! | REQ-8 (SafeTensors interop helpers) | SHIPPED | `fn st_dtype` / `fn as_le_bytes` / `fn save_tensors_to_file` / `fn load_tensors_from_file` in `checkpoint.rs`; consumer `pub fn save_distributed` / `pub fn reshard` / `AsyncCheckpointer::save_async` (same file) |
32//! | REQ-9 (`concat_along_dim`) | SHIPPED | `fn concat_along_dim` in `checkpoint.rs`; consumer `pub fn reshard` (same file) |
33//! | REQ-10 (`slice_along_dim`) | SHIPPED | `fn slice_along_dim` in `checkpoint.rs`; consumer `pub fn reshard` (same file) |
34//! | REQ-11 (`AsyncCheckpointer` + `CheckpointFuture`) | SHIPPED | `pub struct CheckpointFuture` + `pub struct AsyncCheckpointer` + `pub fn save_async` / `wait` / `is_done` in `checkpoint.rs` mirror `torch.distributed.checkpoint.async_save` in `torch/distributed/checkpoint/__init__.py`; consumer `lib.rs` re-export of `AsyncCheckpointer` and `CheckpointFuture` |
35//! | REQ-12 (`flat_shard_metadata` convenience) | SHIPPED | `pub fn flat_shard_metadata` in `checkpoint.rs`; consumer `pub use checkpoint::flat_shard_metadata` in `lib.rs` |
36
37use std::collections::HashMap;
38use std::path::{Path, PathBuf};
39use std::sync::{Arc, Mutex};
40use std::thread;
41
42use safetensors::serialize_to_file;
43use safetensors::tensor::{Dtype, SafeTensors, TensorView};
44use serde::{Deserialize, Serialize};
45
46use ferrotorch_core::storage::TensorStorage;
47use ferrotorch_core::{FerrotorchError, Float, Tensor};
48
49// ---------------------------------------------------------------------------
50// Error type
51// ---------------------------------------------------------------------------
52
/// Errors specific to distributed checkpointing.
///
/// Every variant carries a human-readable `message` (or `path`) string; any
/// underlying error is formatted into that string rather than kept as a
/// `source()` chain. Converting into [`FerrotorchError`] maps every variant
/// to `FerrotorchError::InvalidArgument`, using this error's `Display` text as
/// the message.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DistCheckpointError {
    /// A filesystem operation failed (creating the checkpoint directory, or
    /// reading/writing a file). Also produced by the `From<std::io::Error>`
    /// conversion.
    #[error("I/O error: {message}")]
    Io { message: String },

    /// Encoding or decoding failed: SafeTensors (de)serialization of a shard
    /// file, or JSON (de)serialization of `metadata.json`.
    #[error("serialization error: {message}")]
    Serialization { message: String },

    /// Checkpoint metadata is invalid.
    ///
    /// NOTE(review): not constructed anywhere in this section of the file;
    /// presumably reserved for metadata validation — confirm before relying
    /// on it.
    #[error("metadata error: {message}")]
    Metadata { message: String },

    /// An expected shard file was not found on disk. `path` is the display
    /// form of the expected file path.
    #[error("shard file missing: {path}")]
    MissingShard { path: String },

    /// Tensor-level problem: data could not be read, a stored dtype or byte
    /// length does not match, a tensor is missing from a shard, shard shapes
    /// are inconsistent with the full shape, or a tensor could not be built
    /// from loaded data.
    #[error("tensor error: {message}")]
    Tensor { message: String },

    /// A caller-supplied argument is invalid (e.g. `world_size == 0`,
    /// `rank >= world_size`, an out-of-range shard dimension, or an element
    /// type whose size has no SafeTensors dtype mapping).
    #[error("invalid argument: {message}")]
    InvalidArgument { message: String },

    /// Reported by [`CheckpointFuture::wait`]: the background write thread
    /// panicked, or the write it performed returned an error (whose text is
    /// embedded in `message`).
    #[error("async checkpoint failed: {message}")]
    AsyncFailed { message: String },
}
78
79impl From<DistCheckpointError> for FerrotorchError {
80    fn from(e: DistCheckpointError) -> Self {
81        FerrotorchError::InvalidArgument {
82            message: e.to_string(),
83        }
84    }
85}
86
87impl From<std::io::Error> for DistCheckpointError {
88    fn from(e: std::io::Error) -> Self {
89        DistCheckpointError::Io {
90            message: e.to_string(),
91        }
92    }
93}
94
95// ---------------------------------------------------------------------------
96// Metadata types
97// ---------------------------------------------------------------------------
98
/// Describes how a single tensor is sharded across ranks.
///
/// Exactly one dimension (`shard_dim`) is split; when resharding, every other
/// dimension of each per-rank shard is checked against `full_shape`.
///
/// Marked `#[non_exhaustive]` so future fields (e.g., dtype, ordering,
/// per-rank ranges for non-uniform shards) can be added without a major
/// version bump. Construct via in-crate helpers (e.g.
/// [`flat_shard_metadata`]) and access fields by name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TensorShardSpec {
    /// Full (unsharded) shape of the tensor.
    pub full_shape: Vec<usize>,
    /// Which dimension is split across ranks. Must be `< full_shape.len()`;
    /// resharding rejects out-of-range values with
    /// [`DistCheckpointError::InvalidArgument`].
    pub shard_dim: usize,
    /// Size along `shard_dim` for each rank. The sum must equal
    /// `full_shape[shard_dim]`.
    ///
    /// Note: this field is not validated on save, and [`reshard`] does not
    /// read it. `reshard` uses the shapes of the shard tensors actually stored
    /// on disk and checks only that they sum to `full_shape[shard_dim]`.
    pub shard_sizes: Vec<usize>,
}
116
/// Metadata for a distributed checkpoint: how many ranks participated and
/// how each tensor is sharded.
///
/// Rank 0 writes this as `metadata.json` in [`save_distributed`];
/// [`load_distributed`] and [`reshard`] read it back.
///
/// Marked `#[non_exhaustive]` to allow forward-compatible additions
/// (versioning, dtype catalog, optimizer-state metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ShardMetadata {
    /// Number of ranks that produced this checkpoint. [`load_distributed`]
    /// compares it with the current world size and calls [`reshard`] when
    /// they differ; `reshard` then reads this many shard files.
    pub num_ranks: usize,
    /// Per-tensor sharding specification, keyed by tensor name. On the
    /// resharding path only tensors listed here are returned; tensors present
    /// in shard files but absent from this map are dropped.
    pub tensor_specs: HashMap<String, TensorShardSpec>,
}
130
131/// Handle for a distributed checkpoint directory.
132///
133/// Wraps the directory path and the shard metadata. Typically created by
134/// [`save_distributed`] (which writes files) or by reading an existing
135/// checkpoint's `metadata.json`.
136///
137/// Marked `#[non_exhaustive]` for forward-compatible additions
138/// (e.g., async-write handles, optimizer-state cache).
139#[non_exhaustive]
140pub struct DistributedCheckpoint {
141    /// Directory where shard files are stored.
142    pub checkpoint_dir: PathBuf,
143    /// Metadata about how tensors are sharded across ranks.
144    pub shard_metadata: ShardMetadata,
145}
146
147// ---------------------------------------------------------------------------
148// Helpers
149// ---------------------------------------------------------------------------
150
151/// SafeTensors dtype for a given `Float` type.
152fn st_dtype<T: Float>() -> Result<Dtype, DistCheckpointError> {
153    match std::mem::size_of::<T>() {
154        4 => Ok(Dtype::F32),
155        8 => Ok(Dtype::F64),
156        other => Err(DistCheckpointError::InvalidArgument {
157            message: format!("unsupported element size {other} for safetensors serialization"),
158        }),
159    }
160}
161
162/// Reinterpret a `&[T]` as a byte slice (little-endian platforms).
163fn as_le_bytes<T: Float>(data: &[T]) -> &[u8] {
164    // SAFETY: byte-reinterpret of a `&[T]` of `Float` (f32/f64/bf16/f16) into
165    // `&[u8]`.
166    //
167    // - VALIDITY: every byte pattern is a valid `u8`, so reading the
168    //   underlying bytes never produces an invalid value (unlike e.g. `bool`
169    //   or `char`).
170    // - LIFETIME: the resulting `&[u8]` borrows the same allocation as
171    //   `data` and is returned tied to `data`'s lifetime — the borrow
172    //   checker therefore prevents `data` from being dropped or mutated
173    //   while the byte slice is live.
174    // - LENGTH: `mem::size_of_val(data) == data.len() * mem::size_of::<T>()`,
175    //   so the resulting slice covers exactly the same memory as the input.
176    // - ALIGNMENT: `*const u8` has alignment 1, which is always satisfied
177    //   by an allocation aligned to `T` (alignment of `T` is a multiple of
178    //   1 for every primitive `Float` type).
179    // - ALIASING: this returns a shared borrow; no `&mut [u8]` aliasing is
180    //   possible while the borrow lives.
181    // - ENDIANNESS: ferrotorch targets little-endian platforms (the
182    //   workspace's MSRV-supported targets are all LE); on a hypothetical
183    //   big-endian build the byte order would not match SafeTensors'
184    //   on-disk LE convention, but no such target is currently supported.
185    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) }
186}
187
/// Location of `rank`'s shard inside `dir`: `<dir>/rank_<rank>.safetensors`.
fn shard_path(dir: &Path, rank: usize) -> PathBuf {
    let mut path = dir.join(format!("rank_{rank}"));
    // "rank_<n>" has no extension yet, so this appends ".safetensors".
    path.set_extension("safetensors");
    path
}
192
/// Location of the checkpoint-wide metadata file: `<dir>/metadata.json`.
fn metadata_path(dir: &Path) -> PathBuf {
    const METADATA_FILE: &str = "metadata.json";
    let mut path = dir.to_path_buf();
    path.push(METADATA_FILE);
    path
}
197
198/// Save a `HashMap<String, Tensor<T>>` as a single SafeTensors file.
199fn save_tensors_to_file<T: Float>(
200    tensors: &HashMap<String, Tensor<T>>,
201    path: &Path,
202) -> Result<(), DistCheckpointError> {
203    let dtype = st_dtype::<T>()?;
204
205    let mut keys: Vec<&str> = tensors.keys().map(String::as_str).collect();
206    keys.sort_unstable();
207
208    struct Entry<'a> {
209        name: String,
210        shape: Vec<usize>,
211        data: &'a [u8],
212    }
213
214    let mut entries: Vec<Entry<'_>> = Vec::with_capacity(keys.len());
215    for key in &keys {
216        let tensor = &tensors[*key];
217        let data_slice = tensor.data().map_err(|e| DistCheckpointError::Tensor {
218            message: format!("failed to read tensor \"{key}\": {e}"),
219        })?;
220        entries.push(Entry {
221            name: (*key).to_owned(),
222            shape: tensor.shape().to_vec(),
223            data: as_le_bytes(data_slice),
224        });
225    }
226
227    let views: Vec<(String, TensorView<'_>)> = entries
228        .iter()
229        .map(|entry| {
230            TensorView::new(dtype, entry.shape.clone(), entry.data)
231                .map(|v| (entry.name.clone(), v))
232                .map_err(|e| DistCheckpointError::Serialization {
233                    message: format!("TensorView for \"{}\": {e}", entry.name),
234                })
235        })
236        .collect::<Result<Vec<_>, _>>()?;
237
238    serialize_to_file(views, &None, path).map_err(|e| DistCheckpointError::Serialization {
239        message: format!("safetensors write to {}: {e}", path.display()),
240    })?;
241
242    Ok(())
243}
244
/// Load a SafeTensors file into `HashMap<String, Tensor<T>>`.
///
/// The whole file is read into memory and parsed. Each tensor's bytes are
/// copied into a freshly allocated `Vec<T>` backed by CPU storage, so the
/// returned tensors do not borrow the file buffer.
///
/// # Errors
///
/// - [`DistCheckpointError::InvalidArgument`] if `T` has no SafeTensors dtype.
/// - [`DistCheckpointError::Io`] if the file cannot be read.
/// - [`DistCheckpointError::Serialization`] if the file is not valid SafeTensors.
/// - [`DistCheckpointError::Tensor`] if a stored tensor's dtype differs from
///   `T`'s, its byte length does not match its shape, or building the tensor
///   fails.
fn load_tensors_from_file<T: Float>(
    path: &Path,
) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
    let elem_size = std::mem::size_of::<T>();
    // `st_dtype` rejects every element size other than 4 or 8, so past this
    // point `elem_size` is non-zero and fits the 8-byte scratch buffer below.
    let expected = st_dtype::<T>()?;

    let file_data = std::fs::read(path).map_err(|e| DistCheckpointError::Io {
        message: format!("reading {}: {e}", path.display()),
    })?;

    let st =
        SafeTensors::deserialize(&file_data).map_err(|e| DistCheckpointError::Serialization {
            message: format!("parsing {}: {e}", path.display()),
        })?;

    let tensor_list = st.tensors();
    let mut result: HashMap<String, Tensor<T>> = HashMap::with_capacity(tensor_list.len());

    for (name, view) in &tensor_list {
        // Values are reinterpreted byte-for-byte below, so the stored dtype
        // must match `T` exactly (no conversion is performed).
        if view.dtype() != expected {
            return Err(DistCheckpointError::Tensor {
                message: format!(
                    "tensor \"{name}\" has dtype {:?}, expected {:?}",
                    view.dtype(),
                    expected
                ),
            });
        }

        let shape = view.shape().to_vec();
        let byte_data = view.data();
        // A 0-d (scalar) shape holds one element. The empty product would
        // also be 1, so this branch is explicit rather than necessary.
        let numel: usize = if shape.is_empty() {
            1
        } else {
            shape.iter().product()
        };
        let expected_bytes = numel * elem_size;

        if byte_data.len() != expected_bytes {
            return Err(DistCheckpointError::Tensor {
                message: format!(
                    "tensor \"{name}\" has {} bytes but shape {shape:?} requires {expected_bytes}",
                    byte_data.len()
                ),
            });
        }

        // Decode element by element: the file buffer is not guaranteed to be
        // aligned for `T`, so each element is copied into a local buffer and
        // read with `read_unaligned`.
        let data: Vec<T> = byte_data
            .chunks_exact(elem_size)
            .map(|chunk| {
                let mut bytes = [0u8; 8];
                bytes[..elem_size].copy_from_slice(chunk);
                // SAFETY: `read_unaligned` of `T` from a stack buffer.
                //
                // - VALIDITY: `T: Float` is constrained to f32/f64/bf16/f16,
                //   for which every 4- or 8-byte pattern is a valid value
                //   (NaN/inf included). No invalid bit patterns exist for
                //   IEEE-754 floats — unlike `bool` (must be 0/1) or
                //   `NonZero<...>`.
                // - LENGTH: `elem_size == size_of::<T>()` (verified at the
                //   call site by `chunks_exact(elem_size)`); since
                //   `size_of::<T>() <= 8` for every supported float and
                //   `bytes` is a `[u8; 8]`, the read stays within bounds.
                // - ALIGNMENT: `read_unaligned` does not require
                //   `bytes.as_ptr()` to be aligned to `T` — that's the
                //   whole point of the function. Using `read` instead
                //   would be UB here because `[u8; 8]` is 1-aligned.
                // - LIFETIME: `bytes` is a stack-local array; the read
                //   produces an owned `T` by value, so no dangling
                //   reference can outlive the closure.
                // - PROVENANCE: `bytes.as_ptr()` is derived directly from
                //   the live stack array `bytes`; the cast to `*const T`
                //   preserves provenance under the strict-provenance model.
                // - ENDIANNESS: SafeTensors stores tensors little-endian;
                //   ferrotorch's supported targets are LE, so the byte
                //   pattern read here matches the on-disk pattern. On a
                //   hypothetical BE host the values would be byte-swapped,
                //   but no such target is currently supported.
                unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const T) }
            })
            .collect();

        let storage = TensorStorage::cpu(data);
        let tensor = Tensor::from_storage(storage, shape, false).map_err(|e| {
            DistCheckpointError::Tensor {
                message: format!("creating tensor \"{name}\": {e}"),
            }
        })?;
        result.insert(name.clone(), tensor);
    }

    Ok(result)
}
339
340// ---------------------------------------------------------------------------
341// Save / Load
342// ---------------------------------------------------------------------------
343
344/// Save a distributed checkpoint.
345///
346/// Each rank saves its own shard of the state dict to
347/// `dir/rank_<rank>.safetensors`. Rank 0 additionally writes
348/// `dir/metadata.json` containing the [`ShardMetadata`] so that future loads
349/// (potentially with a different world size) know how the data is distributed.
350///
351/// The `state_dict` should contain this rank's shard of each tensor.
352/// `shard_spec` describes the full shapes and how they are sharded.
353///
354/// # Errors
355///
356/// Returns an error if the directory cannot be created, if tensor data cannot
357/// be read (e.g., GPU tensor without prior `.cpu()`), or if serialization
358/// fails.
359pub fn save_distributed<T: Float>(
360    state_dict: &HashMap<String, Tensor<T>>,
361    dir: &Path,
362    rank: usize,
363    world_size: usize,
364    shard_spec: &ShardMetadata,
365) -> Result<(), DistCheckpointError> {
366    // Validate basic inputs.
367    if world_size == 0 {
368        return Err(DistCheckpointError::InvalidArgument {
369            message: "world_size must be >= 1".into(),
370        });
371    }
372    if rank >= world_size {
373        return Err(DistCheckpointError::InvalidArgument {
374            message: format!("rank {rank} >= world_size {world_size}"),
375        });
376    }
377
378    // Create the checkpoint directory if it doesn't exist.
379    std::fs::create_dir_all(dir)?;
380
381    // Save this rank's shard.
382    let path = shard_path(dir, rank);
383    save_tensors_to_file(state_dict, &path)?;
384
385    // Rank 0 writes the metadata file.
386    if rank == 0 {
387        let json = serde_json::to_string_pretty(shard_spec).map_err(|e| {
388            DistCheckpointError::Serialization {
389                message: format!("serializing metadata: {e}"),
390            }
391        })?;
392        std::fs::write(metadata_path(dir), json)?;
393    }
394
395    Ok(())
396}
397
398/// Load a distributed checkpoint for a specific rank.
399///
400/// Reads `dir/metadata.json` to discover the original sharding layout. If the
401/// current `world_size` matches the saved metadata, each rank simply loads its
402/// own shard file. If world sizes differ, automatic resharding is performed
403/// via [`reshard`].
404///
405/// # Errors
406///
407/// Returns an error if metadata or shard files are missing, if tensors have
408/// unexpected dtypes, or if resharding fails.
409pub fn load_distributed<T: Float>(
410    dir: &Path,
411    rank: usize,
412    world_size: usize,
413) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
414    if world_size == 0 {
415        return Err(DistCheckpointError::InvalidArgument {
416            message: "world_size must be >= 1".into(),
417        });
418    }
419    if rank >= world_size {
420        return Err(DistCheckpointError::InvalidArgument {
421            message: format!("rank {rank} >= world_size {world_size}"),
422        });
423    }
424
425    // Read metadata.
426    let meta_path = metadata_path(dir);
427    let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
428        message: format!("reading {}: {e}", meta_path.display()),
429    })?;
430    let metadata: ShardMetadata =
431        serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
432            message: format!("parsing metadata: {e}"),
433        })?;
434
435    let old_world_size = metadata.num_ranks;
436
437    if old_world_size == world_size {
438        // Same world size — just load this rank's shard directly.
439        let path = shard_path(dir, rank);
440        if !path.exists() {
441            // If there are no tensor specs, this is a legitimately empty
442            // checkpoint — return an empty map.
443            if metadata.tensor_specs.is_empty() {
444                return Ok(HashMap::new());
445            }
446            return Err(DistCheckpointError::MissingShard {
447                path: path.display().to_string(),
448            });
449        }
450        load_tensors_from_file(&path)
451    } else {
452        // Different world size — need to reshard.
453        reshard(dir, old_world_size, world_size, rank)
454    }
455}
456
457// ---------------------------------------------------------------------------
458// Resharding
459// ---------------------------------------------------------------------------
460
461/// Reshard a checkpoint from `old_world_size` ranks to `new_world_size` ranks,
462/// returning the data for `new_rank`.
463///
464/// For each tensor described in the checkpoint metadata, this function:
465///
466/// 1. Reconstructs the full (unsharded) tensor by loading and concatenating
467///    all old shard files along the shard dimension.
468/// 2. Splits the full tensor along the shard dimension into `new_world_size`
469///    pieces and returns the piece for `new_rank`.
470///
471/// This handles both scale-up (e.g., 4 GPUs to 8) and scale-down (e.g., 8
472/// GPUs to 4) as well as arbitrary remappings.
473///
474/// # Errors
475///
476/// Returns an error if shard files are missing, if tensors have unexpected
477/// shapes, or if the full tensor cannot be evenly divided for resharding.
478pub fn reshard<T: Float>(
479    dir: &Path,
480    old_world_size: usize,
481    new_world_size: usize,
482    new_rank: usize,
483) -> Result<HashMap<String, Tensor<T>>, DistCheckpointError> {
484    if new_world_size == 0 {
485        return Err(DistCheckpointError::InvalidArgument {
486            message: "new_world_size must be >= 1".into(),
487        });
488    }
489    if new_rank >= new_world_size {
490        return Err(DistCheckpointError::InvalidArgument {
491            message: format!("new_rank {new_rank} >= new_world_size {new_world_size}"),
492        });
493    }
494    if old_world_size == 0 {
495        return Err(DistCheckpointError::InvalidArgument {
496            message: "old_world_size must be >= 1".into(),
497        });
498    }
499
500    // Read metadata.
501    let meta_path = metadata_path(dir);
502    let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| DistCheckpointError::Io {
503        message: format!("reading {}: {e}", meta_path.display()),
504    })?;
505    let metadata: ShardMetadata =
506        serde_json::from_str(&meta_json).map_err(|e| DistCheckpointError::Serialization {
507            message: format!("parsing metadata: {e}"),
508        })?;
509
510    // Load all old shards into memory.
511    let mut old_shards: Vec<HashMap<String, Tensor<T>>> = Vec::with_capacity(old_world_size);
512    for old_rank in 0..old_world_size {
513        let path = shard_path(dir, old_rank);
514        if !path.exists() {
515            return Err(DistCheckpointError::MissingShard {
516                path: path.display().to_string(),
517            });
518        }
519        old_shards.push(load_tensors_from_file(&path)?);
520    }
521
522    // For each tensor in the metadata, reconstruct full tensor then re-split.
523    let mut result: HashMap<String, Tensor<T>> = HashMap::new();
524
525    for (name, spec) in &metadata.tensor_specs {
526        let shard_dim = spec.shard_dim;
527        let full_shape = &spec.full_shape;
528
529        // Collect shard data from each old rank for this tensor.
530        let mut shard_datas: Vec<Vec<T>> = Vec::with_capacity(old_world_size);
531        let mut shard_shapes: Vec<Vec<usize>> = Vec::with_capacity(old_world_size);
532
533        for (old_rank, shard) in old_shards.iter().enumerate().take(old_world_size) {
534            let tensor = shard.get(name).ok_or_else(|| DistCheckpointError::Tensor {
535                message: format!("tensor \"{name}\" missing from rank {old_rank} shard"),
536            })?;
537            shard_datas.push(tensor.data_vec().map_err(|e| DistCheckpointError::Tensor {
538                message: format!("reading tensor \"{name}\" from rank {old_rank}: {e}"),
539            })?);
540            shard_shapes.push(tensor.shape().to_vec());
541        }
542
543        // Reconstruct the full tensor by concatenating along shard_dim.
544        let full_data = concat_along_dim(&shard_datas, &shard_shapes, shard_dim, full_shape)?;
545
546        // Now split the full tensor along shard_dim for the new world size.
547        let full_dim_size = full_shape[shard_dim];
548        let new_shard_sizes = compute_shard_sizes(full_dim_size, new_world_size);
549        let new_offset: usize = new_shard_sizes[..new_rank].iter().sum();
550        let new_size = new_shard_sizes[new_rank];
551
552        // Extract this rank's slice from the full data.
553        let mut new_shape = full_shape.clone();
554        new_shape[shard_dim] = new_size;
555
556        let new_data = slice_along_dim(&full_data, full_shape, shard_dim, new_offset, new_size);
557
558        let tensor =
559            Tensor::from_storage(TensorStorage::cpu(new_data), new_shape, false).map_err(|e| {
560                DistCheckpointError::Tensor {
561                    message: format!("creating resharded tensor \"{name}\": {e}"),
562                }
563            })?;
564
565        result.insert(name.clone(), tensor);
566    }
567
568    Ok(result)
569}
570
/// Split `total` items into `num_parts` contiguous part sizes that differ by
/// at most one, with the larger parts first: the first `total % num_parts`
/// parts get one extra item. The sizes always sum to `total`.
///
/// Panics (division by zero) if `num_parts == 0`; the caller validates world
/// sizes before getting here.
fn compute_shard_sizes(total: usize, num_parts: usize) -> Vec<usize> {
    let (base, extra) = (total / num_parts, total % num_parts);
    (0..num_parts)
        .map(|part| base + usize::from(part < extra))
        .collect()
}
580
581/// Concatenate tensor data from multiple shards along a given dimension.
582///
583/// Each entry in `shard_datas` is a flat (row-major) data array for that shard.
584/// `shard_shapes` gives the shape of each shard. `full_shape` is the expected
585/// output shape after concatenation.
586fn concat_along_dim<T: Float>(
587    shard_datas: &[Vec<T>],
588    shard_shapes: &[Vec<usize>],
589    dim: usize,
590    full_shape: &[usize],
591) -> Result<Vec<T>, DistCheckpointError> {
592    let ndim = full_shape.len();
593    if dim >= ndim {
594        return Err(DistCheckpointError::InvalidArgument {
595            message: format!("shard_dim {dim} >= ndim {ndim}"),
596        });
597    }
598
599    let full_numel: usize = full_shape.iter().product();
600    let mut full_data = vec![<T as num_traits::Zero>::zero(); full_numel];
601
602    // For concatenation along `dim`, we think of the tensor as having three
603    // logical parts:
604    //   outer = product of dims before `dim`
605    //   middle = dim (varies per shard)
606    //   inner = product of dims after `dim`
607    let outer: usize = full_shape[..dim].iter().product();
608    let inner: usize = full_shape[dim + 1..].iter().product();
609    let full_middle = full_shape[dim];
610
611    // Walk through shards and copy their data into the right offsets.
612    let mut dim_offset = 0;
613    for (shard_idx, shard_data) in shard_datas.iter().enumerate() {
614        let shard_middle = shard_shapes[shard_idx][dim];
615
616        // Validate shard shape dimensions except along the shard dim.
617        for d in 0..ndim {
618            if d != dim && shard_shapes[shard_idx][d] != full_shape[d] {
619                return Err(DistCheckpointError::Tensor {
620                    message: format!(
621                        "shard {shard_idx} has shape {:?} but expected dim {d} to be {} (full shape {full_shape:?})",
622                        shard_shapes[shard_idx], full_shape[d]
623                    ),
624                });
625            }
626        }
627
628        for o in 0..outer {
629            let src_start = o * shard_middle * inner;
630            let dst_start = o * full_middle * inner + dim_offset * inner;
631            let count = shard_middle * inner;
632
633            full_data[dst_start..dst_start + count]
634                .copy_from_slice(&shard_data[src_start..src_start + count]);
635        }
636
637        dim_offset += shard_middle;
638    }
639
640    if dim_offset != full_middle {
641        return Err(DistCheckpointError::Tensor {
642            message: format!(
643                "shard sizes along dim {dim} sum to {dim_offset}, expected {full_middle}"
644            ),
645        });
646    }
647
648    Ok(full_data)
649}
650
651/// Extract a contiguous slice of a tensor along a given dimension.
652///
653/// Returns the flat data for `shape` with `shape[dim]` replaced by `size`,
654/// starting at offset `offset` along that dimension.
655fn slice_along_dim<T: Float>(
656    data: &[T],
657    shape: &[usize],
658    dim: usize,
659    offset: usize,
660    size: usize,
661) -> Vec<T> {
662    let outer: usize = shape[..dim].iter().product();
663    let full_middle = shape[dim];
664    let inner: usize = shape[dim + 1..].iter().product();
665
666    let out_numel = outer * size * inner;
667    let mut result = Vec::with_capacity(out_numel);
668
669    for o in 0..outer {
670        let src_start = o * full_middle * inner + offset * inner;
671        let count = size * inner;
672        result.extend_from_slice(&data[src_start..src_start + count]);
673    }
674
675    result
676}
677
678// ---------------------------------------------------------------------------
679// Async checkpointing
680// ---------------------------------------------------------------------------
681
682/// Result of an asynchronous checkpoint operation.
683///
684/// The checkpoint is being written in a background thread. Call [`wait`](Self::wait)
685/// to block until completion and retrieve any error.
686pub struct CheckpointFuture {
687    handle: Option<thread::JoinHandle<Result<(), DistCheckpointError>>>,
688    /// Cached result after the thread has been joined.
689    result: Option<Result<(), DistCheckpointError>>,
690}
691
692impl CheckpointFuture {
693    /// Block until the checkpoint write completes.
694    ///
695    /// Returns `Ok(())` if the write succeeded, or the error encountered
696    /// during serialization/I/O.
697    ///
698    /// Calling `wait()` multiple times is safe — subsequent calls return
699    /// the cached result.
700    pub fn wait(&mut self) -> Result<(), DistCheckpointError> {
701        if let Some(handle) = self.handle.take() {
702            let res = handle
703                .join()
704                .map_err(|_| DistCheckpointError::AsyncFailed {
705                    message: "background checkpoint thread panicked".into(),
706                })?;
707            self.result = Some(res);
708        }
709
710        match &self.result {
711            Some(Ok(())) => Ok(()),
712            Some(Err(e)) => Err(DistCheckpointError::AsyncFailed {
713                message: format!("{e}"),
714            }),
715            None => Err(DistCheckpointError::AsyncFailed {
716                message: "no checkpoint was started".into(),
717            }),
718        }
719    }
720
721    /// Returns `true` if the background write has completed (success or failure).
722    pub fn is_done(&self) -> bool {
723        if self.result.is_some() {
724            return true;
725        }
726        match &self.handle {
727            Some(h) => h.is_finished(),
728            None => true,
729        }
730    }
731}
732
733/// Asynchronous checkpointer that writes shards to disk in a background thread.
734///
735/// Before writing, tensor data is staged (copied) into CPU memory so that
736/// training can continue on GPU without blocking. The actual file I/O happens
737/// in a separate thread.
738///
739/// # Usage
740///
741/// ```ignore
742/// let ckpt = AsyncCheckpointer::new(
743///     checkpoint_dir.clone(),
744///     rank,
745///     world_size,
746///     shard_metadata.clone(),
747/// );
748///
749/// // Non-blocking: training continues immediately.
750/// let mut future = ckpt.save_async(&state_dict)?;
751///
752/// // ... continue training ...
753///
754/// // When you need to be sure it finished:
755/// future.wait()?;
756/// ```
757pub struct AsyncCheckpointer {
758    dir: PathBuf,
759    rank: usize,
760    world_size: usize,
761    shard_spec: ShardMetadata,
762    /// Guard against concurrent saves: only one checkpoint at a time.
763    in_flight: Arc<Mutex<bool>>,
764}
765
766impl AsyncCheckpointer {
767    /// Create a new async checkpointer.
768    ///
769    /// - `dir`: directory where shard files will be written.
770    /// - `rank`: this process's rank.
771    /// - `world_size`: total number of ranks.
772    /// - `shard_spec`: metadata describing how tensors are sharded.
773    pub fn new(dir: PathBuf, rank: usize, world_size: usize, shard_spec: ShardMetadata) -> Self {
774        Self {
775            dir,
776            rank,
777            world_size,
778            shard_spec,
779            in_flight: Arc::new(Mutex::new(false)),
780        }
781    }
782
783    /// The checkpoint directory.
784    pub fn dir(&self) -> &Path {
785        &self.dir
786    }
787
788    /// This process's rank.
789    pub fn rank(&self) -> usize {
790        self.rank
791    }
792
793    /// Total number of ranks.
794    pub fn world_size(&self) -> usize {
795        self.world_size
796    }
797
798    /// Start an asynchronous checkpoint.
799    ///
800    /// This immediately copies all tensor data to CPU-owned `Vec`s (staging),
801    /// then spawns a background thread that serializes and writes the shard
802    /// file. For GPU tensors, the staging step transfers data to host memory.
803    ///
804    /// Returns a [`CheckpointFuture`] that can be polled or waited on.
805    ///
806    /// # Errors
807    ///
808    /// Returns an error immediately if another async save is already in flight,
809    /// or if staging (GPU-to-CPU copy) fails.
810    pub fn save_async(
811        &self,
812        state_dict: &HashMap<String, Tensor<f32>>,
813    ) -> Result<CheckpointFuture, DistCheckpointError> {
814        // Check that no other save is in progress.
815        {
816            let mut guard =
817                self.in_flight
818                    .lock()
819                    .map_err(|e| DistCheckpointError::AsyncFailed {
820                        message: format!("lock poisoned: {e}"),
821                    })?;
822            if *guard {
823                return Err(DistCheckpointError::AsyncFailed {
824                    message: "another async checkpoint is already in flight".into(),
825                });
826            }
827            *guard = true;
828        }
829
830        // Stage: copy all tensor data to CPU-owned Vecs. This is the only part
831        // that touches GPU memory and must happen on the calling thread.
832        let mut staged: HashMap<String, (Vec<f32>, Vec<usize>)> =
833            HashMap::with_capacity(state_dict.len());
834
835        for (name, tensor) in state_dict {
836            let data = tensor.data_vec().map_err(|e| {
837                // Release the lock on error.
838                if let Ok(mut g) = self.in_flight.lock() {
839                    *g = false;
840                }
841                DistCheckpointError::Tensor {
842                    message: format!("staging tensor \"{name}\": {e}"),
843                }
844            })?;
845            let shape = tensor.shape().to_vec();
846            staged.insert(name.clone(), (data, shape));
847        }
848
849        // Capture everything the background thread needs.
850        let dir = self.dir.clone();
851        let rank = self.rank;
852        let shard_spec = self.shard_spec.clone();
853        let in_flight = Arc::clone(&self.in_flight);
854
855        let handle = thread::spawn(move || {
856            let result = (|| -> Result<(), DistCheckpointError> {
857                // Rebuild tensors from staged data.
858                let mut tensors: HashMap<String, Tensor<f32>> =
859                    HashMap::with_capacity(staged.len());
860                for (name, (data, shape)) in staged {
861                    let tensor = Tensor::from_storage(TensorStorage::cpu(data), shape, false)
862                        .map_err(|e| DistCheckpointError::Tensor {
863                            message: format!("rebuilding tensor \"{name}\": {e}"),
864                        })?;
865                    tensors.insert(name, tensor);
866                }
867
868                // Write to disk.
869                std::fs::create_dir_all(&dir)?;
870                let path = shard_path(&dir, rank);
871                save_tensors_to_file(&tensors, &path)?;
872
873                // Rank 0 writes metadata.
874                if rank == 0 {
875                    let json = serde_json::to_string_pretty(&shard_spec).map_err(|e| {
876                        DistCheckpointError::Serialization {
877                            message: format!("serializing metadata: {e}"),
878                        }
879                    })?;
880                    std::fs::write(metadata_path(&dir), json)?;
881                }
882
883                Ok(())
884            })();
885
886            // Release the in-flight guard.
887            if let Ok(mut g) = in_flight.lock() {
888                *g = false;
889            }
890
891            result
892        });
893
894        Ok(CheckpointFuture {
895            handle: Some(handle),
896            result: None,
897        })
898    }
899}
900
901// ---------------------------------------------------------------------------
902// Convenience: auto-generate ShardMetadata from a flat state dict
903// ---------------------------------------------------------------------------
904
905/// Build a [`ShardMetadata`] where every tensor is sharded along dimension 0
906/// with equal shard sizes.
907///
908/// This is the common case for FSDP where parameters are simply chunked along
909/// the flattened (dim-0) axis.
910pub fn flat_shard_metadata(
911    state_dict: &HashMap<String, Tensor<f32>>,
912    world_size: usize,
913) -> ShardMetadata {
914    let mut tensor_specs = HashMap::new();
915    for (name, tensor) in state_dict {
916        let shape = tensor.shape();
917        // For FSDP-style flat sharding, the shard tensor is 1-D with shape
918        // [chunk_size]. The full tensor has shape [chunk_size * world_size].
919        let shard_numel = shape.iter().product::<usize>();
920        let full_numel = shard_numel * world_size;
921        let shard_sizes = vec![shard_numel; world_size];
922        tensor_specs.insert(
923            name.clone(),
924            TensorShardSpec {
925                full_shape: vec![full_numel],
926                shard_dim: 0,
927                shard_sizes,
928            },
929        );
930    }
931    ShardMetadata {
932        num_ranks: world_size,
933        tensor_specs,
934    }
935}
936
937// ---------------------------------------------------------------------------
938// Tests
939// ---------------------------------------------------------------------------
940
941#[cfg(test)]
942mod tests {
943    use super::*;
944    use ferrotorch_core::Tensor;
945    use ferrotorch_core::storage::TensorStorage;
946    use std::collections::HashMap;
947
948    fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
949        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
950    }
951
952    fn temp_dir(name: &str) -> PathBuf {
953        std::env::temp_dir()
954            .join("ferrotorch_test_dist_ckpt")
955            .join(name)
956    }
957
958    fn cleanup(dir: &Path) {
959        let _ = std::fs::remove_dir_all(dir);
960    }
961
962    // --- save_distributed / load_distributed roundtrip ---
963
964    #[test]
965    fn test_save_load_single_rank() {
966        let dir = temp_dir("single_rank");
967        cleanup(&dir);
968
969        let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
970        state.insert(
971            "weight".into(),
972            make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
973        );
974        state.insert("bias".into(), make_tensor(vec![0.1, 0.2], vec![2]));
975
976        let spec = ShardMetadata {
977            num_ranks: 1,
978            tensor_specs: {
979                let mut m = HashMap::new();
980                m.insert(
981                    "weight".into(),
982                    TensorShardSpec {
983                        full_shape: vec![4],
984                        shard_dim: 0,
985                        shard_sizes: vec![4],
986                    },
987                );
988                m.insert(
989                    "bias".into(),
990                    TensorShardSpec {
991                        full_shape: vec![2],
992                        shard_dim: 0,
993                        shard_sizes: vec![2],
994                    },
995                );
996                m
997            },
998        };
999
1000        save_distributed(&state, &dir, 0, 1, &spec).unwrap();
1001        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1002
1003        assert_eq!(loaded.len(), 2);
1004        assert_eq!(loaded["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1005        assert_eq!(loaded["bias"].data().unwrap(), &[0.1, 0.2]);
1006
1007        cleanup(&dir);
1008    }
1009
1010    #[test]
1011    fn test_save_load_two_ranks() {
1012        let dir = temp_dir("two_ranks");
1013        cleanup(&dir);
1014
1015        // Rank 0 shard: first half of weight.
1016        let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1017        state0.insert("weight".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1018
1019        // Rank 1 shard: second half of weight.
1020        let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1021        state1.insert("weight".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1022
1023        let spec = ShardMetadata {
1024            num_ranks: 2,
1025            tensor_specs: {
1026                let mut m = HashMap::new();
1027                m.insert(
1028                    "weight".into(),
1029                    TensorShardSpec {
1030                        full_shape: vec![4],
1031                        shard_dim: 0,
1032                        shard_sizes: vec![2, 2],
1033                    },
1034                );
1035                m
1036            },
1037        };
1038
1039        save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1040        save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1041
1042        // Load each rank's shard (same world size).
1043        let loaded0 = load_distributed::<f32>(&dir, 0, 2).unwrap();
1044        let loaded1 = load_distributed::<f32>(&dir, 1, 2).unwrap();
1045
1046        assert_eq!(loaded0["weight"].data().unwrap(), &[1.0, 2.0]);
1047        assert_eq!(loaded1["weight"].data().unwrap(), &[3.0, 4.0]);
1048
1049        cleanup(&dir);
1050    }
1051
1052    // --- Resharding tests ---
1053
1054    #[test]
1055    fn test_reshard_2_to_4() {
1056        // Saved with 2 ranks, load with 4 ranks.
1057        let dir = temp_dir("reshard_2_to_4");
1058        cleanup(&dir);
1059
1060        // Full tensor: [1, 2, 3, 4, 5, 6, 7, 8], shape [8].
1061        // 2 ranks: rank 0 = [1,2,3,4], rank 1 = [5,6,7,8].
1062        let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1063        state0.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]));
1064
1065        let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1066        state1.insert("w".into(), make_tensor(vec![5.0, 6.0, 7.0, 8.0], vec![4]));
1067
1068        let spec = ShardMetadata {
1069            num_ranks: 2,
1070            tensor_specs: {
1071                let mut m = HashMap::new();
1072                m.insert(
1073                    "w".into(),
1074                    TensorShardSpec {
1075                        full_shape: vec![8],
1076                        shard_dim: 0,
1077                        shard_sizes: vec![4, 4],
1078                    },
1079                );
1080                m
1081            },
1082        };
1083
1084        save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1085        save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1086
1087        // Reshard to 4 ranks. Each rank gets 2 elements.
1088        let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1089        let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1090        let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1091        let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1092
1093        assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0]);
1094        assert_eq!(r1["w"].data().unwrap(), &[3.0, 4.0]);
1095        assert_eq!(r2["w"].data().unwrap(), &[5.0, 6.0]);
1096        assert_eq!(r3["w"].data().unwrap(), &[7.0, 8.0]);
1097
1098        cleanup(&dir);
1099    }
1100
1101    #[test]
1102    fn test_reshard_4_to_2() {
1103        // Saved with 4 ranks, load with 2.
1104        let dir = temp_dir("reshard_4_to_2");
1105        cleanup(&dir);
1106
1107        let spec = ShardMetadata {
1108            num_ranks: 4,
1109            tensor_specs: {
1110                let mut m = HashMap::new();
1111                m.insert(
1112                    "w".into(),
1113                    TensorShardSpec {
1114                        full_shape: vec![8],
1115                        shard_dim: 0,
1116                        shard_sizes: vec![2, 2, 2, 2],
1117                    },
1118                );
1119                m
1120            },
1121        };
1122
1123        for rank in 0..4 {
1124            let start = rank as f32 * 2.0 + 1.0;
1125            let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1126            state.insert("w".into(), make_tensor(vec![start, start + 1.0], vec![2]));
1127            save_distributed(&state, &dir, rank, 4, &spec).unwrap();
1128        }
1129
1130        // Reshard to 2 ranks. Each rank gets 4 elements.
1131        let r0 = reshard::<f32>(&dir, 4, 2, 0).unwrap();
1132        let r1 = reshard::<f32>(&dir, 4, 2, 1).unwrap();
1133
1134        assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1135        assert_eq!(r1["w"].data().unwrap(), &[5.0, 6.0, 7.0, 8.0]);
1136
1137        cleanup(&dir);
1138    }
1139
1140    #[test]
1141    fn test_reshard_2d_tensor() {
1142        // Full tensor: [[1,2,3],[4,5,6],[7,8,9],[10,11,12]], shape [4, 3].
1143        // Shard along dim 0: rank 0 gets rows 0-1, rank 1 gets rows 2-3.
1144        let dir = temp_dir("reshard_2d");
1145        cleanup(&dir);
1146
1147        let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1148        state0.insert(
1149            "w".into(),
1150            make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]),
1151        );
1152
1153        let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1154        state1.insert(
1155            "w".into(),
1156            make_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![2, 3]),
1157        );
1158
1159        let spec = ShardMetadata {
1160            num_ranks: 2,
1161            tensor_specs: {
1162                let mut m = HashMap::new();
1163                m.insert(
1164                    "w".into(),
1165                    TensorShardSpec {
1166                        full_shape: vec![4, 3],
1167                        shard_dim: 0,
1168                        shard_sizes: vec![2, 2],
1169                    },
1170                );
1171                m
1172            },
1173        };
1174
1175        save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1176        save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1177
1178        // Reshard to 4 ranks: each gets 1 row.
1179        let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1180        let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1181        let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1182        let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1183
1184        assert_eq!(r0["w"].shape(), &[1, 3]);
1185        assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0]);
1186        assert_eq!(r1["w"].shape(), &[1, 3]);
1187        assert_eq!(r1["w"].data().unwrap(), &[4.0, 5.0, 6.0]);
1188        assert_eq!(r2["w"].shape(), &[1, 3]);
1189        assert_eq!(r2["w"].data().unwrap(), &[7.0, 8.0, 9.0]);
1190        assert_eq!(r3["w"].shape(), &[1, 3]);
1191        assert_eq!(r3["w"].data().unwrap(), &[10.0, 11.0, 12.0]);
1192
1193        cleanup(&dir);
1194    }
1195
1196    #[test]
1197    fn test_reshard_dim1() {
1198        // Full tensor: [[1,2,3,4],[5,6,7,8]], shape [2, 4].
1199        // Shard along dim 1: rank 0 gets cols 0-1, rank 1 gets cols 2-3.
1200        let dir = temp_dir("reshard_dim1");
1201        cleanup(&dir);
1202
1203        let mut state0: HashMap<String, Tensor<f32>> = HashMap::new();
1204        state0.insert(
1205            "w".into(),
1206            make_tensor(vec![1.0, 2.0, 5.0, 6.0], vec![2, 2]),
1207        );
1208
1209        let mut state1: HashMap<String, Tensor<f32>> = HashMap::new();
1210        state1.insert(
1211            "w".into(),
1212            make_tensor(vec![3.0, 4.0, 7.0, 8.0], vec![2, 2]),
1213        );
1214
1215        let spec = ShardMetadata {
1216            num_ranks: 2,
1217            tensor_specs: {
1218                let mut m = HashMap::new();
1219                m.insert(
1220                    "w".into(),
1221                    TensorShardSpec {
1222                        full_shape: vec![2, 4],
1223                        shard_dim: 1,
1224                        shard_sizes: vec![2, 2],
1225                    },
1226                );
1227                m
1228            },
1229        };
1230
1231        save_distributed(&state0, &dir, 0, 2, &spec).unwrap();
1232        save_distributed(&state1, &dir, 1, 2, &spec).unwrap();
1233
1234        // Reshard to 4 ranks: each gets 1 column.
1235        let r0 = reshard::<f32>(&dir, 2, 4, 0).unwrap();
1236        let r1 = reshard::<f32>(&dir, 2, 4, 1).unwrap();
1237        let r2 = reshard::<f32>(&dir, 2, 4, 2).unwrap();
1238        let r3 = reshard::<f32>(&dir, 2, 4, 3).unwrap();
1239
1240        assert_eq!(r0["w"].shape(), &[2, 1]);
1241        assert_eq!(r0["w"].data().unwrap(), &[1.0, 5.0]);
1242        assert_eq!(r1["w"].shape(), &[2, 1]);
1243        assert_eq!(r1["w"].data().unwrap(), &[2.0, 6.0]);
1244        assert_eq!(r2["w"].shape(), &[2, 1]);
1245        assert_eq!(r2["w"].data().unwrap(), &[3.0, 7.0]);
1246        assert_eq!(r3["w"].shape(), &[2, 1]);
1247        assert_eq!(r3["w"].data().unwrap(), &[4.0, 8.0]);
1248
1249        cleanup(&dir);
1250    }
1251
1252    #[test]
1253    fn test_reshard_3_to_2_uneven() {
1254        // Full tensor: [1,2,3,4,5,6,7,8,9], shape [9].
1255        // 3 ranks with equal shards of 3 each.
1256        let dir = temp_dir("reshard_3_to_2");
1257        cleanup(&dir);
1258
1259        let spec = ShardMetadata {
1260            num_ranks: 3,
1261            tensor_specs: {
1262                let mut m = HashMap::new();
1263                m.insert(
1264                    "w".into(),
1265                    TensorShardSpec {
1266                        full_shape: vec![9],
1267                        shard_dim: 0,
1268                        shard_sizes: vec![3, 3, 3],
1269                    },
1270                );
1271                m
1272            },
1273        };
1274
1275        for rank in 0..3usize {
1276            let start = rank as f32 * 3.0 + 1.0;
1277            let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1278            state.insert(
1279                "w".into(),
1280                make_tensor(vec![start, start + 1.0, start + 2.0], vec![3]),
1281            );
1282            save_distributed(&state, &dir, rank, 3, &spec).unwrap();
1283        }
1284
1285        // Reshard to 2 ranks: 9 / 2 = 4 rem 1, so rank 0 gets 5, rank 1 gets 4.
1286        let r0 = reshard::<f32>(&dir, 3, 2, 0).unwrap();
1287        let r1 = reshard::<f32>(&dir, 3, 2, 1).unwrap();
1288
1289        assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
1290        assert_eq!(r1["w"].data().unwrap(), &[6.0, 7.0, 8.0, 9.0]);
1291
1292        cleanup(&dir);
1293    }
1294
1295    // --- load_distributed triggers resharding ---
1296
1297    #[test]
1298    fn test_load_distributed_reshards_when_world_size_differs() {
1299        let dir = temp_dir("load_reshard");
1300        cleanup(&dir);
1301
1302        // Save with 2 ranks.
1303        let spec = ShardMetadata {
1304            num_ranks: 2,
1305            tensor_specs: {
1306                let mut m = HashMap::new();
1307                m.insert(
1308                    "w".into(),
1309                    TensorShardSpec {
1310                        full_shape: vec![4],
1311                        shard_dim: 0,
1312                        shard_sizes: vec![2, 2],
1313                    },
1314                );
1315                m
1316            },
1317        };
1318
1319        let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1320        s0.insert("w".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1321        save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1322
1323        let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1324        s1.insert("w".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1325        save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1326
1327        // Load with 4 ranks — triggers resharding.
1328        let r0 = load_distributed::<f32>(&dir, 0, 4).unwrap();
1329        let r1 = load_distributed::<f32>(&dir, 1, 4).unwrap();
1330        let r2 = load_distributed::<f32>(&dir, 2, 4).unwrap();
1331        let r3 = load_distributed::<f32>(&dir, 3, 4).unwrap();
1332
1333        assert_eq!(r0["w"].data().unwrap(), &[1.0]);
1334        assert_eq!(r1["w"].data().unwrap(), &[2.0]);
1335        assert_eq!(r2["w"].data().unwrap(), &[3.0]);
1336        assert_eq!(r3["w"].data().unwrap(), &[4.0]);
1337
1338        cleanup(&dir);
1339    }
1340
1341    // --- Metadata serialization ---
1342
1343    #[test]
1344    fn test_metadata_roundtrip() {
1345        let spec = ShardMetadata {
1346            num_ranks: 4,
1347            tensor_specs: {
1348                let mut m = HashMap::new();
1349                m.insert(
1350                    "layer.weight".into(),
1351                    TensorShardSpec {
1352                        full_shape: vec![256, 512],
1353                        shard_dim: 0,
1354                        shard_sizes: vec![64, 64, 64, 64],
1355                    },
1356                );
1357                m.insert(
1358                    "layer.bias".into(),
1359                    TensorShardSpec {
1360                        full_shape: vec![256],
1361                        shard_dim: 0,
1362                        shard_sizes: vec![64, 64, 64, 64],
1363                    },
1364                );
1365                m
1366            },
1367        };
1368
1369        let json = serde_json::to_string_pretty(&spec).unwrap();
1370        let loaded: ShardMetadata = serde_json::from_str(&json).unwrap();
1371
1372        assert_eq!(loaded.num_ranks, 4);
1373        assert_eq!(loaded.tensor_specs.len(), 2);
1374        assert_eq!(
1375            loaded.tensor_specs["layer.weight"].full_shape,
1376            vec![256, 512]
1377        );
1378        assert_eq!(loaded.tensor_specs["layer.weight"].shard_dim, 0);
1379        assert_eq!(
1380            loaded.tensor_specs["layer.weight"].shard_sizes,
1381            vec![64, 64, 64, 64]
1382        );
1383    }
1384
1385    // --- compute_shard_sizes ---
1386
1387    #[test]
1388    fn test_compute_shard_sizes_even() {
1389        assert_eq!(compute_shard_sizes(8, 4), vec![2, 2, 2, 2]);
1390        assert_eq!(compute_shard_sizes(12, 3), vec![4, 4, 4]);
1391    }
1392
1393    #[test]
1394    fn test_compute_shard_sizes_uneven() {
1395        // 9 / 2 = 4 rem 1 -> [5, 4]
1396        assert_eq!(compute_shard_sizes(9, 2), vec![5, 4]);
1397        // 10 / 3 = 3 rem 1 -> [4, 3, 3]
1398        assert_eq!(compute_shard_sizes(10, 3), vec![4, 3, 3]);
1399        // 7 / 4 = 1 rem 3 -> [2, 2, 2, 1]
1400        assert_eq!(compute_shard_sizes(7, 4), vec![2, 2, 2, 1]);
1401    }
1402
1403    // --- concat_along_dim ---
1404
1405    #[test]
1406    fn test_concat_1d() {
1407        let data0 = vec![1.0f32, 2.0];
1408        let data1 = vec![3.0f32, 4.0, 5.0];
1409        let full_shape = vec![5];
1410
1411        let result =
1412            concat_along_dim(&[data0, data1], &[vec![2], vec![3]], 0, &full_shape).unwrap();
1413
1414        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1415    }
1416
1417    #[test]
1418    fn test_concat_2d_dim0() {
1419        // Two 1x3 matrices -> 2x3 matrix.
1420        let data0 = vec![1.0f32, 2.0, 3.0];
1421        let data1 = vec![4.0f32, 5.0, 6.0];
1422        let full_shape = vec![2, 3];
1423
1424        let result =
1425            concat_along_dim(&[data0, data1], &[vec![1, 3], vec![1, 3]], 0, &full_shape).unwrap();
1426
1427        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1428    }
1429
1430    #[test]
1431    fn test_concat_2d_dim1() {
1432        // Two 2x1 matrices -> 2x2 matrix.
1433        // data0 = [[1],[3]], data1 = [[2],[4]]
1434        let data0 = vec![1.0f32, 3.0];
1435        let data1 = vec![2.0f32, 4.0];
1436        let full_shape = vec![2, 2];
1437
1438        let result =
1439            concat_along_dim(&[data0, data1], &[vec![2, 1], vec![2, 1]], 1, &full_shape).unwrap();
1440
1441        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
1442    }
1443
1444    // --- slice_along_dim ---
1445
1446    #[test]
1447    fn test_slice_1d() {
1448        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
1449        let shape = vec![5];
1450
1451        let s0 = slice_along_dim(&data, &shape, 0, 0, 2);
1452        assert_eq!(s0, vec![1.0, 2.0]);
1453
1454        let s1 = slice_along_dim(&data, &shape, 0, 2, 3);
1455        assert_eq!(s1, vec![3.0, 4.0, 5.0]);
1456    }
1457
1458    #[test]
1459    fn test_slice_2d_dim0() {
1460        // [[1,2,3],[4,5,6],[7,8,9],[10,11,12]], shape [4,3]
1461        let data = vec![
1462            1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1463        ];
1464        let shape = vec![4, 3];
1465
1466        let s = slice_along_dim(&data, &shape, 0, 1, 2);
1467        assert_eq!(s, vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
1468    }
1469
1470    #[test]
1471    fn test_slice_2d_dim1() {
1472        // [[1,2,3,4],[5,6,7,8]], shape [2,4]
1473        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1474        let shape = vec![2, 4];
1475
1476        let s = slice_along_dim(&data, &shape, 1, 1, 2);
1477        assert_eq!(s, vec![2.0, 3.0, 6.0, 7.0]);
1478    }
1479
1480    // --- flat_shard_metadata ---
1481
1482    #[test]
1483    fn test_flat_shard_metadata() {
1484        let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1485        state.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0], vec![3]));
1486
1487        let meta = flat_shard_metadata(&state, 4);
1488        assert_eq!(meta.num_ranks, 4);
1489
1490        let spec = &meta.tensor_specs["w"];
1491        assert_eq!(spec.full_shape, vec![12]); // 3 * 4
1492        assert_eq!(spec.shard_dim, 0);
1493        assert_eq!(spec.shard_sizes, vec![3, 3, 3, 3]);
1494    }
1495
1496    // --- AsyncCheckpointer ---
1497
1498    #[test]
1499    fn test_async_checkpoint_basic() {
1500        let dir = temp_dir("async_basic");
1501        cleanup(&dir);
1502
1503        let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1504        state.insert("w".into(), make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]));
1505
1506        let spec = ShardMetadata {
1507            num_ranks: 1,
1508            tensor_specs: {
1509                let mut m = HashMap::new();
1510                m.insert(
1511                    "w".into(),
1512                    TensorShardSpec {
1513                        full_shape: vec![4],
1514                        shard_dim: 0,
1515                        shard_sizes: vec![4],
1516                    },
1517                );
1518                m
1519            },
1520        };
1521
1522        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1523        let mut future = ckpt.save_async(&state).unwrap();
1524        future.wait().unwrap();
1525
1526        // Verify the file was written correctly.
1527        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1528        assert_eq!(loaded["w"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1529
1530        cleanup(&dir);
1531    }
1532
1533    #[test]
1534    fn test_async_checkpoint_wait_idempotent() {
1535        let dir = temp_dir("async_idempotent");
1536        cleanup(&dir);
1537
1538        let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1539        state.insert("x".into(), make_tensor(vec![42.0], vec![1]));
1540
1541        let spec = ShardMetadata {
1542            num_ranks: 1,
1543            tensor_specs: {
1544                let mut m = HashMap::new();
1545                m.insert(
1546                    "x".into(),
1547                    TensorShardSpec {
1548                        full_shape: vec![1],
1549                        shard_dim: 0,
1550                        shard_sizes: vec![1],
1551                    },
1552                );
1553                m
1554            },
1555        };
1556
1557        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1558        let mut future = ckpt.save_async(&state).unwrap();
1559
1560        // Wait twice — should not panic.
1561        future.wait().unwrap();
1562        future.wait().unwrap();
1563
1564        cleanup(&dir);
1565    }
1566
1567    #[test]
1568    fn test_async_checkpoint_is_done() {
1569        let dir = temp_dir("async_is_done");
1570        cleanup(&dir);
1571
1572        let mut state: HashMap<String, Tensor<f32>> = HashMap::new();
1573        state.insert("x".into(), make_tensor(vec![1.0], vec![1]));
1574
1575        let spec = ShardMetadata {
1576            num_ranks: 1,
1577            tensor_specs: {
1578                let mut m = HashMap::new();
1579                m.insert(
1580                    "x".into(),
1581                    TensorShardSpec {
1582                        full_shape: vec![1],
1583                        shard_dim: 0,
1584                        shard_sizes: vec![1],
1585                    },
1586                );
1587                m
1588            },
1589        };
1590
1591        let ckpt = AsyncCheckpointer::new(dir.clone(), 0, 1, spec);
1592        let mut future = ckpt.save_async(&state).unwrap();
1593        future.wait().unwrap();
1594        assert!(future.is_done());
1595
1596        cleanup(&dir);
1597    }
1598
1599    // --- Error cases ---
1600
1601    #[test]
1602    fn test_save_invalid_rank() {
1603        let dir = temp_dir("invalid_rank");
1604        let state: HashMap<String, Tensor<f32>> = HashMap::new();
1605        let spec = ShardMetadata {
1606            num_ranks: 2,
1607            tensor_specs: HashMap::new(),
1608        };
1609
1610        let result = save_distributed(&state, &dir, 5, 2, &spec);
1611        assert!(result.is_err());
1612    }
1613
1614    #[test]
1615    fn test_load_missing_metadata() {
1616        let dir = temp_dir("missing_meta");
1617        cleanup(&dir);
1618        std::fs::create_dir_all(&dir).unwrap();
1619
1620        let result = load_distributed::<f32>(&dir, 0, 1);
1621        assert!(result.is_err());
1622
1623        cleanup(&dir);
1624    }
1625
1626    #[test]
1627    fn test_load_missing_shard() {
1628        let dir = temp_dir("missing_shard");
1629        cleanup(&dir);
1630        std::fs::create_dir_all(&dir).unwrap();
1631
1632        // Write metadata but no shard file.
1633        let spec = ShardMetadata {
1634            num_ranks: 1,
1635            tensor_specs: HashMap::new(),
1636        };
1637        let json = serde_json::to_string_pretty(&spec).unwrap();
1638        std::fs::write(metadata_path(&dir), json).unwrap();
1639
1640        // This should work (empty tensor_specs, no shard needed for the direct load).
1641        // But if we try to load with a different world_size triggering reshard,
1642        // it needs the shard files.
1643        let loaded = load_distributed::<f32>(&dir, 0, 1).unwrap();
1644        assert!(loaded.is_empty());
1645
1646        cleanup(&dir);
1647    }
1648
1649    // --- Multiple tensors ---
1650
1651    #[test]
1652    fn test_reshard_multiple_tensors() {
1653        let dir = temp_dir("reshard_multi");
1654        cleanup(&dir);
1655
1656        let spec = ShardMetadata {
1657            num_ranks: 2,
1658            tensor_specs: {
1659                let mut m = HashMap::new();
1660                m.insert(
1661                    "weight".into(),
1662                    TensorShardSpec {
1663                        full_shape: vec![4],
1664                        shard_dim: 0,
1665                        shard_sizes: vec![2, 2],
1666                    },
1667                );
1668                m.insert(
1669                    "bias".into(),
1670                    TensorShardSpec {
1671                        full_shape: vec![6],
1672                        shard_dim: 0,
1673                        shard_sizes: vec![3, 3],
1674                    },
1675                );
1676                m
1677            },
1678        };
1679
1680        let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1681        s0.insert("weight".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1682        s0.insert("bias".into(), make_tensor(vec![10.0, 20.0, 30.0], vec![3]));
1683
1684        let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1685        s1.insert("weight".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1686        s1.insert("bias".into(), make_tensor(vec![40.0, 50.0, 60.0], vec![3]));
1687
1688        save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1689        save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1690
1691        // Reshard to 1 rank (consolidate).
1692        let r0 = load_distributed::<f32>(&dir, 0, 1).unwrap();
1693
1694        assert_eq!(r0["weight"].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1695        assert_eq!(
1696            r0["bias"].data().unwrap(),
1697            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
1698        );
1699
1700        cleanup(&dir);
1701    }
1702
1703    // --- DistributedCheckpoint struct ---
1704
1705    #[test]
1706    fn test_distributed_checkpoint_struct() {
1707        let ckpt = DistributedCheckpoint {
1708            checkpoint_dir: PathBuf::from("/tmp/test"),
1709            shard_metadata: ShardMetadata {
1710                num_ranks: 2,
1711                tensor_specs: HashMap::new(),
1712            },
1713        };
1714        assert_eq!(ckpt.checkpoint_dir, PathBuf::from("/tmp/test"));
1715        assert_eq!(ckpt.shard_metadata.num_ranks, 2);
1716    }
1717
1718    // Edge case: reshard to same world_size (no-op path through reshard fn)
1719    #[test]
1720    fn test_reshard_same_world_size() {
1721        let dir = temp_dir("reshard_same");
1722        cleanup(&dir);
1723
1724        let spec = ShardMetadata {
1725            num_ranks: 2,
1726            tensor_specs: {
1727                let mut m = HashMap::new();
1728                m.insert(
1729                    "w".into(),
1730                    TensorShardSpec {
1731                        full_shape: vec![4],
1732                        shard_dim: 0,
1733                        shard_sizes: vec![2, 2],
1734                    },
1735                );
1736                m
1737            },
1738        };
1739
1740        let mut s0: HashMap<String, Tensor<f32>> = HashMap::new();
1741        s0.insert("w".into(), make_tensor(vec![1.0, 2.0], vec![2]));
1742        save_distributed(&s0, &dir, 0, 2, &spec).unwrap();
1743
1744        let mut s1: HashMap<String, Tensor<f32>> = HashMap::new();
1745        s1.insert("w".into(), make_tensor(vec![3.0, 4.0], vec![2]));
1746        save_distributed(&s1, &dir, 1, 2, &spec).unwrap();
1747
1748        // Reshard with same world_size — should produce identical results.
1749        let r0 = reshard::<f32>(&dir, 2, 2, 0).unwrap();
1750        let r1 = reshard::<f32>(&dir, 2, 2, 1).unwrap();
1751
1752        assert_eq!(r0["w"].data().unwrap(), &[1.0, 2.0]);
1753        assert_eq!(r1["w"].data().unwrap(), &[3.0, 4.0]);
1754
1755        cleanup(&dir);
1756    }
1757}