border_async_trainer/
lib.rs

1//! Asynchronous trainer with parallel sampling processes.
2//!
3//! A typical training setup looks like the following.
4//!
5//! ```
6//! # use serde::{Deserialize, Serialize};
7//! # use border_core::test::{
8//! #     TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch,
9//! #     TestAct, TestActBatch
10//! # };
11//! # use border_core::Env as _;
12//! # use border_async_trainer::{
13//! #     //test::{TestAgent, TestAgentConfig, TestEnv},
14//! #     ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig,
15//! # };
16//! # use border_core::{
17//! #     generic_replay_buffer::{
18//! #         SimpleReplayBuffer, SimpleReplayBufferConfig,
19//! #         SimpleStepProcessorConfig, SimpleStepProcessor
20//! #     },
21//! #     record::{Recorder, NullRecorder}, DefaultEvaluator,
22//! # };
23//! #
24//! # use std::path::{Path, PathBuf};
25//! #
26//! # fn agent_config() -> TestAgentConfig {
27//! #     TestAgentConfig
28//! # }
29//! #
30//! # fn env_config() -> usize {
31//! #     0
32//! # }
33//!
34//! type Env = TestEnv;
35//! type ObsBatch = TestObsBatch;
36//! type ActBatch = TestActBatch;
37//! type ReplayBuffer = SimpleReplayBuffer<ObsBatch, ActBatch>;
38//! type StepProcessor = SimpleStepProcessor<Env, ObsBatch, ActBatch>;
39//!
40//! // Create a new agent by wrapping the existing agent in order to implement SyncModel.
41//! struct TestAgent2(TestAgent);
42//!
43//! impl border_core::Configurable for TestAgent2 {
44//!     type Config = TestAgentConfig;
45//!
46//!     fn build(config: Self::Config) -> Self {
47//!         Self(TestAgent::build(config))
48//!     }
49//! }
50//!
51//! impl border_core::Agent<Env, ReplayBuffer> for TestAgent2 {
52//!     // Boilerplate code to delegate the method calls to the inner agent.
53//!     fn train(&mut self) {
54//!         self.0.train();
55//!      }
56//!
57//!      // For other methods ...
58//! #     fn is_train(&self) -> bool {
59//! #         self.0.is_train()
60//! #     }
61//! #
62//! #     fn eval(&mut self) {
63//! #         self.0.eval();
64//! #     }
65//! #
66//! #     fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record {
67//! #         self.0.opt_with_record(buffer)
68//! #     }
69//! #
70//! #     fn save_params(&self, path: &Path) -> anyhow::Result<Vec<PathBuf>> {
71//! #         self.0.save_params(path)
72//! #     }
73//! #
74//! #     fn load_params(&mut self, path: &Path) -> anyhow::Result<()> {
75//! #         self.0.load_params(path)
76//! #     }
77//! #
78//! #     fn opt(&mut self, buffer: &mut ReplayBuffer) {
79//! #         self.0.opt_with_record(buffer);
80//! #     }
81//! #
82//! #     fn as_any_ref(&self) -> &dyn std::any::Any {
83//! #         self
84//! #     }
85//! #
86//! #     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
87//! #         self
88//! #     }
89//! }
90//!
91//! impl border_core::Policy<Env> for TestAgent2 {
92//!       // Boilerplate code to delegate the method calls to the inner agent.
93//!       // ...
94//! #     fn sample(&mut self, obs: &TestObs) -> TestAct {
95//! #         self.0.sample(obs)
96//! #     }
97//! }
98//!
99//! impl border_async_trainer::SyncModel for TestAgent2{
100//!     // Self::ModelInfo should include the model parameters.
101//!     type ModelInfo = usize;
102//!
103//!
104//!     fn model_info(&self) -> (usize, Self::ModelInfo) {
105//!         // Extracts the model parameters and returns them as Self::ModelInfo.
106//!         // The first element of the tuple is the number of optimization steps.
107//!         (0, 0)
108//!     }
109//!
110//!     fn sync_model(&mut self, _model_info: &Self::ModelInfo) {
111//!         // Synchronize the model with the parameters carried by `_model_info`.
112//!     }
113//! }
114//!
115//! let agent_configs: Vec<_> = vec![agent_config()];
116//! let env_config_train = env_config();
117//! let env_config_eval = env_config();
118//! let replay_buffer_config = SimpleReplayBufferConfig::default();
119//! let step_proc_config = SimpleStepProcessorConfig::default();
120//! let actor_man_config = ActorManagerConfig::default();
121//! let async_trainer_config = AsyncTrainerConfig::default();
122//! let mut recorder: Box<dyn Recorder<_, _>> = Box::new(NullRecorder::new());
123//! let mut evaluator = DefaultEvaluator::<TestEnv>::new(&env_config_eval, 0, 1).unwrap();
124//!
125//! border_async_trainer::util::train_async::<TestAgent2, _, _, StepProcessor>(
126//!     &agent_config(),
127//!     &agent_configs,
128//!     &env_config_train,
129//!     &env_config_eval,
130//!     &step_proc_config,
131//!     &replay_buffer_config,
132//!     &actor_man_config,
133//!     &async_trainer_config,
134//!     &mut recorder,
135//!     &mut evaluator,
136//! );
137//! ```
138//!
139//! The training process consists of the following two components:
140//!
141//! * [`ActorManager`] manages [`Actor`]s, each of which runs a thread in which an
142//!   [`Agent`] interacts with an [`Env`] to collect samples. Those samples are sent to
143//!   the replay buffer in [`AsyncTrainer`].
144//! * [`AsyncTrainer`] is responsible for training an agent. It also runs a thread
145//!   that pushes samples received from [`ActorManager`] into the replay buffer.
146//!
147//! The `Agent` must implement the [`SyncModel`] trait so that the model of the agent
148//! in each [`Actor`] can be synchronized with the agent trained in [`AsyncTrainer`].
149//! The trait exports and imports the model information (e.g., its parameters) as
150//! [`SyncModel`]`::ModelInfo`.
151//!
152//! The `Agent` in [`AsyncTrainer`] is responsible for training, typically on a GPU,
153//! while the `Agent`s in the [`Actor`]s managed by [`ActorManager`] are responsible
154//! for sampling on the CPU.
155//!
156//! Both [`AsyncTrainer`] and [`ActorManager`] run on the same machine and
157//! communicate through channels.
158//!
159//! [`Agent`]: border_core::Agent
160//! [`Env`]: border_core::Env
161mod actor;
162mod actor_manager;
163mod async_trainer;
164mod error;
165mod messages;
166mod replay_buffer_proxy;
167mod sync_model;
168pub mod util;
169
170pub use actor::{actor_stats_fmt, Actor, ActorStat};
171pub use actor_manager::{ActorManager, ActorManagerConfig};
172pub use async_trainer::{AsyncTrainStat, AsyncTrainer, AsyncTrainerConfig};
173pub use error::BorderAsyncTrainerError;
174pub use messages::PushedItemMessage;
175pub use replay_buffer_proxy::{ReplayBufferProxy, ReplayBufferProxyConfig};
176pub use sync_model::SyncModel;
177
/// Agent and Env for testing.
#[cfg(test)]
pub mod test {
    use serde::{Deserialize, Serialize};
    use std::path::{Path, PathBuf};

