border_tensorboard/
lib.rs

//! A TensorBoard logger for the border-core crate.
2//!
3//! [`TensorboardRecorder`] saves TFRecord files and model parameters to a directory
4//! in the local file system during training.
5use anyhow::Result;
6use border_core::{
7    record::{Record, RecordValue, Recorder},
8    Env, ReplayBufferBase,
9};
10use std::{
11    marker::PhantomData,
12    path::{Path, PathBuf},
13};
14use tensorboard_rs::summary_writer::SummaryWriter;
15
16/// Write records to TFRecord.
17pub struct TensorboardRecorder<E, R>
18where
19    E: Env,
20    R: ReplayBufferBase,
21{
22    model_dir: PathBuf,
23    writer: SummaryWriter,
24    step_key: String,
25    latest_record: Option<Record>,
26    ignore_unsupported_value: bool,
27    phantom: PhantomData<(E, R)>,
28}
29
30impl<E, R> TensorboardRecorder<E, R>
31where
32    E: Env,
33    R: ReplayBufferBase,
34{
35    /// Construct a [`TensorboardRecorder`].
36    ///
37    /// * `log_dir` - Directory in which TFRecords will be stored.
38    /// * `model_dir` - Directory in which the trained model will be saved.
39    /// * `check_unsupported_value` - If true, check unsupported record value in the write() method.
40    pub fn new(
41        log_dir: impl AsRef<Path>,
42        model_dir: impl AsRef<Path>,
43        check_unsupported_value: bool,
44    ) -> Self {
45        Self {
46            model_dir: model_dir.as_ref().to_path_buf(),
47            writer: SummaryWriter::new(log_dir),
48            step_key: "opt_steps".to_string(),
49            ignore_unsupported_value: !check_unsupported_value,
50            latest_record: None,
51            phantom: PhantomData,
52        }
53    }
54}
55
56impl<E, R> Recorder<E, R> for TensorboardRecorder<E, R>
57where
58    E: Env,
59    R: ReplayBufferBase,
60{
61    /// Writes a given [`Record`] into a TFRecord.
62    ///
63    /// This method handles [RecordValue::Scalar] and [RecordValue::DateTime] in the [`Record`].
64    /// Other variants will be ignored.
65    fn write(&mut self, record: Record) {
66        // TODO: handle error
67        let step = match record.get(&self.step_key).unwrap() {
68            RecordValue::Scalar(v) => *v as usize,
69            _ => {
70                panic!()
71            }
72        };
73
74        for (k, v) in record.iter() {
75            if *k != self.step_key {
76                match v {
77                    RecordValue::Scalar(v) => self.writer.add_scalar(k, *v as f32, step),
78                    RecordValue::DateTime(_) => {} // discard value
79                    RecordValue::Array2(data, shape) => {
80                        let shape = [3, shape[0], shape[1]];
81                        let min = data.iter().fold(f32::MAX, |m, v| v.min(m));
82                        let scale = data.iter().fold(-f32::MAX, |m, v| v.max(m)) - min;
83                        let mut data = data
84                            .iter()
85                            .map(|&e| ((e - min) / scale * 255f32) as u8)
86                            .collect::<Vec<_>>();
87                        let data_ = data.clone();
88                        data.extend(data_.iter());
89                        data.extend(data_.iter());
90                        self.writer.add_image(k, data.as_slice(), &shape, step)
91                    }
92                    _ => {
93                        if !self.ignore_unsupported_value {
94                            panic!("Unsupported value: {:?}", (k, v));
95                        }
96                    }
97                };
98            }
99        }
100    }
101
102    fn store(&mut self, record: Record) {
103        self.latest_record = Some(record);
104    }
105
106    fn flush(&mut self, step: i64) {
107        if self.latest_record.is_some() {
108            let mut record = self.latest_record.take().unwrap();
109            record.insert("opt_steps", RecordValue::Scalar(step as _));
110            self.write(record);
111        }
112    }
113
114    /// Saves the model parameters in the local file system.
115    fn save_model(&self, base: &Path, agent: &Box<dyn border_core::Agent<E, R>>) -> Result<()> {
116        let path = self.model_dir.join(base);
117        let _ = agent.save_params(&path)?;
118        Ok(())
119    }
120
121    /// Loads the model parameters from the local file system.
122    fn load_model(&self, base: &Path, agent: &mut Box<dyn border_core::Agent<E, R>>) -> Result<()> {
123        let path = self.model_dir.join(base);
124        agent.load_params(&path)
125    }
126}