/// Base interface for models.
///
/// Implementors expose their `VarStore` (both shared and mutable), perform a
/// training step driven by a loss tensor, and can persist or restore the
/// parameters of their neural network.
pub trait ModelBase {
    /// Trains the network given a loss.
    fn backward_step(&mut self, loss: &Tensor);

    /// Returns the var store by shared reference.
    fn get_var_store(&self) -> &VarStore;

    /// Returns the var store by mutable reference.
    fn get_var_store_mut(&mut self) -> &mut VarStore;

    /// Saves the parameters of the neural network to `path`.
    ///
    /// # Errors
    ///
    /// Returns an `Err` if saving fails.
    fn save<P>(&self, path: P) -> Result<()>
    where
        P: AsRef<Path>;

    /// Loads the parameters of the neural network from `path`.
    ///
    /// # Errors
    ///
    /// Returns an `Err` if loading fails.
    fn load<P>(&mut self, path: P) -> Result<()>
    where
        P: AsRef<Path>;
}
Expand description

Base interface for models.

Required Methods

source

fn backward_step(&mut self, loss: &Tensor)

Trains the network given a loss.

source

fn get_var_store_mut(&mut self) -> &mut VarStore

Returns var_store as mutable reference.

source

fn get_var_store(&self) -> &VarStore

Returns var_store.

source

fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>

Saves the parameters of the neural network.

source

fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>

Loads the parameters of the neural network.

Implementors

source§

impl<F, M> ModelBase for IqnModel<F, M> where F: SubModel<Output = Tensor>, M: SubModel<Input = Tensor, Output = Tensor>, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize,

source§

impl<P> ModelBase for Actor<P> where P: SubModel<Output = (Tensor, Tensor)>, P::Config: DeserializeOwned + Serialize + OutDim,

source§

impl<Q> ModelBase for DqnModel<Q> where Q: SubModel<Output = Tensor>, Q::Config: DeserializeOwned + Serialize + OutDim,

source§

impl<Q> ModelBase for Critic<Q> where Q: SubModel2<Output = Tensor>, Q::Config: DeserializeOwned + Serialize,