burn_core/module/param/visitor.rs

1use super::{Param, ParamId};
2use crate::module::{Module, ModuleVisitor};
3use alloc::vec::Vec;
4use burn_tensor::{Bool, Int, Tensor, backend::Backend};
5use core::marker::PhantomData;
6
7struct ParamIdCollector<'a, M> {
8    param_ids: &'a mut Vec<ParamId>,
9    phantom: PhantomData<M>,
10}
11
12impl<B, M> ModuleVisitor<B> for ParamIdCollector<'_, M>
13where
14    B: Backend,
15    M: Module<B>,
16{
17    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
18        self.param_ids.push(param.id);
19    }
20    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
21        self.param_ids.push(param.id);
22    }
23    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
24        self.param_ids.push(param.id);
25    }
26}
27
28/// List all the parameter ids in a module.
29pub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {
30    let mut params_ids = Vec::new();
31    let mut visitor = ParamIdCollector {
32        param_ids: &mut params_ids,
33        phantom: PhantomData::<M>,
34    };
35    module.visit(&mut visitor);
36
37    params_ids
38}