// (removed stray line-number gutter left over from a copy/paste)
//! Core functionalities. use std::{fmt::Debug, path::Path, error::Error}; use crate::core::record::Record; /// Represents an observation of the environment. pub trait Obs: Clone + Debug { /// Returns a dummy observation. /// /// The observation created with this method is ignored. fn dummy(n_procs: usize) -> Self; /// Replace elements of observation where `is_done[i] == 1.0`. /// This method assumes that `is_done.len() == n_procs`. fn merge(self, obs_reset: Self, is_done: &[i8]) -> Self; } /// Represents an action of the environment. pub trait Act: Clone + Debug {} /// Represents additional information to `Obs` and `Act`. pub trait Info {} /// Represents all information given at every step of agent-envieronment interaction. /// `reward` and `is_done` have the same length, the number of processes (environments). pub struct Step<E: Env> { /// Action. pub act: E::Act, /// Observation. pub obs: E::Obs, /// Reward. pub reward: Vec<f32>, /// Flag denoting if episode is done. pub is_done: Vec<i8>, /// Information defined by user. pub info: E::Info, } impl<E: Env> Step<E> { /// Constructs a [Step] object. pub fn new(obs: E::Obs, act: E::Act, reward: Vec<f32>, is_done: Vec<i8>, info: E::Info) -> Self { Step { act, obs, reward, is_done, info, } } } /// Represents an environment, typically an MDP. pub trait Env { /// Observation of the environment. type Obs: Obs; /// Action of the environment. type Act: Act; /// Information in the [self::Step] object. type Info: Info; /// Performes an interaction step. fn step(&mut self, a: &Self::Act) -> (Step<Self>, Record) where Self: Sized; /// Reset the i-th environment if `is_done[i]==1.0`. /// Thei-th return value should be ignored if `is_done[i]==0.0`. fn reset(&mut self, is_done: Option<&Vec<i8>>) -> Result<Self::Obs, Box<dyn Error>>; } /// Represents a policy. on an environment. It is based on a mapping from an observation /// to an action. The mapping can be either of deterministic or stochastic. 
pub trait Policy<E: Env> {
    /// Sample an action given an observation.
    ///
    /// Depending on the implementation, sampling may be deterministic or stochastic.
    fn sample(&mut self, obs: &E::Obs) -> E::Act;
}

/// Represents a trainable policy on an environment.
pub trait Agent<E: Env>: Policy<E> {
    /// Set the policy to training mode.
    fn train(&mut self);

    /// Set the policy to evaluation mode.
    fn eval(&mut self);

    /// Return if it is in training mode.
    fn is_train(&self) -> bool;

    /// Observe a [crate::core::base::Step] object.
    ///
    /// The agent is expected to train its policy based on the observation.
    ///
    /// If an optimization step was performed, it returns `Some(crate::core::record::Record)`,
    /// otherwise `None`.
    fn observe(&mut self, step: Step<E>) -> Option<Record>;

    /// Push observation to the agent.
    ///
    /// This method is used when resetting the environment.
    // NOTE(review): takes `&self`, so implementations presumably rely on
    // interior mutability to store the observation — confirm against implementors.
    fn push_obs(&self, obs: &E::Obs);

    /// Save the agent in the given directory.
    ///
    /// This method commonly creates a number of files constituting the agent
    /// in the given directory. For example, [`crate::agent::tch::dqn::DQN`] agent saves
    /// two Q-networks corresponding to the original and target networks.
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), Box<dyn Error>>;

    /// Load the agent from the given directory.
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<(), Box<dyn Error>>;
}