Trait rai_core::nn::Module

source ·
pub trait Module {
    type Input;
    type Output;

    // Required methods
    fn forward(&self, x: &Self::Input) -> Self::Output;
    fn gather_params(&self, params: &mut HashMap<usize, Tensor>);
    fn update_params(&self, params: &mut HashMap<usize, Tensor>);
    fn gather_named_params(
        &self,
        prefix: &str,
        params: &mut HashMap<String, Tensor>
    );
    fn update_named_params(
        &self,
        prefix: &str,
        params: &mut HashMap<String, Tensor>
    );

    // Provided methods
    fn params(&self) -> HashMap<usize, Tensor> { ... }
    fn named_params(&self, prefix: &str) -> HashMap<String, Tensor> { ... }
    fn to_safetensors<P: AsRef<Path>>(&self, filename: P)
       where Self: Sized { ... }
    fn update_by_safetensors<P: AsRef<Path>>(
        &self,
        filenames: &[P],
        device: impl AsDevice
    ) { ... }
}

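Required Associated Types§

type Input

The type that `forward` takes by reference (`x: &Self::Input`).

type Output

The type that `forward` returns.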

Required Methods§

source

fn forward(&self, x: &Self::Input) -> Self::Output

source

fn gather_params(&self, params: &mut HashMap<usize, Tensor>)

source

fn update_params(&self, params: &mut HashMap<usize, Tensor>)

source

fn gather_named_params( &self, prefix: &str, params: &mut HashMap<String, Tensor> )

source

fn update_named_params( &self, prefix: &str, params: &mut HashMap<String, Tensor> )

Provided Methods§

source

fn params(&self) -> HashMap<usize, Tensor>

source

fn named_params(&self, prefix: &str) -> HashMap<String, Tensor>

source

fn to_safetensors<P: AsRef<Path>>(&self, filename: P)
where Self: Sized,

source

fn update_by_safetensors<P: AsRef<Path>>( &self, filenames: &[P], device: impl AsDevice )

Object Safety§

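This trait is not object safe (also called "dyn compatible"), so it cannot be used as `dyn Module`. The cause is the provided method `update_by_safetensors`, which has generic type parameters (`P` and the `impl AsDevice` argument). Unlike `to_safetensors`, it is not excluded from dynamic dispatch by a `where Self: Sized` bound.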

Implementations on Foreign Types§

source§

impl<'a, T> Module for &'a T
where T: Module,

§

type Input = <T as Module>::Input

§

type Output = <T as Module>::Output

source§

fn forward(&self, x: &Self::Input) -> Self::Output

source§

fn gather_params(&self, params: &mut HashMap<usize, Tensor>)

source§

fn update_params(&self, params: &mut HashMap<usize, Tensor>)

source§

fn gather_named_params( &self, prefix: &str, params: &mut HashMap<String, Tensor> )

source§

fn update_named_params( &self, prefix: &str, params: &mut HashMap<String, Tensor> )

Implementors§