border_core/base/agent.rs
//! A trainable policy that can interact with an environment and learn from experience.
//!
//! The [`Agent`] trait extends [`Policy`] with training capabilities, allowing the policy to
//! learn from interactions with the environment. It provides methods for training, evaluation,
//! parameter optimization, and model persistence.
use super::{Env, Policy, ReplayBufferBase};
use crate::record::Record;
use anyhow::Result;
use std::path::{Path, PathBuf};

/// A trainable policy that can learn from environment interactions.
///
/// This trait extends [`Policy`] with training capabilities, allowing the policy to:
/// - Switch between training and evaluation modes
/// - Perform optimization steps using experience from a replay buffer
/// - Save and load model parameters
///
/// The agent operates in two distinct modes:
/// - Training mode: the policy may be stochastic to facilitate exploration
/// - Evaluation mode: the policy is typically deterministic for consistent performance
///
/// During training, the agent uses a replay buffer to store and sample experiences,
/// which are then used to update the policy's parameters through optimization steps.
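///
/// # Examples
///
/// A minimal sketch of a training loop built on this trait. `MyAgent` and
/// `MyBuffer` are hypothetical implementors of [`Agent`] and
/// [`ReplayBufferBase`], not types provided by this crate:
///
/// ```ignore
/// let mut agent = MyAgent::new();   // hypothetical Agent implementor
/// let mut buffer = MyBuffer::new(); // hypothetical ReplayBufferBase implementor
///
/// agent.train(); // stochastic actions for exploration
/// for _ in 0..1000 {
///     // ... collect transitions from the environment into `buffer` ...
///     let _record = agent.opt_with_record(&mut buffer); // inspect metrics if needed
/// }
///
/// agent.eval(); // deterministic actions for evaluation
/// assert!(!agent.is_train());
/// ```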
pub trait Agent<E: Env, R: ReplayBufferBase>: Policy<E> {
    /// Switches the agent to training mode.
    ///
    /// In training mode, the policy may become stochastic to facilitate exploration.
    /// This is typically implemented by enabling noise or randomness in the action
    /// selection process.
    fn train(&mut self) {
        unimplemented!();
    }

    /// Switches the agent to evaluation mode.
    ///
    /// In evaluation mode, the policy typically becomes deterministic to ensure
    /// consistent performance. This is often implemented by disabling noise or
    /// using the mean action instead of sampling from a distribution.
    fn eval(&mut self) {
        unimplemented!();
    }

    /// Returns whether the agent is currently in training mode.
    ///
    /// Implementations and callers can use this to conditionally enable or disable
    /// mode-dependent behaviors such as exploration noise.
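    ///
    /// # Examples
    ///
    /// A sketch of how an implementor might track the mode with a simple flag
    /// (`train_mode` is a hypothetical field, not part of this trait):
    ///
    /// ```ignore
    /// fn train(&mut self) { self.train_mode = true; }
    /// fn eval(&mut self) { self.train_mode = false; }
    /// fn is_train(&self) -> bool { self.train_mode }
    /// ```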
    fn is_train(&self) -> bool {
        unimplemented!();
    }

    /// Performs a single optimization step using experiences from the replay buffer.
    ///
    /// This method updates the agent's parameters using a batch of transitions
    /// sampled from the provided replay buffer. The specific optimization algorithm
    /// (e.g., Q-learning, policy gradient) is determined by the agent's implementation.
    ///
    /// # Arguments
    ///
    /// * `buffer` - The replay buffer containing experiences used for training
    fn opt(&mut self, buffer: &mut R) {
        let _ = self.opt_with_record(buffer);
    }

    /// Performs an optimization step and returns training metrics.
    ///
    /// Similar to [`opt`], but also returns a [`Record`] containing training metrics
    /// such as loss values, gradients, or other relevant statistics.
    ///
    /// # Arguments
    ///
    /// * `buffer` - The replay buffer containing experiences used for training
    ///
    /// # Returns
    ///
    /// A [`Record`] containing training metrics and statistics
    ///
    /// [`opt`]: Agent::opt
    /// [`Record`]: crate::record::Record
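    ///
    /// # Examples
    ///
    /// A sketch of an implementation, assuming hypothetical `sample` and
    /// `update_on_batch` helpers; consult the crate's `record` module for the
    /// exact [`Record`] constructors available:
    ///
    /// ```ignore
    /// fn opt_with_record(&mut self, buffer: &mut R) -> Record {
    ///     let batch = buffer.sample(self.batch_size); // hypothetical sampling call
    ///     let loss = self.update_on_batch(batch);     // hypothetical gradient step
    ///     // Pack the loss into a Record so trainers can log it.
    ///     Record::from_scalar("loss", loss)           // hypothetical constructor
    /// }
    /// ```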
    #[allow(unused_variables)]
    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
        unimplemented!();
    }

    /// Saves the agent's parameters to the specified directory.
    ///
    /// This method serializes the agent's current state (e.g., neural network weights,
    /// policy parameters) to files in the given directory. The specific format and
    /// number of files created depend on the agent's implementation.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory where parameters will be saved
    ///
    /// # Returns
    ///
    /// A vector of paths to the saved parameter files
    ///
    /// # Examples
    ///
    /// A DQN agent, for instance, might save its two Q-networks (online and target)
    /// as separate files within the specified directory.
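    ///
    /// A sketch of such an implementation, assuming hypothetical `qnet` and
    /// `qnet_target` fields that each expose their own `save` method:
    ///
    /// ```ignore
    /// fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
    ///     std::fs::create_dir_all(path)?; // ensure the directory exists
    ///     let q = path.join("qnet.pt");
    ///     let q_tgt = path.join("qnet_target.pt");
    ///     self.qnet.save(&q)?;            // hypothetical network save
    ///     self.qnet_target.save(&q_tgt)?;
    ///     Ok(vec![q, q_tgt])
    /// }
    /// ```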
    #[allow(unused_variables)]
    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
        unimplemented!();
    }

    /// Loads the agent's parameters from the specified directory.
    ///
    /// This method deserializes the agent's state from files in the given directory,
    /// restoring the agent to a previously saved state.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory containing the saved parameter files
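    ///
    /// # Examples
    ///
    /// Restoring a checkpoint saved earlier with [`save_params`](Agent::save_params)
    /// (the path is illustrative):
    ///
    /// ```ignore
    /// agent.load_params(Path::new("./checkpoints/best"))?;
    /// ```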
    #[allow(unused_variables)]
    fn load_params(&mut self, path: &Path) -> Result<()> {
        unimplemented!();
    }

    /// Returns a reference to the agent as a type-erased `Any` value.
    ///
    /// This method is required for asynchronous training, allowing the agent to be
    /// stored in a type-erased container. The returned reference can be downcast
    /// to the concrete agent type when needed.
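    ///
    /// # Examples
    ///
    /// Recovering the concrete type with `downcast_ref` (`MyAgent` is a
    /// hypothetical implementor):
    ///
    /// ```ignore
    /// if let Some(concrete) = agent.as_any_ref().downcast_ref::<MyAgent>() {
    ///     // use the concrete agent here
    /// }
    /// ```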
    fn as_any_ref(&self) -> &dyn std::any::Any {
        unimplemented!("as_any_ref() must be implemented for train_async()");
    }

    /// Returns a mutable reference to the agent as a type-erased `Any` value.
    ///
    /// This method is required for asynchronous training, allowing the agent to be
    /// stored in a type-erased container. The returned reference can be downcast
    /// to the concrete agent type when needed.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        unimplemented!("as_any_mut() must be implemented for train_async()");
    }
}