border_core/lib.rs
#![warn(missing_docs)]
//! Core components for reinforcement learning.
//!
//! # Observation and action
//!
//! The [`Obs`] and [`Act`] traits are abstractions of observations and actions in environments.
//! These traits can handle two or more samples to support vectorized environments,
//! although no vectorized environment implementation currently exists.
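//!
//! A minimal sketch of implementing these traits for a scalar observation and action is
//! shown below. It mirrors the `test` module at the bottom of this crate; the `MyObs` and
//! `MyAct` names are placeholders:
//!
//! ```ignore
//! use border_core::{Act, Obs};
//!
//! #[derive(Clone, Debug)]
//! struct MyObs {
//!     obs: usize,
//! }
//!
//! impl Obs for MyObs {
//!     // A dummy observation; the argument is the number of samples, ignored here.
//!     fn dummy(_n: usize) -> Self {
//!         Self { obs: 0 }
//!     }
//!
//!     // The number of samples in this observation (1, since it is not vectorized).
//!     fn len(&self) -> usize {
//!         1
//!     }
//! }
//!
//! #[derive(Clone, Debug)]
//! struct MyAct {
//!     act: usize,
//! }
//!
//! // No additional methods are required here, as in the `test` module of this crate.
//! impl Act for MyAct {}
//! ```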
//!
//! # Environment
//!
//! The [`Env`] trait is an abstraction of environments. It has four associated types:
//! `Config`, `Obs`, `Act` and `Info`. `Obs` and `Act` are the concrete types of
//! observations and actions of the environment.
//! These types must implement the [`Obs`] and [`Act`] traits, respectively.
//! An environment implementing [`Env`] generates a [`Step<E: Env>`] object
//! at every interaction step via the [`Env::step()`] method.
//! [`Info`] stores additional information at every step of interaction between the agent
//! and the environment. It may be empty (a zero-sized struct). `Config` represents the
//! configuration of the environment and is used to build it.
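//!
//! A skeleton of an [`Env`] implementation is sketched below. It follows the shape of the
//! `TestEnv` implementation in the `test` module of this crate; all `My*` names are
//! placeholders and the method bodies are omitted:
//!
//! ```ignore
//! use border_core::{record::Record, Env, Step};
//!
//! struct MyEnv { /* internal state */ }
//!
//! impl Env for MyEnv {
//!     type Config = MyEnvConfig; // configuration used by `build()`
//!     type Obs = MyObs;          // must implement `Obs`
//!     type Act = MyAct;          // must implement `Act`
//!     type Info = MyInfo;        // must implement `Info`; may be a zero-sized struct
//!
//!     fn build(config: &Self::Config, seed: i64) -> anyhow::Result<Self> {
//!         todo!("construct the environment from its configuration")
//!     }
//!
//!     fn reset(&mut self, is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
//!         todo!("reset the environment and return the initial observation")
//!     }
//!
//!     fn step(&mut self, act: &Self::Act) -> (Step<Self>, Record) {
//!         todo!("apply the action, then return a `Step` and a `Record` of metrics")
//!     }
//!
//!     // `reset_with_index()` and `step_with_reset()` are omitted in this sketch.
//! }
//! ```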
//!
//! # Policy
//!
//! [`Policy<E: Env>`] represents a policy. [`Policy::sample()`] takes an `E::Obs` and
//! generates an `E::Act`. The policy may be probabilistic or deterministic.
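//!
//! For example, a trivial deterministic policy can be written as follows (compare the
//! `Policy` implementation for `TestAgent` in the `test` module of this crate; the `My*`
//! names are placeholders):
//!
//! ```ignore
//! use border_core::Policy;
//!
//! struct MyPolicy;
//!
//! impl Policy<MyEnv> for MyPolicy {
//!     // Maps an observation to an action; this sketch always emits the same action.
//!     fn sample(&mut self, _obs: &MyObs) -> MyAct {
//!         MyAct { act: 1 }
//!     }
//! }
//! ```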
//!
//! # Agent
//!
//! In this crate, [`Agent<E: Env, R: ReplayBufferBase>`] is defined as a trainable
//! [`Policy<E: Env>`]. The agent is in either training or evaluation mode. In training mode,
//! the agent's policy might be probabilistic for exploration, while in evaluation mode,
//! the policy might be deterministic.
//!
//! The [`Agent::opt()`] method performs a single optimization step. The definition of an
//! optimization step varies between agents; it might consist of multiple stochastic
//! gradient steps. Samples for training are taken from
//! [`R: ReplayBufferBase`][`ReplayBufferBase`].
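//!
//! The sketch below shows how these methods fit together in a simplified training loop.
//! It is illustrative only: [`Trainer`], described later, wraps this logic, and how
//! transitions reach the buffer depends on the [`StepProcessor`] in use. `agent`, `buffer`
//! and `n_steps` are placeholders:
//!
//! ```ignore
//! agent.train(); // switch the agent to training mode
//! for _ in 0..n_steps {
//!     // ... interact with the environment and push transitions into `buffer` ...
//!     // Perform one optimization step; the returned record holds any metrics
//!     // the agent chooses to report (e.g., losses).
//!     let record = agent.opt_with_record(&mut buffer);
//! }
//! agent.eval(); // switch to evaluation mode before evaluating the policy
//! ```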
//!
//! This trait also has methods for saving and loading the parameters of the trained policy
//! to and from a directory.
//!
//! # Batch
//!
//! [`TransitionBatch`] is a trait for a batch of transitions `(o_t, r_t, a_t, o_t+1)`.
//! This trait is used to train [`Agent`]s with an RL algorithm.
//!
//! # Replay buffer and experience buffer
//!
//! The [`ReplayBufferBase`] trait is an abstraction of replay buffers.
//! One of its associated types, [`ReplayBufferBase::Batch`], represents samples taken from
//! the buffer for training [`Agent`]s. Agents must implement the [`Agent::opt()`] method,
//! where [`ReplayBufferBase::Batch`] has an appropriate type or trait bound(s) to train
//! the agent.
//!
//! As explained above, the [`ReplayBufferBase`] trait is able to generate batches of
//! samples with which agents are trained. The [`ExperienceBufferBase`] trait, on the other
//! hand, is able to store samples. [`ExperienceBufferBase::push()`] is used to push
//! samples of type [`ExperienceBufferBase::Item`], which might be obtained via interaction
//! steps with an environment.
//!
//! ## A reference implementation
//!
//! [`SimpleReplayBuffer<O, A>`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`].
//! This type has two type parameters, `O` and `A`, which are the representations of
//! observations and actions in the replay buffer. `O` and `A` must implement
//! [`BatchBase`], which provides the functionality of storing samples, like `Vec<T>`,
//! for observations and actions. The associated types `Item` and `Batch`
//! are the same type, [`GenericTransitionBatch`], representing sets of `(o_t, r_t, a_t, o_t+1)`.
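//!
//! As an illustration, a [`BatchBase`] implementation that stores observations in a
//! `Vec` can look like the following; it matches the shape of `TestObsBatch` in the
//! `test` module of this crate (`MyObsBatch` is a placeholder name):
//!
//! ```ignore
//! use border_core::generic_replay_buffer::BatchBase;
//!
//! struct MyObsBatch {
//!     obs: Vec<usize>,
//! }
//!
//! impl BatchBase for MyObsBatch {
//!     // Allocates storage for `capacity` samples.
//!     fn new(capacity: usize) -> Self {
//!         Self { obs: vec![0; capacity] }
//!     }
//!
//!     // Writes the single sample carried by `data` at index `i`.
//!     fn push(&mut self, i: usize, data: Self) {
//!         self.obs[i] = data.obs[0];
//!     }
//!
//!     // Gathers the samples at the given indices into a new batch.
//!     fn sample(&self, ixs: &Vec<usize>) -> Self {
//!         let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
//!         Self { obs }
//!     }
//! }
//! ```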
//!
//! [`SimpleStepProcessor<E, O, A>`] might be used with [`SimpleReplayBuffer<O, A>`].
//! It converts `E::Obs` and `E::Act` into [`BatchBase`]s of the respective types
//! and generates a [`GenericTransitionBatch`]. The conversion relies on the trait bounds
//! `O: From<E::Obs>` and `A: From<E::Act>`.
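//!
//! These bounds can be satisfied with plain `From` implementations, as in the sketch
//! below (the `My*` types are placeholders; compare the conversions in the `test` module
//! of this crate):
//!
//! ```ignore
//! use border_core::generic_replay_buffer::{SimpleReplayBuffer, SimpleStepProcessor};
//!
//! // Conversions required by `SimpleStepProcessor<MyEnv, MyObsBatch, MyActBatch>`.
//! impl From<MyObs> for MyObsBatch {
//!     fn from(obs: MyObs) -> Self {
//!         Self { obs: vec![obs.obs] }
//!     }
//! }
//!
//! impl From<MyAct> for MyActBatch {
//!     fn from(act: MyAct) -> Self {
//!         Self { act: vec![act.act] }
//!     }
//! }
//!
//! // A replay buffer whose observation and action storage are the batch types above.
//! type MyReplayBuffer = SimpleReplayBuffer<MyObsBatch, MyActBatch>;
//! ```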
//!
//! # Trainer
//!
//! [`Trainer`] manages the training loop and related objects. A [`Trainer`] object is
//! built with configurations of training parameters such as the maximum number of
//! optimization steps and the directory in which the agent's model parameters are saved
//! during training. The [`Trainer::train`] method runs online training of an agent in an
//! environment. In the training loop of this method, the agent interacts with the
//! environment to collect samples and perform optimization steps, and some metrics are
//! recorded at the same time.
//!
//! # Evaluator
//!
//! [`Evaluator<E, P>`] is used to evaluate the performance of a policy (`P`) in an
//! environment (`E`). An object of this type is given to the [`Trainer`] object to
//! evaluate the policy during training.
//! [`DefaultEvaluator<E, P>`] is a default implementation of [`Evaluator<E, P>`].
//! This evaluator runs the policy in the environment for a certain number of episodes.
//! At the start of each episode, the environment is reset using [`Env::reset_with_index()`]
//! to control specific conditions for evaluation.
//!
//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer
//! [`SimpleReplayBuffer<O, A>`]: generic_replay_buffer::SimpleReplayBuffer
//! [`BatchBase`]: generic_replay_buffer::BatchBase
//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch
//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor
//! [`SimpleStepProcessor<E, O, A>`]: generic_replay_buffer::SimpleStepProcessor
pub mod error;
mod evaluator;
pub mod generic_replay_buffer;
pub mod record;

mod base;
pub use base::{
    Act, Agent, Configurable, Env, ExperienceBufferBase, Info, Obs, Policy, ReplayBufferBase, Step,
    StepProcessor, TransitionBatch,
};

mod trainer;
pub use evaluator::{DefaultEvaluator, Evaluator};
pub use trainer::{Sampler, Trainer, TrainerConfig};

// TODO: Consider compiling this module only for tests.
/// Agent and Env for testing.
pub mod test {
    use serde::{Deserialize, Serialize};

    /// Obs for testing.
    #[derive(Clone, Debug)]
    pub struct TestObs {
        obs: usize,
    }

    impl crate::Obs for TestObs {
        fn dummy(_n: usize) -> Self {
            Self { obs: 0 }
        }

        fn len(&self) -> usize {
            1
        }
    }

    /// Batch of obs for testing.
    pub struct TestObsBatch {
        obs: Vec<usize>,
    }

    impl crate::generic_replay_buffer::BatchBase for TestObsBatch {
        fn new(capacity: usize) -> Self {
            Self {
                obs: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.obs[i] = data.obs[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
            Self { obs }
        }
    }

    impl From<TestObs> for TestObsBatch {
        fn from(obs: TestObs) -> Self {
            Self { obs: vec![obs.obs] }
        }
    }

    /// Act for testing.
    #[derive(Clone, Debug)]
    pub struct TestAct {
        act: usize,
    }

    impl crate::Act for TestAct {}

    /// Batch of act for testing.
    pub struct TestActBatch {
        act: Vec<usize>,
    }

    impl From<TestAct> for TestActBatch {
        fn from(act: TestAct) -> Self {
            Self { act: vec![act.act] }
        }
    }

    impl crate::generic_replay_buffer::BatchBase for TestActBatch {
        fn new(capacity: usize) -> Self {
            Self {
                act: vec![0; capacity],
            }
        }

        fn push(&mut self, i: usize, data: Self) {
            self.act[i] = data.act[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let act = ixs.iter().map(|ix| self.act[*ix]).collect();
            Self { act }
        }
    }

    /// Info for testing.
    pub struct TestInfo {}

    impl crate::Info for TestInfo {}

    /// Environment for testing.
    ///
    /// The state is a counter that accumulates the applied actions;
    /// episodes never terminate and the reward is always zero.
    pub struct TestEnv {
        state_init: usize,
        state: usize,
    }

    impl crate::Env for TestEnv {
        type Config = usize;
        type Obs = TestObs;
        type Act = TestAct;
        type Info = TestInfo;

        fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn step_with_reset(
            &mut self,
            a: &Self::Act,
        ) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: TestObs {
                    obs: self.state_init,
                },
            };
            (step, crate::record::Record::empty())
        }

        fn step(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
        where
            Self: Sized,
        {
            self.state += a.act;
            let step = crate::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: TestObs {
                    obs: self.state_init,
                },
            };
            (step, crate::record::Record::empty())
        }

        fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
        where
            Self: Sized,
        {
            Ok(Self {
                state_init: *config,
                state: 0,
            })
        }
    }

    type ReplayBuffer =
        crate::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;

    /// Agent for testing.
    pub struct TestAgent {}

    /// Config of agent for testing.
    #[derive(Clone, Deserialize, Serialize)]
    pub struct TestAgentConfig;

    impl crate::Agent<TestEnv, ReplayBuffer> for TestAgent {
        fn train(&mut self) {}

        fn is_train(&self) -> bool {
            false
        }

        fn eval(&mut self) {}

        fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record {
            crate::record::Record::empty()
        }

        fn save_params<T: AsRef<std::path::Path>>(&self, _path: T) -> anyhow::Result<()> {
            Ok(())
        }

        fn load_params<T: AsRef<std::path::Path>>(&mut self, _path: T) -> anyhow::Result<()> {
            Ok(())
        }
    }

    impl crate::Policy<TestEnv> for TestAgent {
        fn sample(&mut self, _obs: &TestObs) -> TestAct {
            TestAct { act: 1 }
        }
    }

    impl crate::Configurable<TestEnv> for TestAgent {
        type Config = TestAgentConfig;

        fn build(_config: Self::Config) -> Self {
            Self {}
        }
    }
323}