border_async_trainer/lib.rs
1//! Asynchronous trainer with parallel sampling processes.
2//!
3//! The code might look like the following.
4//!
5//! ```
6//! # use serde::{Deserialize, Serialize};
7//! # use border_core::test::{
8//! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch,
9//! # TestAct, TestActBatch
10//! # };
11//! # use border_core::Env as _;
12//! # use border_async_trainer::{
13//! # //test::{TestAgent, TestAgentConfig, TestEnv},
14//! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig,
15//! # };
16//! # use border_core::{
17//! # generic_replay_buffer::{
18//! # SimpleReplayBuffer, SimpleReplayBufferConfig,
19//! # SimpleStepProcessorConfig, SimpleStepProcessor
20//! # },
21//! # record::{Recorder, NullRecorder}, DefaultEvaluator,
22//! # };
23//! #
24//! # use std::path::{Path, PathBuf};
25//! #
26//! # fn agent_config() -> TestAgentConfig {
27//! # TestAgentConfig
28//! # }
29//! #
30//! # fn env_config() -> usize {
31//! # 0
32//! # }
33//!
34//! type Env = TestEnv;
35//! type ObsBatch = TestObsBatch;
36//! type ActBatch = TestActBatch;
37//! type ReplayBuffer = SimpleReplayBuffer<ObsBatch, ActBatch>;
38//! type StepProcessor = SimpleStepProcessor<Env, ObsBatch, ActBatch>;
39//!
40//! // Create a new agent by wrapping the existing agent in order to implement SyncModel.
41//! struct TestAgent2(TestAgent);
42//!
43//! impl border_core::Configurable for TestAgent2 {
44//! type Config = TestAgentConfig;
45//!
46//! fn build(config: Self::Config) -> Self {
47//! Self(TestAgent::build(config))
48//! }
49//! }
50//!
51//! impl border_core::Agent<Env, ReplayBuffer> for TestAgent2 {
52//! // Boilerplate code to delegate the method calls to the inner agent.
53//! fn train(&mut self) {
54//! self.0.train();
55//! }
56//!
57//! // For other methods ...
58//! # fn is_train(&self) -> bool {
59//! # self.0.is_train()
60//! # }
61//! #
62//! # fn eval(&mut self) {
63//! # self.0.eval();
64//! # }
65//! #
66//! # fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record {
67//! # self.0.opt_with_record(buffer)
68//! # }
69//! #
70//! # fn save_params(&self, path: &Path) -> anyhow::Result<Vec<PathBuf>> {
71//! # self.0.save_params(path)
72//! # }
73//! #
74//! # fn load_params(&mut self, path: &Path) -> anyhow::Result<()> {
75//! # self.0.load_params(path)
76//! # }
77//! #
78//! # fn opt(&mut self, buffer: &mut ReplayBuffer) {
79//! # self.0.opt_with_record(buffer);
80//! # }
81//! #
82//! # fn as_any_ref(&self) -> &dyn std::any::Any {
83//! # self
84//! # }
85//! #
86//! # fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
87//! # self
88//! # }
89//! }
90//!
91//! impl border_core::Policy<Env> for TestAgent2 {
92//! // Boilerplate code to delegate the method calls to the inner agent.
93//! // ...
94//! # fn sample(&mut self, obs: &TestObs) -> TestAct {
95//! # self.0.sample(obs)
96//! # }
97//! }
98//!
99//! impl border_async_trainer::SyncModel for TestAgent2{
100//! // Self::ModelInfo should include the model parameters.
101//! type ModelInfo = usize;
102//!
103//!
104//! fn model_info(&self) -> (usize, Self::ModelInfo) {
105//! // Extracts the model parameters and returns them as Self::ModelInfo.
106//! // The first element of the tuple is the number of optimization steps.
107//! (0, 0)
108//! }
109//!
110//! fn sync_model(&mut self, _model_info: &Self::ModelInfo) {
111//! // implements synchronization of the model based on the _model_info
112//! }
113//! }
114//!
115//! let agent_configs: Vec<_> = vec![agent_config()];
116//! let env_config_train = env_config();
117//! let env_config_eval = env_config();
118//! let replay_buffer_config = SimpleReplayBufferConfig::default();
119//! let step_proc_config = SimpleStepProcessorConfig::default();
120//! let actor_man_config = ActorManagerConfig::default();
121//! let async_trainer_config = AsyncTrainerConfig::default();
122//! let mut recorder: Box<dyn Recorder<_, _>> = Box::new(NullRecorder::new());
123//! let mut evaluator = DefaultEvaluator::<TestEnv>::new(&env_config_eval, 0, 1).unwrap();
124//!
125//! border_async_trainer::util::train_async::<TestAgent2, _, _, StepProcessor>(
126//! &agent_config(),
127//! &agent_configs,
128//! &env_config_train,
129//! &env_config_eval,
130//! &step_proc_config,
131//! &replay_buffer_config,
132//! &actor_man_config,
133//! &async_trainer_config,
134//! &mut recorder,
135//! &mut evaluator,
136//! );
137//! ```
138//!
139//! Training process consists of the following two components:
140//!
141//! * [`ActorManager`] manages [`Actor`]s, each of which runs a thread in which an
142//! [`Agent`] interacts with an [`Env`] and collects samples. Those samples are sent to
143//! the replay buffer in [`AsyncTrainer`].
144//! * [`AsyncTrainer`] is responsible for training of an agent. It also runs a thread
145//! for pushing samples from [`ActorManager`] into a replay buffer.
146//!
147//! The `Agent` must implement [`SyncModel`] trait in order to synchronize the model of
148//! the agent in [`Actor`] with the trained agent in [`AsyncTrainer`]. The trait has
149//! the ability to import and export the information of the model as
150//! [`SyncModel`]`::ModelInfo`.
151//!
152//! The `Agent` in [`AsyncTrainer`] is responsible for training, typically with a GPU,
153//! while the `Agent`s in the [`Actor`]s of [`ActorManager`] are responsible for sampling
154//! using CPUs.
155//!
156//! Both [`AsyncTrainer`] and [`ActorManager`] run on the same machine and
157//! communicate via channels.
158//!
159//! [`Agent`]: border_core::Agent
160//! [`Env`]: border_core::Env
161mod actor;
162mod actor_manager;
163mod async_trainer;
164mod error;
165mod messages;
166mod replay_buffer_proxy;
167mod sync_model;
168pub mod util;
169
170pub use actor::{actor_stats_fmt, Actor, ActorStat};
171pub use actor_manager::{ActorManager, ActorManagerConfig};
172pub use async_trainer::{AsyncTrainStat, AsyncTrainer, AsyncTrainerConfig};
173pub use error::BorderAsyncTrainerError;
174pub use messages::PushedItemMessage;
175pub use replay_buffer_proxy::{ReplayBufferProxy, ReplayBufferProxyConfig};
176pub use sync_model::SyncModel;
177
178/// Agent and Env for testing.
179#[cfg(test)]
180pub mod test {
181 use serde::{Deserialize, Serialize};
182 use std::path::{Path, PathBuf};
183
184 /// Obs for testing.
185 #[derive(Clone, Debug)]
186 pub struct TestObs {
187 obs: usize,
188 }
189
190 impl border_core::Obs for TestObs {
191 fn len(&self) -> usize {
192 1
193 }
194 }
195
196 /// Batch of obs for testing.
197 pub struct TestObsBatch {
198 obs: Vec<usize>,
199 }
200
201 impl border_core::generic_replay_buffer::BatchBase for TestObsBatch {
202 fn new(capacity: usize) -> Self {
203 Self {
204 obs: vec![0; capacity],
205 }
206 }
207
208 fn push(&mut self, i: usize, data: Self) {
209 self.obs[i] = data.obs[0];
210 }
211
212 fn sample(&self, ixs: &Vec<usize>) -> Self {
213 let obs = ixs.iter().map(|ix| self.obs[*ix]).collect();
214 Self { obs }
215 }
216 }
217
218 impl From<TestObs> for TestObsBatch {
219 fn from(obs: TestObs) -> Self {
220 Self { obs: vec![obs.obs] }
221 }
222 }
223
224 /// Act for testing.
225 #[derive(Clone, Debug)]
226 pub struct TestAct {
227 act: usize,
228 }
229
230 impl border_core::Act for TestAct {}
231
232 /// Batch of act for testing.
233 pub struct TestActBatch {
234 act: Vec<usize>,
235 }
236
237 impl From<TestAct> for TestActBatch {
238 fn from(act: TestAct) -> Self {
239 Self { act: vec![act.act] }
240 }
241 }
242
243 impl border_core::generic_replay_buffer::BatchBase for TestActBatch {
244 fn new(capacity: usize) -> Self {
245 Self {
246 act: vec![0; capacity],
247 }
248 }
249
250 fn push(&mut self, i: usize, data: Self) {
251 self.act[i] = data.act[0];
252 }
253
254 fn sample(&self, ixs: &Vec<usize>) -> Self {
255 let act = ixs.iter().map(|ix| self.act[*ix]).collect();
256 Self { act }
257 }
258 }
259
260 /// Info for testing.
261 pub struct TestInfo {}
262
263 impl border_core::Info for TestInfo {}
264
265 /// Environment for testing.
266 pub struct TestEnv {
267 state_init: usize,
268 state: usize,
269 }
270
271 impl border_core::Env for TestEnv {
272 type Config = usize;
273 type Obs = TestObs;
274 type Act = TestAct;
275 type Info = TestInfo;
276
277 fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
278 self.state = self.state_init;
279 Ok(TestObs { obs: self.state })
280 }
281
282 fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
283 self.state = self.state_init;
284 Ok(TestObs { obs: self.state })
285 }
286
287 fn step_with_reset(
288 &mut self,
289 a: &Self::Act,
290 ) -> (border_core::Step<Self>, border_core::record::Record)
291 where
292 Self: Sized,
293 {
294 self.state = self.state + a.act;
295 let step = border_core::Step {
296 obs: TestObs { obs: self.state },
297 act: a.clone(),
298 reward: vec![0.0],
299 is_terminated: vec![0],
300 is_truncated: vec![0],
301 info: TestInfo {},
302 init_obs: Some(TestObs {
303 obs: self.state_init,
304 }),
305 };
306 return (step, border_core::record::Record::empty());
307 }
308
309 fn step(&mut self, a: &Self::Act) -> (border_core::Step<Self>, border_core::record::Record)
310 where
311 Self: Sized,
312 {
313 self.state = self.state + a.act;
314 let step = border_core::Step {
315 obs: TestObs { obs: self.state },
316 act: a.clone(),
317 reward: vec![0.0],
318 is_terminated: vec![0],
319 is_truncated: vec![0],
320 info: TestInfo {},
321 init_obs: Some(TestObs {
322 obs: self.state_init,
323 }),
324 };
325 return (step, border_core::record::Record::empty());
326 }
327
328 fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
329 where
330 Self: Sized,
331 {
332 Ok(Self {
333 state_init: *config,
334 state: 0,
335 })
336 }
337 }
338
339 type ReplayBuffer =
340 border_core::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;
341
342 /// Agent for testing.
343 pub struct TestAgent {}
344
345 #[derive(Clone, Deserialize, Serialize)]
346 /// Config of agent for testing.
347 pub struct TestAgentConfig;
348
349 impl border_core::Agent<TestEnv, ReplayBuffer> for TestAgent {
350 fn train(&mut self) {}
351
352 fn is_train(&self) -> bool {
353 false
354 }
355
356 fn eval(&mut self) {}
357
358 fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> border_core::record::Record {
359 border_core::record::Record::empty()
360 }
361
362 fn save_params(&self, _path: &Path) -> anyhow::Result<Vec<PathBuf>> {
363 Ok(vec![])
364 }
365
366 fn load_params(&mut self, _path: &Path) -> anyhow::Result<()> {
367 Ok(())
368 }
369
370 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
371 self
372 }
373
374 fn as_any_ref(&self) -> &dyn std::any::Any {
375 self
376 }
377 }
378
379 impl border_core::Policy<TestEnv> for TestAgent {
380 fn sample(&mut self, _obs: &TestObs) -> TestAct {
381 TestAct { act: 1 }
382 }
383 }
384
385 impl border_core::Configurable for TestAgent {
386 type Config = TestAgentConfig;
387
388 fn build(_config: Self::Config) -> Self {
389 Self {}
390 }
391 }
392
393 impl crate::SyncModel for TestAgent {
394 type ModelInfo = usize;
395
396 fn model_info(&self) -> (usize, Self::ModelInfo) {
397 (0, 0)
398 }
399
400 fn sync_model(&mut self, _model_info: &Self::ModelInfo) {
401 // nothing to do
402 }
403 }
404}