// burn_store/pytorch/lazy_data.rs
//! Lazy data loading support for PyTorch files.
//!
//! This module provides abstractions for lazy loading of tensor data from PyTorch files,
//! avoiding the need to load all data into memory upfront.

use alloc::string::String;
use alloc::vec::Vec;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use zip::ZipArchive;

/// A data source that can lazily load tensor data.
///
/// Both variants are reference-counted and internally locked, so a source can
/// be cloned cheaply and shared across tensor-loading closures.
#[derive(Clone)]
pub enum LazyDataSource {
    /// ZIP archive with lazy, per-entry loading
    Zip(Arc<Mutex<ZipSource>>),
    /// Legacy format (pre-1.6) with multiple storages in a single blob
    LegacyMultiStorage(Arc<Mutex<LegacyMultiStorageSource>>),
}

/// ZIP archive source for lazy loading.
///
/// Holds only the archive path plus cached per-entry metadata; the archive
/// itself is reopened for each read.
pub struct ZipSource {
    path: PathBuf,
    // Cache the file list to avoid reopening archive repeatedly
    file_list: Vec<(String, u64, u64)>, // (name, offset, compressed_size)
}

/// Legacy multi-storage source for old PyTorch format (pre-1.6)
///
/// ## Format Analysis
///
/// Based on research into PyTorch's serialization.py and the legacy TAR format:
///
/// 1. **Storage Layout**: PyTorch legacy format (0.1.10-1.5) stores data as:
///    - Pickle metadata containing tensor definitions
///    - A list of storage keys in order
///    - Raw binary data with all storages concatenated
///
/// 2. **Boundary Detection Challenge**: After extensive research, it turns out that:
///    - PyTorch does NOT store explicit storage boundaries in the file
///    - Storages are concatenated in the order specified by the storage keys list
///    - Each tensor references its storage by key and specifies offset/size
///
/// 3. **Why True Lazy Loading is Difficult**:
///    - To determine storage boundaries, we would need to:
///      a. Parse ALL tensor metadata to find which storage each uses
///      b. Track the maximum extent of each storage based on tensor usage
///      c. Infer boundaries from the gaps between storages
///    - However, the TensorSnapshot abstraction hides storage keys in closures
///    - This would require deep modifications to the pickle parsing logic
///
/// ## Current Implementation
///
/// This implementation provides a best-effort approach:
/// - Supports setting a storage map if boundaries can be determined externally
/// - Falls back to loading the entire blob if boundaries are unknown
pub struct LegacyMultiStorageSource {
    // Path to the checkpoint file on disk.
    path: PathBuf,
    // Absolute file offset where the concatenated storage blob begins.
    data_offset: u64,
    #[allow(dead_code)]
    data_size: u64,
    // Map of storage_key -> (offset_in_blob, size); built lazily once usage
    // for every storage key has been observed.
    storage_map: RwLock<Option<HashMap<String, (u64, u64)>>>,
    // Storage keys in order (for boundary calculation)
    storage_keys: RwLock<Option<Vec<String>>>,
    // Track storage usage as tensors are accessed
    storage_usage: RwLock<HashMap<String, usize>>, // key -> max_bytes_needed
}

