Struct border_tch_agent::iqn::IqnModel
pub struct IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
{ /* private fields */ }
Constructs an IQN output layer, which takes input features and percent points and returns action-value quantiles.
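As background, the original IQN paper (Dabney et al., 2018) forms each quantile estimate by multiplying the state feature ψ(x) element-wise with a cosine embedding φ(τ) of the percent point τ, then passing the product through a final head f. Presumably F plays the role of ψ and M the role of f here, but this page does not document the embedding size n or IqnModel's private layers. Read the following as the general technique, not as this crate's exact computation:

```latex
Z_\tau(x, a) \approx f\bigl(\psi(x) \odot \phi(\tau)\bigr)_a,
\qquad
\phi_j(\tau) = \mathrm{ReLU}\Bigl(\sum_{i=0}^{n-1} \cos(\pi i \tau)\, w_{ij} + b_j\Bigr)
```

Here j indexes the feature dimension, so φ(τ) has the same length as ψ(x), and w_{ij} and b_j are learned parameters.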
Implementations
impl<F, M> IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
pub fn build(
    config: IqnModelConfig<F::Config, M::Config>,
    device: Device
) -> Result<IqnModel<F, M>>
Constructs IqnModel.
pub fn build_with_submodel_configs(
    config: IqnModelConfig<F::Config, M::Config>,
    f_config: F::Config,
    m_config: M::Config,
    device: Device
) -> IqnModel<F, M>
Constructs IqnModel with the given configurations of sub models.
pub fn forward(&self, x: &F::Input, tau: &Tensor) -> Tensor
Returns the tensor of action-value quantiles.
- The shape of psi(x) (feature vector) is [batch_size, feature_dim].
- The shape of tau is [batch_size, n_percent_points].
- The shape of the output is [batch_size, n_percent_points, self.out_dim].
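A dependency-free sketch can make this shape contract concrete. It is an illustration under stated assumptions, not this crate's code: it replaces tch::Tensor with nested Vec<f64>, uses a hypothetical cosine embedding (cos_embedding) with fixed weights where IQN would learn them, and uses a hypothetical linear head (linear_head) as a stand-in for the sub model M. It shows how a [batch_size, feature_dim] feature and a [batch_size, n_percent_points] tau combine into a [batch_size, n_percent_points, out_dim] result.

```rust
use std::f64::consts::PI;

/// Hypothetical cosine embedding of one percent point `tau`:
/// phi_j(tau) = ReLU(sum_i cos(pi * i * tau) * w[i][j] + b[j]).
/// `w` has shape [n_cos][feature_dim]; `b` has length feature_dim.
fn cos_embedding(tau: f64, w: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    (0..b.len())
        .map(|j| {
            let s: f64 = w
                .iter()
                .enumerate()
                .map(|(i, row)| (PI * i as f64 * tau).cos() * row[j])
                .sum();
            (s + b[j]).max(0.0) // ReLU
        })
        .collect()
}

/// Hypothetical linear head standing in for the sub model `M`.
/// `w` has shape [out_dim][feature_dim].
fn linear_head(h: &[f64], w: &[Vec<f64>]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(h).map(|(a, b)| a * b).sum::<f64>())
        .collect()
}

/// Shape-level sketch of the forward pass (not the crate's implementation).
/// psi: [batch_size][feature_dim], tau: [batch_size][n_percent_points]
/// returns: [batch_size][n_percent_points][out_dim]
fn forward_sketch(
    psi: &[Vec<f64>],
    tau: &[Vec<f64>],
    emb_w: &[Vec<f64>],
    emb_b: &[f64],
    head_w: &[Vec<f64>],
) -> Vec<Vec<Vec<f64>>> {
    psi.iter()
        .zip(tau)
        .map(|(feat, taus)| {
            taus.iter()
                .map(|&t| {
                    let phi = cos_embedding(t, emb_w, emb_b);
                    // Element-wise product psi(x) * phi(tau), then the head.
                    let h: Vec<f64> = feat.iter().zip(&phi).map(|(a, b)| a * b).collect();
                    linear_head(&h, head_w)
                })
                .collect::<Vec<Vec<f64>>>()
        })
        .collect()
}

fn main() {
    let (batch_size, feature_dim, n_percent_points, out_dim, n_cos) = (2, 3, 4, 5, 8);
    let psi = vec![vec![1.0_f64; feature_dim]; batch_size];
    // Evenly spaced percent points for a deterministic demo;
    // IQN would normally sample them uniformly from (0, 1).
    let tau_row: Vec<f64> = (0..n_percent_points)
        .map(|k| (k as f64 + 0.5) / n_percent_points as f64)
        .collect();
    let tau = vec![tau_row; batch_size];
    let emb_w = vec![vec![0.1_f64; feature_dim]; n_cos];
    let emb_b = vec![0.0_f64; feature_dim];
    let head_w = vec![vec![0.2_f64; feature_dim]; out_dim];

    let q = forward_sketch(&psi, &tau, &emb_w, &emb_b, &head_w);
    println!("shape = [{}, {}, {}]", q.len(), q[0].len(), q[0][0].len());
}
```

Running main prints shape = [2, 4, 5]. The sketch embeds one percent point at a time for readability; a tensor implementation would more likely broadcast psi(x) across the n_percent_points axis.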
Trait Implementations
impl<F, M> Clone for IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
impl<F, M> ModelBase for IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
fn backward_step(&mut self, loss: &Tensor)
Trains the network given a loss.
fn get_var_store(&self) -> &VarStore
Returns var_store.
fn get_var_store_mut(&mut self) -> &mut VarStore
Returns var_store as a mutable reference.
Auto Trait Implementations
impl<F, M> !RefUnwindSafe for IqnModel<F, M>
impl<F, M> Send for IqnModel<F, M>
where
    F: Send,
    M: Send,
impl<F, M> !Sync for IqnModel<F, M>
impl<F, M> Unpin for IqnModel<F, M>
where
    F: Unpin,
    M: Unpin,
impl<F, M> !UnwindSafe for IqnModel<F, M>
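Because IqnModel is Send (when F and M are) but not Sync, a model can be moved into another thread but not shared by reference between threads. A common pattern is to wrap such a value in Arc<Mutex<_>>, since Mutex<T> is Sync whenever T: Send. The sketch below uses a hypothetical stand-in type, NotSyncModel, whose Cell field gives it the same Send-but-not-Sync profile; it does not construct an IqnModel.

```rust
use std::cell::Cell;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in: `Cell` makes this type `Send` but not `Sync`,
/// mirroring the auto-trait profile listed for IqnModel.
struct NotSyncModel {
    steps: Cell<u64>,
}

impl NotSyncModel {
    fn train_step(&self) {
        self.steps.set(self.steps.get() + 1);
    }
}

/// Shares one model across `n_threads` threads; returns the total step count.
fn run_threads(n_threads: usize) -> u64 {
    // Mutex<T> is Sync when T: Send, so Arc<Mutex<NotSyncModel>> can be shared.
    let model = Arc::new(Mutex::new(NotSyncModel { steps: Cell::new(0) }));
    let handles: Vec<_> = (0..n_threads)
        .map(|_| {
            let model = Arc::clone(&model);
            thread::spawn(move || {
                model.lock().unwrap().train_step();
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    // Bind first so the lock guard is dropped before `model`.
    let steps = model.lock().unwrap().steps.get();
    steps
}

fn main() {
    // Moving the model into a single thread only requires `Send`.
    let moved = NotSyncModel { steps: Cell::new(0) };
    let steps = thread::spawn(move || {
        moved.train_step();
        moved.steps.get()
    })
    .join()
    .unwrap();
    println!("moved: {steps}, shared: {}", run_threads(4));
}
```

Sharing a plain &NotSyncModel between threads would be rejected at compile time; the Mutex serializes access instead.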
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.