// burn-core 0.18.0
//
// Flexible and Comprehensive Deep Learning Framework in Rust
// Documentation
use super::GradientsParams;
use crate::module::{AutodiffModule, ModuleVisitor, ParamId};
use burn_tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

/// Module visitor that moves gradients out of an autodiff gradients container
/// (`B::Gradients`) and registers them, keyed by parameter id, in a
/// [`GradientsParams`].
///
/// When `filter` is `Some`, only parameters whose id appears in the list are
/// moved; parameters without a gradient are skipped.
#[derive(new)]
pub struct GradientsParamsConverter<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
    /// Source gradients; each visited parameter's gradient is removed from here.
    grads: &'a mut B::Gradients,
    /// Destination container the extracted gradients are registered in.
    grads_params: &'a mut GradientsParams,
    /// Ties the visitor to the module type `M` without storing a value of it.
    /// (`derive(new)` initializes `PhantomData` fields automatically, so this
    /// field is not a constructor argument.)
    phantom: PhantomData<M>,
    /// Optional allow-list of parameter ids; `None` converts every parameter.
    filter: Option<Vec<ParamId>>,
}

/// Module visitor that moves the gradient of each visited float parameter in a
/// [`GradientsParams`] to a target device, updating the container in place.
///
/// Parameters with no registered gradient are left untouched.
#[derive(new)]
pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
    /// Device the gradients are moved to.
    device: &'a B::Device,
    /// Gradients to relocate; each entry is removed and re-registered under
    /// the same parameter id after the device transfer.
    grads: &'a mut GradientsParams,
    /// Ties the visitor to the module type `M` without storing a value of it.
    /// (`derive(new)` initializes `PhantomData` fields automatically, so this
    /// field is not a constructor argument.)
    phantom: PhantomData<M>,
}

impl<B, M> ModuleVisitor<B> for GradientsParamsConverter<'_, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
{
    /// Transfers the gradient of one float parameter into `grads_params`.
    ///
    /// The parameter is skipped when an allow-list is set and does not contain
    /// `id`. Otherwise its gradient is removed from the autodiff gradients and,
    /// if one existed, registered under `id` on the inner (non-autodiff) backend.
    fn visit_float<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>) {
        // The allow-list is consulted before touching the gradients so that
        // filtered-out parameters keep their gradient in `self.grads`.
        let excluded = self
            .filter
            .as_ref()
            .is_some_and(|allowed| !allowed.contains(&id));
        if excluded {
            return;
        }

        if let Some(grad) = tensor.grad_remove(self.grads) {
            self.grads_params.register::<B::InnerBackend, D>(id, grad);
        }
    }
}

impl<B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'_, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
{
    /// Moves the gradient registered for `id`, if any, to `self.device`.
    ///
    /// The gradient is taken out of the container, transferred, and registered
    /// again under the same id; the tensor argument itself is not used.
    fn visit_float<const D: usize>(&mut self, id: ParamId, _tensor: &Tensor<B, D>) {
        if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
            let relocated = grad.to_device(self.device);
            self.grads.register::<B::InnerBackend, D>(id, relocated);
        }
    }
}