border_async_trainer/actor/
base.rs1use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel};
2use border_core::{
3 Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, Sampler, StepProcessor,
4};
5use crossbeam_channel::Sender;
6use log::{debug, info};
7use std::{
8 marker::PhantomData,
9 ops::DerefMut,
10 sync::{Arc, Mutex},
11};
12
13#[cfg_attr(doc, aquamarine::aquamarine)]
14pub struct Actor<A, E, P, R>
40where
41 A: Agent<E, R> + Configurable + SyncModel + 'static,
42 E: Env,
43 P: StepProcessor<E>,
44 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
45{
46 id: usize,
48 stop: Arc<Mutex<bool>>,
49 agent_config: A::Config,
50 env_config: E::Config,
51 step_proc_config: P::Config,
52 replay_buffer_config: ReplayBufferProxyConfig,
53 env_seed: i64,
54 stats: Arc<Mutex<Option<ActorStat>>>,
55 phantom: PhantomData<(A, E, P, R)>,
56}
57
58impl<A, E, P, R> Actor<A, E, P, R>
59where
60 A: Agent<E, R> + Configurable + SyncModel + 'static,
61 E: Env,
62 P: StepProcessor<E>,
63 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
64{
65 pub fn build(
66 id: usize,
67 agent_config: A::Config,
68 env_config: E::Config,
69 step_proc_config: P::Config,
70 replay_buffer_config: ReplayBufferProxyConfig,
71 stop: Arc<Mutex<bool>>,
72 env_seed: i64,
73 stats: Arc<Mutex<Option<ActorStat>>>,
74 ) -> Self {
75 log::info!("Create actor {}", id);
76 Self {
77 id,
78 stop,
79 agent_config: agent_config.clone(),
80 env_config: env_config.clone(),
81 step_proc_config: step_proc_config.clone(),
82 replay_buffer_config: replay_buffer_config.clone(),
83 env_seed,
84 stats,
85 phantom: PhantomData,
86 }
87 }
88
89 fn sync_model_first(agent: &mut A, model_info: &Arc<Mutex<(usize, A::ModelInfo)>>, id: usize) {
90 let model_info = model_info.lock().unwrap();
91 agent.sync_model(&model_info.1);
92 info!("Received the initial model info in actor {}", id);
93 }
94
95 fn sync_model(
96 agent: &mut A,
97 n_opt_steps: &mut usize,
98 model_info: &Arc<Mutex<(usize, A::ModelInfo)>>,
99 id: usize,
100 ) {
101 let model_info = model_info.lock().unwrap();
102 if model_info.0 > *n_opt_steps {
103 *n_opt_steps = model_info.0;
104 agent.sync_model(&model_info.1);
105 debug!(
106 "Synchronized the model info of {} opt steps in actor {}",
107 n_opt_steps, id
108 );
109 }
110 }
111
112 #[inline]
113 fn downcast_mut(agent: &mut Box<dyn Agent<E, R>>) -> &mut A {
114 agent.deref_mut().as_any_mut().downcast_mut::<A>().unwrap()
115 }
116
117 pub fn run(
121 &mut self,
122 sender: Sender<PushedItemMessage<R::Item>>,
123 model_info: Arc<Mutex<(usize, A::ModelInfo)>>,
124 guard: Arc<Mutex<bool>>,
125 guard_init_model: Arc<Mutex<bool>>,
126 ) {
127 let mut agent: Box<dyn Agent<E, R>> = Box::new(A::build(self.agent_config.clone()));
128 let mut buffer =
129 ReplayBufferProxy::<R>::build_with_sender(self.id, &self.replay_buffer_config, sender);
130 let mut sampler = {
131 let mut tmp = guard.lock().unwrap();
132 let env = E::build(&self.env_config, self.env_seed).unwrap();
133 let step_proc = P::build(&self.step_proc_config);
134 *tmp = true;
135 Sampler::new(env, step_proc)
136 };
137 info!("Starts actor {:?}", self.id);
138
139 let mut env_steps = 0;
140 let mut n_opt_steps = 0;
141 let time = std::time::SystemTime::now();
142
143 {
145 let mut guard_init_model = guard_init_model.lock().unwrap();
146 Self::sync_model_first(Self::downcast_mut(&mut agent), &model_info, self.id);
147 *guard_init_model = true;
148 }
149
150 agent.train();
152
153 info!("Starts sampling loop in actor {}", self.id);
155 loop {
156 Self::sync_model(
158 Self::downcast_mut(&mut agent),
159 &mut n_opt_steps,
160 &model_info,
161 self.id,
162 );
163
164 let _record = sampler.sample_and_push(&mut agent, &mut buffer).unwrap();
166 env_steps += 1;
167
168 if *self.stop.lock().unwrap() {
170 *self.stats.lock().unwrap() = Some(ActorStat {
171 env_steps,
172 duration: time.elapsed().unwrap(),
173 });
174 break;
175 }
176 }
177 info!("Stopped thread for actor {}", self.id);
178 }
179}