Trait border_tch_agent::model::ModelBase
pub trait ModelBase {
// Required methods
fn backward_step(&mut self, loss: &Tensor);
fn get_var_store_mut(&mut self) -> &mut VarStore;
fn get_var_store(&self) -> &VarStore;
fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}
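Base interface for models: training from a loss, access to the model's `VarStore`, and saving/loading parameters to and from a path.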
Required Methods
fn backward_step(&mut self, loss: &Tensor)
Trains the network given a loss.
fn get_var_store_mut(&mut self) -> &mut VarStore
Returns a mutable reference to the model's `VarStore`.
fn get_var_store(&self) -> &VarStore
Returns a shared (immutable) reference to the model's `VarStore`.