border_tensorboard/
lib.rs1use anyhow::Result;
6use border_core::{
7 record::{Record, RecordValue, Recorder},
8 Env, ReplayBufferBase,
9};
10use std::{
11 marker::PhantomData,
12 path::{Path, PathBuf},
13};
14use tensorboard_rs::summary_writer::SummaryWriter;
15
16pub struct TensorboardRecorder<E, R>
18where
19 E: Env,
20 R: ReplayBufferBase,
21{
22 model_dir: PathBuf,
23 writer: SummaryWriter,
24 step_key: String,
25 latest_record: Option<Record>,
26 ignore_unsupported_value: bool,
27 phantom: PhantomData<(E, R)>,
28}
29
30impl<E, R> TensorboardRecorder<E, R>
31where
32 E: Env,
33 R: ReplayBufferBase,
34{
35 pub fn new(
41 log_dir: impl AsRef<Path>,
42 model_dir: impl AsRef<Path>,
43 check_unsupported_value: bool,
44 ) -> Self {
45 Self {
46 model_dir: model_dir.as_ref().to_path_buf(),
47 writer: SummaryWriter::new(log_dir),
48 step_key: "opt_steps".to_string(),
49 ignore_unsupported_value: !check_unsupported_value,
50 latest_record: None,
51 phantom: PhantomData,
52 }
53 }
54}
55
56impl<E, R> Recorder<E, R> for TensorboardRecorder<E, R>
57where
58 E: Env,
59 R: ReplayBufferBase,
60{
61 fn write(&mut self, record: Record) {
66 let step = match record.get(&self.step_key).unwrap() {
68 RecordValue::Scalar(v) => *v as usize,
69 _ => {
70 panic!()
71 }
72 };
73
74 for (k, v) in record.iter() {
75 if *k != self.step_key {
76 match v {
77 RecordValue::Scalar(v) => self.writer.add_scalar(k, *v as f32, step),
78 RecordValue::DateTime(_) => {} RecordValue::Array2(data, shape) => {
80 let shape = [3, shape[0], shape[1]];
81 let min = data.iter().fold(f32::MAX, |m, v| v.min(m));
82 let scale = data.iter().fold(-f32::MAX, |m, v| v.max(m)) - min;
83 let mut data = data
84 .iter()
85 .map(|&e| ((e - min) / scale * 255f32) as u8)
86 .collect::<Vec<_>>();
87 let data_ = data.clone();
88 data.extend(data_.iter());
89 data.extend(data_.iter());
90 self.writer.add_image(k, data.as_slice(), &shape, step)
91 }
92 _ => {
93 if !self.ignore_unsupported_value {
94 panic!("Unsupported value: {:?}", (k, v));
95 }
96 }
97 };
98 }
99 }
100 }
101
102 fn store(&mut self, record: Record) {
103 self.latest_record = Some(record);
104 }
105
106 fn flush(&mut self, step: i64) {
107 if self.latest_record.is_some() {
108 let mut record = self.latest_record.take().unwrap();
109 record.insert("opt_steps", RecordValue::Scalar(step as _));
110 self.write(record);
111 }
112 }
113
114 fn save_model(&self, base: &Path, agent: &Box<dyn border_core::Agent<E, R>>) -> Result<()> {
116 let path = self.model_dir.join(base);
117 let _ = agent.save_params(&path)?;
118 Ok(())
119 }
120
121 fn load_model(&self, base: &Path, agent: &mut Box<dyn border_core::Agent<E, R>>) -> Result<()> {
123 let path = self.model_dir.join(base);
124 agent.load_params(&path)
125 }
126}