border_core/
lib.rs

#![warn(missing_docs)]
//! Core components for reinforcement learning.
//!
//! # Observation and action
//!
//! The [`Obs`] and [`Act`] traits are abstractions of observations and actions in environments.
//! These traits can handle two or more samples to support vectorized environments,
//! although there is currently no implementation of a vectorized environment.
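//!
//! Below is a minimal sketch of implementing these traits for hypothetical `MyObs` and
//! `MyAct` types; the method bodies are illustrative only.
//!
//! ```ignore
//! use border_core::{Act, Obs};
//!
//! #[derive(Clone, Debug)]
//! struct MyObs(Vec<f32>);
//!
//! impl Obs for MyObs {
//!     // A placeholder observation (the `n` parameter is unused in this sketch).
//!     fn dummy(_n: usize) -> Self {
//!         Self(vec![0.0])
//!     }
//!
//!     // The number of samples contained in this observation.
//!     fn len(&self) -> usize {
//!         1
//!     }
//! }
//!
//! #[derive(Clone, Debug)]
//! struct MyAct(f32);
//!
//! impl Act for MyAct {}
//! ```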
//!
//! # Environment
//!
//! The [`Env`] trait is an abstraction of environments. It has four associated types:
//! `Config`, `Obs`, `Act` and `Info`. `Obs` and `Act` are the concrete types of
//! observations and actions of the environment.
//! These types must implement the [`Obs`] and [`Act`] traits, respectively.
//! An environment that implements [`Env`] generates a [`Step<E: Env>`] object
//! at every interaction step via the [`Env::step()`] method.
//! [`Info`] stores additional information at every step of the interaction between
//! an agent and the environment. It can be empty (a zero-sized struct). `Config`
//! represents the configuration of the environment and is used to build it.
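//!
//! The sketch below runs a single episode with a generic environment and policy;
//! the `n_steps` limit and the early-exit logic are illustrative assumptions, not part
//! of the API.
//!
//! ```ignore
//! use border_core::{Env, Policy};
//!
//! fn run_episode<E: Env, P: Policy<E>>(
//!     config: &E::Config,
//!     policy: &mut P,
//!     n_steps: usize,
//! ) -> anyhow::Result<()> {
//!     // Build the environment from its configuration and a seed.
//!     let mut env = E::build(config, 42)?;
//!     // Reset the environment to obtain the initial observation.
//!     let mut obs = env.reset(None)?;
//!
//!     for _ in 0..n_steps {
//!         let act = policy.sample(&obs);
//!         // `step()` returns a `Step` object and a `Record` of step-wise metrics.
//!         let (step, _record) = env.step(&act);
//!         if step.is_terminated[0] == 1 || step.is_truncated[0] == 1 {
//!             break;
//!         }
//!         obs = step.obs;
//!     }
//!     Ok(())
//! }
//! ```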
//!
//! # Policy
//!
//! [`Policy<E: Env>`] represents a policy. [`Policy::sample()`] takes an `E::Obs` and
//! generates an `E::Act`. The policy may be probabilistic or deterministic.
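//!
//! A minimal sketch of a policy implementation, assuming a hypothetical `MyEnv`
//! environment whose observations and actions are the `MyObs` and `MyAct` types
//! sketched above:
//!
//! ```ignore
//! use border_core::Policy;
//!
//! struct ConstantPolicy;
//!
//! impl Policy<MyEnv> for ConstantPolicy {
//!     // Always returns the same action, regardless of the observation.
//!     fn sample(&mut self, _obs: &MyObs) -> MyAct {
//!         MyAct(0.0)
//!     }
//! }
//! ```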
//!
//! # Agent
//!
//! In this crate, [`Agent<E: Env, R: ReplayBufferBase>`] is defined as a trainable
//! [`Policy<E: Env>`]. It is in either training or evaluation mode. In training mode,
//! the agent's policy might be probabilistic for exploration, while in evaluation mode,
//! the policy might be deterministic.
//!
//! The [`Agent::opt()`] method performs a single optimization step. The definition of an
//! optimization step varies for each agent; it might comprise multiple stochastic
//! gradient steps. Samples for training are taken from
//! [`R: ReplayBufferBase`][`ReplayBufferBase`].
//!
//! This trait also has methods for saving and loading the parameters of the trained policy
//! in a directory.
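//!
//! The sketch below switches between the two modes and performs optimization steps,
//! assuming a replay buffer that has already been filled; the number of steps and the
//! checkpoint path are illustrative.
//!
//! ```ignore
//! use border_core::{Agent, Env, ReplayBufferBase};
//!
//! fn train_agent<E, R, A>(agent: &mut A, buffer: &mut R) -> anyhow::Result<()>
//! where
//!     E: Env,
//!     R: ReplayBufferBase,
//!     A: Agent<E, R>,
//! {
//!     // Switch to training mode; the policy may be probabilistic here.
//!     agent.train();
//!     for _ in 0..100 {
//!         // A single optimization step; metrics are returned as a `Record`.
//!         let _record = agent.opt_with_record(buffer);
//!     }
//!     // Switch to evaluation mode and save the trained parameters.
//!     agent.eval();
//!     agent.save_params("./checkpoint")?;
//!     Ok(())
//! }
//! ```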
//!
//! # Batch
//!
//! [`TransitionBatch`] is a trait representing a batch of transitions `(o_t, r_t, a_t, o_t+1)`.
//! This trait is used to train [`Agent`]s with an RL algorithm.
//!
//! # Replay buffer and experience buffer
//!
//! The [`ReplayBufferBase`] trait is an abstraction of replay buffers.
//! Its associated type [`ReplayBufferBase::Batch`] represents the samples taken from
//! the buffer for training [`Agent`]s. Agents must implement the [`Agent::opt()`] method,
//! in which [`ReplayBufferBase::Batch`] must have the appropriate type or trait bound(s)
//! to train the agent.
//!
//! As explained above, the [`ReplayBufferBase`] trait is able to generate batches of
//! samples with which agents are trained. The [`ExperienceBufferBase`] trait, on the
//! other hand, is able to store samples. [`ExperienceBufferBase::push()`] pushes
//! samples of type [`ExperienceBufferBase::Item`], which might be obtained via interaction
//! steps with an environment.
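//!
//! The sketch below combines the two roles: experience is pushed into a buffer that
//! implements both traits, and the agent is then optimized on batches drawn from it.
//! The exact return value of `push()` is glossed over here and may differ.
//!
//! ```ignore
//! use border_core::{Agent, Env, ExperienceBufferBase, ReplayBufferBase};
//!
//! fn store_then_optimize<E, R, A>(agent: &mut A, buffer: &mut R, items: Vec<R::Item>)
//! where
//!     E: Env,
//!     R: ExperienceBufferBase + ReplayBufferBase,
//!     A: Agent<E, R>,
//! {
//!     // Store samples obtained from interaction steps with the environment.
//!     for item in items {
//!         let _ = buffer.push(item);
//!     }
//!     // Optimize the agent on batches generated by the replay buffer.
//!     let _record = agent.opt_with_record(buffer);
//! }
//! ```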
//!
//! ## A reference implementation
//!
//! [`SimpleReplayBuffer<O, A>`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`].
//! This type has two parameters `O` and `A`, which are the representations of
//! observations and actions in the replay buffer. `O` and `A` must implement
//! [`BatchBase`], which provides the functionality of storing samples, like `Vec<T>`,
//! for observations and actions. The associated types `Item` and `Batch`
//! are the same type, [`GenericTransitionBatch`], representing sets of `(o_t, r_t, a_t, o_t+1)`.
//!
//! [`SimpleStepProcessor<E, O, A>`] might be used with [`SimpleReplayBuffer<O, A>`].
//! It converts `E::Obs` and `E::Act` into [`BatchBase`]s of the respective types
//! and generates a [`GenericTransitionBatch`]. The conversion process relies on the
//! trait bounds `O: From<E::Obs>` and `A: From<E::Act>`.
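//!
//! The sketch below shows how the pieces fit together at the type level, assuming
//! hypothetical `MyEnv`, `MyObsBatch`, and `MyActBatch` types that implement
//! [`BatchBase`] as well as `From<MyObs>` and `From<MyAct>`:
//!
//! ```ignore
//! use border_core::generic_replay_buffer::{SimpleReplayBuffer, SimpleStepProcessor};
//!
//! // The replay buffer stores observations and actions in their batch representations.
//! type MyReplayBuffer = SimpleReplayBuffer<MyObsBatch, MyActBatch>;
//!
//! // The step processor converts the environment's observations and actions into
//! // their batch representations and produces `GenericTransitionBatch`es.
//! type MyStepProcessor = SimpleStepProcessor<MyEnv, MyObsBatch, MyActBatch>;
//! ```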
//!
//! # Trainer
//!
//! [`Trainer`] manages the training loop and related objects. A [`Trainer`] object is
//! built with a configuration of training parameters such as the maximum number of
//! optimization steps and the model directory in which the agent's parameters are saved
//! during training. The [`Trainer::train`] method executes online training of an agent
//! in an environment. In the training loop of this method, the agent interacts with the
//! environment to collect samples and performs optimization steps, while some metrics
//! are recorded.
//!
//! # Evaluator
//!
//! [`Evaluator<E, P>`] is used to evaluate the performance of a policy (`P`) in an
//! environment (`E`). An object of this type is given to the [`Trainer`] object to
//! evaluate the policy during training. [`DefaultEvaluator<E, P>`] is a default
//! implementation of [`Evaluator<E, P>`]. This evaluator runs the policy in the
//! environment for a certain number of episodes. At the start of each episode, the
//! environment is reset with [`Env::reset_with_index()`] to control the specific
//! conditions of evaluation.
//!
//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer
//! [`SimpleReplayBuffer<O, A>`]: generic_replay_buffer::SimpleReplayBuffer
//! [`BatchBase`]: generic_replay_buffer::BatchBase
//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch
//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor
//! [`SimpleStepProcessor<E, O, A>`]: generic_replay_buffer::SimpleStepProcessor
pub mod error;
mod evaluator;
pub mod generic_replay_buffer;
pub mod record;

mod base;
pub use base::{
    Act, Agent, Configurable, Env, ExperienceBufferBase, Info, Obs, Policy, ReplayBufferBase, Step,
    StepProcessor, TransitionBatch,
};

mod trainer;
pub use evaluator::{DefaultEvaluator, Evaluator};
pub use trainer::{Sampler, Trainer, TrainerConfig};

// TODO: Consider compiling this module only for tests.
/// Agent and Env for testing.
pub mod test {
    use serde::{Deserialize, Serialize};

    /// Obs for testing.
    #[derive(Clone, Debug)]
    pub struct TestObs {
        obs: usize,
    }

    impl crate::Obs for TestObs {
        fn dummy(_n: usize) -> Self {
            Self { obs: 0 }
        }

        fn len(&self) -> usize {
            1
        }
    }

    /// Batch of obs for testing.
    pub struct TestObsBatch {
        obs: Vec<usize>,
    }

    impl crate::generic_replay_buffer::BatchBase for TestObsBatch {
        fn new(capacity: usize) -> Self {
            Self {
                obs: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.obs[i] = data.obs[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
            Self { obs }
        }
    }

    impl From<TestObs> for TestObsBatch {
        fn from(obs: TestObs) -> Self {
            Self { obs: vec![obs.obs] }
        }
    }

    /// Act for testing.
    #[derive(Clone, Debug)]
    pub struct TestAct {
        act: usize,
    }

    impl crate::Act for TestAct {}

    /// Batch of act for testing.
    pub struct TestActBatch {
        act: Vec<usize>,
    }

    impl From<TestAct> for TestActBatch {
        fn from(act: TestAct) -> Self {
            Self { act: vec![act.act] }
        }
    }

    impl crate::generic_replay_buffer::BatchBase for TestActBatch {
        fn new(capacity: usize) -> Self {
            Self {
                act: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.act[i] = data.act[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let act = ixs.iter().map(|ix| self.act[*ix]).collect();
            Self { act }
        }
    }

    /// Info for testing.
    pub struct TestInfo {}

    impl crate::Info for TestInfo {}

    /// Environment for testing.
    pub struct TestEnv {
        state_init: usize,
        state: usize,
    }

    impl crate::Env for TestEnv {
        type Config = usize;
        type Obs = TestObs;
        type Act = TestAct;
        type Info = TestInfo;

        fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn step_with_reset(
            &mut self,
            a: &Self::Act,
        ) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            // The test environment never terminates, so this is identical to `step()`.
            self.step(a)
        }

        fn step(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            // The state is incremented by the action value; the reward is always zero.
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: TestObs {
                    obs: self.state_init,
                },
            };
            (step, crate::record::Record::empty())
        }

        fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
        where
            Self: Sized,
        {
            Ok(Self {
                state_init: *config,
                state: 0,
            })
        }
    }

    type ReplayBuffer =
        crate::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;

    /// Agent for testing.
    pub struct TestAgent {}

    /// Config of agent for testing.
    #[derive(Clone, Deserialize, Serialize)]
    pub struct TestAgentConfig;

    impl crate::Agent<TestEnv, ReplayBuffer> for TestAgent {
        fn train(&mut self) {}

        fn is_train(&self) -> bool {
            false
        }

        fn eval(&mut self) {}

        fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record {
            crate::record::Record::empty()
        }

        fn save_params<T: AsRef<std::path::Path>>(&self, _path: T) -> anyhow::Result<()> {
            Ok(())
        }

        fn load_params<T: AsRef<std::path::Path>>(&mut self, _path: T) -> anyhow::Result<()> {
            Ok(())
        }
    }

    impl crate::Policy<TestEnv> for TestAgent {
        fn sample(&mut self, _obs: &TestObs) -> TestAct {
            TestAct { act: 1 }
        }
    }

    impl crate::Configurable<TestEnv> for TestAgent {
        type Config = TestAgentConfig;

        fn build(_config: Self::Config) -> Self {
            Self {}
        }
    }
}