burn_store/lib.rs
#![cfg_attr(not(feature = "std"), no_std)]

//! # Burn Store
//!
//! Advanced model storage and serialization infrastructure for the Burn deep learning framework.
//!
//! This crate provides comprehensive functionality for storing and loading Burn modules
//! and their tensor data, with support for cross-framework interoperability, flexible filtering,
//! and efficient memory management through lazy materialization.
//!
//! ## Key Features
//!
//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support
//! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization
//! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation
//! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance
//! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates
//! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility
//! - **No-std Support**: Core functionality available in embedded and WASM environments
//!
//! ## Quick Start
//!
//! ### Basic Save and Load
//!
//! ```rust,ignore
//! use burn_store::{ModuleSnapshot, SafetensorsStore};
//!
//! // Save a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.save_into(&mut store)?;
//!
//! // Load a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Loading PyTorch Models
//!
//! ```rust,ignore
//! use burn_store::PytorchStore;
//!
//! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter)
//! let mut store = PytorchStore::from_file("pytorch_model.pth")
//!     .with_top_level_key("state_dict") // Access nested state dict if needed
//!     .allow_partial(true); // Skip unknown tensors
//!
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Filtering and Remapping
//!
//! ```rust,no_run
//! # use burn_store::SafetensorsStore;
//! // Save only specific layers with renaming
//! let mut store = SafetensorsStore::from_file("encoder.safetensors")
//!     .with_regex(r"^encoder\..*") // Filter: only encoder layers
//!     .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X
//!     .metadata("subset", "encoder_only");
//!
//! // Use store with model.save_into(&mut store)?;
//! ```
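//!
//! ### Burnpack (Native Format)
//!
//! The native Burnpack format ([`BurnpackStore`]) additionally persists `ParamId`s for
//! stateful training workflows. A minimal sketch, assuming `BurnpackStore` mirrors the
//! `from_file` constructor and `save_into`/`load_from` round trip shown above (the
//! constructor name and the `.bpk` file extension are assumptions, not verified API):
//!
//! ```rust,ignore
//! use burn_store::{BurnpackStore, ModuleSnapshot};
//!
//! // Save a model in the native format (tensor data plus ParamIds)
//! let mut store = BurnpackStore::from_file("model.bpk");
//! model.save_into(&mut store)?;
//!
//! // Load it back, restoring ParamIds
//! let mut store = BurnpackStore::from_file("model.bpk");
//! model.load_from(&mut store)?;
//! ```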
//!
//! ## Core Components
//!
//! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods (see the sketch after this list)
//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows
//! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format
//! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files
//! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving
//! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns
//! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility
//!
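//! For direct access to a module's tensors without going through a store, [`ModuleSnapshot`]
//! exposes `collect()` and `apply()`. The sketch below assumes optional filter/adapter
//! parameters on both methods; the exact signatures may differ from what is shown:
//!
//! ```rust,ignore
//! use burn_store::ModuleSnapshot;
//!
//! // Collect lazy snapshots of every tensor in a source module
//! // (filter and adapter arguments assumed optional here).
//! let snapshots = source_model.collect(None, None);
//!
//! // Apply the snapshots onto another module, e.g. for weight transfer.
//! let result = target_model.apply(snapshots, None, None);
//! ```
//!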
//! ## Feature Flags
//!
//! - `std`: Enables file I/O and other std-only features (default)
//! - `safetensors`: Enables SafeTensors format support (default)
//! - `pytorch`: Enables loading PyTorch `.pth`/`.pt` models via [`PytorchStore`]
//! - `burnpack`: Enables the native Burnpack format via [`BurnpackStore`]

extern crate alloc;

mod adapter;
mod applier;
mod apply_result;
mod collector;
mod filter;
mod tensor_snapshot;
mod traits;

pub use adapter::{BurnToPyTorchAdapter, IdentityAdapter, ModuleAdapter, PyTorchToBurnAdapter};
pub use applier::Applier;
pub use apply_result::{ApplyError, ApplyResult};
pub use collector::Collector;
pub use filter::PathFilter;
pub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError};
pub use traits::{ModuleSnapshot, ModuleStore};

#[cfg(feature = "std")]
mod keyremapper;
#[cfg(feature = "std")]
pub use keyremapper::{KeyRemapper, map_indices_contiguous};

#[cfg(feature = "pytorch")]
pub mod pytorch;
#[cfg(feature = "pytorch")]
pub use pytorch::{PytorchStore, PytorchStoreError};

#[cfg(feature = "safetensors")]
mod safetensors;
#[cfg(feature = "safetensors")]
pub use safetensors::{SafetensorsStore, SafetensorsStoreError};

#[cfg(feature = "burnpack")]
mod burnpack;
#[cfg(feature = "burnpack")]
pub use burnpack::store::BurnpackStore;
#[cfg(feature = "burnpack")]
pub use burnpack::writer::BurnpackWriter;