burn_store/burnpack/
store.rs1#[cfg(feature = "std")]
2use std::path::PathBuf;
3
4use super::reader::BurnpackReader;
5use super::writer::BurnpackWriter;
6#[cfg(feature = "std")]
7use crate::KeyRemapper;
8use crate::burnpack::base::BurnpackError;
9use crate::{ModuleSnapshot, ModuleStore, PathFilter};
10use alloc::collections::BTreeMap;
11use alloc::format;
12use alloc::string::String;
13use burn_core::prelude::Backend;
14use burn_tensor::Bytes;
15
16enum StoreMode {
18 #[cfg(feature = "std")]
19 File(PathBuf),
20 Bytes(Option<Bytes>),
21}
22
23pub struct BurnpackStore {
25 mode: StoreMode,
27 filter: Option<PathFilter>,
29 metadata: BTreeMap<String, String>,
31 allow_partial: bool,
33 validate: bool,
35 overwrite: bool,
37 #[cfg(feature = "std")]
39 auto_extension: bool,
40 #[cfg(feature = "std")]
42 remapper: KeyRemapper,
43 writer: Option<BurnpackWriter>,
45 reader: Option<BurnpackReader>,
47}
48
49impl BurnpackStore {
50 pub fn default_metadata() -> BTreeMap<String, String> {
59 let mut metadata = BTreeMap::new();
60 metadata.insert("format".into(), "burnpack".into());
61 metadata.insert("producer".into(), "burn".into());
62 metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
63 metadata
64 }
65 #[cfg(feature = "std")]
85 pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
86 Self {
87 mode: StoreMode::File(path.as_ref().to_path_buf()),
88 filter: None,
89 metadata: Self::default_metadata(),
90 allow_partial: false,
91 validate: true,
92 overwrite: false,
93 #[cfg(feature = "std")]
94 auto_extension: true,
95 #[cfg(feature = "std")]
96 remapper: KeyRemapper::new(),
97 writer: None,
98 reader: None,
99 }
100 }
101
102 pub fn from_bytes(bytes: Option<Bytes>) -> Self {
104 Self {
105 mode: StoreMode::Bytes(bytes),
106 filter: None,
107 metadata: Self::default_metadata(),
108 allow_partial: false,
109 validate: true,
110 overwrite: false,
111 #[cfg(feature = "std")]
112 auto_extension: false, #[cfg(feature = "std")]
114 remapper: KeyRemapper::new(),
115 writer: None,
116 reader: None,
117 }
118 }
119
120 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
122 self.metadata.insert(key.into(), value.into());
123 self
124 }
125
126 pub fn clear_metadata(mut self) -> Self {
131 self.metadata.clear();
132 self
133 }
134
135 pub fn allow_partial(mut self, allow: bool) -> Self {
142 self.allow_partial = allow;
143 self
144 }
145
146 pub fn validate(mut self, validate: bool) -> Self {
154 self.validate = validate;
155 self
156 }
157
158 pub fn overwrite(mut self, overwrite: bool) -> Self {
165 self.overwrite = overwrite;
166 self
167 }
168
169 #[cfg(feature = "std")]
190 pub fn auto_extension(mut self, enable: bool) -> Self {
191 self.auto_extension = enable;
192 self
193 }
194
195 pub fn with_filter(mut self, filter: PathFilter) -> Self {
197 self.filter = Some(filter);
198 self
199 }
200
201 #[cfg(feature = "std")]
203 pub fn with_regex(mut self, pattern: &str) -> Self {
204 let filter = self.filter.unwrap_or_default();
205 self.filter = Some(filter.with_regex(pattern));
206 self
207 }
208
209 pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
211 let filter = self.filter.unwrap_or_default();
212 self.filter = Some(filter.with_full_path(path));
213 self
214 }
215
216 pub fn match_all(mut self) -> Self {
218 self.filter = Some(PathFilter::new().match_all());
219 self
220 }
221
222 #[cfg(feature = "std")]
224 pub fn remap(mut self, remapper: KeyRemapper) -> Self {
225 self.remapper = remapper;
226 self
227 }
228
229 #[cfg(feature = "std")]
231 pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
232 where
233 S1: AsRef<str>,
234 S2: Into<String>,
235 {
236 self.remapper = self
237 .remapper
238 .add_pattern(from.as_ref(), to.into())
239 .expect("Invalid regex pattern");
240 self
241 }
242
243 pub fn filter(mut self, filter: PathFilter) -> Self {
245 self.filter = Some(filter);
246 self
247 }
248
249 pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
251 if let Some(writer) = &self.writer {
252 return writer.to_bytes();
253 }
254
255 match &self.mode {
256 StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
257 _ => Err(BurnpackError::IoError("No bytes available".into())),
258 }
259 }
260
261 #[cfg(feature = "std")]
263 fn process_path(&self, path: &std::path::Path) -> PathBuf {
264 if !self.auto_extension {
265 return path.to_path_buf();
266 }
267
268 if path.extension().is_some() {
270 return path.to_path_buf();
272 }
273
274 let mut new_path = path.to_path_buf();
276 new_path.set_extension("bpk");
277 new_path
278 }
279}
280
281impl ModuleStore for BurnpackStore {
282 type Error = BurnpackError;
283
284 fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
285 &mut self,
286 module: &M,
287 ) -> Result<(), Self::Error> {
288 let snapshots = module.collect(self.filter.clone(), None);
290
291 let mut writer = BurnpackWriter::new(snapshots);
293
294 for (key, value) in &self.metadata {
296 writer = writer.with_metadata(key.as_str(), value.as_str());
297 }
298
299 self.writer = Some(writer);
301
302 if let Some(writer) = &self.writer {
304 match &self.mode {
305 #[cfg(feature = "std")]
306 StoreMode::File(path) => {
307 let final_path = self.process_path(path);
309
310 if final_path.exists() && !self.overwrite {
312 return Err(BurnpackError::IoError(format!(
313 "File already exists: {}. Use .overwrite(true) to overwrite.",
314 final_path.display()
315 )));
316 }
317 writer.write_to_file(&final_path)?;
318 }
319 StoreMode::Bytes(_) => {
320 let bytes_data = writer.to_bytes()?;
322 #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
324 let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
325 unreachable!("We just matched Bytes variant");
326 };
327 *bytes_ref = Some(bytes_data);
328 }
329 }
330 }
331
332 Ok(())
333 }
334
335 fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
336 &mut self,
337 module: &mut M,
338 ) -> Result<crate::ApplyResult, Self::Error> {
339 if self.reader.is_none() {
341 let reader = match &self.mode {
342 #[cfg(feature = "std")]
343 StoreMode::File(path) => {
344 let final_path = self.process_path(path);
346 BurnpackReader::from_file(&final_path)?
347 }
348 StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
349 StoreMode::Bytes(None) => {
350 return Err(BurnpackError::IoError("No bytes to read from".into()));
351 }
352 };
353 self.reader = Some(reader);
354 }
355
356 let reader = self
357 .reader
358 .as_ref()
359 .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))?;
360
361 #[cfg(feature = "std")]
363 let snapshots = if !self.remapper.patterns.is_empty() {
364 let (remapped, _remapped_names) = self.remapper.remap(reader.get_snapshots()?);
365 remapped
367 } else {
368 reader.get_snapshots()?
369 };
370
371 #[cfg(not(feature = "std"))]
372 let snapshots = reader.get_snapshots()?;
373
374 let result = module.apply(snapshots, self.filter.clone(), None);
376
377 if self.validate && !result.errors.is_empty() {
379 return Err(BurnpackError::ValidationError(format!(
380 "Import errors: {:?}",
381 result.errors
382 )));
383 }
384
385 if !self.allow_partial && !result.missing.is_empty() {
387 return Err(BurnpackError::ValidationError(format!(
388 "Missing tensors: {:?}",
389 result.missing
390 )));
391 }
392
393 Ok(result)
394 }
395}