use super::{OnlineStepsSummary, Simulation, Steps, StepsSummary};
use crate::agents::{buffers::HistoryDataBound, ActorMode, Agent, BatchUpdate, WriteExperience};
use crate::envs::{EnvStructure, Environment, StructuredEnvironment};
use crate::feedback::{Feedback, Summary};
use crate::logging::{Loggable, StatsLogger};
use crate::spaces::{LogElementSpace, Space};
use crate::Prng;
use log::warn;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use std::iter;
use std::time::Instant;

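/// Train an agent-environment pair in the current thread.
///
/// On each of `num_periods` iterations, collects at least
/// `agent.min_update_size()` steps of experience with a training-mode actor,
/// then performs a batch update from the buffer.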
pub fn train_serial<T, E>(
    agent: &mut T,
    environment: &E,
    num_periods: usize,
    rng_env: &mut Prng,
    rng_agent: &mut Prng,
    logger: &mut dyn StatsLogger,
) where
    T: Agent<E::Observation, E::Action>
        + BatchUpdate<E::Observation, E::Action, Feedback = E::Feedback>
        + ?Sized,
    E: StructuredEnvironment + ?Sized,
    E::ObservationSpace: LogElementSpace,
    E::ActionSpace: LogElementSpace,
    E::Feedback: Feedback,
{
    let mut buffer = agent.buffer();
    for _ in 0..num_periods {
        let update_size = agent.min_update_size();
        buffer
            .write_experience(
                update_size
                    .take(Steps::new(
                        environment,
                        agent.actor(ActorMode::Training),
                        &mut *rng_env,
                        &mut *rng_agent,
                        &mut *logger,
                    ))
                    .log(),
            )
            .unwrap_or_else(|err| warn!("error filling buffer: {}", err));
        agent.batch_update(iter::once(&mut buffer), logger);
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
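/// Configuration for [`train_parallel`].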
pub struct TrainParallelConfig {
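    /// Number of data-collection / batch-update loop iterations.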
    pub num_periods: usize,
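    /// Number of parallel simulation worker threads.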
    pub num_threads: usize,
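    /// Minimum number of steps each worker collects per period.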
    pub min_worker_steps: usize,
}

#[allow(clippy::trait_duplication_in_bounds)]
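/// Train an agent with parallel data collection over several threads.
///
/// On each of `config.num_periods` iterations, `config.num_threads` workers
/// collect experience into separate buffers in parallel, then the main thread
/// performs a single batch update over all buffers. The provided logger is
/// lent to the first worker (scoped as "worker0") during collection.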
pub fn train_parallel<T, E>(
    agent: &mut T,
    environment: &E,
    config: &TrainParallelConfig,
    rng_env: &mut Prng,
    rng_agent: &mut Prng,
    logger: &mut dyn StatsLogger,
) where
    E: EnvStructure
        + Environment<
            Observation = <E::ObservationSpace as Space>::Element,
            Action = <E::ActionSpace as Space>::Element,
            Feedback = <E::FeedbackSpace as Space>::Element,
        > + Sync
        + ?Sized,
    T: Agent<<E::ObservationSpace as Space>::Element, <E::ActionSpace as Space>::Element>
        + BatchUpdate<
            <E::ObservationSpace as Space>::Element,
            <E::ActionSpace as Space>::Element,
            Feedback = <E::FeedbackSpace as Space>::Element,
        > + ?Sized,
    T::Actor: Send,
    T::HistoryBuffer: Send,
    E::ObservationSpace: LogElementSpace,
    E::ActionSpace: LogElementSpace,
    <E::FeedbackSpace as Space>::Element: Feedback,
    <<E::FeedbackSpace as Space>::Element as Feedback>::StepSummary: Send,
    <<E::FeedbackSpace as Space>::Element as Feedback>::EpisodeSummary: Send,
{
    let mut buffers: Vec<_> = (0..config.num_threads).map(|_| agent.buffer()).collect();
    let mut thread_rngs: Vec<_> = (0..config.num_threads)
        .map(|_| {
            (
                Prng::from_rng(&mut *rng_env).expect("Prng should be infallible"),
                Prng::from_rng(&mut *rng_agent).expect("Prng should be infallible"),
            )
        })
        .collect();

    for _ in 0..config.num_periods {
        let collect_start = Instant::now();

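        // Each worker collects a per-thread share of the agent's minimum
        // update size, but never fewer than `min_worker_steps` steps.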
        let worker_update_size = agent
            .min_update_size()
            .divide(config.num_threads)
            .max(HistoryDataBound {
                min_steps: config.min_worker_steps,
                slack_steps: 0,
            });

        let mut worker0_logger = logger.with_scope("worker0");
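        // Only one thread can borrow the scoped logger; the first worker takes it.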
        let mut send_logger = Some(&mut worker0_logger as &mut dyn StatsLogger);

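        // Spawn one collection worker per buffer; each returns a summary of its steps.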
        let summary = crossbeam::scope(|scope| {
            let mut threads = Vec::new();

            for (buffer, rngs) in buffers.iter_mut().zip(&mut thread_rngs) {
                let actor = agent.actor(ActorMode::Training);
                let thread_logger = send_logger.take();
                threads.push(scope.spawn(move |_scope| {
                    let mut summary = OnlineStepsSummary::default();
                    buffer
                        .write_experience(
                            worker_update_size
                                .take(Steps::new(
                                    environment,
                                    actor,
                                    &mut rngs.0,
                                    &mut rngs.1,
                                    thread_logger.unwrap_or(&mut ()),
                                ))
                                .log()
                                .map(|step| {
                                    summary.push(&step);
                                    step
                                }),
                        )
                        .unwrap_or_else(|err| warn!("error filling buffer: {}", err));
                    StepsSummary::from(summary)
                }));
            }

            threads
                .into_iter()
                .map(|t| t.join().unwrap())
                .sum::<StepsSummary<_>>()
        })
        .unwrap();

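        // Log aggregate episode and step statistics for this collection period.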
        let mut sim_logger = logger.with_scope("sim").group();
        let mut episode_logger = (&mut sim_logger).with_scope("ep");
        let num_episodes = summary.episode_length.count();
        if num_episodes > 0 {
            summary
                .episode_feedback
                .log("fbk", &mut episode_logger)
                .unwrap();
            episode_logger.log_scalar("length_mean", summary.episode_length.mean().unwrap());
            episode_logger.log_scalar("length_stddev", summary.episode_length.stddev().unwrap());
        }
        episode_logger.log_counter_increment("count", num_episodes);

        let mut step_logger = (&mut sim_logger).with_scope("step");
        summary.step_feedback.log("fbk", &mut step_logger).unwrap();
        step_logger.log_counter_increment("count", summary.step_feedback.size());
        let update_start = Instant::now();
        sim_logger.log_duration("time", update_start - collect_start);
        drop(sim_logger);

        agent.batch_update(&mut buffers, &mut *logger);

        let mut agent_logger = logger.with_scope("agent_update").group();
        agent_logger.log_duration("time", update_start.elapsed());
        agent_logger.log_counter_increment("count", 1);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::{testing, BuildAgent, TabularQLearningAgentConfig};
    use crate::envs::DeterministicBandit;

    #[test]
    fn train_parallel_tabular_q_bandit() {
        let config = TrainParallelConfig {
            num_periods: 10,
            num_threads: 4,
            min_worker_steps: 100,
        };
        let mut rng_env = Prng::seed_from_u64(0);
        let mut rng_actor = Prng::seed_from_u64(1);

        let env = DeterministicBandit::from_values(vec![0.0, 1.0]);
        let mut agent = TabularQLearningAgentConfig::default()
            .build_agent(&env, &mut rng_actor)
            .unwrap();

        train_parallel(
            &mut agent,
            &env,
            &config,
            &mut rng_env,
            &mut rng_actor,
            &mut (),
        );

        testing::eval_deterministic_bandit(agent.actor(ActorMode::Evaluation), &env, 0.9);
    }
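
    // A serial counterpart to the parallel test above. This is a sketch:
    // it assumes `DeterministicBandit` satisfies the `StructuredEnvironment`
    // bound required by `train_serial`.
    #[test]
    fn train_serial_tabular_q_bandit() {
        let mut rng_env = Prng::seed_from_u64(0);
        let mut rng_actor = Prng::seed_from_u64(1);

        let env = DeterministicBandit::from_values(vec![0.0, 1.0]);
        let mut agent = TabularQLearningAgentConfig::default()
            .build_agent(&env, &mut rng_actor)
            .unwrap();

        train_serial(&mut agent, &env, 10, &mut rng_env, &mut rng_actor, &mut ());

        testing::eval_deterministic_bandit(agent.actor(ActorMode::Evaluation), &env, 0.9);
    }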
}