use protobuf_codec::field::num::{F1, F2, F3, F4, F5, F6};
use protobuf_codec::field::{FieldDecoder, Fields, MaybeDefault, MessageFieldDecoder, Repeated};
use protobuf_codec::message::MessageDecoder;
use protobuf_codec::scalar::{DoubleDecoder, Int32Decoder, StringDecoder};
#[derive(Debug, Clone, PartialEq)]
pub struct ModelMetrics {
pub evaluation_data_list: Vec<EvaluationData>,
pub trainable_parameters: i32,
pub total_time: f64,
}
#[derive(Debug, Clone, PartialEq)]
pub struct EvaluationData {
pub current_epoch: f64,
pub training_time: f64,
pub train_accuracy: f64,
pub validation_accuracy: f64,
pub test_accuracy: f64,
pub checkpoint_path: String,
}
macro_rules! impl_message_decode {
($decoder:ty, $item:ty, $map:expr) => {
impl ::bytecodec::Decode for $decoder {
type Item = $item;
fn decode(&mut self, buf: &[u8], eos: ::bytecodec::Eos) -> ::bytecodec::Result<usize> {
track!(self.inner.decode(buf, eos))
}
fn finish_decoding(&mut self) -> ::bytecodec::Result<Self::Item> {
let item = track!(self.inner.finish_decoding())?;
$map(item)
}
fn is_idle(&self) -> bool {
self.inner.is_idle()
}
fn requiring_bytes(&self) -> ::bytecodec::ByteCount {
self.inner.requiring_bytes()
}
}
impl ::protobuf_codec::message::MessageDecode for $decoder {}
};
}
#[derive(Debug, Default)]
#[allow(clippy::type_complexity)]
pub struct ModelMetricsDecoder {
inner: MessageDecoder<
Fields<(
Repeated<MessageFieldDecoder<F1, EvaluationDataDecoder>, Vec<EvaluationData>>,
MaybeDefault<FieldDecoder<F2, Int32Decoder>>,
MaybeDefault<FieldDecoder<F3, DoubleDecoder>>,
)>,
>,
}
impl_message_decode!(ModelMetricsDecoder, ModelMetrics, |t: (_, _, _)| {
Ok(ModelMetrics {
evaluation_data_list: t.0,
trainable_parameters: t.1,
total_time: t.2,
})
});
#[derive(Debug, Default)]
#[allow(clippy::type_complexity)]
pub struct EvaluationDataDecoder {
inner: MessageDecoder<
Fields<(
MaybeDefault<FieldDecoder<F1, DoubleDecoder>>,
MaybeDefault<FieldDecoder<F2, DoubleDecoder>>,
MaybeDefault<FieldDecoder<F3, DoubleDecoder>>,
MaybeDefault<FieldDecoder<F4, DoubleDecoder>>,
MaybeDefault<FieldDecoder<F5, DoubleDecoder>>,
MaybeDefault<FieldDecoder<F6, StringDecoder>>,
)>,
>,
}
impl_message_decode!(EvaluationDataDecoder, EvaluationData, |t: (
_,
_,
_,
_,
_,
_
)| Ok(
EvaluationData {
current_epoch: t.0,
training_time: t.1,
train_accuracy: t.2,
validation_accuracy: t.3,
test_accuracy: t.4,
checkpoint_path: t.5
}
));
#[cfg(test)]
mod tests {
use super::*;
use base64;
use bytecodec::DecodeExt;
use trackable::result::TopLevelResult;
#[test]
fn decoder_works() -> TopLevelResult {
let base64_bytes = "Ci0JAAAAAAAAAAARAAAAAAAAAAAZAAAAoKqquT8hAAAAYJAGuT8pAAAAAGmQuT8KLQkAAAAAAAAAQBEAAADAdJNGQBkAAABgkAa6PyEAAAAg3/K6PykAAACAy7e6PwotCQAAAAAAABBAEQAAACDbeVZAGQAAAEAaZNY/IQAAAGBVVdY/KQAAAODybdY/EIqYigQZAABAhznbeUA=";
let bytes = track_any_err!(base64::decode(&base64_bytes))?;
let metrics = track!(ModelMetricsDecoder::default().decode_from_bytes(&bytes))?;
let expected = ModelMetrics {
evaluation_data_list: vec![
EvaluationData {
current_epoch: 0.0,
training_time: 0.0,
train_accuracy: 0.1002604141831398,
validation_accuracy: 0.09775640815496445,
test_accuracy: 0.09985977411270142,
checkpoint_path: "".to_owned(),
},
EvaluationData {
current_epoch: 2.0,
training_time: 45.152000427246094,
train_accuracy: 0.10166265815496445,
validation_accuracy: 0.10526842623949051,
test_accuracy: 0.10436698794364929,
checkpoint_path: "".to_owned(),
},
EvaluationData {
current_epoch: 4.0,
training_time: 89.90399932861328,
train_accuracy: 0.3498597741127014,
validation_accuracy: 0.3489583432674408,
test_accuracy: 0.3504607379436493,
checkpoint_path: "".to_owned(),
},
],
trainable_parameters: 8555530,
total_time: 413.7015450000763,
};
assert_eq!(metrics, expected);
Ok(())
}
}