1use crate::checkpoint::{
2 AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
3 KeepLastNCheckpoints, MetricCheckpointingStrategy,
4};
5use crate::learner::base::Interrupter;
6use crate::logger::{FileMetricLogger, MetricLogger};
7use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
8use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
9use crate::renderer::{MetricsRenderer, default_renderer};
10use crate::{
11 ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
12 LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
13 RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
14 RLPolicyRecord, RLStrategy,
15};
16use crate::{EpisodeSummary, RLStrategies};
17use burn_core::record::FileRecorder;
18use burn_core::tensor::backend::AutodiffBackend;
19use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess};
20use std::collections::BTreeSet;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23
24pub struct RLTraining<RLC: RLComponentsTypes> {
26 #[allow(clippy::type_complexity)]
28 checkpointers: Option<(
29 AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
30 AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
31 )>,
32 num_steps: usize,
33 checkpoint: Option<usize>,
34 directory: PathBuf,
35 grad_accumulation: Option<usize>,
36 renderer: Option<Box<dyn MetricsRenderer + 'static>>,
37 metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
38 event_store: LogEventStore,
39 interrupter: Interrupter,
40 tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
41 checkpointer_strategy: Box<dyn CheckpointingStrategy>,
42 learning_strategy: RLStrategies<RLC>,
43 summary_metrics: BTreeSet<String>,
45 summary: bool,
46 env_initializer: RLC::EnvInit,
47}
48
49impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
50where
51 B: AutodiffBackend,
52 E: Environment + 'static,
53 EI: EnvironmentInit<E> + Send + 'static,
54 A: PolicyLearner<B> + Send + 'static,
55 A::TrainContext: ItemLazy + Clone + Send,
56 A::InnerPolicy: Policy<B> + Send,
57 <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
58 <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
59 <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
60 <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
61 <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
62 E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
63 E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
64 + Into<<A::InnerPolicy as Policy<B>>::Action>
65 + Clone
66 + Send
67 + 'static,
68{
69 pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
76 let directory = directory.as_ref().to_path_buf();
77 let experiment_log_file = directory.join("experiment.log");
78 Self {
79 num_steps: 1,
80 checkpoint: None,
81 checkpointers: None,
82 directory,
83 grad_accumulation: None,
84 metrics: RLMetrics::default(),
85 event_store: LogEventStore::default(),
86 renderer: None,
87 interrupter: Interrupter::new(),
88 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
89 experiment_log_file,
90 ))),
91 checkpointer_strategy: Box::new(
92 ComposedCheckpointingStrategy::builder()
93 .add(KeepLastNCheckpoints::new(2))
94 .add(MetricCheckpointingStrategy::new(
95 &EpisodeLengthMetric::new(), Aggregate::Mean,
97 Direction::Lowest,
98 Split::Valid,
99 ))
100 .build(),
101 ),
102 learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
103 summary_metrics: BTreeSet::new(),
104 summary: false,
105 env_initializer,
106 }
107 }
108}
109
110impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
111 pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
117 self.learning_strategy = learning_strategy;
118 self
119 }
120
121 pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
127 where
128 ML: MetricLogger + 'static,
129 {
130 self.event_store.register_logger(logger);
131 self
132 }
133
134 pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
136 mut self,
137 strategy: CS,
138 ) -> Self {
139 self.checkpointer_strategy = Box::new(strategy);
140 self
141 }
142
143 pub fn renderer<MR>(mut self, renderer: MR) -> Self
149 where
150 MR: MetricsRenderer + 'static,
151 {
152 self.renderer = Some(Box::new(renderer));
153 self
154 }
155
156 pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {
158 metrics.register(self)
159 }
160
161 pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
163 metrics.register(self)
164 }
165
166 pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {
168 metrics.register(self)
169 }
170
171 pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
173 metrics.register(self)
174 }
175
176 pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {
178 metrics.register(self)
179 }
180
181 pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
183 metrics.register(self)
184 }
185
186 pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
188 where
189 <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
190 {
191 self.metrics.register_text_metric_train(metric);
192 self
193 }
194
195 pub fn metric_train<Me>(mut self, metric: Me) -> Self
197 where
198 Me: Metric + Numeric + 'static,
199 <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
200 {
201 self.summary_metrics.insert(metric.name().to_string());
202 self.metrics.register_metric_train(metric);
203 self
204 }
205
206 pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self
208 where
209 <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
210 {
211 self.metrics.register_text_metric_agent(metric.clone());
212 self.metrics.register_text_metric_agent_valid(metric);
213 self
214 }
215
216 pub fn metric_agent<Me>(mut self, metric: Me) -> Self
218 where
219 Me: Metric + Numeric + 'static,
220 <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
221 {
222 self.summary_metrics.insert(metric.name().to_string());
223 self.metrics.register_agent_metric(metric.clone());
224 self.metrics.register_agent_metric_valid(metric);
225 self
226 }
227
228 pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self
230 where
231 EpisodeSummary: Adaptor<Me::Input> + 'static,
232 {
233 self.metrics.register_text_metric_episode(metric.clone());
234 self.metrics.register_text_metric_episode_valid(metric);
235 self
236 }
237
238 pub fn metric_episode<Me>(mut self, metric: Me) -> Self
240 where
241 Me: Metric + Numeric + 'static,
242 EpisodeSummary: Adaptor<Me::Input> + 'static,
243 {
244 self.summary_metrics.insert(metric.name().to_string());
245 self.metrics.register_episode_metric(metric.clone());
246 self.metrics.register_episode_metric_valid(metric);
247 self
248 }
249
250 pub fn num_steps(mut self, num_steps: usize) -> Self {
252 self.num_steps = num_steps;
253 self
254 }
255
256 pub fn checkpoint(mut self, checkpoint: usize) -> Self {
258 self.checkpoint = Some(checkpoint);
259 self
260 }
261
262 pub fn interrupter(&self) -> Interrupter {
264 self.interrupter.clone()
265 }
266
267 pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
269 self.interrupter = interrupter;
270 self
271 }
272
273 pub fn with_application_logger(
277 mut self,
278 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
279 ) -> Self {
280 self.tracing_logger = logger;
281 self
282 }
283
284 pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
287 where
288 FR: FileRecorder<RLC::Backend> + 'static,
289 FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
290 {
291 let checkpoint_dir = self.directory.join("checkpoint");
292 let checkpointer_policy =
293 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
294 let checkpointer_learning =
295 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent");
296
297 self.checkpointers = Some((
298 AsyncCheckpointer::new(checkpointer_policy),
299 AsyncCheckpointer::new(checkpointer_learning),
300 ));
301
302 self
303 }
304
305 pub fn summary(mut self) -> Self {
309 self.summary = true;
310 self
311 }
312
313 pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy>
315 where
316 RLC::PolicyObs: SliceAccess<RLC::Backend>,
317 RLC::PolicyAction: SliceAccess<RLC::Backend>,
318 {
319 if self.tracing_logger.is_some()
320 && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
321 {
322 log::warn!("Failed to install the experiment logger: {e}");
323 }
324 let renderer = self
325 .renderer
326 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
327
328 if !self.event_store.has_loggers() {
329 self.event_store
330 .register_logger(FileMetricLogger::new(self.directory.clone()));
331 }
332
333 let event_store = Arc::new(EventStoreClient::new(self.event_store));
334 let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
335 self.metrics,
336 renderer,
337 event_store.clone(),
338 ));
339
340 let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
341 RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
342 });
343
344 let summary = if self.summary {
345 Some(LearnerSummaryConfig {
346 directory: self.directory,
347 metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
348 })
349 } else {
350 None
351 };
352
353 let components = RLComponents::<RLC> {
354 checkpoint: self.checkpoint,
355 checkpointer,
356 interrupter: self.interrupter,
357 event_processor,
358 event_store,
359 num_steps: self.num_steps,
360 grad_accumulation: self.grad_accumulation,
361 summary,
362 };
363
364 match self.learning_strategy {
365 RLStrategies::OffPolicyStrategy(config) => {
366 let strategy = OffPolicyStrategy::new(config);
367 strategy.train(learner_agent, components, self.env_initializer)
368 }
369 RLStrategies::Custom(strategy) => {
370 strategy.train(learner_agent, components, self.env_initializer)
371 }
372 }
373 }
374}
375
/// Value returned by `RLTraining::launch`, generic over the policy type `P`.
pub struct RLResult<P> {
    /// The policy handed back by the learning strategy.
    pub policy: P,
    /// A metrics renderer handed back to the caller (presumably the one used
    /// during training -- confirm against the learning strategy).
    pub renderer: Box<dyn MetricsRenderer>,
}
383
384pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
386 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
388}
389
390pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
392 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
394}
395
396pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
398 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
400}
401
402pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
404 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
406}
407
408pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
410 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
412}
413
414pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
416 fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
418}
419
420macro_rules! gen_tuple {
421 ($($M:ident),*) => {
422 impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
423 where
424 $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
425 $($M: Metric + 'static,)*
426 {
427 #[allow(non_snake_case)]
428 fn register(
429 self,
430 builder: RLTraining<RLC>,
431 ) -> RLTraining<RLC> {
432 let ($($M,)*) = self;
433 $(let builder = builder.text_metric_train($M.clone());)*
434 builder
435 }
436 }
437
438 impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
439 where
440 $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
441 $($M: Metric + Numeric + 'static,)*
442 {
443 #[allow(non_snake_case)]
444 fn register(
445 self,
446 builder: RLTraining<RLC>,
447 ) -> RLTraining<RLC> {
448 let ($($M,)*) = self;
449 $(let builder = builder.metric_train($M.clone());)*
450 builder
451 }
452 }
453
454 impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
455 where
456 $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
457 $($M: Metric + 'static,)*
458 {
459 #[allow(non_snake_case)]
460 fn register(
461 self,
462 builder: RLTraining<RLC>,
463 ) -> RLTraining<RLC> {
464 let ($($M,)*) = self;
465 $(let builder = builder.text_metric_agent($M.clone());)*
466 builder
467 }
468 }
469
470 impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
471 where
472 $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
473 $($M: Metric + Numeric + 'static,)*
474 {
475 #[allow(non_snake_case)]
476 fn register(
477 self,
478 builder: RLTraining<RLC>,
479 ) -> RLTraining<RLC> {
480 let ($($M,)*) = self;
481 $(let builder = builder.metric_agent($M.clone());)*
482 builder
483 }
484 }
485
486 impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
487 where
488 $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
489 $($M: Metric + 'static,)*
490 {
491 #[allow(non_snake_case)]
492 fn register(
493 self,
494 builder: RLTraining<RLC>,
495 ) -> RLTraining<RLC> {
496 let ($($M,)*) = self;
497 $(let builder = builder.text_metric_episode($M.clone());)*
498 builder
499 }
500 }
501
502 impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
503 where
504 $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
505 $($M: Metric + Numeric + 'static,)*
506 {
507 #[allow(non_snake_case)]
508 fn register(
509 self,
510 builder: RLTraining<RLC>,
511 ) -> RLTraining<RLC> {
512 let ($($M,)*) = self;
513 $(let builder = builder.metric_episode($M.clone());)*
514 builder
515 }
516 }
517 };
518}
519
520gen_tuple!(M1);
521gen_tuple!(M1, M2);
522gen_tuple!(M1, M2, M3);
523gen_tuple!(M1, M2, M3, M4);
524gen_tuple!(M1, M2, M3, M4, M5);
525gen_tuple!(M1, M2, M3, M4, M5, M6);