    /// Obs for testing: wraps a single `usize` state value.
    #[derive(Clone, Debug)]
    pub struct TestObs {
        obs: usize,
    }

    impl border_core::Obs for TestObs {
        // Each `TestObs` holds exactly one value.
        fn len(&self) -> usize {
            1
        }
    }

    /// Batch of obs for testing.
    pub struct TestObsBatch {
        obs: Vec<usize>,
    }

    impl border_core::generic_replay_buffer::BatchBase for TestObsBatch {
        fn new(capacity: usize) -> Self {
            Self {
                obs: vec![0; capacity],
            }
        }

        // Only the first element of `data` is stored at slot `i`.
        // NOTE(review): assumes callers push one item at a time — confirm if used otherwise.
        fn push(&mut self, i: usize, data: Self) {
            self.obs[i] = data.obs[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let obs = ixs.iter().map(|&ix| self.obs[ix]).collect();
            Self { obs }
        }
    }

    impl From<TestObs> for TestObsBatch {
        fn from(obs: TestObs) -> Self {
            Self { obs: vec![obs.obs] }
        }
    }

    /// Act for testing: wraps a single `usize` increment.
    #[derive(Clone, Debug)]
    pub struct TestAct {
        act: usize,
    }

    impl border_core::Act for TestAct {}

