Trait ModelBase

pub trait ModelBase {
    // Required methods
    fn backward_step(&mut self, loss: &Tensor);
    fn get_var_store_mut(&mut self) -> &mut VarStore;
    fn get_var_store(&self) -> &VarStore;
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}

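Base interface for neural network models whose parameters are held in a `VarStore`. It covers training on a loss, access to the parameter store, and saving and loading parameters.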
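Because every method is defined on the trait, agent code can drive any implementor through a generic bound. The following is a minimal, std-only sketch under stated assumptions, not code from the crate: `Tensor`, `VarStore` and `Result` are hypothetical stand-ins for the tch and error types used by the real trait, `ToyModel` is a hypothetical one-weight implementor, and `train_and_checkpoint` is a hypothetical helper.

```rust
use std::fs;
use std::path::Path;

// Hypothetical stand-ins for tch::Tensor, tch::nn::VarStore and the real Result alias.
// Here a "loss" simply carries d(loss)/dw for a single weight `w`.
pub struct Tensor(pub f64);
pub struct VarStore {
    pub w: f64,
}
pub type Result<T> = std::result::Result<T, String>;

// Same method set as ModelBase, expressed over the stand-in types.
pub trait ModelBase {
    fn backward_step(&mut self, loss: &Tensor);
    fn get_var_store_mut(&mut self) -> &mut VarStore;
    fn get_var_store(&self) -> &VarStore;
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}

// Hypothetical implementor: plain gradient descent on one weight,
// persisted as a decimal string.
pub struct ToyModel {
    pub vs: VarStore,
    pub lr: f64,
}

impl ModelBase for ToyModel {
    fn backward_step(&mut self, loss: &Tensor) {
        self.vs.w -= self.lr * loss.0;
    }
    fn get_var_store_mut(&mut self) -> &mut VarStore {
        &mut self.vs
    }
    fn get_var_store(&self) -> &VarStore {
        &self.vs
    }
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
        fs::write(path, self.vs.w.to_string()).map_err(|e| e.to_string())
    }
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
        let text = fs::read_to_string(path).map_err(|e| e.to_string())?;
        self.vs.w = text.trim().parse::<f64>().map_err(|e| e.to_string())?;
        Ok(())
    }
}

// Hypothetical helper: works for any implementor via a generic bound.
pub fn train_and_checkpoint<M: ModelBase>(model: &mut M, losses: &[Tensor], ckpt: &Path) -> Result<()> {
    for loss in losses {
        model.backward_step(loss);
    }
    model.save(ckpt)
}
```

In the real trait the loss would be a tch `Tensor` carrying an autograd graph rather than a precomputed gradient.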

Required Methods

fn backward_step(&mut self, loss: &Tensor)

Performs one training step of the network, given a loss.
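An implementation would typically delegate to an optimizer; tch's `nn::Optimizer::backward_step` zeroes the gradients, backpropagates the loss and applies one update. The sketch below mimics that three-phase sequence without tch, for a hypothetical one-parameter model `LinearModel` with loss (w * x - y)^2. `SquaredErrorLoss` is a hypothetical stand-in for the loss tensor; since there is no autograd here, the gradient is written out by hand.

```rust
// Hypothetical stand-in for a loss tensor: instead of an autograd graph, it
// records what is needed to differentiate loss = (w * x - y)^2 by hand.
pub struct SquaredErrorLoss {
    pub x: f64,
    pub y: f64,
}

// Hypothetical one-parameter model with a plain SGD update rule.
pub struct LinearModel {
    pub w: f64,
    pub grad: f64,
    pub lr: f64,
}

impl LinearModel {
    pub fn loss_value(&self, loss: &SquaredErrorLoss) -> f64 {
        let r = self.w * loss.x - loss.y;
        r * r
    }

    /// Mirrors the zero_grad / backward / step sequence of one training step.
    pub fn backward_step(&mut self, loss: &SquaredErrorLoss) {
        // 1. zero_grad: discard gradients left over from a previous step.
        self.grad = 0.0;
        // 2. backward: accumulate d(loss)/dw = 2 (w x - y) x.
        self.grad += 2.0 * (self.w * loss.x - loss.y) * loss.x;
        // 3. step: move the parameter against the gradient.
        self.w -= self.lr * self.grad;
    }
}
```

The stale `grad` value is cleared before backpropagation; skipping that phase would mix gradients from different steps.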

fn get_var_store_mut(&mut self) -> &mut VarStore

Returns a mutable reference to the model's `VarStore`.
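Mutable access is what parameter-rewriting routines need, for example a soft (Polyak) update of a target network as used by DQN-style agents. In the sketch below, `VarStore` is a hypothetical stand-in holding named `Vec<f64>` parameters and `soft_update` is a hypothetical helper; with the real trait, the destination store would come from `get_var_store_mut()` and the source from `get_var_store()`.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a VarStore: named parameters as flat f64 vectors.
pub struct VarStore {
    pub vars: HashMap<String, Vec<f64>>,
}

/// Hypothetical helper: dest <- tau * src + (1 - tau) * dest, variable by variable.
/// Only the destination needs mutable access. Returns an error on a missing or
/// mismatched variable; this sketch does not roll back variables already updated.
pub fn soft_update(dest: &mut VarStore, src: &VarStore, tau: f64) -> Result<(), String> {
    for (name, d) in dest.vars.iter_mut() {
        let s = src
            .vars
            .get(name)
            .ok_or_else(|| format!("missing variable `{name}` in source"))?;
        if s.len() != d.len() {
            return Err(format!("length mismatch for `{name}`"));
        }
        for (dv, sv) in d.iter_mut().zip(s) {
            *dv = tau * sv + (1.0 - tau) * *dv;
        }
    }
    Ok(())
}
```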

fn get_var_store(&self) -> &VarStore

Returns a shared reference to the model's `VarStore`.
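Read-only access suffices for inspection, such as logging the parameter count or the global parameter norm. `VarStore` is again a hypothetical stand-in and `param_stats` a hypothetical helper; with tch, one would iterate over the real store's variables instead.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a VarStore: named parameters as flat f64 vectors.
pub struct VarStore {
    pub vars: HashMap<String, Vec<f64>>,
}

/// Hypothetical helper: (number of scalar parameters, global L2 norm).
/// Needs only `&VarStore`, so it works with a shared reference.
pub fn param_stats(vs: &VarStore) -> (usize, f64) {
    let count: usize = vs.vars.values().map(Vec::len).sum();
    let sq_norm: f64 = vs.vars.values().flatten().map(|x| x * x).sum();
    (count, sq_norm.sqrt())
}
```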

fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>

Saves the parameters of the neural network to `path`.
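`save` takes `&self`, so writing a checkpoint does not change the model, and its `T: AsRef<Path>` bound accepts `&str`, `String`, `&Path` or `PathBuf`. The sketch below illustrates the idea with a hypothetical stand-in `VarStore` and a made-up text format with one `name v1 v2 ...` line per variable; tch's `VarStore::save` uses its own serialization, not this format.

```rust
use std::collections::BTreeMap;
use std::fs;
use std::io;
use std::path::Path;

// Hypothetical stand-in store; BTreeMap keeps the file output in sorted order.
pub struct VarStore {
    pub vars: BTreeMap<String, Vec<f64>>,
}

/// Hypothetical saver: one `name v1 v2 ...` line per variable.
/// Assumes variable names contain no whitespace.
pub fn save<T: AsRef<Path>>(vs: &VarStore, path: T) -> io::Result<()> {
    let mut out = String::new();
    for (name, values) in &vs.vars {
        out.push_str(name);
        for v in values {
            out.push(' ');
            out.push_str(&v.to_string());
        }
        out.push('\n');
    }
    fs::write(path, out)
}
```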

fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>

Loads the parameters of the neural network from `path`.
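`load` takes `&mut self` because it overwrites parameter values in place. The sketch below reads the same hypothetical `name v1 v2 ...` text format into a hypothetical stand-in `VarStore`. It loosely mirrors tch's `VarStore::load`, which modifies the values of variables already in the store rather than adding new ones; the all-or-nothing validation is a design choice of this sketch.

```rust
use std::collections::BTreeMap;
use std::fs;
use std::path::Path;

// Hypothetical stand-in store: named parameters as flat f64 vectors.
pub struct VarStore {
    pub vars: BTreeMap<String, Vec<f64>>,
}

/// Hypothetical loader. Every existing variable must appear in the file with
/// the same length; extra entries in the file are ignored.
pub fn load<T: AsRef<Path>>(vs: &mut VarStore, path: T) -> Result<(), String> {
    let text = fs::read_to_string(path).map_err(|e| e.to_string())?;
    let mut file_vars: BTreeMap<&str, Vec<f64>> = BTreeMap::new();
    for line in text.lines().filter(|l| !l.trim().is_empty()) {
        let mut parts = line.split_whitespace();
        let name = parts.next().ok_or("empty line")?;
        let values = parts
            .map(|s| s.parse::<f64>().map_err(|e| format!("`{name}`: {e}")))
            .collect::<Result<Vec<_>, _>>()?;
        file_vars.insert(name, values);
    }
    // Validate everything first so a failed load leaves `vs` untouched.
    for (name, dest) in &vs.vars {
        match file_vars.get(name.as_str()) {
            None => return Err(format!("missing variable `{name}` in file")),
            Some(src) if src.len() != dest.len() => {
                return Err(format!("length mismatch for `{name}`"))
            }
            Some(_) => {}
        }
    }
    for (name, dest) in vs.vars.iter_mut() {
        dest.copy_from_slice(&file_vars[name.as_str()]);
    }
    Ok(())
}
```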

Dyn Compatibility

This trait is not dyn compatible, because `save` and `load` are generic over `T: AsRef<Path>`.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
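Since generic methods rule out trait objects, code that handles several model types has to use generic bounds, or an enum over the concrete types, instead of `Box<dyn ModelBase>`. The sketch below uses a reduced, hypothetical `Saveable` trait with the same kind of generic method as `save`; `ModelA`, `ModelB`, `save_with` and `AnyModel` are hypothetical names.

```rust
use std::path::Path;

// Reduced, hypothetical trait with a generic method like ModelBase::save;
// the generic parameter `T` is what makes it not dyn compatible.
pub trait Saveable {
    fn save<T: AsRef<Path>>(&self, path: T) -> String;
}

pub struct ModelA;
pub struct ModelB;

impl Saveable for ModelA {
    fn save<T: AsRef<Path>>(&self, path: T) -> String {
        format!("A -> {}", path.as_ref().display())
    }
}

impl Saveable for ModelB {
    fn save<T: AsRef<Path>>(&self, path: T) -> String {
        format!("B -> {}", path.as_ref().display())
    }
}

// `fn save_all(models: &[Box<dyn Saveable>])` would be rejected (error E0038).
// Static dispatch through a generic bound works instead:
pub fn save_with<M: Saveable>(model: &M, path: &str) -> String {
    model.save(path)
}

// For heterogeneous collections, an enum over the concrete types also works.
pub enum AnyModel {
    A(ModelA),
    B(ModelB),
}

impl AnyModel {
    pub fn save(&self, path: &str) -> String {
        match self {
            AnyModel::A(m) => m.save(path),
            AnyModel::B(m) => m.save(path),
        }
    }
}
```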

Implementors

impl<F, M> ModelBase for IqnModel<F, M>
where F: SubModel<Output = Tensor>, M: SubModel<Input = Tensor, Output = Tensor>, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize,

impl<P> ModelBase for Actor<P>
where P: SubModel<Output = (Tensor, Tensor)>, P::Config: DeserializeOwned + Serialize + OutDim,

impl<Q> ModelBase for DqnModel<Q>
where Q: SubModel<Output = Tensor>, Q::Config: DeserializeOwned + Serialize + OutDim,

impl<Q> ModelBase for Critic<Q>
where Q: SubModel2<Output = Tensor>, Q::Config: DeserializeOwned + Serialize,