//! Core functionalities.
use crate::core::record::Record;
use anyhow::Result;
use std::{fmt::Debug, path::Path};

/// Represents an observation of the environment.
pub trait Obs: Clone + Debug {
    /// Returns a dummy observation.
    ///
    /// The observation created with this method is ignored.
    fn dummy(n_procs: usize) -> Self;

    /// Replaces elements of the observation where `is_done[i] == 1`.
    ///
    /// This method assumes that `is_done.len() == n_procs`.
    fn merge(self, obs_reset: Self, is_done: &[i8]) -> Self;

    /// Returns the number of processes that created this observation;
    /// it assumes a synchronous vectorized environment.
    ///
    /// TODO: consider removing this and replacing it with `batch_size()`.
    fn n_procs(&self) -> usize;

    /// Returns the batch size.
    fn batch_size(&self) -> usize;
}

/// Represents an action of the environment.
pub trait Act: Clone + Debug {}

/// Represents additional information to `Obs` and `Act`.
pub trait Info {}

/// Represents all information given at every step of agent-environment interaction.
///
/// `reward` and `is_done` have the same length, the number of processes (environments).
pub struct Step<E: Env> {
    /// Action.
    pub act: E::Act,

    /// Observation.
    pub obs: E::Obs,

    /// Reward.
    pub reward: Vec<f32>,

    /// Flag denoting if the episode is done.
    pub is_done: Vec<i8>,

    /// Information defined by the user.
    pub info: E::Info,
}

impl<E: Env> Step<E> {
    /// Constructs a [Step] object.
    pub fn new(
        obs: E::Obs,
        act: E::Act,
        reward: Vec<f32>,
        is_done: Vec<i8>,
        info: E::Info,
    ) -> Self {
        Step {
            act,
            obs,
            reward,
            is_done,
            info,
        }
    }
}

/// Represents an environment, typically an MDP.
pub trait Env {
    /// Observation of the environment.
    type Obs: Obs;

    /// Action of the environment.
    type Act: Act;

    /// Information in the [self::Step] object.
    type Info: Info;

    /// Performs an interaction step.
    fn step(&mut self, a: &Self::Act) -> (Step<Self>, Record)
    where
        Self: Sized;

    /// Resets the i-th environment if `is_done[i] == 1`.
    ///
    /// The i-th return value should be ignored if `is_done[i] == 0`.
    fn reset(&mut self, is_done: Option<&Vec<i8>>) -> Result<Self::Obs>;
}

/// Represents a policy on an environment. It is based on a mapping from an observation
/// to an action. The mapping can be either deterministic or stochastic.
pub trait Policy<E: Env> {
    /// Samples an action given an observation.
    fn sample(&mut self, obs: &E::Obs) -> E::Act;
}

/// Represents a trainable policy on an environment.
pub trait Agent<E: Env>: Policy<E> {
    /// Sets the policy to training mode.
    fn train(&mut self);

    /// Sets the policy to evaluation mode.
    fn eval(&mut self);

    /// Returns `true` if the policy is in training mode.
    fn is_train(&self) -> bool;

    /// Observes a [crate::core::base::Step] object.
    /// The agent is expected to train its policy based on the observation.
    ///
    /// If an optimization step was performed, this method returns
    /// `Some(crate::core::record::Record)`, otherwise `None`.
    fn observe(&mut self, step: Step<E>) -> Option<Record>;

    /// Pushes an observation to the agent.
    /// This method is used when resetting the environment.
    fn push_obs(&self, obs: &E::Obs);

    /// Saves the agent in the given directory.
    ///
    /// This method commonly creates a number of files constituting the agent
    /// in the given directory. For example, the DQN agent in the `border_tch_agent`
    /// crate saves two Q-networks corresponding to the original and target networks.
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;

    /// Loads the agent from the given directory.
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}
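
// --- Illustrative sketch (not part of the crate API) ---
// A minimal, hypothetical `Obs`/`Act` implementation showing the intended
// `merge` contract: for each process `i`, the freshly reset observation is
// kept only where `is_done[i] == 1`. The type names (`VecObs`, `DiscreteAct`)
// and the test below are assumptions made for illustration.
#[cfg(test)]
mod sketch {
    use super::{Act, Obs};

    /// One inner `Vec<f32>` per process.
    #[derive(Clone, Debug, PartialEq)]
    struct VecObs(Vec<Vec<f32>>);

    impl Obs for VecObs {
        fn dummy(n_procs: usize) -> Self {
            VecObs(vec![vec![]; n_procs])
        }

        fn merge(self, obs_reset: Self, is_done: &[i8]) -> Self {
            // Keep the reset observation for processes whose episode ended.
            let merged = self
                .0
                .into_iter()
                .zip(obs_reset.0)
                .zip(is_done.iter())
                .map(|((o, o_reset), &done)| if done == 1 { o_reset } else { o })
                .collect();
            VecObs(merged)
        }

        fn n_procs(&self) -> usize {
            self.0.len()
        }

        fn batch_size(&self) -> usize {
            self.0.len()
        }
    }

    /// One discrete action per process.
    #[derive(Clone, Debug)]
    struct DiscreteAct(Vec<i32>);

    impl Act for DiscreteAct {}

    #[test]
    fn merge_keeps_reset_obs_only_where_done() {
        let obs = VecObs(vec![vec![1.0], vec![2.0], vec![3.0]]);
        let obs_reset = VecObs(vec![vec![9.0], vec![9.0], vec![9.0]]);
        // Only the second process has finished its episode.
        let merged = obs.merge(obs_reset, &[0, 1, 0]);
        assert_eq!(merged, VecObs(vec![vec![1.0], vec![9.0], vec![3.0]]));

        let _acts = DiscreteAct(vec![0, 1, 0]); // one action per process
    }
}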