1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
//! Types for recording various values obtained during training and evaluation.
//!
//! [Record] is a thin wrapper around a [HashMap] whose keys and values represent various values
//! obtained during training and evaluation. A single record may contain values of multiple types.
//!
//! ```no_run
//! use border_core::record::{Record, RecordValue};
//!
//! // In practice, the following values would be produced by some training or evaluation process.
//! let step = 1;
//! let obs = vec![1f32, 2.0, 3.0, 4.0, 5.0];
//! let reward = -1f32;
//!
//! let mut record = Record::empty();
//! record.insert("Step", RecordValue::Scalar(step as f32));
//! record.insert("Reward", RecordValue::Scalar(reward));
//! record.insert("Obs", RecordValue::Array1(obs));
//! ```
//!
//! A typical use case is recording internal values obtained during training.
//! [Trainer::train](crate::Trainer::train), which executes a training loop, writes records
//! to a [Recorder] given as an input argument.
//!
use chrono::prelude::{DateTime, Local};
use std::{
collections::{
hash_map::{IntoIter, Iter, Keys},
HashMap,
},
convert::Into,
iter::IntoIterator,
};
use crate::error::LrrError;
/// A single value that can be stored in a [`Record`].
///
/// Numeric payloads are stored as `Vec<f32>`; the multi-dimensional variants
/// additionally carry a fixed-size `usize` array describing their dimensions.
#[derive(Debug, Clone)]
pub enum RecordValue {
    /// A single number, such as an optimization step or a loss value.
    Scalar(f32),

    /// A local date and time.
    DateTime(DateTime<Local>),

    /// A one-dimensional array of values.
    Array1(Vec<f32>),

    /// A two-dimensional array: the data together with a `[usize; 2]`
    /// (presumably the shape — confirm against producers of this variant).
    Array2(Vec<f32>, [usize; 2]),

    /// A three-dimensional array: the data together with a `[usize; 3]`
    /// (presumably the shape — confirm against producers of this variant).
    Array3(Vec<f32>, [usize; 3]),

    /// A text value.
    String(String),
}
/// A collection of named values obtained during training or evaluation.
///
/// Keys are arbitrary strings and each value is a [`RecordValue`], so a single
/// record may hold values of different types.
#[derive(Debug, Default)]
pub struct Record(HashMap<String, RecordValue>);

impl Record {
    /// Constructs an empty record.
    pub fn empty() -> Self {
        Self(HashMap::new())
    }

    /// Creates a `Record` from a slice of `(key, value)` pairs.
    ///
    /// Keys are converted with `Into<String>` and values are cloned. If the same
    /// key appears more than once, the last occurrence wins.
    pub fn from_slice<K: Into<String> + Clone>(s: &[(K, RecordValue)]) -> Self {
        Self(
            s.iter()
                .map(|(k, v)| (k.clone().into(), v.clone()))
                .collect(),
        )
    }

