//! Core abstractions for agent-environment interaction: observations, actions,
//! environments, policies, and agents.
use crate::core::record::Record;
use anyhow::Result;
use std::{fmt::Debug, path::Path};

/// Represents an observation of the environment.
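///
/// # Example
///
/// A minimal sketch of an implementation for a hypothetical `MyObs` type that stores
/// one feature vector per process (the type and its layout are illustrative
/// assumptions, not part of this crate):
///
/// ```ignore
/// #[derive(Clone, Debug)]
/// struct MyObs {
///     /// One row of features per process.
///     rows: Vec<Vec<f32>>,
/// }
///
/// impl Obs for MyObs {
///     fn dummy(n_procs: usize) -> Self {
///         // Placeholder observation; its contents are never used.
///         Self { rows: vec![vec![0.0]; n_procs] }
///     }
///
///     fn merge(mut self, obs_reset: Self, is_done: &[i8]) -> Self {
///         // Overwrite the rows of processes whose episodes have just ended
///         // with the corresponding rows of the post-reset observation.
///         for (i, &done) in is_done.iter().enumerate() {
///             if done == 1 {
///                 self.rows[i] = obs_reset.rows[i].clone();
///             }
///         }
///         self
///     }
///
///     fn n_procs(&self) -> usize {
///         self.rows.len()
///     }
///
///     fn batch_size(&self) -> usize {
///         self.rows.len()
///     }
/// }
/// ```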
pub trait Obs: Clone + Debug {
    /// Returns a dummy observation.
    ///
    /// The observation created with this method is ignored.
    fn dummy(n_procs: usize) -> Self;

    /// Replace the elements of the observation for which `is_done[i] == 1`
    /// with the corresponding elements of `obs_reset`.
    /// This method assumes that `is_done.len()` equals the number of processes.
    fn merge(self, obs_reset: Self, is_done: &[i8]) -> Self;

    /// Returns the number of processes that created this observation;
    /// it assumes a synchronous vectorized environment.
    ///
    /// TODO: consider removing this method and replacing it with `batch_size()`.
    fn n_procs(&self) -> usize;

    /// Returns the batch size.
    fn batch_size(&self) -> usize;
}

/// Represents an action of the environment.
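///
/// # Example
///
/// A minimal sketch for a hypothetical discrete action type (illustrative only):
///
/// ```ignore
/// #[derive(Clone, Debug)]
/// struct MyAct(usize);
///
/// impl Act for MyAct {}
/// ```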
pub trait Act: Clone + Debug {}

/// Represents additional information associated with `Obs` and `Act`.
pub trait Info {}

/// Represents all information given at every step of agent-environment interaction.
/// `reward` and `is_done` have the same length, equal to the number of processes (environments).
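///
/// # Example
///
/// A sketch of constructing a step for two parallel processes, assuming a hypothetical
/// environment type `MyEnv` and pre-built `obs`, `act`, and `info` values (all illustrative):
///
/// ```ignore
/// let step: Step<MyEnv> = Step::new(
///     obs,              // observation for this step
///     act,              // action applied in this step
///     vec![1.0, -0.5],  // one reward per process
///     vec![0, 1],       // the second process just finished its episode
///     info,             // user-defined information
/// );
/// ```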
pub struct Step<E: Env> {
    /// Action.
    pub act: E::Act,
    /// Observation.
    pub obs: E::Obs,
    /// Reward.
    pub reward: Vec<f32>,
    /// Flags denoting if each episode is done.
    pub is_done: Vec<i8>,
    /// Information defined by user.
    pub info: E::Info,
}

impl<E: Env> Step<E> {
    /// Constructs a [Step] object.
    pub fn new(
        obs: E::Obs,
        act: E::Act,
        reward: Vec<f32>,
        is_done: Vec<i8>,
        info: E::Info,
    ) -> Self {
        Step {
            act,
            obs,
            reward,
            is_done,
            info,
        }
    }
}

/// Represents an environment, typically an MDP.
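///
/// # Example
///
/// A sketch of a basic interaction loop, assuming hypothetical `env` and `policy`
/// values implementing [Env] and [Policy], run inside a function returning
/// `anyhow::Result<()>` (all names are illustrative):
///
/// ```ignore
/// // Obtain the initial observation
/// // (here `None` is assumed to mean "reset all processes").
/// let mut obs = env.reset(None)?;
///
/// for _ in 0..n_steps {
///     let act = policy.sample(&obs);
///     let (step, _record) = env.step(&act);
///
///     // Reset only the processes whose episodes have finished and merge
///     // the post-reset observations into the latest observation.
///     let is_done = step.is_done;
///     let obs_reset = env.reset(Some(&is_done))?;
///     obs = step.obs.merge(obs_reset, &is_done);
/// }
/// ```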
pub trait Env {
    /// Observation of the environment.
    type Obs: Obs;
    /// Action of the environment.
    type Act: Act;
    /// Information in the [self::Step] object.
    type Info: Info;

    /// Performs an interaction step.
    fn step(&mut self, a: &Self::Act) -> (Step<Self>, Record)
    where
        Self: Sized;

    /// Reset the i-th environment if `is_done[i] == 1`.
    /// The i-th element of the returned observation should be ignored if `is_done[i] == 0`.
    fn reset(&mut self, is_done: Option<&Vec<i8>>) -> Result<Self::Obs>;
}

/// Represents a policy on an environment. It is based on a mapping from an observation
/// to an action. The mapping can be either deterministic or stochastic.
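///
/// # Example
///
/// A minimal sketch for a hypothetical environment `MyEnv` whose action type is the
/// `MyAct` wrapper from the example on [Act] (all names are illustrative assumptions):
///
/// ```ignore
/// struct CyclicPolicy {
///     n_actions: usize,
///     next: usize,
/// }
///
/// impl Policy<MyEnv> for CyclicPolicy {
///     fn sample(&mut self, _obs: &<MyEnv as Env>::Obs) -> <MyEnv as Env>::Act {
///         // Cycle deterministically through the discrete actions,
///         // ignoring the observation.
///         let a = self.next;
///         self.next = (self.next + 1) % self.n_actions;
///         MyAct(a)
///     }
/// }
/// ```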
pub trait Policy<E: Env> {
    /// Sample an action given an observation.
    fn sample(&mut self, obs: &E::Obs) -> E::Act;
}

/// Represents a trainable policy on an environment.
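///
/// # Example
///
/// A sketch of a simplified training loop, assuming hypothetical `agent` and `env`
/// values (environment resets and observation merging are omitted for brevity;
/// all names are illustrative):
///
/// ```ignore
/// agent.train(); // switch to training mode
/// let mut obs = env.reset(None)?;
///
/// for _ in 0..max_steps {
///     let act = agent.sample(&obs);
///     let (step, _record) = env.step(&act);
///     obs = step.obs.clone();
///
///     // `observe` returns `Some(record)` when an optimization step was performed.
///     if let Some(record) = agent.observe(step) {
///         // log `record` here
///     }
/// }
///
/// // Persist the trained agent to a directory.
/// agent.save("./model")?;
/// ```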
pub trait Agent<E: Env>: Policy<E> {
    /// Set the policy to training mode.
    fn train(&mut self);

    /// Set the policy to evaluation mode.
    fn eval(&mut self);

    /// Return whether the policy is in training mode.
    fn is_train(&self) -> bool;

    /// Observe a [crate::core::base::Step] object.
    /// The agent is expected to train its policy based on the observation.
    ///
    /// If an optimization step was performed, it returns `Some(crate::core::record::Record)`,
    /// otherwise `None`.
    fn observe(&mut self, step: Step<E>) -> Option<Record>;

    /// Push an observation to the agent.
    /// This method is used when resetting the environment.
    fn push_obs(&self, obs: &E::Obs);

    /// Save the agent in the given directory.
    /// This method commonly creates a number of files, which constitute the agent,
    /// in the given directory. For example, the DQN agent in the `border_tch_agent` crate
    /// saves two Q-networks, corresponding to the original and target networks.
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;

    /// Load the agent from the given directory.
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}