Trait border::core::base::Agent
Represents a trainable policy on an environment.
Required methods
fn train(&mut self)
Set the policy to training mode.
fn eval(&mut self)
Set the policy to evaluation mode.
fn is_train(&self) -> bool
Return whether the agent is in training mode.
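A minimal sketch of switching modes, built only on the three methods above; the import path for Env is an assumption based on the trait's module:

```rust
use border::core::base::{Agent, Env};

// Toggle an agent between training and evaluation mode. Only the three
// methods documented above are used, so this works for any implementor.
fn set_mode<E: Env, A: Agent<E>>(agent: &mut A, training: bool) {
    if training {
        agent.train();
    } else {
        agent.eval();
    }
    debug_assert_eq!(agent.is_train(), training);
}
```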
fn observe(&mut self, step: Step<E>) -> Option<Record>
Observe a crate::core::base::Step object. The agent is expected to train its policy based on the observation. If an optimization step was performed, returns Some(crate::core::record::Record); otherwise, None.
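A hedged sketch of driving training through this method; how the Step<E> values are produced is left abstract, since the Env interaction API is not shown on this page:

```rust
use border::core::base::{Agent, Env, Step};
use border::core::record::Record;

// Feed pre-collected steps to the agent and keep the records emitted
// whenever an optimization step actually ran.
fn train_on<E: Env, A: Agent<E>>(agent: &mut A, steps: Vec<Step<E>>) -> Vec<Record> {
    agent.train(); // put the policy in training mode first
    steps
        .into_iter()
        .filter_map(|step| agent.observe(step)) // Some(Record) per optimization step
        .collect()
}
```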
fn push_obs(&self, obs: &E::Obs)
Push an observation to the agent. This method is used when resetting the environment.
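A minimal sketch of the reset flow, assuming the initial observation has already been obtained from the environment:

```rust
use border::core::base::{Agent, Env};

// Hand the initial observation to the agent right after an environment
// reset, before the first action is selected. Obtaining `obs` from the
// environment is left to the caller.
fn on_reset<E: Env, A: Agent<E>>(agent: &A, obs: &E::Obs) {
    agent.push_obs(obs); // note: takes &self, per the signature above
}
```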
fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>
Save the agent in the given directory.
This method typically creates a number of files that make up the agent in the given directory. For example, the crate::agent::tch::dqn::DQN agent saves two Q-networks, corresponding to the original and target networks.
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>
Load the agent from the given directory.
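A round-trip sketch using only the two documented signatures; the checkpoint directory is an arbitrary example path:

```rust
use std::error::Error;
use std::path::Path;
use border::core::base::{Agent, Env};

// Persist an agent's files into a directory and restore them again.
fn checkpoint<E: Env, A: Agent<E>>(agent: &mut A) -> Result<(), Box<dyn Error>> {
    let dir = Path::new("./checkpoints/agent");
    agent.save(dir)?; // may create several files in this directory
    agent.load(dir)?; // restores the parameters just written
    Ok(())
}
```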
Implementors
impl<E, M, O, A> Agent<E> for DQN<E, M, O, A> where
E: Env,
M: Model1<Input = Tensor, Output = Tensor> + Clone,
E::Obs: Into<M::Input>,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs, SubBatch = M::Input>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
fn train(&mut self)
fn eval(&mut self)
fn is_train(&self) -> bool
fn push_obs(&self, obs: &E::Obs)
fn observe(&mut self, step: Step<E>) -> Option<Record>
Update model parameters.
When the return value is Some(Record), it includes:
- loss_critic: Loss of the critic.
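For illustration, a hedged sketch of reading the critic loss out of such a record; the `get` accessor and `RecordValue::Scalar` variant used below are assumptions about the crate::core::record API, not confirmed by this page:

```rust
use border::core::record::{Record, RecordValue};

// Assumed API: Record::get(&str) -> Option<&RecordValue> and a
// RecordValue::Scalar variant. Check crate::core::record before use.
fn critic_loss(record: &Record) -> Option<f32> {
    match record.get("loss_critic") {
        Some(RecordValue::Scalar(v)) => Some(*v),
        _ => None,
    }
}
```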
fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>
impl<E, M, O, A> Agent<E> for PGDiscrete<E, M, O, A> where
E: Env,
M: Model1<Input = Tensor, Output = Tensor>,
E::Obs: Into<M::Input> + Clone,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs, SubBatch = M::Input>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
fn train(&mut self)
fn eval(&mut self)
fn is_train(&self) -> bool
fn push_obs(&self, obs: &E::Obs)
fn observe(&mut self, step: Step<E>) -> Option<Record>
Update model parameters.
When the return value is Some(Record), it includes:
- loss: Loss for the policy gradient.
fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>
impl<E, M, O, A> Agent<E> for PPODiscrete<E, M, O, A> where
E: Env,
M: Model1<Input = Tensor, Output = (Tensor, Tensor)>,
E::Obs: Into<M::Input> + Clone,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs, SubBatch = M::Input>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
fn train(&mut self)
fn eval(&mut self)
fn is_train(&self) -> bool
fn push_obs(&self, obs: &E::Obs)
fn observe(&mut self, step: Step<E>) -> Option<Record>
Update model parameters.
When the return value is Some(Record), it includes:
- loss_critic: Loss of the critic.
- loss_actor: Loss of the actor.
fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>
impl<E, Q, P, O, A> Agent<E> for DDPG<E, Q, P, O, A> where
E: Env,
Q: Model2<Input1 = O::SubBatch, Input2 = A::SubBatch, Output = Tensor> + Clone,
P: Model1<Output = A::SubBatch> + Clone,
E::Obs: Into<O::SubBatch>,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs, SubBatch = P::Input>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
fn train(&mut self)
fn eval(&mut self)
fn is_train(&self) -> bool
fn push_obs(&self, obs: &E::Obs)
fn observe(&mut self, step: Step<E>) -> Option<Record>
Update model parameters.
When the return value is Some(Record), it includes:
- loss_critic: Loss of the critic.
- loss_actor: Loss of the actor.
fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>
impl<E, Q, P, O, A> Agent<E> for SAC<E, Q, P, O, A> where
E: Env,
Q: Model2<Input1 = O::SubBatch, Input2 = A::SubBatch, Output = Tensor> + Clone,
P: Model1<Input = Tensor, Output = (Tensor, Tensor)> + Clone,
E::Obs: Into<O::SubBatch>,
E::Act: From<Tensor>,
O: TchBuffer<Item = E::Obs, SubBatch = P::Input>,
A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
fn train(&mut self)
fn eval(&mut self)
fn is_train(&self) -> bool
fn push_obs(&self, obs: &E::Obs)
fn observe(&mut self, step: Step<E>) -> Option<Record>
Update model parameters.
When the return value is Some(Record), it includes:
- loss_critic: Loss of the critic.
- loss_actor: Loss of the actor.