1use crate::protobuf::EvaluationData;
2use crate::Result;
3use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
4use md5;
5use std::collections::BTreeMap;
6use std::fmt;
7use std::hash::{Hash, Hasher};
8use std::io::{Read, Write};
9use std::str::FromStr;
10use trackable::error::{Failed, Failure};
11
12#[derive(Clone)]
33pub struct ModelSpec {
34 ops: Vec<Op>,
35 adjacency: AdjacencyMatrix,
36 module_hash: u128,
37}
38impl ModelSpec {
39 pub fn new(mut ops: Vec<Op>, mut adjacency: AdjacencyMatrix) -> Result<Self> {
41 track_assert_eq!(ops.len(), adjacency.dimension(), Failed);
42 track_assert!(adjacency.dimension() >= 2, Failed);
43
44 Self::prune(&mut ops, &mut adjacency);
45 let module_hash = Self::module_hash(&ops, &adjacency);
46 Ok(Self {
47 ops,
48 adjacency,
49 module_hash,
50 })
51 }
52
53 pub(crate) fn with_module_hash(
54 mut ops: Vec<Op>,
55 mut adjacency: AdjacencyMatrix,
56 module_hash: u128,
57 ) -> Self {
58 Self::prune(&mut ops, &mut adjacency);
59 Self {
60 ops,
61 adjacency,
62 module_hash,
63 }
64 }
65
66 pub(crate) fn validate_module_hash(&self) -> Result<()> {
67 let expected_module_hash = Self::module_hash(&self.ops, &self.adjacency);
68 track_assert_eq!(self.module_hash, expected_module_hash, Failed);
69 Ok(())
70 }
71
72 pub fn ops(&self) -> &[Op] {
74 &self.ops
75 }
76
77 pub fn adjacency(&self) -> &AdjacencyMatrix {
79 &self.adjacency
80 }
81
82 fn prune(ops: &mut Vec<Op>, adjacency: &mut AdjacencyMatrix) {
83 let mut deleted = true;
84 while deleted {
85 deleted = false;
86
87 for row in 1..adjacency.dimension() - 1 {
88 let in_edges = adjacency.in_edges(row);
89 if in_edges == 0 {
90 deleted = true;
91 ops.remove(row);
92 adjacency.remove(row);
93 break;
94 }
95
96 let out_edges = adjacency.out_edges(row);
97 if out_edges == 0 {
98 deleted = true;
99 ops.remove(row);
100 adjacency.remove(row);
101 break;
102 }
103 }
104 }
105 }
106
107 fn module_hash(ops: &[Op], adjacency: &AdjacencyMatrix) -> u128 {
108 let dim = ops.len();
109
110 let mut hashes = Vec::with_capacity(dim);
111 for (row, op) in ops.iter().enumerate() {
112 let in_edges = adjacency.in_edges(row);
113 let out_edges = adjacency.out_edges(row);
114 let s = format!("({}, {}, {})", out_edges, in_edges, op.to_hash_index());
115 hashes.push(format!("{:032x}", md5::compute(s.as_bytes())));
116 }
117
118 for _ in 0..dim {
119 let mut new_hashes = Vec::with_capacity(dim);
120 for (v, h) in hashes.iter().enumerate() {
121 let mut in_neighbors = (0..dim)
122 .filter(|&w| adjacency.has_edge(w, v))
123 .map(|w| hashes[w].as_str())
124 .collect::<Vec<_>>();
125 let mut out_neighbors = (0..dim)
126 .filter(|&w| adjacency.has_edge(v, w))
127 .map(|w| hashes[w].as_str())
128 .collect::<Vec<_>>();
129 in_neighbors.sort();
130 out_neighbors.sort();
131
132 let s = format!("{}|{}|{}", in_neighbors.join(""), out_neighbors.join(""), h);
133 new_hashes.push(format!("{:032x}", md5::compute(s.as_bytes())));
134 }
135 hashes = new_hashes;
136 }
137
138 hashes.sort();
139 let hashes = hashes
140 .iter()
141 .map(|h| format!("'{}'", h))
142 .collect::<Vec<_>>();
143 let fingerprint = format!("[{}]", hashes.join(", "));
144 BigEndian::read_u128(&md5::compute(fingerprint.as_bytes()).0)
145 }
146
147 pub(crate) fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
148 let len = track_any_err!(reader.read_u8())? as usize;
149 let mut ops = Vec::with_capacity(len);
150 for _ in 0..len {
151 let op = match track_any_err!(reader.read_u8())? {
152 0 => Op::Input,
153 1 => Op::Conv1x1,
154 2 => Op::Conv3x3,
155 3 => Op::MaxPool3x3,
156 4 => Op::Output,
157 n => track_panic!(Failed, "Unknown operation number: {}", n),
158 };
159 ops.push(op);
160 }
161
162 let adjacency = track!(AdjacencyMatrix::from_reader(&mut reader))?;
163 let module_hash = track_any_err!(reader.read_u128::<BigEndian>())?;
164
165 Ok(Self {
166 ops,
167 adjacency,
168 module_hash,
169 })
170 }
171
172 pub(crate) fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
173 track_any_err!(writer.write_u8(self.ops.len() as u8))?;
174 for op in &self.ops {
175 track_any_err!(writer.write_u8(*op as u8))?;
176 }
177
178 track!(self.adjacency.to_writer(&mut writer))?;
179 track_any_err!(writer.write_u128::<BigEndian>(self.module_hash))?;
180
181 Ok(())
182 }
183}
184impl fmt::Debug for ModelSpec {
185 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186 write!(
187 f,
188 "ModelSpec {{ ops: {:?}, adjacency: {:?}, .. }}",
189 self.ops, self.adjacency
190 )
191 }
192}
193impl PartialEq for ModelSpec {
194 fn eq(&self, other: &Self) -> bool {
195 self.module_hash == other.module_hash
196 }
197}
198impl Eq for ModelSpec {}
199impl Hash for ModelSpec {
200 fn hash<H: Hasher>(&self, h: &mut H) {
201 self.module_hash.hash(h);
202 }
203}
204
/// Operation of a vertex.
///
/// The declaration order matters: `ModelSpec::to_writer()` stores a variant as `op as u8`,
/// so reordering the variants would change the serialized numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Op {
    /// Input vertex (parsed from `"input"`).
    Input,

    /// Parsed from `"conv1x1-bn-relu"`.
    Conv1x1,

    /// Parsed from `"conv3x3-bn-relu"`.
    Conv3x3,

    /// Parsed from `"maxpool3x3"`.
    MaxPool3x3,

    /// Output vertex (parsed from `"output"`).
    Output,
}
impl Op {
    /// Returns the label that represents this operation in the module hash input.
    ///
    /// Note that these labels differ from the `as u8` discriminants used for serialization.
    fn to_hash_index(self) -> isize {
        match self {
            // The two terminal vertices use negative labels.
            Op::Input => -1,
            Op::Output => -2,
            // The remaining operations are numbered from zero.
            Op::Conv3x3 => 0,
            Op::Conv1x1 => 1,
            Op::MaxPool3x3 => 2,
        }
    }
}
246impl FromStr for Op {
247 type Err = Failure;
248
249 fn from_str(op: &str) -> Result<Self> {
250 Ok(match op {
251 "input" => Op::Input,
252 "conv1x1-bn-relu" => Op::Conv1x1,
253 "conv3x3-bn-relu" => Op::Conv3x3,
254 "maxpool3x3" => Op::MaxPool3x3,
255 "output" => Op::Output,
256 _ => track_panic!(Failed, "Unknown operator: {:?}", op),
257 })
258 }
259}
260
261#[derive(Clone, PartialEq, Eq, Hash)]
287pub struct AdjacencyMatrix {
288 dim: u8,
289 triangle: u32,
290}
291impl AdjacencyMatrix {
292 pub fn new(matrix: Vec<Vec<bool>>) -> Result<Self> {
294 let dim = matrix.len();
295 track_assert_ne!(dim, 0, Failed);
296 track_assert!(dim <= 7, Failed; dim);
297
298 let mut triangle = 0;
299 let mut offset = 0;
300 for (i, row) in matrix.into_iter().enumerate() {
301 track_assert_eq!(row.len(), dim, Failed);
302
303 for (j, adjacent) in row.into_iter().enumerate() {
304 if j <= i {
305 track_assert!(!adjacent, Failed; i, j);
306 continue;
307 }
308
309 offset += 1;
310 if !adjacent {
311 continue;
312 }
313
314 triangle |= 1 << (offset - 1);
315 }
316 }
317
318 let dim = dim as u8;
319 Ok(Self { dim, triangle })
320 }
321
322 pub fn dimension(&self) -> usize {
324 usize::from(self.dim)
325 }
326
327 fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
328 let dim = track_any_err!(reader.read_u8())?;
329 let triangle = track_any_err!(reader.read_u32::<BigEndian>())?;
330 Ok(Self { dim, triangle })
331 }
332
333 fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
334 track_any_err!(writer.write_u8(self.dim))?;
335 track_any_err!(writer.write_u32::<BigEndian>(self.triangle))?;
336 Ok(())
337 }
338
339 fn remove(&mut self, row: usize) {
340 let mut triangle = 0;
341 let mut offset = 0;
342 for i in (0..self.dimension()).filter(|&i| i != row) {
343 for j in (i + 1..self.dimension()).filter(|&j| j != row) {
344 offset += 1;
345 if !self.has_edge(i, j) {
346 continue;
347 }
348
349 triangle |= 1 << (offset - 1);
350 }
351 }
352
353 self.dim -= 1;
354 self.triangle = triangle;
355 }
356
357 fn has_edge(&self, row: usize, column: usize) -> bool {
358 if column <= row {
359 return false;
360 }
361
362 let offset = match self.dim {
363 7 => &[0, 6, 11, 15, 18, 20, 21][..],
364 6 => &[0, 5, 9, 12, 14, 15][..],
365 5 => &[0, 4, 7, 9, 10][..],
366 4 => &[0, 3, 5, 6][..],
367 3 => &[0, 2, 1][..],
368 2 => &[0, 1][..],
369 1 => &[0][..],
370 _ => {
371 unreachable!("dim={}", self.dim);
372 }
373 };
374 let i = offset[row] + column - row - 1;
375 (self.triangle & (1 << i)) != 0
376 }
377
378 fn in_edges(&self, row: usize) -> usize {
379 (0..row)
380 .filter(|&column| self.has_edge(column, row))
381 .count()
382 }
383
384 fn out_edges(&self, row: usize) -> usize {
385 (row + 1..self.dimension())
386 .filter(|&column| self.has_edge(row, column))
387 .count()
388 }
389}
390impl FromStr for AdjacencyMatrix {
391 type Err = Failure;
392
393 fn from_str(s: &str) -> Result<Self> {
394 let dim = (s.len() as f64).sqrt() as usize;
395 track_assert_eq!(dim * dim, s.len(), Failed, "Not a matrix: {:?}", s);
396
397 let mut matrix = vec![vec![false; dim]; dim];
398 for (i, row) in matrix.iter_mut().enumerate() {
399 for (j, v) in row.iter_mut().enumerate() {
400 *v = s.as_bytes()[i * dim + j] == b'1';
401 }
402 }
403
404 track!(Self::new(matrix), "Not an upper triangular matrix; {:?}", s)
405 }
406}
407impl fmt::Debug for AdjacencyMatrix {
408 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
409 write!(f, "AdjacencyMatrix(0b")?;
410 for row in 0..self.dimension() {
411 for column in 0..self.dimension() {
412 write!(f, "{}", self.has_edge(row, column) as u8)?;
413 }
414 if row != self.dimension() - 1 {
415 write!(f, "_")?;
416 }
417 }
418 write!(f, ")")?;
419 Ok(())
420 }
421}
422
423#[derive(Debug, Default, PartialEq)]
425pub struct ModelStats {
426 pub trainable_parameters: u32,
428
429 pub epochs: BTreeMap<u8, Vec<EpochStats>>,
433}
434impl ModelStats {
435 pub(crate) fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
436 let trainable_parameters = track_any_err!(reader.read_u32::<BigEndian>())?;
437
438 let len = track_any_err!(reader.read_u8())? as usize;
439 let mut epochs = BTreeMap::new();
440 for _ in 0..len {
441 let epoch_num = track_any_err!(reader.read_u8())?;
442
443 let len = track_any_err!(reader.read_u8())? as usize;
444 let mut stats_list = Vec::with_capacity(len);
445 for _ in 0..len {
446 stats_list.push(track!(EpochStats::from_reader(&mut reader))?);
447 }
448
449 epochs.insert(epoch_num, stats_list);
450 }
451
452 Ok(Self {
453 trainable_parameters,
454 epochs,
455 })
456 }
457
458 pub(crate) fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
459 track_any_err!(writer.write_u32::<BigEndian>(self.trainable_parameters))?;
460
461 track_any_err!(writer.write_u8(self.epochs.len() as u8))?;
462 for (epoch_num, stats_list) in &self.epochs {
463 track_any_err!(writer.write_u8(*epoch_num))?;
464
465 track_any_err!(writer.write_u8(stats_list.len() as u8))?;
466 for s in stats_list {
467 track!(s.to_writer(&mut writer))?;
468 }
469 }
470
471 Ok(())
472 }
473}
474
475#[derive(Debug, PartialEq)]
477pub struct EpochStats {
478 pub halfway: EvaluationMetrics,
480
481 pub complete: EvaluationMetrics,
483}
484impl EpochStats {
485 fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
486 let halfway = track!(EvaluationMetrics::from_reader(&mut reader))?;
487 let complete = track!(EvaluationMetrics::from_reader(&mut reader))?;
488 Ok(Self { halfway, complete })
489 }
490
491 fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
492 track!(self.halfway.to_writer(&mut writer))?;
493 track!(self.complete.to_writer(&mut writer))?;
494 Ok(())
495 }
496}
497
498#[derive(Debug, Clone, PartialEq)]
500pub struct EvaluationMetrics {
501 pub training_time: f64,
503
504 pub training_accuracy: f64,
506
507 pub validation_accuracy: f64,
509
510 pub test_accuracy: f64,
512}
513impl EvaluationMetrics {
514 fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
515 let training_time = track_any_err!(reader.read_f64::<BigEndian>())?;
516 let training_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
517 let validation_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
518 let test_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
519 Ok(Self {
520 training_time,
521 training_accuracy,
522 validation_accuracy,
523 test_accuracy,
524 })
525 }
526
527 fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
528 track_any_err!(writer.write_f64::<BigEndian>(self.training_time))?;
529 track_any_err!(writer.write_f64::<BigEndian>(self.training_accuracy))?;
530 track_any_err!(writer.write_f64::<BigEndian>(self.validation_accuracy))?;
531 track_any_err!(writer.write_f64::<BigEndian>(self.test_accuracy))?;
532 Ok(())
533 }
534}
535impl From<EvaluationData> for EvaluationMetrics {
536 fn from(f: EvaluationData) -> Self {
537 Self {
538 training_time: f.training_time,
539 training_accuracy: f.train_accuracy,
540 validation_accuracy: f.validation_accuracy,
541 test_accuracy: f.test_accuracy,
542 }
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use trackable::result::TopLevelResult;

    /// Specs that differ only in prunable vertices must compare equal.
    #[test]
    fn model_spec_works() -> TopLevelResult {
        // An interior vertex without any edge is pruned.
        let direct = ModelSpec::new(vec![Op::Input, Op::Output], "0100".parse()?)?;
        let with_isolated_vertex = ModelSpec::new(
            vec![Op::Input, Op::Conv1x1, Op::Output],
            "001000000".parse()?,
        )?;
        assert_eq!(direct, with_isolated_vertex);

        // Pruning two unconnected vertices out of a six-vertex graph leaves the four-vertex graph.
        let four_vertices = ModelSpec::new(
            vec![Op::Input, Op::Conv3x3, Op::MaxPool3x3, Op::Output],
            "0101001000010000".parse()?,
        )?;
        let six_vertices = ModelSpec::new(
            vec![
                Op::Input,
                Op::Conv1x1,
                Op::Conv3x3,
                Op::MaxPool3x3,
                Op::Conv3x3,
                Op::Output,
            ],
            "001001000000000100000001000000000000".parse()?,
        )?;
        assert_eq!(four_vertices, six_vertices);

        // Graphs without any edge collapse to the two terminal vertices.
        let empty_small = ModelSpec::new(vec![Op::Input, Op::Output], "0000".parse()?)?;
        let empty_large = ModelSpec::new(
            vec![
                Op::Input,
                Op::Conv1x1,
                Op::MaxPool3x3,
                Op::Conv3x3,
                Op::Output,
            ],
            "0000000000000000000000000".parse()?,
        )?;
        assert_eq!(empty_small, empty_large);

        // Interior vertices without outgoing edges are pruned.
        let direct_again = ModelSpec::new(vec![Op::Input, Op::Output], "0100".parse()?)?;
        let with_dead_ends = ModelSpec::new(
            vec![Op::Input, Op::Conv3x3, Op::Conv1x1, Op::Output],
            "0111000000000000".parse()?,
        )?;
        assert_eq!(direct_again, with_dead_ends);

        Ok(())
    }

    /// Pins the module hash of one seven-vertex model to a known value.
    #[test]
    fn module_hash_works() -> TopLevelResult {
        let rows = vec![
            vec![false, true, true, true, false, true, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, true, false, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, false],
        ];
        let matrix = track!(AdjacencyMatrix::new(rows))?;
        let ops = vec![
            Op::Input,
            Op::Conv1x1,
            Op::Conv3x3,
            Op::Conv3x3,
            Op::Conv3x3,
            Op::MaxPool3x3,
            Op::Output,
        ];

        let spec = track!(ModelSpec::new(ops, matrix))?;
        assert_eq!(spec.module_hash, 0x28cfc7874f6d200472e1a9dcd8650aa0);

        Ok(())
    }

    /// Every operation name parses to its variant.
    #[test]
    fn op_works() {
        let cases = [
            ("input", Op::Input),
            ("conv1x1-bn-relu", Op::Conv1x1),
            ("conv3x3-bn-relu", Op::Conv3x3),
            ("maxpool3x3", Op::MaxPool3x3),
            ("output", Op::Output),
        ];
        for &(name, expected) in &cases {
            assert_eq!(name.parse().ok(), Some(expected));
        }
    }

    /// The packed matrix reports exactly the edges of the matrix it was built from.
    #[test]
    fn adjacency_matrix_works() -> TopLevelResult {
        let original_matrix = vec![
            vec![false, true, false, false, true, true, false],
            vec![false, false, true, false, false, false, false],
            vec![false, false, false, true, false, false, true],
            vec![false, false, false, false, false, true, false],
            vec![false, false, false, false, false, true, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, false],
        ];

        let from_rows = track!(AdjacencyMatrix::new(original_matrix.clone()))?;
        assert_eq!(from_rows.dimension(), 7);

        // The textual form of the same matrix must parse to an equal value.
        let from_text = track!("0100110001000000010010000010000001000000010000000".parse())?;
        assert_eq!(from_rows, from_text);

        for (row, expected_row) in original_matrix.iter().enumerate() {
            for (column, &expected) in expected_row.iter().enumerate() {
                assert_eq!(
                    from_rows.has_edge(row, column),
                    expected,
                    "row={}, column={}",
                    row,
                    column
                );
            }
        }
        Ok(())
    }
}