relearn/simulation/train.rs

use super::{OnlineStepsSummary, Simulation, Steps, StepsSummary};
use crate::agents::{buffers::HistoryDataBound, ActorMode, Agent, BatchUpdate, WriteExperience};
use crate::envs::{EnvStructure, Environment, StructuredEnvironment};
use crate::feedback::{Feedback, Summary};
use crate::logging::{Loggable, StatsLogger};
use crate::spaces::{LogElementSpace, Space};
use crate::Prng;
use log::warn;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use std::iter;
use std::time::Instant;

/// Train a batch learning agent in this thread.
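///
/// # Example
///
/// A minimal sketch, not compiled here: the `relearn::` paths are assumed
/// re-exports of the items used in the tests at the bottom of this file.
///
/// ```ignore
/// // Paths below are assumed re-exports; adjust to the actual crate layout.
/// use rand::SeedableRng;
/// use relearn::agents::{BuildAgent, TabularQLearningAgentConfig};
/// use relearn::envs::DeterministicBandit;
/// use relearn::simulation::train_serial;
/// use relearn::Prng;
///
/// let mut rng_env = Prng::seed_from_u64(0);
/// let mut rng_agent = Prng::seed_from_u64(1);
///
/// let env = DeterministicBandit::from_values(vec![0.0, 1.0]);
/// let mut agent = TabularQLearningAgentConfig::default()
///     .build_agent(&env, &mut rng_agent)
///     .unwrap();
///
/// // Run 10 collect-then-update periods with no logging.
/// train_serial(&mut agent, &env, 10, &mut rng_env, &mut rng_agent, &mut ());
/// ```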
pub fn train_serial<T, E>(
    agent: &mut T,
    environment: &E,
    num_periods: usize,
    rng_env: &mut Prng,
    rng_agent: &mut Prng,
    logger: &mut dyn StatsLogger,
) where
    T: Agent<E::Observation, E::Action>
        + BatchUpdate<E::Observation, E::Action, Feedback = E::Feedback>
        + ?Sized,
    E: StructuredEnvironment + ?Sized,
    E::ObservationSpace: LogElementSpace,
    E::ActionSpace: LogElementSpace,
    E::Feedback: Feedback,
{
    let mut buffer = agent.buffer();
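    // Each period: collect at least `min_update_size` steps of experience
    // into the buffer, then run a batch update from it.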
    for _ in 0..num_periods {
        let update_size = agent.min_update_size();
        buffer
            .write_experience(
                update_size
                    .take(Steps::new(
                        environment,
                        agent.actor(ActorMode::Training),
                        &mut *rng_env,
                        &mut *rng_agent,
                        &mut *logger,
                    ))
                    .log(),
            )
            .unwrap_or_else(|err| warn!("error filling buffer: {}", err));
        agent.batch_update(iter::once(&mut buffer), logger);
    }
}

/// Configuration for [`train_parallel`].
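///
/// # Example
///
/// Field values mirroring the test at the bottom of this file:
///
/// ```ignore
/// let config = TrainParallelConfig {
///     num_periods: 10,       // collect-then-update iterations
///     num_threads: 4,        // parallel simulation workers
///     min_worker_steps: 100, // lower bound on steps collected per worker per period
/// };
/// ```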
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TrainParallelConfig {
    /// Number of data-collection / batch-update loop iterations.
    pub num_periods: usize,
    /// Number of simulation threads.
    pub num_threads: usize,
    /// Minimum step capacity of each worker buffer.
    pub min_worker_steps: usize,
}

/// Train a batch learning agent in parallel across several threads.
///
/// The logger is used by the main thread for agent updates
/// as well as by one of the worker threads for action and step logs.
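///
/// # Example
///
/// A sketch mirroring the test at the bottom of this file, not compiled here;
/// the `relearn::` paths are assumed re-exports.
///
/// ```ignore
/// // Paths below are assumed re-exports; adjust to the actual crate layout.
/// use rand::SeedableRng;
/// use relearn::agents::{BuildAgent, TabularQLearningAgentConfig};
/// use relearn::envs::DeterministicBandit;
/// use relearn::simulation::{train_parallel, TrainParallelConfig};
/// use relearn::Prng;
///
/// let config = TrainParallelConfig {
///     num_periods: 10,
///     num_threads: 4,
///     min_worker_steps: 100,
/// };
/// let mut rng_env = Prng::seed_from_u64(0);
/// let mut rng_agent = Prng::seed_from_u64(1);
///
/// let env = DeterministicBandit::from_values(vec![0.0, 1.0]);
/// let mut agent = TabularQLearningAgentConfig::default()
///     .build_agent(&env, &mut rng_agent)
///     .unwrap();
///
/// train_parallel(
///     &mut agent,
///     &env,
///     &config,
///     &mut rng_env,
///     &mut rng_agent,
///     &mut (), // no logging
/// );
/// ```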
// False positive on the `StepSummary: Send` bound below.
#[allow(clippy::trait_duplication_in_bounds)]
pub fn train_parallel<T, E>(
    agent: &mut T,
    environment: &E,
    config: &TrainParallelConfig,
    rng_env: &mut Prng,
    rng_agent: &mut Prng,
    logger: &mut dyn StatsLogger,
) where
    // TODO: Why does the simpler bound work for train_serial but not here?
    E: EnvStructure
        + Environment<
            Observation = <E::ObservationSpace as Space>::Element,
            Action = <E::ActionSpace as Space>::Element,
            Feedback = <E::FeedbackSpace as Space>::Element,
        > + Sync
        + ?Sized,
    T: Agent<<E::ObservationSpace as Space>::Element, <E::ActionSpace as Space>::Element>
        + BatchUpdate<
            <E::ObservationSpace as Space>::Element,
            <E::ActionSpace as Space>::Element,
            Feedback = <E::FeedbackSpace as Space>::Element,
        > + ?Sized,
    T::Actor: Send,
    T::HistoryBuffer: Send,
    E::ObservationSpace: LogElementSpace,
    E::ActionSpace: LogElementSpace,
    <E::FeedbackSpace as Space>::Element: Feedback,
    <<E::FeedbackSpace as Space>::Element as Feedback>::StepSummary: Send,
    <<E::FeedbackSpace as Space>::Element as Feedback>::EpisodeSummary: Send,
{
    let mut buffers: Vec<_> = (0..config.num_threads).map(|_| agent.buffer()).collect();
    let mut thread_rngs: Vec<_> = (0..config.num_threads)
        .map(|_| {
            (
                Prng::from_rng(&mut *rng_env).expect("Prng should be infallible"),
                Prng::from_rng(&mut *rng_agent).expect("Prng should be infallible"),
            )
        })
        .collect();

    for _ in 0..config.num_periods {
        let collect_start = Instant::now();

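        // Split the agent's required update size evenly across the workers,
        // but never ask a worker for fewer than `min_worker_steps` steps.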
        let worker_update_size =
            agent
                .min_update_size()
                .divide(config.num_threads)
                .max(HistoryDataBound {
                    min_steps: config.min_worker_steps,
                    slack_steps: 0,
                });

        // Send the logger to the first thread
        let mut worker0_logger = logger.with_scope("worker0");
        let mut send_logger = Some(&mut worker0_logger as &mut dyn StatsLogger);

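        // Collect experience with one scoped worker thread per buffer;
        // each worker returns a summary of the steps it simulated.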
        let summary = crossbeam::scope(|scope| {
            let mut threads = Vec::new();

            for (buffer, rngs) in buffers.iter_mut().zip(&mut thread_rngs) {
                let actor = agent.actor(ActorMode::Training);
                let thread_logger = send_logger.take();
                threads.push(scope.spawn(move |_scope| {
                    let mut summary = OnlineStepsSummary::default();
                    buffer
                        .write_experience(
                            worker_update_size
                                .take(Steps::new(
                                    environment,
                                    actor,
                                    &mut rngs.0,
                                    &mut rngs.1,
                                    thread_logger.unwrap_or(&mut ()),
                                ))
                                .log()
                                .map(|step| {
                                    summary.push(&step);
                                    step
                                }),
                        )
                        .unwrap_or_else(|err| warn!("error filling buffer: {}", err));
                    StepsSummary::from(summary)
                }));
            }

            threads
                .into_iter()
                .map(|t| t.join().unwrap())
                .sum::<StepsSummary<_>>()
        })
        .unwrap();

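        // Log aggregate episode and step statistics for this collection period.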
        let mut sim_logger = logger.with_scope("sim").group();
        let mut episode_logger = (&mut sim_logger).with_scope("ep");
        let num_episodes = summary.episode_length.count();
        if num_episodes > 0 {
            summary
                .episode_feedback
                .log("fbk", &mut episode_logger)
                .unwrap();
            episode_logger.log_scalar("length_mean", summary.episode_length.mean().unwrap());
            episode_logger.log_scalar("length_stddev", summary.episode_length.stddev().unwrap());
        }
        episode_logger.log_counter_increment("count", num_episodes);

        let mut step_logger = (&mut sim_logger).with_scope("step");
        summary.step_feedback.log("fbk", &mut step_logger).unwrap();
        step_logger.log_counter_increment("count", summary.step_feedback.size());
        let update_start = Instant::now();
        sim_logger.log_duration("time", update_start - collect_start);
        drop(sim_logger);

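        // Update the agent from the experience collected in every worker buffer.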
        agent.batch_update(&mut buffers, &mut *logger);

        let mut agent_logger = logger.with_scope("agent_update").group();
        agent_logger.log_duration("time", update_start.elapsed());
        agent_logger.log_counter_increment("count", 1);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::{testing, BuildAgent, TabularQLearningAgentConfig};
    use crate::envs::DeterministicBandit;

    #[test]
    fn train_parallel_tabular_q_bandit() {
        let config = TrainParallelConfig {
            num_periods: 10,
            num_threads: 4,
            min_worker_steps: 100,
        };
        let mut rng_env = Prng::seed_from_u64(0);
        let mut rng_actor = Prng::seed_from_u64(1);

        let env = DeterministicBandit::from_values(vec![0.0, 1.0]);
        let mut agent = TabularQLearningAgentConfig::default()
            .build_agent(&env, &mut rng_actor)
            .unwrap();

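        // Train across several threads, then check that the evaluation-mode
        // actor reliably picks the higher-valued arm.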
        train_parallel(
            &mut agent,
            &env,
            &config,
            &mut rng_env,
            &mut rng_actor,
            &mut (),
        );

        testing::eval_deterministic_bandit(agent.actor(ActorMode::Evaluation), &env, 0.9);
    }
}