border_core/lib.rs
#![warn(missing_docs)]
//! Core components for reinforcement learning.
//!
//! # Observation and Action
//!
//! The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments.
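//!
//! As a minimal sketch, implementing the two traits for a hypothetical grid world might look as follows
//! (`GridObs` and `GridAct` are illustrative types, not part of this crate; only the `len` method is
//! shown, mirroring the crate's own `test` module types):
//!
//! ```ignore
//! use border_core::{Act, Obs};
//!
//! /// A hypothetical observation: a single discrete state.
//! #[derive(Clone, Debug)]
//! pub struct GridObs {
//!     pub state: usize,
//! }
//!
//! impl Obs for GridObs {
//!     fn len(&self) -> usize {
//!         1 // this sketch assumes a non-vectorized environment
//!     }
//! }
//!
//! /// A hypothetical action: a discrete choice.
//! #[derive(Clone, Debug)]
//! pub struct GridAct {
//!     pub choice: usize,
//! }
//!
//! impl Act for GridAct {}
//! ```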
//!
//! # Environment
//!
//! The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types:
//! `Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of
//! environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits.
//! Environments implementing [`Env`] generate [`Step<E: Env>`] objects at each interaction step through the
//! [`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction,
//! which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations
//! and is used during environment construction.
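//!
//! The skeleton below sketches such an implementation for the hypothetical types introduced above,
//! mirroring the `TestEnv` in this crate's `test` module; `CounterEnv` and `NoInfo` are illustrative
//! only, and the `Step` construction assumes the field layout used there:
//!
//! ```ignore
//! use border_core::{record::Record, Env, Info, Step};
//!
//! /// Hypothetical info type carrying no data.
//! struct NoInfo;
//! impl Info for NoInfo {}
//!
//! /// A hypothetical environment whose state is a counter.
//! struct CounterEnv {
//!     init: usize,
//!     state: usize,
//! }
//!
//! impl Env for CounterEnv {
//!     type Config = usize; // the initial counter value
//!     type Obs = GridObs;
//!     type Act = GridAct;
//!     type Info = NoInfo;
//!
//!     fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self> {
//!         Ok(Self { init: *config, state: *config })
//!     }
//!
//!     fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
//!         self.state = self.init;
//!         Ok(GridObs { state: self.state })
//!     }
//!
//!     fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
//!         self.reset(None)
//!     }
//!
//!     fn step(&mut self, act: &Self::Act) -> (Step<Self>, Record) {
//!         self.state += act.choice;
//!         let step = Step {
//!             obs: GridObs { state: self.state },
//!             act: act.clone(),
//!             reward: vec![0.0],
//!             is_terminated: vec![0],
//!             is_truncated: vec![0],
//!             info: NoInfo,
//!             init_obs: Some(GridObs { state: self.init }),
//!         };
//!         (step, Record::empty())
//!     }
//!
//!     fn step_with_reset(&mut self, act: &Self::Act) -> (Step<Self>, Record) {
//!         // This sketch never terminates an episode, so no reset is needed here.
//!         self.step(act)
//!     }
//! }
//! ```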
//!
//! # Policy
//!
//! The [`Policy<E: Env>`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an
//! `E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the
//! implementation.
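//!
//! A minimal deterministic policy for the hypothetical `CounterEnv` above might look like this
//! (`ConstantPolicy` is illustrative only):
//!
//! ```ignore
//! use border_core::Policy;
//!
//! /// A hypothetical policy that always selects the same action.
//! struct ConstantPolicy;
//!
//! impl Policy<CounterEnv> for ConstantPolicy {
//!     fn sample(&mut self, _obs: &GridObs) -> GridAct {
//!         GridAct { choice: 1 }
//!     }
//! }
//! ```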
//!
//! # Agent
//!
//! In this crate, an [`Agent<E: Env, R: ReplayBufferBase>`] is defined as a trainable [`Policy<E: Env>`].
//! Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic
//! to facilitate exploration, while in evaluation mode, it typically becomes deterministic.
//!
//! The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization
//! step varies between agents and may include multiple stochastic gradient descent steps. Training samples are
//! obtained from the [`ReplayBufferBase`].
//!
//! This trait also provides methods for saving and loading trained policy parameters to and from a directory.
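//!
//! As a sketch, the `ConstantPolicy` above can be turned into a (non-learning) agent. The buffer type
//! uses the hypothetical `GridObsBatch`/`GridActBatch` batch types sketched in the Replay Buffer section
//! below, and a real agent would perform its parameter update inside `opt_with_record`:
//!
//! ```ignore
//! use border_core::{generic_replay_buffer::SimpleReplayBuffer, record::Record, Agent};
//! use std::path::{Path, PathBuf};
//!
//! type Buffer = SimpleReplayBuffer<GridObsBatch, GridActBatch>;
//!
//! impl Agent<CounterEnv, Buffer> for ConstantPolicy {
//!     fn train(&mut self) {}
//!
//!     fn eval(&mut self) {}
//!
//!     fn is_train(&self) -> bool {
//!         false // this sketch has no trainable parameters
//!     }
//!
//!     fn opt_with_record(&mut self, _buffer: &mut Buffer) -> Record {
//!         // A real agent would sample a batch from the buffer and run one optimization step here.
//!         Record::empty()
//!     }
//!
//!     fn save_params(&self, _path: &Path) -> anyhow::Result<Vec<PathBuf>> {
//!         Ok(vec![]) // nothing to save
//!     }
//!
//!     fn load_params(&mut self, _path: &Path) -> anyhow::Result<()> {
//!         Ok(())
//!     }
//!
//!     fn as_any_ref(&self) -> &dyn std::any::Any {
//!         self
//!     }
//!
//!     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
//!         self
//!     }
//! }
//! ```
//!
//! A complete agent would typically also implement [`Configurable`], as the `test` module's `TestAgent` does.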
//!
//! # Batch
//!
//! The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`.
//! This trait is used when training [`Agent`]s with reinforcement learning algorithms.
//!
//! # Replay Buffer and Experience Buffer
//!
//! The [`ReplayBufferBase`] trait provides an abstraction for replay buffers. Its associated type
//! [`ReplayBufferBase::Batch`] represents samples used for training [`Agent`]s. Agents must implement the
//! [`Agent::opt()`] method, in which [`ReplayBufferBase::Batch`] must satisfy the type or trait bounds
//! required for training the agent.
//!
//! While [`ReplayBufferBase`] focuses on generating training batches, the [`ExperienceBufferBase`] trait
//! handles sample storage. The [`ExperienceBufferBase::push()`] method stores samples of type
//! [`ExperienceBufferBase::Item`], typically obtained through environment interactions.
//!
//! ## Reference Implementation
//!
//! [`SimpleReplayBuffer<O, A>`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`].
//! This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer.
//! Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec<T>`.
//! The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of
//! `(o_t, r_t, a_t, o_t+1)` transitions.
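//!
//! For instance, the hypothetical `GridObsBatch` below stores `GridObs` values and could serve as the
//! `O` parameter of [`SimpleReplayBuffer<O, A>`]; an analogous `GridActBatch` (sketched in the next
//! section) plays the role of `A`:
//!
//! ```ignore
//! use border_core::generic_replay_buffer::BatchBase;
//!
//! /// A hypothetical batch of `GridObs` values, stored as a flat `Vec`.
//! struct GridObsBatch {
//!     state: Vec<usize>,
//! }
//!
//! impl BatchBase for GridObsBatch {
//!     fn new(capacity: usize) -> Self {
//!         Self { state: vec![0; capacity] }
//!     }
//!
//!     fn push(&mut self, i: usize, data: Self) {
//!         // This sketch assumes `data` holds a single element, as in this crate's test types.
//!         self.state[i] = data.state[0];
//!     }
//!
//!     fn sample(&self, ixs: &Vec<usize>) -> Self {
//!         Self { state: ixs.iter().map(|ix| self.state[*ix]).collect() }
//!     }
//! }
//! ```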
//!
//! # Step Processor
//!
//! The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment
//! interactions into training samples. It processes [`Step<E: Env>`] objects, which contain the current
//! observation, action, reward, and next observation, into a format suitable for the replay buffer.
//!
//! The [`SimpleStepProcessor<E, O, A>`] is a concrete implementation that:
//! 1. Maintains the previous observation to construct complete transitions
//! 2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible
//!    types (`O` and `A`) using the `From` trait
//! 3. Generates [`GenericTransitionBatch`] objects containing the complete transition
//!    `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)`
//! 4. Handles episode termination by properly resetting the previous observation
//!
//! This processor is essential for implementing temporal difference learning algorithms, as it ensures
//! that transitions are properly formatted and stored in the replay buffer for training.
//!
//! [`SimpleStepProcessor<E, O, A>`] can be used with [`SimpleReplayBuffer<O, A>`]. It converts `E::Obs` and
//! `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion
//! relies on the trait bounds `O: From<E::Obs>` and `A: From<E::Act>`.
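//!
//! Concretely, the glue required for this pairing is the pair of `From` conversions, shown here for the
//! hypothetical types from the sketches above together with the `GridActBatch` counterpart of `GridObsBatch`:
//!
//! ```ignore
//! use border_core::generic_replay_buffer::BatchBase;
//!
//! /// A hypothetical batch of `GridAct` values.
//! struct GridActBatch {
//!     choice: Vec<usize>,
//! }
//!
//! impl BatchBase for GridActBatch {
//!     fn new(capacity: usize) -> Self {
//!         Self { choice: vec![0; capacity] }
//!     }
//!
//!     fn push(&mut self, i: usize, data: Self) {
//!         self.choice[i] = data.choice[0];
//!     }
//!
//!     fn sample(&self, ixs: &Vec<usize>) -> Self {
//!         Self { choice: ixs.iter().map(|ix| self.choice[*ix]).collect() }
//!     }
//! }
//!
//! // Conversions used to turn raw observations and actions into batch items.
//! impl From<GridObs> for GridObsBatch {
//!     fn from(obs: GridObs) -> Self {
//!         Self { state: vec![obs.state] }
//!     }
//! }
//!
//! impl From<GridAct> for GridActBatch {
//!     fn from(act: GridAct) -> Self {
//!         Self { choice: vec![act.choice] }
//!     }
//! }
//! ```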
//!
//! # Trainer
//!
//! The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with
//! training parameters such as the maximum number of optimization steps and the directory for saving agent
//! parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment.
//! During the training loop, the agent interacts with the environment to collect samples and perform optimization
//! steps, while simultaneously recording various metrics.
//!
//! # Evaluator
//!
//! The [`Evaluator<E, P>`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`).
//! An instance of this type is provided to the [`Trainer`] for policy evaluation during training.
//! [`DefaultEvaluator<E, P>`] serves as the default implementation of [`Evaluator<E, P>`]. This evaluator
//! runs the policy in the environment for a specified number of episodes. At the start of each episode,
//! the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions.
//!
//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer
//! [`SimpleReplayBuffer<O, A>`]: generic_replay_buffer::SimpleReplayBuffer
//! [`BatchBase`]: generic_replay_buffer::BatchBase
//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch
//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor
//! [`SimpleStepProcessor<E, O, A>`]: generic_replay_buffer::SimpleStepProcessor
pub mod dummy;
pub mod error;
mod evaluator;
pub mod generic_replay_buffer;
pub mod record;

mod base;
pub use base::{
    Act, Agent, Configurable, Env, ExperienceBufferBase, Info, NullReplayBuffer, Obs, Policy,
    ReplayBufferBase, Step, StepProcessor, TransitionBatch,
};

mod trainer;
pub use evaluator::{DefaultEvaluator, Evaluator};
pub use trainer::{Sampler, Trainer, TrainerConfig};

// TODO: Consider compiling this module only for tests.
/// Agent and Env for testing.
pub mod test {
    use serde::{Deserialize, Serialize};

    /// Obs for testing.
    #[derive(Clone, Debug)]
    pub struct TestObs {
        obs: usize,
    }

    impl crate::Obs for TestObs {
        fn len(&self) -> usize {
            1
        }
    }

    /// Batch of obs for testing.
    pub struct TestObsBatch {
        obs: Vec<usize>,
    }

    impl crate::generic_replay_buffer::BatchBase for TestObsBatch {
        fn new(capacity: usize) -> Self {
            Self {
                obs: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.obs[i] = data.obs[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
            Self { obs }
        }
    }

    impl From<TestObs> for TestObsBatch {
        fn from(obs: TestObs) -> Self {
            Self { obs: vec![obs.obs] }
        }
    }

    /// Act for testing.
    #[derive(Clone, Debug)]
    pub struct TestAct {
        act: usize,
    }

    impl crate::Act for TestAct {}

    /// Batch of act for testing.
    pub struct TestActBatch {
        act: Vec<usize>,
    }

    impl From<TestAct> for TestActBatch {
        fn from(act: TestAct) -> Self {
            Self { act: vec![act.act] }
        }
    }

    impl crate::generic_replay_buffer::BatchBase for TestActBatch {
        fn new(capacity: usize) -> Self {
            Self {
                act: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.act[i] = data.act[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let act = ixs.iter().map(|ix| self.act[*ix]).collect();
            Self { act }
        }
    }

    /// Info for testing.
    pub struct TestInfo {}

    impl crate::Info for TestInfo {}

    /// Environment for testing.
    pub struct TestEnv {
        state_init: usize,
        state: usize,
    }

    impl crate::Env for TestEnv {
        type Config = usize;
        type Obs = TestObs;
        type Act = TestAct;
        type Info = TestInfo;

        fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn step_with_reset(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: Some(TestObs {
                    obs: self.state_init,
                }),
            };
            (step, crate::record::Record::empty())
        }

        fn step(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: Some(TestObs {
                    obs: self.state_init,
                }),
            };
            (step, crate::record::Record::empty())
        }

        fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
        where
            Self: Sized,
        {
            Ok(Self {
                state_init: *config,
                state: 0,
            })
        }
    }

    type ReplayBuffer =
        crate::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;

    /// Agent for testing.
    pub struct TestAgent {}

    /// Config of agent for testing.
    #[derive(Clone, Deserialize, Serialize)]
    pub struct TestAgentConfig;

    impl crate::Agent<TestEnv, ReplayBuffer> for TestAgent {
        fn train(&mut self) {}

        fn is_train(&self) -> bool {
            false
        }

        fn eval(&mut self) {}

        fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record {
            crate::record::Record::empty()
        }

        fn save_params(&self, _path: &std::path::Path) -> anyhow::Result<Vec<std::path::PathBuf>> {
            Ok(vec![])
        }

        fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> {
            Ok(())
        }

        fn as_any_ref(&self) -> &dyn std::any::Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
            self
        }
    }

    impl crate::Policy<TestEnv> for TestAgent {
        fn sample(&mut self, _obs: &TestObs) -> TestAct {
            TestAct { act: 1 }
        }
    }

    impl crate::Configurable for TestAgent {
        type Config = TestAgentConfig;

        fn build(_config: Self::Config) -> Self {
            Self {}
        }
    }
}