    /// Returns an iterator over the keys of the record, in arbitrary order.
    pub fn keys(&self) -> Keys<'_, String, RecordValue> {
        self.0.keys()
    }

    /// Inserts a key-value pair, replacing any existing value stored under the key.
    pub fn insert(&mut self, k: impl Into<String>, v: RecordValue) {
        self.0.insert(k.into(), v);
    }

    /// Returns an iterator over borrowed key-value pairs, in arbitrary order.
    pub fn iter(&self) -> Iter<'_, String, RecordValue> {
        self.0.iter()
    }

    /// Consumes the record and returns an iterator over its owned key-value pairs,
    /// in arbitrary order.
    pub fn into_iter_in_record(self) -> IntoIter<String, RecordValue> {
        self.0.into_iter()
    }

    /// Returns the value stored under `k`, if any.
    pub fn get(&self, k: &str) -> Option<&RecordValue> {
        self.0.get(k)
    }

    /// Merges two records into one.
    ///
    /// When both records contain the same key, the value from `record` (the
    /// argument) replaces the value from `self`.
    pub fn merge(self, record: Record) -> Self {
        let mut merged = self.0;
        // Extending reuses `self`'s map instead of rebuilding a new one from scratch.
        merged.extend(record.0);
        Record(merged)
    }

    /// Looks up `k`, returning [`LrrError::RecordKeyError`] if it is absent.
    fn lookup(&self, k: &str) -> Result<&RecordValue, LrrError> {
        self.0
            .get(k)
            .ok_or_else(|| LrrError::RecordKeyError(k.to_string()))
    }

    /// Gets a scalar value.
    ///
    /// * `k` - The key of an entry in the record.
    ///
    /// # Errors
    ///
    /// Returns [`LrrError::RecordKeyError`] if `k` is not present, or
    /// [`LrrError::RecordValueTypeError`] if the value is not a [`RecordValue::Scalar`].
    pub fn get_scalar(&self, k: &str) -> Result<f32, LrrError> {
        match self.lookup(k)? {
            RecordValue::Scalar(v) => Ok(*v),
            _ => Err(LrrError::RecordValueTypeError("Scalar".to_string())),
        }
    }

    /// Gets an `Array1` value (cloned).
    ///
    /// # Errors
    ///
    /// Returns [`LrrError::RecordKeyError`] if `k` is not present, or
    /// [`LrrError::RecordValueTypeError`] if the value is not a [`RecordValue::Array1`].
    pub fn get_array1(&self, k: &str) -> Result<Vec<f32>, LrrError> {
        match self.lookup(k)? {
            RecordValue::Array1(v) => Ok(v.clone()),
            _ => Err(LrrError::RecordValueTypeError("Array1".to_string())),
        }
    }

    /// Gets an `Array2` value (cloned data plus its dimension array).
    ///
    /// # Errors
    ///
    /// Returns [`LrrError::RecordKeyError`] if `k` is not present, or
    /// [`LrrError::RecordValueTypeError`] if the value is not a [`RecordValue::Array2`].
    pub fn get_array2(&self, k: &str) -> Result<(Vec<f32>, [usize; 2]), LrrError> {
        match self.lookup(k)? {
            // `[usize; 2]` is `Copy`, so a dereference suffices.
            RecordValue::Array2(v, s) => Ok((v.clone(), *s)),
            _ => Err(LrrError::RecordValueTypeError("Array2".to_string())),
        }
    }

    /// Gets an `Array3` value (cloned data plus its dimension array).
    ///
    /// # Errors
    ///
    /// Returns [`LrrError::RecordKeyError`] if `k` is not present, or
    /// [`LrrError::RecordValueTypeError`] if the value is not a [`RecordValue::Array3`].
    pub fn get_array3(&self, k: &str) -> Result<(Vec<f32>, [usize; 3]), LrrError> {
        match self.lookup(k)? {
            // `[usize; 3]` is `Copy`, so a dereference suffices.
            RecordValue::Array3(v, s) => Ok((v.clone(), *s)),
            _ => Err(LrrError::RecordValueTypeError("Array3".to_string())),
        }
    }

    /// Gets a `String` value (cloned).
    ///
    /// # Errors
    ///
    /// Returns [`LrrError::RecordKeyError`] if `k` is not present, or
    /// [`LrrError::RecordValueTypeError`] if the value is not a [`RecordValue::String`].
    pub fn get_string(&self, k: &str) -> Result<String, LrrError> {
        match self.lookup(k)? {
            RecordValue::String(s) => Ok(s.clone()),
            _ => Err(LrrError::RecordValueTypeError("String".to_string())),
        }
    }
}
/// A sink for [`Record`]s produced during training or evaluation.
///
/// Each implementor decides what to do with the records handed to it through
/// [`Recorder::write`] — for example, buffering them in memory or discarding them.
pub trait Recorder {
    /// Hands `record` over to the recorder, which takes ownership of it.
    fn write(&mut self, record: Record);
}
/// A recorder that ignores any record. This struct is used just for debugging.
///
/// It holds no state, so it can be created with `NullRecorder {}` or
/// `NullRecorder::default()` and freely copied.
#[derive(Debug, Default, Clone, Copy)]
pub struct NullRecorder {}

impl Recorder for NullRecorder {
    /// Discards the given record without inspecting it.
    fn write(&mut self, _record: Record) {}
}
/// Buffered recorder.
///
/// This is used for recording sequences of observation and action
/// during evaluation runs in [`crate::util::eval_with_recorder`].
///
/// Every record passed to [`Recorder::write`] is appended to an in-memory
/// buffer and can be read back, in insertion order, with [`BufferedRecorder::iter`].
#[derive(Debug, Default)]
pub struct BufferedRecorder(Vec<Record>);

impl BufferedRecorder {
    /// Constructs an empty recorder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns an iterator over the buffered records in the order they were written.
    pub fn iter(&self) -> std::slice::Iter<'_, Record> {
        self.0.iter()
    }
}

impl Recorder for BufferedRecorder {
    /// Appends `record` to the end of the buffer.
    fn write(&mut self, record: Record) {
        self.0.push(record);
    }
}