border_async_trainer/actor/
base.rs

1use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel};
2use border_core::{
3    Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, Sampler, StepProcessor,
4};
5use crossbeam_channel::Sender;
6use log::{debug, info};
7use std::{
8    marker::PhantomData,
9    ops::DerefMut,
10    sync::{Arc, Mutex},
11};
12
13#[cfg_attr(doc, aquamarine::aquamarine)]
14/// Generate transitions by running [`Agent`] in [`Env`].
15///
16/// ```mermaid
17/// flowchart TB
18///   E["Agent<br>(in AsyncTrainer)"]-->|SyncModel::ModelInfo|A
19///   subgraph D[Actor]
20///     A[Agent]-->|Env::Act|B[Env]
21///     B-->|Env::Obs|A
22///     B-->|Step&ltE: Env&gt|C[StepProcessor]
23///   end
24///   C-->|ReplayBufferBase::PushedItem|F[ReplayBufferProxy]
25/// ```
26///
27/// In [`Actor`], an [`Agent`] runs on an [`Env`] and generates [`Step`] objects.
28/// These objects are processed with [`StepProcessor`] and sent to [`ReplayBufferProxy`].
29/// The [`Agent`] in the [`Actor`] periodically synchronizes with the [`Agent`] in
30/// [`AsyncTrainer`] via [`SyncModel::ModelInfo`].
31///
32/// See also the diagram in [`AsyncTrainer`].
33///
34/// [`AsyncTrainer`]: crate::AsyncTrainer
35/// [`Agent`]: border_core::Agent
36/// [`Env`]: border_core::Env
37/// [`StepProcessor`]: border_core::StepProcessor
38/// [`Step`]: border_core::Step
39pub struct Actor<A, E, P, R>
40where
41    A: Agent<E, R> + Configurable + SyncModel + 'static,
42    E: Env,
43    P: StepProcessor<E>,
44    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
45{
46    /// Stops sampling process if this field is set to `true`.
47    id: usize,
48    stop: Arc<Mutex<bool>>,
49    agent_config: A::Config,
50    env_config: E::Config,
51    step_proc_config: P::Config,
52    replay_buffer_config: ReplayBufferProxyConfig,
53    env_seed: i64,
54    stats: Arc<Mutex<Option<ActorStat>>>,
55    phantom: PhantomData<(A, E, P, R)>,
56}
57
58impl<A, E, P, R> Actor<A, E, P, R>
59where
60    A: Agent<E, R> + Configurable + SyncModel + 'static,
61    E: Env,
62    P: StepProcessor<E>,
63    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
64{
65    pub fn build(
66        id: usize,
67        agent_config: A::Config,
68        env_config: E::Config,
69        step_proc_config: P::Config,
70        replay_buffer_config: ReplayBufferProxyConfig,
71        stop: Arc<Mutex<bool>>,
72        env_seed: i64,
73        stats: Arc<Mutex<Option<ActorStat>>>,
74    ) -> Self {
75        log::info!("Create actor {}", id);
76        Self {
77            id,
78            stop,
79            agent_config: agent_config.clone(),
80            env_config: env_config.clone(),
81            step_proc_config: step_proc_config.clone(),
82            replay_buffer_config: replay_buffer_config.clone(),
83            env_seed,
84            stats,
85            phantom: PhantomData,
86        }
87    }
88
89    fn sync_model_first(agent: &mut A, model_info: &Arc<Mutex<(usize, A::ModelInfo)>>, id: usize) {
90        let model_info = model_info.lock().unwrap();
91        agent.sync_model(&model_info.1);
92        info!("Received the initial model info in actor {}", id);
93    }
94
95    fn sync_model(
96        agent: &mut A,
97        n_opt_steps: &mut usize,
98        model_info: &Arc<Mutex<(usize, A::ModelInfo)>>,
99        id: usize,
100    ) {
101        let model_info = model_info.lock().unwrap();
102        if model_info.0 > *n_opt_steps {
103            *n_opt_steps = model_info.0;
104            agent.sync_model(&model_info.1);
105            debug!(
106                "Synchronized the model info of {} opt steps in actor {}",
107                n_opt_steps, id
108            );
109        }
110    }
111
112    #[inline]
113    fn downcast_mut(agent: &mut Box<dyn Agent<E, R>>) -> &mut A {
114        agent.deref_mut().as_any_mut().downcast_mut::<A>().unwrap()
115    }
116
117    /// Runs sampling loop until `self.stop` becomes `true`.
118    ///
119    /// When finishes, this method sets [ActorStat].
120    pub fn run(
121        &mut self,
122        sender: Sender<PushedItemMessage<R::Item>>,
123        model_info: Arc<Mutex<(usize, A::ModelInfo)>>,
124        guard: Arc<Mutex<bool>>,
125        guard_init_model: Arc<Mutex<bool>>,
126    ) {
127        let mut agent: Box<dyn Agent<E, R>> = Box::new(A::build(self.agent_config.clone()));
128        let mut buffer =
129            ReplayBufferProxy::<R>::build_with_sender(self.id, &self.replay_buffer_config, sender);
130        let mut sampler = {
131            let mut tmp = guard.lock().unwrap();
132            let env = E::build(&self.env_config, self.env_seed).unwrap();
133            let step_proc = P::build(&self.step_proc_config);
134            *tmp = true;
135            Sampler::new(env, step_proc)
136        };
137        info!("Starts actor {:?}", self.id);
138
139        let mut env_steps = 0;
140        let mut n_opt_steps = 0;
141        let time = std::time::SystemTime::now();
142
143        // Waits and syncs the initial model
144        {
145            let mut guard_init_model = guard_init_model.lock().unwrap();
146            Self::sync_model_first(Self::downcast_mut(&mut agent), &model_info, self.id);
147            *guard_init_model = true;
148        }
149
150        // Set agent training mode for exploration
151        agent.train();
152
153        // Sampling loop
154        info!("Starts sampling loop in actor {}", self.id);
155        loop {
156            // Check model update and synchronize
157            Self::sync_model(
158                Self::downcast_mut(&mut agent),
159                &mut n_opt_steps,
160                &model_info,
161                self.id,
162            );
163
164            // TODO: error handling
165            let _record = sampler.sample_and_push(&mut agent, &mut buffer).unwrap();
166            env_steps += 1;
167
168            // Stop sampling loop
169            if *self.stop.lock().unwrap() {
170                *self.stats.lock().unwrap() = Some(ActorStat {
171                    env_steps,
172                    duration: time.elapsed().unwrap(),
173                });
174                break;
175            }
176        }
177        info!("Stopped thread for actor {}", self.id);
178    }
179}