73impl ZipSource {
74    /// Create a new ZIP source
75    pub fn new(path: PathBuf) -> std::io::Result<Self> {
76        let file = File::open(&path)?;
77        let reader = BufReader::new(file);
78        let mut archive = ZipArchive::new(reader)?;
79
80        // Cache file metadata
81        let mut file_list = Vec::new();
82        for i in 0..archive.len() {
83            let file = archive.by_index(i)?;
84            let name = file.name().to_string();
85            let offset = file.data_start();
86            let compressed_size = file.compressed_size();
87            file_list.push((name, offset, compressed_size));
88        }
89
90        Ok(Self { path, file_list })
91    }
92
93    /// Check if a file exists in the archive
94    pub fn contains(&self, name: &str) -> bool {
95        self.file_list.iter().any(|(n, _, _)| n == name)
96    }
97
98    /// Get list of data files (excluding pickle files)
99    pub fn data_files(&self) -> Vec<String> {
100        self.file_list
101            .iter()
102            .filter(|(name, _, _)| name.starts_with("data/") || name.contains("/data/"))
103            .filter(|(name, _, _)| !name.ends_with(".pkl") && !name.ends_with("/"))
104            .map(|(name, _, _)| name.clone())
105            .collect()
106    }
107
108    /// Read a specific file from the archive
109    pub fn read_file(&self, name: &str) -> std::io::Result<Vec<u8>> {
110        let file = File::open(&self.path)?;
111        let reader = BufReader::new(file);
112        let mut archive = ZipArchive::new(reader)?;
113
114        let mut file = archive.by_name(name)?;
115        let mut contents = Vec::with_capacity(file.size() as usize);
116        file.read_to_end(&mut contents)?;
117        Ok(contents)
118    }
119
120    /// Read a portion of a file
121    pub fn read_file_range(
122        &self,
123        name: &str,
124        offset: usize,
125        length: usize,
126    ) -> std::io::Result<Vec<u8>> {
127        let file = File::open(&self.path)?;
128        let reader = BufReader::new(file);
129        let mut archive = ZipArchive::new(reader)?;
130
131        let mut file = archive.by_name(name)?;
132        let mut buffer = vec![0u8; length];
133
134        // Skip to offset
135        let mut skip_buffer = vec![0u8; offset.min(8192)];
136        let mut skipped = 0;
137        while skipped < offset {
138            let to_skip = (offset - skipped).min(skip_buffer.len());
139            file.read_exact(&mut skip_buffer[..to_skip])?;
140            skipped += to_skip;
141        }
142
143        // Read the requested data
144        file.read_exact(&mut buffer)?;
145        Ok(buffer)
146    }
147}
148
149impl LegacyMultiStorageSource {
150    /// Create a new legacy multi-storage source
151    pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self {
152        Self {
153            path,
154            data_offset,
155            data_size,
156            storage_map: RwLock::new(None),
157            storage_keys: RwLock::new(None),
158            storage_usage: RwLock::new(HashMap::new()),
159        }
160    }
161
162    /// Set the ordered storage keys from the pickle
163    pub fn set_storage_keys(&self, keys: Vec<String>) {
164        let mut storage_keys = self
165            .storage_keys
166            .write()
167            .unwrap_or_else(|poisoned| poisoned.into_inner());
168        *storage_keys = Some(keys);
169    }
170
171    /// Track storage usage from tensor access
172    /// This is called from within tensor loading closures
173    pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) {
174        let mut usage = self
175            .storage_usage
176            .write()
177            .unwrap_or_else(|poisoned| poisoned.into_inner());
178        let max_extent = offset + size;
179        usage
180            .entry(storage_key.to_string())
181            .and_modify(|current| *current = (*current).max(max_extent))
182            .or_insert(max_extent);
183
184        // Try to build storage map if we have enough information
185        self.try_build_storage_map();
186    }
187
188    /// Try to build the storage map from tracked usage
189    fn try_build_storage_map(&self) {
190        // Only build if we don't already have a map
191        if self
192            .storage_map
193            .read()
194            .unwrap_or_else(|poisoned| poisoned.into_inner())
195            .is_some()
196        {
197            return;
198        }
199
200        // Check if we have storage keys
201        let keys_guard = self
202            .storage_keys
203            .read()
204            .unwrap_or_else(|poisoned| poisoned.into_inner());
205        if let Some(ref keys) = *keys_guard {
206            let usage = self
207                .storage_usage
208                .read()
209                .unwrap_or_else(|poisoned| poisoned.into_inner());
210
211            // Only build if we have usage info for all storages
212            if keys.iter().all(|k| usage.contains_key(k)) {
213                let mut map = HashMap::new();
214                let mut current_offset = 0u64;
215
216                for key in keys {
217                    if let Some(&size) = usage.get(key) {
218                        map.insert(key.clone(), (current_offset, size as u64));
219                        current_offset += size as u64;
220                    }
221                }
222
223                // Set the storage map
224                drop(keys_guard);
225                drop(usage);
226                let mut storage_map = self
227                    .storage_map
228                    .write()
229                    .unwrap_or_else(|poisoned| poisoned.into_inner());
230                *storage_map = Some(map);
231            }
232        }
233    }
234
235    /// Read data for a specific storage key
236    /// Only loads the specific storage portion, never the entire blob
237    pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
238        // Extract numeric key from paths like "data/0" or just "0"
239        let storage_key = key.split('/').next_back().unwrap_or(key);
240
241        // Get storage map - must be available for lazy loading to work
242        let storage_map = self
243            .storage_map
244            .read()
245            .unwrap_or_else(|poisoned| poisoned.into_inner());
246        if let Some(ref map) = *storage_map
247            && let Some(&(offset, size)) = map.get(storage_key)
248        {
249            // Load only this specific storage
250            let mut file = File::open(&self.path)?;
251            file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?;
252
253            let mut buffer = vec![0u8; size as usize];
254            file.read_exact(&mut buffer)?;
255            return Ok(buffer);
256        }
257
258        // NO FALLBACK! If we don't have storage boundaries, we cannot load data lazily
259        // The storage map MUST be built from tensor metadata for lazy loading to work
260        Err(std::io::Error::new(
261            std::io::ErrorKind::InvalidData,
262            format!(
263                "Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
264                storage_key
265            ),
266        ))
267    }
268}
269
270impl LazyDataSource {
271    /// Create from a ZIP file
272    pub fn from_zip(path: impl AsRef<Path>) -> std::io::Result<Self> {
273        Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new(
274            path.as_ref().to_path_buf(),
275        )?))))
276    }
277
278    /// Create from a legacy multi-storage file
279    pub fn from_legacy_multi_storage(
280        path: impl AsRef<Path>,
281        data_offset: u64,
282        data_size: u64,
283    ) -> Self {
284        Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new(
285            path.as_ref().to_path_buf(),
286            data_offset,
287            data_size,
288        ))))
289    }
290
291    /// Read data for a specific key
292    pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
293        match self {
294            Self::Zip(source) => {
295                let source = source
296                    .lock()
297                    .unwrap_or_else(|poisoned| poisoned.into_inner());
298                source.read_file(key)
299            }
300            Self::LegacyMultiStorage(source) => {
301                let source = source
302                    .lock()
303                    .unwrap_or_else(|poisoned| poisoned.into_inner());
304                source.read(key)
305            }
306        }
307    }
308
309    /// Read a portion of data for a specific key
310    pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<Vec<u8>> {
311        match self {
312            Self::Zip(source) => {
313                let source = source
314                    .lock()
315                    .unwrap_or_else(|poisoned| poisoned.into_inner());
316                source.read_file_range(key, offset, length)
317            }
318            Self::LegacyMultiStorage(source) => {
319                // For legacy format, read only the requested range
320                let storage_key = key.split('/').next_back().unwrap_or(key);
321                let source = source
322                    .lock()
323                    .unwrap_or_else(|poisoned| poisoned.into_inner());
324
325                // Get storage boundaries
326                let storage_map = source
327                    .storage_map
328                    .read()
329                    .unwrap_or_else(|poisoned| poisoned.into_inner());
330                if let Some(ref map) = *storage_map
331                    && let Some(&(storage_offset, storage_size)) = map.get(storage_key)
332                {
333                    // Calculate actual file position
334                    let file_offset = source.data_offset + storage_offset + offset as u64;
335                    let read_length = length.min((storage_size as usize).saturating_sub(offset));
336
337                    // Read only the requested range
338                    let mut file = File::open(&source.path)?;
339                    file.seek(std::io::SeekFrom::Start(file_offset))?;
340
341                    let mut buffer = vec![0u8; read_length];
342                    file.read_exact(&mut buffer)?;
343                    Ok(buffer)
344                } else {
345                    Err(std::io::Error::new(
346                        std::io::ErrorKind::InvalidData,
347                        format!(
348                            "Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
349                            storage_key
350                        ),
351                    ))
352                }
353            }
354        }
355    }
356
357    /// Check if a key exists
358    pub fn contains(&self, key: &str) -> bool {
359        match self {
360            Self::Zip(source) => {
361                let source = source
362                    .lock()
363                    .unwrap_or_else(|poisoned| poisoned.into_inner());
364                source.contains(key)
365            }
366            Self::LegacyMultiStorage(_) => true, // Legacy format has all data
367        }
368    }
369
370    /// Get list of available keys (for ZIP sources)
371    pub fn keys(&self) -> Vec<String> {
372        match self {
373            Self::Zip(source) => {
374                let source = source
375                    .lock()
376                    .unwrap_or_else(|poisoned| poisoned.into_inner());
377                source.data_files()
378            }
379            Self::LegacyMultiStorage(_) => vec![], // Legacy format doesn't have distinct keys
380        }
381    }
382}