use alloc::string::String;
use alloc::vec::Vec;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use zip::ZipArchive;
14
/// Lazily-loading data source for PyTorch checkpoint files.
///
/// Each variant wraps its state in `Arc<Mutex<..>>` so clones share one
/// underlying source and reads are serialized across threads.
#[derive(Clone)]
pub enum LazyDataSource {
    /// Zip-format checkpoint; entries are read on demand from the archive.
    Zip(Arc<Mutex<ZipSource>>),
    /// Legacy (non-zip) format where multiple storages live in one
    /// contiguous data region of the file.
    LegacyMultiStorage(Arc<Mutex<LegacyMultiStorageSource>>),
}
23
/// Index over a zip-format checkpoint. Holds only entry metadata; the
/// archive file is reopened for each read.
pub struct ZipSource {
    // Path to the zip file on disk (reopened on every read call).
    path: PathBuf,
    // Per entry: (name, data start offset, compressed size).
    file_list: Vec<(String, u64, u64)>,
}
30
/// Lazy source for the legacy (non-zip) PyTorch format, where tensor
/// storages are laid out back to back in one contiguous data region.
///
/// Storage boundaries are not known up front; they are discovered
/// incrementally and finalized into `storage_map` once every declared
/// storage's extent has been observed.
pub struct LegacyMultiStorageSource {
    // File containing the data region.
    path: PathBuf,
    // Byte offset of the data region within the file.
    data_offset: u64,
    // Total size of the data region in bytes (kept for completeness;
    // currently unused, hence the allow).
    #[allow(dead_code)]
    data_size: u64,
    // key -> (offset within data region, size); `None` until finalized.
    storage_map: RwLock<Option<HashMap<String, (u64, u64)>>>,
    // Ordered storage keys as declared in the file; order fixes the layout.
    storage_keys: RwLock<Option<Vec<String>>>,
    // Observed maximum extent (offset + size) per storage key.
    storage_usage: RwLock<HashMap<String, usize>>,
}
72
73impl ZipSource {
74 pub fn new(path: PathBuf) -> std::io::Result<Self> {
76 let file = File::open(&path)?;
77 let reader = BufReader::new(file);
78 let mut archive = ZipArchive::new(reader)?;
79
80 let mut file_list = Vec::new();
82 for i in 0..archive.len() {
83 let file = archive.by_index(i)?;
84 let name = file.name().to_string();
85 let offset = file.data_start();
86 let compressed_size = file.compressed_size();
87 file_list.push((name, offset, compressed_size));
88 }
89
90 Ok(Self { path, file_list })
91 }
92
93 pub fn contains(&self, name: &str) -> bool {
95 self.file_list.iter().any(|(n, _, _)| n == name)
96 }
97
98 pub fn data_files(&self) -> Vec<String> {
100 self.file_list
101 .iter()
102 .filter(|(name, _, _)| name.starts_with("data/") || name.contains("/data/"))
103 .filter(|(name, _, _)| !name.ends_with(".pkl") && !name.ends_with("/"))
104 .map(|(name, _, _)| name.clone())
105 .collect()
106 }
107
108 pub fn read_file(&self, name: &str) -> std::io::Result<Vec<u8>> {
110 let file = File::open(&self.path)?;
111 let reader = BufReader::new(file);
112 let mut archive = ZipArchive::new(reader)?;
113
114 let mut file = archive.by_name(name)?;
115 let mut contents = Vec::with_capacity(file.size() as usize);
116 file.read_to_end(&mut contents)?;
117 Ok(contents)
118 }
119
120 pub fn read_file_range(
122 &self,
123 name: &str,
124 offset: usize,
125 length: usize,
126 ) -> std::io::Result<Vec<u8>> {
127 let file = File::open(&self.path)?;
128 let reader = BufReader::new(file);
129 let mut archive = ZipArchive::new(reader)?;
130
131 let mut file = archive.by_name(name)?;
132 let mut buffer = vec![0u8; length];
133
134 let mut skip_buffer = vec![0u8; offset.min(8192)];
136 let mut skipped = 0;
137 while skipped < offset {
138 let to_skip = (offset - skipped).min(skip_buffer.len());
139 file.read_exact(&mut skip_buffer[..to_skip])?;
140 skipped += to_skip;
141 }
142
143 file.read_exact(&mut buffer)?;
145 Ok(buffer)
146 }
147}
148
149impl LegacyMultiStorageSource {
150 pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self {
152 Self {
153 path,
154 data_offset,
155 data_size,
156 storage_map: RwLock::new(None),
157 storage_keys: RwLock::new(None),
158 storage_usage: RwLock::new(HashMap::new()),
159 }
160 }
161
162 pub fn set_storage_keys(&self, keys: Vec<String>) {
164 let mut storage_keys = self
165 .storage_keys
166 .write()
167 .unwrap_or_else(|poisoned| poisoned.into_inner());
168 *storage_keys = Some(keys);
169 }
170
171 pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) {
174 let mut usage = self
175 .storage_usage
176 .write()
177 .unwrap_or_else(|poisoned| poisoned.into_inner());
178 let max_extent = offset + size;
179 usage
180 .entry(storage_key.to_string())
181 .and_modify(|current| *current = (*current).max(max_extent))
182 .or_insert(max_extent);
183
184 self.try_build_storage_map();
186 }
187
188 fn try_build_storage_map(&self) {
190 if self
192 .storage_map
193 .read()
194 .unwrap_or_else(|poisoned| poisoned.into_inner())
195 .is_some()
196 {
197 return;
198 }
199
200 let keys_guard = self
202 .storage_keys
203 .read()
204 .unwrap_or_else(|poisoned| poisoned.into_inner());
205 if let Some(ref keys) = *keys_guard {
206 let usage = self
207 .storage_usage
208 .read()
209 .unwrap_or_else(|poisoned| poisoned.into_inner());
210
211 if keys.iter().all(|k| usage.contains_key(k)) {
213 let mut map = HashMap::new();
214 let mut current_offset = 0u64;
215
216 for key in keys {
217 if let Some(&size) = usage.get(key) {
218 map.insert(key.clone(), (current_offset, size as u64));
219 current_offset += size as u64;
220 }
221 }
222
223 drop(keys_guard);
225 drop(usage);
226 let mut storage_map = self
227 .storage_map
228 .write()
229 .unwrap_or_else(|poisoned| poisoned.into_inner());
230 *storage_map = Some(map);
231 }
232 }
233 }
234
235 pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
238 let storage_key = key.split('/').next_back().unwrap_or(key);
240
241 let storage_map = self
243 .storage_map
244 .read()
245 .unwrap_or_else(|poisoned| poisoned.into_inner());
246 if let Some(ref map) = *storage_map
247 && let Some(&(offset, size)) = map.get(storage_key)
248 {
249 let mut file = File::open(&self.path)?;
251 file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?;
252
253 let mut buffer = vec![0u8; size as usize];
254 file.read_exact(&mut buffer)?;
255 return Ok(buffer);
256 }
257
258 Err(std::io::Error::new(
261 std::io::ErrorKind::InvalidData,
262 format!(
263 "Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
264 storage_key
265 ),
266 ))
267 }
268}
269
270impl LazyDataSource {
271 pub fn from_zip(path: impl AsRef<Path>) -> std::io::Result<Self> {
273 Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new(
274 path.as_ref().to_path_buf(),
275 )?))))
276 }
277
278 pub fn from_legacy_multi_storage(
280 path: impl AsRef<Path>,
281 data_offset: u64,
282 data_size: u64,
283 ) -> Self {
284 Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new(
285 path.as_ref().to_path_buf(),
286 data_offset,
287 data_size,
288 ))))
289 }
290
291 pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
293 match self {
294 Self::Zip(source) => {
295 let source = source
296 .lock()
297 .unwrap_or_else(|poisoned| poisoned.into_inner());
298 source.read_file(key)
299 }
300 Self::LegacyMultiStorage(source) => {
301 let source = source
302 .lock()
303 .unwrap_or_else(|poisoned| poisoned.into_inner());
304 source.read(key)
305 }
306 }
307 }
308
309 pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<Vec<u8>> {
311 match self {
312 Self::Zip(source) => {
313 let source = source
314 .lock()
315 .unwrap_or_else(|poisoned| poisoned.into_inner());
316 source.read_file_range(key, offset, length)
317 }
318 Self::LegacyMultiStorage(source) => {
319 let storage_key = key.split('/').next_back().unwrap_or(key);
321 let source = source
322 .lock()
323 .unwrap_or_else(|poisoned| poisoned.into_inner());
324
325 let storage_map = source
327 .storage_map
328 .read()
329 .unwrap_or_else(|poisoned| poisoned.into_inner());
330 if let Some(ref map) = *storage_map
331 && let Some(&(storage_offset, storage_size)) = map.get(storage_key)
332 {
333 let file_offset = source.data_offset + storage_offset + offset as u64;
335 let read_length = length.min((storage_size as usize).saturating_sub(offset));
336
337 let mut file = File::open(&source.path)?;
339 file.seek(std::io::SeekFrom::Start(file_offset))?;
340
341 let mut buffer = vec![0u8; read_length];
342 file.read_exact(&mut buffer)?;
343 Ok(buffer)
344 } else {
345 Err(std::io::Error::new(
346 std::io::ErrorKind::InvalidData,
347 format!(
348 "Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
349 storage_key
350 ),
351 ))
352 }
353 }
354 }
355 }
356
357 pub fn contains(&self, key: &str) -> bool {
359 match self {
360 Self::Zip(source) => {
361 let source = source
362 .lock()
363 .unwrap_or_else(|poisoned| poisoned.into_inner());
364 source.contains(key)
365 }
366 Self::LegacyMultiStorage(_) => true, }
368 }
369
370 pub fn keys(&self) -> Vec<String> {
372 match self {
373 Self::Zip(source) => {
374 let source = source
375 .lock()
376 .unwrap_or_else(|poisoned| poisoned.into_inner());
377 source.data_files()
378 }
379 Self::LegacyMultiStorage(_) => vec![], }
381 }
382}