border_core/base/
agent.rs

1//! Agent.
2use super::{Env, Policy, ReplayBufferBase};
3use crate::record::Record;
4use anyhow::Result;
5use std::path::Path;
6
/// Represents a trainable policy on an environment.
///
/// An `Agent` extends [`Policy`] with a training/evaluation mode switch,
/// replay-buffer-driven optimization steps, and saving/loading of its
/// parameters.
pub trait Agent<E: Env, R: ReplayBufferBase>: Policy<E> {
    /// Set the policy to training mode.
    ///
    /// NOTE(review): the concrete effect (e.g. enabling exploration or
    /// dropout) is implementation-defined — see each implementing agent.
    fn train(&mut self);

    /// Set the policy to evaluation mode.
    fn eval(&mut self);

    /// Returns `true` if the agent is in training mode.
    fn is_train(&self) -> bool;

    /// Performs an optimization step.
    ///
    /// `buffer` is a replay buffer from which transitions will be taken
    /// for updating model parameters.
    ///
    /// The default implementation delegates to [`Self::opt_with_record`]
    /// and discards the returned [`Record`].
    fn opt(&mut self, buffer: &mut R) {
        let _ = self.opt_with_record(buffer);
    }

    /// Performs an optimization step and returns some information
    /// (e.g. training diagnostics such as losses) as a [`Record`].
    fn opt_with_record(&mut self, buffer: &mut R) -> Record;

    /// Save the parameters of the agent in the given directory.
    ///
    /// This method commonly creates a number of files constituting the agent
    /// in the directory. For example, the DQN agent in the `border_tch_agent`
    /// crate saves two Q-networks corresponding to the original and target
    /// networks.
    fn save_params<T: AsRef<Path>>(&self, path: T) -> Result<()>;

    /// Load the parameters of the agent from the given directory.
    ///
    /// Expects the file layout produced by [`Self::save_params`].
    fn load_params<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
37}