burn_store/pytorch/
store.rs1use crate::{
4 ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
5 TensorSnapshot,
6};
7
8use alloc::format;
9use alloc::string::{String, ToString};
10use alloc::vec::Vec;
11use burn_tensor::backend::Backend;
12use core::fmt;
13use std::path::PathBuf;
14
15use super::reader::{PytorchError as ReaderError, PytorchReader};
16
17#[derive(Debug)]
19pub enum PytorchStoreError {
20 Reader(ReaderError),
22
23 Io(std::io::Error),
25
26 TensorNotFound(String),
28
29 ValidationFailed(String),
31
32 Other(String),
34}
35
36impl fmt::Display for PytorchStoreError {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 match self {
39 Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
40 Self::Io(e) => write!(f, "I/O error: {}", e),
41 Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
42 Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
43 Self::Other(msg) => write!(f, "{}", msg),
44 }
45 }
46}
47
48impl std::error::Error for PytorchStoreError {}
49
50impl From<ReaderError> for PytorchStoreError {
51 fn from(e: ReaderError) -> Self {
52 PytorchStoreError::Reader(e)
53 }
54}
55
56impl From<std::io::Error> for PytorchStoreError {
57 fn from(e: std::io::Error) -> Self {
58 PytorchStoreError::Io(e)
59 }
60}
61
62pub struct PytorchStore {
71 pub(crate) path: PathBuf,
72 pub(crate) filter: PathFilter,
73 pub(crate) remapper: KeyRemapper,
74 pub(crate) validate: bool,
75 pub(crate) allow_partial: bool,
76 pub(crate) top_level_key: Option<String>,
77}
78
79impl PytorchStore {
80 pub fn from_file(path: impl Into<PathBuf>) -> Self {
92 Self {
93 path: path.into(),
94 filter: PathFilter::new(),
95 remapper: KeyRemapper::new(),
96 validate: true,
97 allow_partial: false,
98 top_level_key: None,
99 }
100 }
101
102 pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
114 self.top_level_key = Some(key.into());
115 self
116 }
117
118 pub fn filter(mut self, filter: PathFilter) -> Self {
120 self.filter = filter;
121 self
122 }
123
124 pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
136 self.filter = self.filter.with_regex(pattern);
137 self
138 }
139
140 pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
142 where
143 I: IntoIterator<Item = S>,
144 S: AsRef<str>,
145 {
146 self.filter = self.filter.with_regexes(patterns);
147 self
148 }
149
150 pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
160 self.filter = self.filter.with_full_path(path);
161 self
162 }
163
164 pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
166 where
167 I: IntoIterator<Item = S>,
168 S: Into<String>,
169 {
170 self.filter = self.filter.with_full_paths(paths);
171 self
172 }
173
174 pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
185 self.filter = self.filter.with_predicate(predicate);
186 self
187 }
188
189 pub fn with_predicates<I>(mut self, predicates: I) -> Self
191 where
192 I: IntoIterator<Item = fn(&str, &str) -> bool>,
193 {
194 self.filter = self.filter.with_predicates(predicates);
195 self
196 }
197
198 pub fn match_all(mut self) -> Self {
200 self.filter = self.filter.match_all();
201 self
202 }
203
204 pub fn remap(mut self, remapper: KeyRemapper) -> Self {
206 self.remapper = remapper;
207 self
208 }
209
210 pub fn with_key_remapping(
220 mut self,
221 from_pattern: impl AsRef<str>,
222 to_pattern: impl Into<String>,
223 ) -> Self {
224 self.remapper = self
225 .remapper
226 .add_pattern(from_pattern, to_pattern)
227 .expect("Invalid regex pattern");
228 self
229 }
230
231 pub fn validate(mut self, validate: bool) -> Self {
233 self.validate = validate;
234 self
235 }
236
237 pub fn allow_partial(mut self, allow: bool) -> Self {
239 self.allow_partial = allow;
240 self
241 }
242
243 fn apply_filter(&self, mut snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
245 if self.filter.is_empty() {
246 return snapshots;
247 }
248
249 snapshots.retain(|snapshot| {
250 let path = snapshot.full_path();
251 self.filter.matches(&path)
252 });
253
254 snapshots
255 }
256
257 fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
259 if self.remapper.is_empty() {
260 return snapshots;
261 }
262
263 let (remapped, _) = self.remapper.remap(snapshots);
264 remapped
265 }
266}
267
268impl ModuleStore for PytorchStore {
269 type Error = PytorchStoreError;
270
271 fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
272 &mut self,
273 _module: &M,
274 ) -> Result<(), Self::Error> {
275 Err(PytorchStoreError::Other(
277 "Saving to PyTorch format is not yet supported. Use other formats for saving."
278 .to_string(),
279 ))
280 }
281
282 fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
283 &mut self,
284 module: &mut M,
285 ) -> Result<ApplyResult, Self::Error> {
286 let reader = if let Some(ref key) = self.top_level_key {
288 PytorchReader::with_top_level_key(&self.path, key)?
289 } else {
290 PytorchReader::new(&self.path)?
291 };
292
293 let mut snapshots: Vec<TensorSnapshot> = reader
295 .into_tensors()
296 .into_iter()
297 .map(|(key, mut snapshot)| {
298 let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();
300
301 snapshot.path_stack = Some(path_parts);
304 snapshot.container_stack = None;
305 snapshot.tensor_id = None;
306
307 snapshot
308 })
309 .collect();
310
311 snapshots = self.apply_filter(snapshots);
313
314 snapshots = self.apply_remapping(snapshots);
316
317 let result = module.apply(snapshots, None, Some(Box::new(PyTorchToBurnAdapter)));
322
323 if self.validate && !result.errors.is_empty() {
325 return Err(PytorchStoreError::ValidationFailed(format!(
326 "Import errors: {:?}",
327 result.errors
328 )));
329 }
330
331 if !self.allow_partial && !result.missing.is_empty() {
332 return Err(PytorchStoreError::TensorNotFound(format!(
333 "Missing tensors: {:?}",
334 result.missing
335 )));
336 }
337
338 Ok(result)
339 }
340}