border_core/
lib.rs

#![warn(missing_docs)]
//! Core components for reinforcement learning.
//!
//! # Observation and Action
//!
//! The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments.
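//!
//! As a minimal sketch, these traits can be implemented for simple wrapper
//! types (the types below are illustrative, not part of the crate):
//!
//! ```ignore
//! use border_core::{Act, Obs};
//!
//! // A scalar observation; `len` returns the number of observations
//! // bundled in this value (1 for a single, non-vectorized environment).
//! #[derive(Clone, Debug)]
//! struct MyObs(f32);
//!
//! impl Obs for MyObs {
//!     fn len(&self) -> usize {
//!         1
//!     }
//! }
//!
//! // A discrete action; `Act` has no required methods.
//! #[derive(Clone, Debug)]
//! struct MyAct(usize);
//!
//! impl Act for MyAct {}
//! ```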
//!
//! # Environment
//!
//! The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types:
//! `Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of
//! environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits.
//! Environments implementing [`Env`] generate [`Step<E: Env>`] objects at each interaction step through the
//! [`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction,
//! which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations
//! and is used during environment construction.
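//!
//! The [`Step<E: Env>`] object returned by [`Env::step()`] bundles the transition
//! data, as in this sketch (field values are illustrative; the `Vec` fields allow
//! for vectorized environments):
//!
//! ```ignore
//! use border_core::Step;
//!
//! let step = Step {
//!     obs: next_obs,            // observation after applying the action
//!     act: act.clone(),         // the action that was applied
//!     reward: vec![1.0],        // reward(s) for this step
//!     is_terminated: vec![0],   // 1 if the episode terminated
//!     is_truncated: vec![0],    // 1 if the episode was truncated
//!     info: MyInfo {},          // extra information; may be zero-sized
//!     init_obs: Some(init_obs), // initial observation of the episode
//! };
//! ```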
//!
//! # Policy
//!
//! The [`Policy<E: Env>`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an
//! `E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the
//! implementation.
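//!
//! A sketch of a deterministic policy over the illustrative types above
//! (`MyEnv` is assumed to be an [`Env`] with `Obs = MyObs` and `Act = MyAct`):
//!
//! ```ignore
//! use border_core::Policy;
//!
//! struct MyPolicy;
//!
//! impl Policy<MyEnv> for MyPolicy {
//!     // Maps an observation to an action; a probabilistic policy would
//!     // sample from a distribution here instead.
//!     fn sample(&mut self, obs: &MyObs) -> MyAct {
//!         MyAct(if obs.0 > 0.0 { 1 } else { 0 })
//!     }
//! }
//! ```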
//!
//! # Agent
//!
//! In this crate, an [`Agent<E: Env, R: ReplayBufferBase>`] is defined as a trainable [`Policy<E: Env>`].
//! Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic
//! to facilitate exploration, while in evaluation mode, it typically becomes deterministic.
//!
//! The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization
//! step varies between agents and may include multiple stochastic gradient descent steps. Training samples are
//! obtained from the [`ReplayBufferBase`].
//!
//! This trait also provides methods for saving and loading trained policy parameters to and from a directory.
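//!
//! A typical agent lifecycle looks like the following sketch (`agent` and
//! `buffer` are assumed to be an [`Agent`] and a [`ReplayBufferBase`] built
//! elsewhere):
//!
//! ```ignore
//! agent.train();              // switch to training mode
//! assert!(agent.is_train());
//!
//! // One optimization step, returning a record of training metrics.
//! let record = agent.opt_with_record(&mut buffer);
//!
//! // Persist the trained parameters, then switch to evaluation mode.
//! agent.save_params(std::path::Path::new("./model"))?;
//! agent.eval();
//! ```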
//!
//! # Batch
//!
//! The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`.
//! This trait is used for training [`Agent`]s with reinforcement learning algorithms.
//!
//! # Replay Buffer and Experience Buffer
//!
//! The [`ReplayBufferBase`] trait provides an abstraction for replay buffers. Its associated type
//! [`ReplayBufferBase::Batch`] represents samples used for training [`Agent`]s. Agents must implement the
//! [`Agent::opt()`] method, and [`ReplayBufferBase::Batch`] must satisfy the type or trait bounds the agent
//! requires for training.
//!
//! While [`ReplayBufferBase`] focuses on generating training batches, the [`ExperienceBufferBase`] trait
//! handles sample storage. The [`ExperienceBufferBase::push()`] method stores samples of type
//! [`ExperienceBufferBase::Item`], typically obtained through environment interactions.
//!
//! ## Reference Implementation
//!
//! [`SimpleReplayBuffer<O, A>`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`].
//! This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer.
//! Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec<T>`.
//! The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of
//! `(o_t, r_t, a_t, o_t+1)` transitions.
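//!
//! A [`BatchBase`] implementation for the illustrative observation type might
//! look like this sketch, mirroring the test implementation in this crate:
//!
//! ```ignore
//! use border_core::generic_replay_buffer::BatchBase;
//!
//! // Vec-backed storage for `MyObs` values.
//! struct MyObsBatch {
//!     obs: Vec<f32>,
//! }
//!
//! impl BatchBase for MyObsBatch {
//!     fn new(capacity: usize) -> Self {
//!         Self { obs: vec![0.0; capacity] }
//!     }
//!
//!     // Overwrites the slot at index `i` with the (single) incoming sample.
//!     fn push(&mut self, i: usize, data: Self) {
//!         self.obs[i] = data.obs[0];
//!     }
//!
//!     // Gathers the samples at the given indices into a new batch.
//!     fn sample(&self, ixs: &Vec<usize>) -> Self {
//!         Self { obs: ixs.iter().map(|ix| self.obs[*ix]).collect() }
//!     }
//! }
//! ```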
//!
//! # Step Processor
//!
//! The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment
//! interactions into training samples. It processes [`Step<E: Env>`] objects, which contain the current
//! observation, action, reward, and next observation, into a format suitable for the replay buffer.
//!
//! The [`SimpleStepProcessor<E, O, A>`] is a concrete implementation that:
//! 1. Maintains the previous observation to construct complete transitions
//! 2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible
//!    types (`O` and `A`) using the `From` trait
//! 3. Generates [`GenericTransitionBatch`] objects containing the complete transition
//!    `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)`
//! 4. Handles episode termination by properly resetting the previous observation
//!
//! This processor is essential for implementing temporal difference learning algorithms, as it ensures
//! that transitions are properly formatted and stored in the replay buffer for training.
//!
//! [`SimpleStepProcessor<E, O, A>`] can be used with [`SimpleReplayBuffer<O, A>`]. It converts `E::Obs` and
//! `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion
//! relies on the trait bounds `O: From<E::Obs>` and `A: From<E::Act>`, as sketched below.
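//!
//! The required `From` conversions over the illustrative types above
//! (`MyActBatch` is assumed to be defined analogously to `MyObsBatch`):
//!
//! ```ignore
//! impl From<MyObs> for MyObsBatch {
//!     fn from(obs: MyObs) -> Self {
//!         Self { obs: vec![obs.0] }
//!     }
//! }
//!
//! impl From<MyAct> for MyActBatch {
//!     fn from(act: MyAct) -> Self {
//!         Self { act: vec![act.0] }
//!     }
//! }
//! ```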
//!
//! # Trainer
//!
//! The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with
//! training parameters such as the maximum number of optimization steps and the directory for saving agent
//! parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment.
//! During the training loop, the agent interacts with the environment to collect samples and perform optimization
//! steps, while simultaneously recording various metrics.
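//!
//! Wiring these pieces together typically looks like the following sketch.
//! The constructors shown here are assumptions for illustration, not the
//! crate's verbatim API:
//!
//! ```ignore
//! // Hypothetical wiring: a trainer drives env interaction, optimization
//! // steps, periodic evaluation, and metric recording.
//! let config = TrainerConfig::default(); // assumed constructor
//! let mut trainer = Trainer::build(config); // assumed constructor
//! // `trainer.train(...)` then runs online training with an environment,
//! // an agent, a step processor, and a replay buffer.
//! ```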
//!
//! # Evaluator
//!
//! The [`Evaluator<E, P>`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`).
//! An instance of this type is provided to the [`Trainer`] for policy evaluation during training.
//! [`DefaultEvaluator<E, P>`] serves as the default implementation of [`Evaluator<E, P>`]. This evaluator
//! runs the policy in the environment for a specified number of episodes. At the start of each episode,
//! the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions.
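//!
//! Per episode, an evaluator behaves roughly like this sketch (simplified;
//! reward accumulation is elided, and `env`, `policy`, and `episode_ix` are
//! assumed to exist):
//!
//! ```ignore
//! let mut obs = env.reset_with_index(episode_ix)?;
//! loop {
//!     let act = policy.sample(&obs);
//!     let (step, _record) = env.step(&act);
//!     if step.is_terminated[0] == 1 || step.is_truncated[0] == 1 {
//!         break;
//!     }
//!     obs = step.obs;
//! }
//! ```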
//!
//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer
//! [`SimpleReplayBuffer<O, A>`]: generic_replay_buffer::SimpleReplayBuffer
//! [`BatchBase`]: generic_replay_buffer::BatchBase
//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch
//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor
//! [`SimpleStepProcessor<E, O, A>`]: generic_replay_buffer::SimpleStepProcessor
pub mod dummy;
pub mod error;
mod evaluator;
pub mod generic_replay_buffer;
pub mod record;

mod base;
pub use base::{
    Act, Agent, Configurable, Env, ExperienceBufferBase, Info, NullReplayBuffer, Obs, Policy,
    ReplayBufferBase, Step, StepProcessor, TransitionBatch,
};

mod trainer;
pub use evaluator::{DefaultEvaluator, Evaluator};
pub use trainer::{Sampler, Trainer, TrainerConfig};

// TODO: Consider compiling this module only for tests.
/// Agent and Env for testing.
pub mod test {
    use serde::{Deserialize, Serialize};

    /// Obs for testing.
    #[derive(Clone, Debug)]
    pub struct TestObs {
        obs: usize,
    }

    impl crate::Obs for TestObs {
        fn len(&self) -> usize {
            1
        }
    }

    /// Batch of obs for testing.
    pub struct TestObsBatch {
        obs: Vec<usize>,
    }

    impl crate::generic_replay_buffer::BatchBase for TestObsBatch {
        fn new(capacity: usize) -> Self {
            Self {
                obs: vec![0; capacity],
            }
        }

        // Overwrites slot `i` with the single sample carried by `data`.
        fn push(&mut self, i: usize, data: Self) {
            self.obs[i] = data.obs[0];
        }

        // Gathers the observations at the given indices into a new batch.
        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
            Self { obs }
        }
    }

    impl From<TestObs> for TestObsBatch {
        fn from(obs: TestObs) -> Self {
            Self { obs: vec![obs.obs] }
        }
    }

    /// Act for testing.
    #[derive(Clone, Debug)]
    pub struct TestAct {
        act: usize,
    }

    impl crate::Act for TestAct {}

    /// Batch of act for testing.
    pub struct TestActBatch {
        act: Vec<usize>,
    }

    impl From<TestAct> for TestActBatch {
        fn from(act: TestAct) -> Self {
            Self { act: vec![act.act] }
        }
    }

    impl crate::generic_replay_buffer::BatchBase for TestActBatch {
        fn new(capacity: usize) -> Self {
            Self {
                act: vec![0; capacity],
            }
        }

        // Overwrites slot `i` with the single sample carried by `data`.
        fn push(&mut self, i: usize, data: Self) {
            self.act[i] = data.act[0];
        }

        // Gathers the actions at the given indices into a new batch.
        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let act = ixs.iter().map(|ix| self.act[*ix]).collect();
            Self { act }
        }
    }

    /// Info for testing.
    pub struct TestInfo {}

    impl crate::Info for TestInfo {}

    /// Environment for testing.
    pub struct TestEnv {
        /// The state to which the env resets.
        state_init: usize,
        /// The current state, incremented by each action.
        state: usize,
    }

    impl crate::Env for TestEnv {
        type Config = usize;
        type Obs = TestObs;
        type Act = TestAct;
        type Info = TestInfo;

        fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        // Identical to `step()` because this test env never terminates,
        // so no reset is ever required.
        fn step_with_reset(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.step(a)
        }

        // Increments the state by the action value; the reward is always 0
        // and episodes never terminate.
        fn step(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: Some(TestObs {
                    obs: self.state_init,
                }),
            };
            (step, crate::record::Record::empty())
        }

        fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
        where
            Self: Sized,
        {
            Ok(Self {
                state_init: *config,
                state: 0,
            })
        }
    }

    type ReplayBuffer =
        crate::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;

    /// Agent for testing.
    pub struct TestAgent {}

    /// Config of agent for testing.
    #[derive(Clone, Deserialize, Serialize)]
    pub struct TestAgentConfig;

    impl crate::Agent<TestEnv, ReplayBuffer> for TestAgent {
        fn train(&mut self) {}

        fn is_train(&self) -> bool {
            false
        }

        fn eval(&mut self) {}

        fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record {
            crate::record::Record::empty()
        }

        fn save_params(&self, _path: &std::path::Path) -> anyhow::Result<Vec<std::path::PathBuf>> {
            Ok(vec![])
        }

        fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> {
            Ok(())
        }

        fn as_any_ref(&self) -> &dyn std::any::Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
            self
        }
    }

    impl crate::Policy<TestEnv> for TestAgent {
        fn sample(&mut self, _obs: &TestObs) -> TestAct {
            TestAct { act: 1 }
        }
    }

    impl crate::Configurable for TestAgent {
        type Config = TestAgentConfig;

        fn build(_config: Self::Config) -> Self {
            Self {}
        }
    }
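
    // A minimal smoke test (an illustrative addition): builds the test env
    // and agent, then runs a few interaction steps to exercise the traits
    // defined above.
    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::{Configurable, Env, Policy};

        #[test]
        fn env_and_agent_interact() {
            let mut env = TestEnv::build(&0, 0).unwrap();
            let mut agent = TestAgent::build(TestAgentConfig);
            let mut obs = env.reset(None).unwrap();
            for _ in 0..3 {
                let act = agent.sample(&obs);
                let (step, _record) = env.step(&act);
                assert_eq!(step.reward[0], 0.0);
                obs = step.obs;
            }
        }
    }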
}