border_core/base/agent.rs

//! A trainable policy that can interact with an environment and learn from experience.
//!
//! The [`Agent`] trait extends [`Policy`] with training capabilities, allowing the policy to
//! learn from interactions with the environment. It provides methods for training, evaluation,
//! parameter optimization, and model persistence.

use super::{Env, Policy, ReplayBufferBase};
use crate::record::Record;
use anyhow::Result;
use std::path::{Path, PathBuf};

/// A trainable policy that can learn from environment interactions.
///
/// This trait extends [`Policy`] with training capabilities, allowing the policy to:
/// - Switch between training and evaluation modes
/// - Perform optimization steps using experience from a replay buffer
/// - Save and load model parameters
///
/// The agent operates in two distinct modes:
/// - Training mode: The policy may be stochastic to facilitate exploration
/// - Evaluation mode: The policy is typically deterministic for consistent performance
///
/// During training, the agent uses a replay buffer to store and sample experiences,
/// which are then used to update the policy's parameters through optimization steps.
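///
/// # Examples
///
/// A minimal sketch of a typical training loop. `MyAgent` and `MyReplayBuffer` are
/// hypothetical types used for illustration only; how the agent and buffer are
/// constructed and filled depends on the implementation.
///
/// ```ignore
/// // Hypothetical agent and replay buffer (not part of this crate).
/// let mut agent = MyAgent::new();
/// let mut buffer = MyReplayBuffer::new();
///
/// // Switch to training mode; the policy may act stochastically to explore.
/// agent.train();
/// for _ in 0..100 {
///     // Experiences are assumed to have been pushed into `buffer` elsewhere.
///     // Each call samples a batch and updates the agent's parameters.
///     agent.opt(&mut buffer);
/// }
///
/// // Switch to evaluation mode; the policy typically acts deterministically.
/// agent.eval();
/// assert!(!agent.is_train());
/// ```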
pub trait Agent<E: Env, R: ReplayBufferBase>: Policy<E> {
    /// Switches the agent to training mode.
    ///
    /// In training mode, the policy may become stochastic to facilitate exploration.
    /// This is typically implemented by enabling noise or randomness in the action
    /// selection process.
    fn train(&mut self) {
        unimplemented!();
    }

    /// Switches the agent to evaluation mode.
    ///
    /// In evaluation mode, the policy typically becomes deterministic to ensure
    /// consistent performance. This is often implemented by disabling noise or
    /// using the mean action instead of sampling from a distribution.
    fn eval(&mut self) {
        unimplemented!();
    }

    /// Returns whether the agent is currently in training mode.
    ///
    /// Use this to query the agent's current mode and to conditionally enable
    /// or disable mode-dependent behaviors.
    fn is_train(&self) -> bool {
        unimplemented!();
    }

    /// Performs a single optimization step using experiences from the replay buffer.
    ///
    /// This method updates the agent's parameters using a batch of transitions
    /// sampled from the provided replay buffer. The specific optimization algorithm
    /// (e.g., Q-learning, policy gradient) is determined by the agent's implementation.
    ///
    /// # Arguments
    ///
    /// * `buffer` - The replay buffer containing experiences used for training
    fn opt(&mut self, buffer: &mut R) {
        let _ = self.opt_with_record(buffer);
    }

    /// Performs an optimization step and returns training metrics.
    ///
    /// Similar to [`opt`], but also returns a [`Record`] containing training metrics
    /// such as loss values, gradients, or other relevant statistics.
    ///
    /// # Arguments
    ///
    /// * `buffer` - The replay buffer containing experiences used for training
    ///
    /// # Returns
    ///
    /// A [`Record`] containing training metrics and statistics
    ///
    /// [`opt`]: Agent::opt
    /// [`Record`]: crate::record::Record
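    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a hypothetical `agent` implementing this trait and
    /// a `buffer` implementing [`ReplayBufferBase`]; the metric keys stored in the
    /// returned record are implementation-defined.
    ///
    /// ```ignore
    /// // One optimization step; the returned record carries metrics such as a
    /// // loss value (the exact keys depend on the agent implementation).
    /// let record = agent.opt_with_record(&mut buffer);
    /// ```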
    #[allow(unused_variables)]
    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
        unimplemented!();
    }

    /// Saves the agent's parameters to the specified directory.
    ///
    /// This method serializes the agent's current state (e.g., neural network weights,
    /// policy parameters) to files in the given directory. The specific format and
    /// number of files created depend on the agent's implementation.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory where parameters will be saved
    ///
    /// # Returns
    ///
    /// A vector of paths to the saved parameter files
    ///
    /// # Examples
    ///
    /// A DQN agent, for example, might save its two Q-networks (original and target)
    /// as separate files within the specified directory.
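    ///
    /// A minimal sketch of calling this method, assuming a hypothetical `agent` and an
    /// illustrative checkpoint directory, inside a function returning `anyhow::Result`:
    ///
    /// ```ignore
    /// use std::path::Path;
    ///
    /// // Serialize the current parameters; the returned paths point to the files
    /// // actually written (their number and format are implementation-defined).
    /// let saved_files = agent.save_params(Path::new("./checkpoints/step_10000"))?;
    /// ```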
    #[allow(unused_variables)]
    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
        unimplemented!();
    }

    /// Loads the agent's parameters from the specified directory.
    ///
    /// This method deserializes the agent's state from files in the given directory,
    /// restoring the agent to a previously saved state.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory containing the saved parameter files
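    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a hypothetical `agent` and a directory previously
    /// written by [`save_params`](Agent::save_params), inside a function returning
    /// `anyhow::Result`:
    ///
    /// ```ignore
    /// use std::path::Path;
    ///
    /// // Restore parameters from an illustrative checkpoint directory.
    /// agent.load_params(Path::new("./checkpoints/step_10000"))?;
    /// ```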
    #[allow(unused_variables)]
    fn load_params(&mut self, path: &Path) -> Result<()> {
        unimplemented!();
    }

    /// Returns a reference to the agent as a type-erased `Any` value.
    ///
    /// This method is required for asynchronous training, allowing the agent to be
    /// stored in a type-erased container. The returned reference can be downcast
    /// to the concrete agent type when needed.
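    ///
    /// # Examples
    ///
    /// A minimal sketch of downcasting back to a concrete type; `MyAgent` is a
    /// hypothetical implementor used for illustration only.
    ///
    /// ```ignore
    /// // Recover the concrete type from the type-erased reference.
    /// let concrete: &MyAgent = agent
    ///     .as_any_ref()
    ///     .downcast_ref::<MyAgent>()
    ///     .expect("not a MyAgent");
    /// ```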
    fn as_any_ref(&self) -> &dyn std::any::Any {
        unimplemented!("as_any_ref() must be implemented for train_async()");
    }

    /// Returns a mutable reference to the agent as a type-erased `Any` value.
    ///
    /// This method is required for asynchronous training, allowing the agent to be
    /// stored in a type-erased container. The returned reference can be downcast
    /// to the concrete agent type when needed.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        unimplemented!("as_any_mut() must be implemented for train_async()");
    }
}