1use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::{String, ToString};
6use alloc::vec::Vec;
7
8use hashbrown::{HashMap, HashSet};
9
10use burn_core::module::{ModuleMapper, Param};
11use burn_tensor::{Bool, DType, Int, Shape, Tensor, backend::Backend};
12
13use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
14
/// Errors that can occur while applying tensor snapshots to a module.
#[derive(Debug, Clone)]
pub enum ApplyError {
    /// The snapshot's data shape does not match the target parameter's shape.
    ShapeMismatch {
        /// Dot-separated path of the tensor within the module tree.
        path: String,
        /// Shape expected by the target parameter.
        expected: Vec<usize>,
        /// Shape found in the snapshot data.
        found: Vec<usize>,
    },
    /// The snapshot's data type does not match the target parameter's data type.
    DTypeMismatch {
        /// Dot-separated path of the tensor within the module tree.
        path: String,
        /// Data type expected by the target parameter.
        expected: DType,
        /// Data type found in the snapshot data.
        found: DType,
    },
    /// A module adapter failed while transforming a snapshot.
    AdapterError {
        /// Dot-separated path of the tensor within the module tree.
        path: String,
        /// Adapter-provided description of the failure.
        message: String,
    },
    /// The snapshot's lazy data could not be materialized.
    LoadError {
        /// Dot-separated path of the tensor within the module tree.
        path: String,
        /// Description of the load failure.
        message: String,
    },
}
51
52impl core::fmt::Display for ApplyError {
53 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54 match self {
55 Self::ShapeMismatch {
56 path,
57 expected,
58 found,
59 } => {
60 write!(
61 f,
62 "Shape mismatch for '{}': expected {:?}, found {:?}",
63 path, expected, found
64 )
65 }
66 Self::DTypeMismatch {
67 path,
68 expected,
69 found,
70 } => {
71 write!(
72 f,
73 "DType mismatch for '{}': expected {:?}, found {:?}",
74 path, expected, found
75 )
76 }
77 Self::AdapterError { path, message } => {
78 write!(f, "Adapter error for '{}': {}", path, message)
79 }
80 Self::LoadError { path, message } => {
81 write!(f, "Load error for '{}': {}", path, message)
82 }
83 }
84 }
85}
86
87impl core::error::Error for ApplyError {}
88
/// Summary of an apply operation over a module's parameters.
#[derive(Debug, Clone)]
pub struct ApplyResult {
    /// Paths of tensors that were successfully applied to the module.
    pub applied: Vec<String>,
    /// Paths that matched a snapshot but were skipped by the filter.
    pub skipped: Vec<String>,
    /// Parameter paths visited in the module that had no matching snapshot.
    pub missing: Vec<String>,
    /// Snapshot paths that were never visited during module traversal.
    pub unused: Vec<String>,
    /// Errors encountered while applying snapshots.
    pub errors: Vec<ApplyError>,
}
103
104impl ApplyResult {
105 pub fn is_success(&self) -> bool {
108 self.errors.is_empty()
109 }
110}
111
/// Module mapper that applies tensor snapshots to a module's parameters
/// during traversal, tracking what was applied, skipped, or failed.
pub struct Applier<B: Backend> {
    // Snapshots indexed by their dot-separated full path.
    snapshots: HashMap<String, TensorSnapshot>,
    // Stack of module names from the root to the current traversal position.
    path_stack: Vec<String>,
    // Stack of container type names, kept parallel to `path_stack`.
    container_stack: Vec<String>,
    // Optional filter deciding which paths receive their snapshot.
    filter: Option<PathFilter>,
    // Optional adapter that transforms snapshots before they are applied.
    adapter: Option<Box<dyn ModuleAdapter>>,
    // Paths whose snapshots were successfully applied.
    applied: Vec<String>,
    // Paths excluded by the filter (a matching snapshot existed).
    skipped: HashSet<String>,
    // Errors accumulated while applying snapshots.
    errors: Vec<ApplyError>,
    // Every parameter path visited during traversal, applied or not.
    visited_paths: HashSet<String>,
    // Ties the applier to a backend type without storing backend data.
    _backend: core::marker::PhantomData<B>,
}
136
impl<B: Backend> Applier<B> {
    /// Create a new applier from a list of snapshots, an optional path
    /// filter, and an optional module adapter.
    pub fn new(
        views: Vec<TensorSnapshot>,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
    ) -> Self {
        // Index snapshots by full path; if two snapshots share a path,
        // the later one wins.
        let views_map: HashMap<String, TensorSnapshot> = views
            .into_iter()
            .map(|view| (view.full_path(), view))
            .collect();

        Self {
            snapshots: views_map,
            path_stack: Vec::new(),
            container_stack: Vec::new(),
            filter,
            adapter,
            applied: Vec::new(),
            skipped: HashSet::new(),
            errors: Vec::new(),
            visited_paths: HashSet::new(),
            _backend: core::marker::PhantomData,
        }
    }