    /// Batch of act for testing.
    pub struct TestActBatch {
        act: Vec<usize>,
    }

    impl From<TestAct> for TestActBatch {
        fn from(act: TestAct) -> Self {
            Self { act: vec![act.act] }
        }
    }

    impl border_core::generic_replay_buffer::BatchBase for TestActBatch {
        fn new(capacity: usize) -> Self {
            Self {
                act: vec![0; capacity],
            }
        }

        // Only the first element of `data` is stored at slot `i`.
        fn push(&mut self, i: usize, data: Self) {
            self.act[i] = data.act[0];
        }

        fn sample(&self, ixs: &Vec<usize>) -> Self {
            let act = ixs.iter().map(|&ix| self.act[ix]).collect();
            Self { act }
        }
    }

    /// Info for testing.
    pub struct TestInfo {}

    impl border_core::Info for TestInfo {}

    /// Environment for testing.
    ///
    /// The state starts at `state_init` on reset and is incremented by the action
    /// value on every step. Episodes never terminate or truncate.
    pub struct TestEnv {
        state_init: usize,
        state: usize,
    }

    impl TestEnv {
        /// Advances the state by `a.act` and builds the resulting step.
        ///
        /// Shared by `step` and `step_with_reset`. Because `is_terminated` and
        /// `is_truncated` are always 0, `step_with_reset` never actually resets.
        fn advance(
            &mut self,
            a: &TestAct,
        ) -> (border_core::Step<Self>, border_core::record::Record) {
            self.state += a.act;
            let step = border_core::Step {
                obs: TestObs { obs: self.state },
                act: a.clone(),
                reward: vec![0.0],
                is_terminated: vec![0],
                is_truncated: vec![0],
                info: TestInfo {},
                init_obs: Some(TestObs {
                    obs: self.state_init,
                }),
            };
            (step, border_core::record::Record::empty())
        }
    }

    impl border_core::Env for TestEnv {
        type Config = usize;
        type Obs = TestObs;
        type Act = TestAct;
        type Info = TestInfo;

        fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
            self.state = self.state_init;
            Ok(TestObs { obs: self.state })
        }

        fn step_with_reset(
            &mut self,
            a: &Self::Act,
        ) -> (border_core::Step<Self>, border_core::record::Record)
        where
            Self: Sized,
        {
            self.advance(a)
        }

        fn step(&mut self, a: &Self::Act) -> (border_core::Step<Self>, border_core::record::Record)
        where
            Self: Sized,
        {
            self.advance(a)
        }

        // `config` is the initial state value. `state` stays 0 until the first reset.
        fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
        where
            Self: Sized,
        {
            Ok(Self {
                state_init: *config,
                state: 0,
            })
        }
    }

    type ReplayBuffer =
        border_core::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;

    /// Agent for testing: never trains and always samples action `1`.
    pub struct TestAgent {}

    /// Config of agent for testing.
    #[derive(Clone, Deserialize, Serialize)]
    pub struct TestAgentConfig;

    impl border_core::Agent<TestEnv, ReplayBuffer> for TestAgent {
        fn train(&mut self) {}

        fn is_train(&self) -> bool {
            false
        }

        fn eval(&mut self) {}

        fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> border_core::record::Record {
            border_core::record::Record::empty()
        }

        fn save_params(&self, _path: &Path) -> anyhow::Result<Vec<PathBuf>> {
            Ok(vec![])
        }

        fn load_params(&mut self, _path: &Path) -> anyhow::Result<()> {
            Ok(())
        }

        fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
            self
        }

        fn as_any_ref(&self) -> &dyn std::any::Any {
            self
        }
    }

    impl border_core::Policy<TestEnv> for TestAgent {
        fn sample(&mut self, _obs: &TestObs) -> TestAct {
            TestAct { act: 1 }
        }
    }

    impl border_core::Configurable for TestAgent {
        type Config = TestAgentConfig;

        fn build(_config: Self::Config) -> Self {
            Self {}
        }
    }

    impl crate::SyncModel for TestAgent {
        type ModelInfo = usize;

        // No parameters to export: always reports 0 optimization steps and model info 0.
        fn model_info(&self) -> (usize, Self::ModelInfo) {
            (0, 0)
        }

        fn sync_model(&mut self, _model_info: &Self::ModelInfo) {
            // nothing to do
        }
    }
}