// border_core/base/agent.rs
1//! Agent.
2use super::{Env, Policy, ReplayBufferBase};
3use crate::record::Record;
4use anyhow::Result;
5use std::path::Path;
6
7/// Represents a trainable policy on an environment.
8pub trait Agent<E: Env, R: ReplayBufferBase>: Policy<E> {
9 /// Set the policy to training mode.
10 fn train(&mut self);
11
12 /// Set the policy to evaluation mode.
13 fn eval(&mut self);
14
15 /// Return if it is in training mode.
16 fn is_train(&self) -> bool;
17
18 /// Performs an optimization step.
19 ///
20 /// `buffer` is a replay buffer from which transitions will be taken
21 /// for updating model parameters.
22 fn opt(&mut self, buffer: &mut R) {
23 let _ = self.opt_with_record(buffer);
24 }
25
26 /// Performs an optimization step and returns some information.
27 fn opt_with_record(&mut self, buffer: &mut R) -> Record;
28
29 /// Save the parameters of the agent in the given directory.
30 /// This method commonly creates a number of files consisting the agent
31 /// in the directory. For example, the DQN agent in `border_tch_agent` crate saves
32 /// two Q-networks corresponding to the original and target networks.
33 fn save_params<T: AsRef<Path>>(&self, path: T) -> Result<()>;
34
35 /// Load the parameters of the agent from the given directory.
36 fn load_params<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
37}