    /// Dot-separated path of the current traversal position.
    fn current_path(&self) -> String {
        self.path_stack.join(".")
    }

    /// Whether the current path passes the filter.
    ///
    /// With no filter configured, every path is applied.
    fn should_apply(&self) -> bool {
        match &self.filter {
            None => true,
            Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
        }
    }

    /// Run a snapshot through the configured adapter, if any.
    ///
    /// Returns a clone of the snapshot unchanged when no adapter is set.
    fn adapt_snapshot(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        if let Some(ref adapter) = self.adapter {
            // Rebuild the snapshot with the applier's *current* path and
            // container context so the adapter can make path-aware decisions
            // (the stored snapshot may carry a different source path).
            let snapshot_with_context = TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                self.path_stack.clone(),
                self.container_stack.clone(),
                snapshot.tensor_id.unwrap_or_default(),
            );

            return adapter.adapt(&snapshot_with_context);
        }
        snapshot.clone()
    }

    /// Consume the applier and produce the final [`ApplyResult`].
    pub fn into_result(self) -> ApplyResult {
        // Snapshots that were never visited during traversal and were not
        // explicitly skipped by the filter.
        let unused: Vec<String> = self
            .snapshots
            .keys()
            .filter(|path| !self.visited_paths.contains(*path) && !self.skipped.contains(*path))
            .cloned()
            .collect();

        // Parameter paths visited in the module that had no snapshot.
        let missing: Vec<String> = self
            .visited_paths
            .into_iter()
            .filter(|p| !self.snapshots.contains_key(p) && !self.skipped.contains(p))
            .collect();

        ApplyResult {
            applied: self.applied,
            skipped: self.skipped.into_iter().collect(),
            missing,
            unused,
            errors: self.errors,
        }
    }

    /// Look up, adapt, and materialize the snapshot for the current path.
    ///
    /// Returns `None` when there is no snapshot, the filter excludes the
    /// path, or loading fails (errors are recorded in `self.errors`).
    fn apply_tensor<const D: usize, K>(
        &mut self,
        target_device: &B::Device,
        target_shape: Shape,
    ) -> Option<Tensor<B, D, K>>
    where
        K: burn_tensor::TensorKind<B>,
        K: burn_tensor::BasicOps<B>,
    {
        let path = self.current_path();
        // Record the visit *before* any checks so `into_result` can report
        // missing snapshots for every parameter the module exposed.
        self.visited_paths.insert(path.clone());

        let snapshot = match self.snapshots.get(&path) {
            Some(s) => s,
            None => {
                // No snapshot for this parameter; it will appear in `missing`.
                return None;
            }
        };

        if !self.should_apply() {
            self.skipped.insert(path.clone());
            return None;
        }

        let adapted_snapshot = self.adapt_snapshot(snapshot);
        // Materialize the (lazy) tensor data; failures are recorded rather
        // than panicking so the traversal can continue.
        let data = match adapted_snapshot.to_data() {
            Ok(data) => data,
            Err(e) => {
                self.errors.push(ApplyError::LoadError {
                    path: path.clone(),
                    message: format!("Failed to load tensor data: {:?}", e),
                });
                return None;
            }
        };

        if data.shape != target_shape.dims {
            self.errors.push(ApplyError::ShapeMismatch {
                path: path.clone(),
                expected: target_shape.dims,
                found: data.shape.clone(),
            });
            return None;
        }

        self.applied.push(path);
        Some(Tensor::from_data(data, target_device))
    }
}
282
283impl<B: Backend> ModuleMapper<B> for Applier<B> {
284 fn enter_module(&mut self, name: &str, container_type: &str) {
285 self.path_stack.push(name.to_string());
286 self.container_stack.push(container_type.to_string());
287 }
288
289 fn exit_module(&mut self, _name: &str, _container_type: &str) {
290 self.path_stack.pop();
291 self.container_stack.pop();
292 }
293
294 fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
295 let param_id = param.id;
296 let target_device = param.lazy_device();
297 let target_shape = param.lazy_shape();
298
299 match self.apply_tensor(&target_device, target_shape) {
301 Some(tensor) => {
302 param.transform_for_load(tensor, param_id)
304 }
305 None => {
306 param
308 }
309 }
310 }
311
312 fn map_int<const D: usize>(
313 &mut self,
314 param: Param<Tensor<B, D, Int>>,
315 ) -> Param<Tensor<B, D, Int>> {
316 let param_id = param.id;
317 let target_device = param.lazy_device();
318 let target_shape = param.lazy_shape();
319
320 match self.apply_tensor(&target_device, target_shape) {
322 Some(tensor) => {
323 param.transform_for_load(tensor, param_id)
325 }
326 None => {
327 param
329 }
330 }
331 }
332
333 fn map_bool<const D: usize>(
334 &mut self,
335 param: Param<Tensor<B, D, Bool>>,
336 ) -> Param<Tensor<B, D, Bool>> {
337 let param_id = param.id;
338 let target_device = param.lazy_device();
339 let target_shape = param.lazy_shape();
340
341 match self.apply_tensor(&target_device, target_shape) {
343 Some(tensor) => {
344 param.transform_for_load(tensor, param_id)
346 }
347 None => {
348 param
350 }
351 }
352 }
353}
354
#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
    use super::*;
    use burn_core::module::{ModuleMapper, Param, ParamId};
    use burn_tensor::Tensor;

    type TestBackend = burn_ndarray::NdArray;

    /// Snapshots named after root-level parameters are applied in place.
    #[test]
    fn root_level_parameters() {
        let device = Default::default();

        // Source parameters whose values should end up in the targets.
        let src_weight =
            Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let src_bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);

        let snapshots = vec![
            crate::TensorSnapshot::from_data(
                src_weight.val().to_data(),
                vec!["weight".to_string()],
                vec![],
                ParamId::new(),
            ),
            crate::TensorSnapshot::from_data(
                src_bias.val().to_data(),
                vec!["bias".to_string()],
                vec![],
                ParamId::new(),
            ),
        ];

        let mut applier = Applier::<TestBackend>::new(snapshots, None, None);

        // Zero-initialized targets to be overwritten by the applier.
        let target_weight = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );
        let target_bias = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 1>::zeros([2], &device),
        );

        applier.enter_module("weight", "");
        let loaded_weight = applier.map_float(target_weight);
        applier.exit_module("weight", "");

        applier.enter_module("bias", "");
        let loaded_bias = applier.map_float(target_bias);
        applier.exit_module("bias", "");

        assert_eq!(
            loaded_weight.val().to_data().to_vec::<f32>().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(
            loaded_bias.val().to_data().to_vec::<f32>().unwrap(),
            vec![5.0, 6.0]
        );

        let outcome = applier.into_result();
        assert_eq!(outcome.applied.len(), 2);
        assert_eq!(outcome.errors.len(), 0);
    }
}