nasbench/
model.rs

1use crate::protobuf::EvaluationData;
2use crate::Result;
3use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
4use md5;
5use std::collections::BTreeMap;
6use std::fmt;
7use std::hash::{Hash, Hasher};
8use std::io::{Read, Write};
9use std::str::FromStr;
10use trackable::error::{Failed, Failure};
11
12/// Model specification given adjacency matrix and operations (a.k.a. "module").
13///
14/// Note that two instances of `ModuleSpec` are regarded as the same
15/// if their structures are semantically equivalent (see below).
16///
17/// ```rust
18/// use nasbench::{ModelSpec, Op};
19/// # use trackable::result::TopLevelResult;
20///
21/// # fn main() -> TopLevelResult {
22/// let model0 = ModelSpec::new(vec![Op::Input, Op::Output], "0100".parse()?)?;
23/// let model1 = ModelSpec::new(
24///     vec![Op::Input, Op::Conv1x1, Op::Output],
25///     "001000000".parse()?,
26/// )?;
27///
28/// assert_eq!(model0, model1);
29/// # Ok(())
30/// # }
31/// ```
32#[derive(Clone)]
33pub struct ModelSpec {
34    ops: Vec<Op>,
35    adjacency: AdjacencyMatrix,
36    module_hash: u128,
37}
38impl ModelSpec {
39    /// Makes a new `ModelSpec` instance.
40    pub fn new(mut ops: Vec<Op>, mut adjacency: AdjacencyMatrix) -> Result<Self> {
41        track_assert_eq!(ops.len(), adjacency.dimension(), Failed);
42        track_assert!(adjacency.dimension() >= 2, Failed);
43
44        Self::prune(&mut ops, &mut adjacency);
45        let module_hash = Self::module_hash(&ops, &adjacency);
46        Ok(Self {
47            ops,
48            adjacency,
49            module_hash,
50        })
51    }
52
53    pub(crate) fn with_module_hash(
54        mut ops: Vec<Op>,
55        mut adjacency: AdjacencyMatrix,
56        module_hash: u128,
57    ) -> Self {
58        Self::prune(&mut ops, &mut adjacency);
59        Self {
60            ops,
61            adjacency,
62            module_hash,
63        }
64    }
65
66    pub(crate) fn validate_module_hash(&self) -> Result<()> {
67        let expected_module_hash = Self::module_hash(&self.ops, &self.adjacency);
68        track_assert_eq!(self.module_hash, expected_module_hash, Failed);
69        Ok(())
70    }
71
72    /// Returns a reference to the operations of this model.
73    pub fn ops(&self) -> &[Op] {
74        &self.ops
75    }
76
77    /// Returns a reference to the adjacency matrix of this model.
78    pub fn adjacency(&self) -> &AdjacencyMatrix {
79        &self.adjacency
80    }
81
82    fn prune(ops: &mut Vec<Op>, adjacency: &mut AdjacencyMatrix) {
83        let mut deleted = true;
84        while deleted {
85            deleted = false;
86
87            for row in 1..adjacency.dimension() - 1 {
88                let in_edges = adjacency.in_edges(row);
89                if in_edges == 0 {
90                    deleted = true;
91                    ops.remove(row);
92                    adjacency.remove(row);
93                    break;
94                }
95
96                let out_edges = adjacency.out_edges(row);
97                if out_edges == 0 {
98                    deleted = true;
99                    ops.remove(row);
100                    adjacency.remove(row);
101                    break;
102                }
103            }
104        }
105    }
106
107    fn module_hash(ops: &[Op], adjacency: &AdjacencyMatrix) -> u128 {
108        let dim = ops.len();
109
110        let mut hashes = Vec::with_capacity(dim);
111        for (row, op) in ops.iter().enumerate() {
112            let in_edges = adjacency.in_edges(row);
113            let out_edges = adjacency.out_edges(row);
114            let s = format!("({}, {}, {})", out_edges, in_edges, op.to_hash_index());
115            hashes.push(format!("{:032x}", md5::compute(s.as_bytes())));
116        }
117
118        for _ in 0..dim {
119            let mut new_hashes = Vec::with_capacity(dim);
120            for (v, h) in hashes.iter().enumerate() {
121                let mut in_neighbors = (0..dim)
122                    .filter(|&w| adjacency.has_edge(w, v))
123                    .map(|w| hashes[w].as_str())
124                    .collect::<Vec<_>>();
125                let mut out_neighbors = (0..dim)
126                    .filter(|&w| adjacency.has_edge(v, w))
127                    .map(|w| hashes[w].as_str())
128                    .collect::<Vec<_>>();
129                in_neighbors.sort();
130                out_neighbors.sort();
131
132                let s = format!("{}|{}|{}", in_neighbors.join(""), out_neighbors.join(""), h);
133                new_hashes.push(format!("{:032x}", md5::compute(s.as_bytes())));
134            }
135            hashes = new_hashes;
136        }
137
138        hashes.sort();
139        let hashes = hashes
140            .iter()
141            .map(|h| format!("'{}'", h))
142            .collect::<Vec<_>>();
143        let fingerprint = format!("[{}]", hashes.join(", "));
144        BigEndian::read_u128(&md5::compute(fingerprint.as_bytes()).0)
145    }
146
147    pub(crate) fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
148        let len = track_any_err!(reader.read_u8())? as usize;
149        let mut ops = Vec::with_capacity(len);
150        for _ in 0..len {
151            let op = match track_any_err!(reader.read_u8())? {
152                0 => Op::Input,
153                1 => Op::Conv1x1,
154                2 => Op::Conv3x3,
155                3 => Op::MaxPool3x3,
156                4 => Op::Output,
157                n => track_panic!(Failed, "Unknown operation number: {}", n),
158            };
159            ops.push(op);
160        }
161
162        let adjacency = track!(AdjacencyMatrix::from_reader(&mut reader))?;
163        let module_hash = track_any_err!(reader.read_u128::<BigEndian>())?;
164
165        Ok(Self {
166            ops,
167            adjacency,
168            module_hash,
169        })
170    }
171
172    pub(crate) fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
173        track_any_err!(writer.write_u8(self.ops.len() as u8))?;
174        for op in &self.ops {
175            track_any_err!(writer.write_u8(*op as u8))?;
176        }
177
178        track!(self.adjacency.to_writer(&mut writer))?;
179        track_any_err!(writer.write_u128::<BigEndian>(self.module_hash))?;
180
181        Ok(())
182    }
183}
184impl fmt::Debug for ModelSpec {
185    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186        write!(
187            f,
188            "ModelSpec {{ ops: {:?}, adjacency: {:?}, .. }}",
189            self.ops, self.adjacency
190        )
191    }
192}
193impl PartialEq for ModelSpec {
194    fn eq(&self, other: &Self) -> bool {
195        self.module_hash == other.module_hash
196    }
197}
198impl Eq for ModelSpec {}
199impl Hash for ModelSpec {
200    fn hash<H: Hasher>(&self, h: &mut H) {
201        self.module_hash.hash(h);
202    }
203}
204
/// Operation assigned to a vertex of a model.
///
/// The explicit discriminants are the byte values used by the binary serialization:
/// `ModelSpec::to_writer` writes `op as u8`, and `ModelSpec::from_reader` decodes the same
/// numbers. They must therefore stay stable even if variants are reordered or added.
///
/// # Examples
///
/// ```
/// use nasbench::Op;
///
/// assert_eq!("input".parse().ok(), Some(Op::Input));
/// assert_eq!("conv1x1-bn-relu".parse().ok(), Some(Op::Conv1x1));
/// assert_eq!("conv3x3-bn-relu".parse().ok(), Some(Op::Conv3x3));
/// assert_eq!("maxpool3x3".parse().ok(), Some(Op::MaxPool3x3));
/// assert_eq!("output".parse().ok(), Some(Op::Output));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Op {
    /// Input tensor.
    Input = 0,

    /// 1x1 convolution -> batch-norm -> ReLU.
    Conv1x1 = 1,

    /// 3x3 convolution -> batch-norm -> ReLU.
    Conv3x3 = 2,

    /// 3x3 max-pool.
    MaxPool3x3 = 3,

    /// Output tensor.
    Output = 4,
}
235impl Op {
236    fn to_hash_index(self) -> isize {
237        match self {
238            Op::Input => -1,
239            Op::Conv3x3 => 0,
240            Op::Conv1x1 => 1,
241            Op::MaxPool3x3 => 2,
242            Op::Output => -2,
243        }
244    }
245}
246impl FromStr for Op {
247    type Err = Failure;
248
249    fn from_str(op: &str) -> Result<Self> {
250        Ok(match op {
251            "input" => Op::Input,
252            "conv1x1-bn-relu" => Op::Conv1x1,
253            "conv3x3-bn-relu" => Op::Conv3x3,
254            "maxpool3x3" => Op::MaxPool3x3,
255            "output" => Op::Output,
256            _ => track_panic!(Failed, "Unknown operator: {:?}", op),
257        })
258    }
259}
260
/// Upper-triangular adjacency matrix of a module (at most 7 vertices).
///
/// Only the strictly-upper triangle is stored, so every edge goes from a lower-numbered
/// vertex to a higher-numbered one.
///
/// # Examples
///
/// ```
/// use nasbench::AdjacencyMatrix;
/// # use trackable::result::TopLevelResult;
///
/// # fn main() -> TopLevelResult {
/// let matrix0 = AdjacencyMatrix::new(vec![
///     vec![false, true, false, false, true, true, false],
///     vec![false, false, true, false, false, false, false],
///     vec![false, false, false, true, false, false, true],
///     vec![false, false, false, false, false, true, false],
///     vec![false, false, false, false, false, true, false],
///     vec![false, false, false, false, false, false, true],
///     vec![false, false, false, false, false, false, false]
/// ])?;
/// assert_eq!(matrix0.dimension(), 7);
///
/// let matrix1 = "0100110001000000010010000010000001000000010000000".parse()?;
/// assert_eq!(matrix0, matrix1);
/// # Ok(())
/// # }
/// ```
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AdjacencyMatrix {
    // Number of vertices.
    dim: u8,
    // Strictly-upper triangle packed row by row: bit 0 is (0, 1), bit 1 is (0, 2), ...
    triangle: u32,
}
291impl AdjacencyMatrix {
292    /// Makes a new `AdjacencyMatrix` instance.
293    pub fn new(matrix: Vec<Vec<bool>>) -> Result<Self> {
294        let dim = matrix.len();
295        track_assert_ne!(dim, 0, Failed);
296        track_assert!(dim <= 7, Failed; dim);
297
298        let mut triangle = 0;
299        let mut offset = 0;
300        for (i, row) in matrix.into_iter().enumerate() {
301            track_assert_eq!(row.len(), dim, Failed);
302
303            for (j, adjacent) in row.into_iter().enumerate() {
304                if j <= i {
305                    track_assert!(!adjacent, Failed; i, j);
306                    continue;
307                }
308
309                offset += 1;
310                if !adjacent {
311                    continue;
312                }
313
314                triangle |= 1 << (offset - 1);
315            }
316        }
317
318        let dim = dim as u8;
319        Ok(Self { dim, triangle })
320    }
321
322    /// Returns the dimension of this matrix.
323    pub fn dimension(&self) -> usize {
324        usize::from(self.dim)
325    }
326
327    fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
328        let dim = track_any_err!(reader.read_u8())?;
329        let triangle = track_any_err!(reader.read_u32::<BigEndian>())?;
330        Ok(Self { dim, triangle })
331    }
332
333    fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
334        track_any_err!(writer.write_u8(self.dim))?;
335        track_any_err!(writer.write_u32::<BigEndian>(self.triangle))?;
336        Ok(())
337    }
338
339    fn remove(&mut self, row: usize) {
340        let mut triangle = 0;
341        let mut offset = 0;
342        for i in (0..self.dimension()).filter(|&i| i != row) {
343            for j in (i + 1..self.dimension()).filter(|&j| j != row) {
344                offset += 1;
345                if !self.has_edge(i, j) {
346                    continue;
347                }
348
349                triangle |= 1 << (offset - 1);
350            }
351        }
352
353        self.dim -= 1;
354        self.triangle = triangle;
355    }
356
357    fn has_edge(&self, row: usize, column: usize) -> bool {
358        if column <= row {
359            return false;
360        }
361
362        let offset = match self.dim {
363            7 => &[0, 6, 11, 15, 18, 20, 21][..],
364            6 => &[0, 5, 9, 12, 14, 15][..],
365            5 => &[0, 4, 7, 9, 10][..],
366            4 => &[0, 3, 5, 6][..],
367            3 => &[0, 2, 1][..],
368            2 => &[0, 1][..],
369            1 => &[0][..],
370            _ => {
371                unreachable!("dim={}", self.dim);
372            }
373        };
374        let i = offset[row] + column - row - 1;
375        (self.triangle & (1 << i)) != 0
376    }
377
378    fn in_edges(&self, row: usize) -> usize {
379        (0..row)
380            .filter(|&column| self.has_edge(column, row))
381            .count()
382    }
383
384    fn out_edges(&self, row: usize) -> usize {
385        (row + 1..self.dimension())
386            .filter(|&column| self.has_edge(row, column))
387            .count()
388    }
389}
390impl FromStr for AdjacencyMatrix {
391    type Err = Failure;
392
393    fn from_str(s: &str) -> Result<Self> {
394        let dim = (s.len() as f64).sqrt() as usize;
395        track_assert_eq!(dim * dim, s.len(), Failed, "Not a matrix: {:?}", s);
396
397        let mut matrix = vec![vec![false; dim]; dim];
398        for (i, row) in matrix.iter_mut().enumerate() {
399            for (j, v) in row.iter_mut().enumerate() {
400                *v = s.as_bytes()[i * dim + j] == b'1';
401            }
402        }
403
404        track!(Self::new(matrix), "Not an upper triangular matrix; {:?}", s)
405    }
406}
407impl fmt::Debug for AdjacencyMatrix {
408    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
409        write!(f, "AdjacencyMatrix(0b")?;
410        for row in 0..self.dimension() {
411            for column in 0..self.dimension() {
412                write!(f, "{}", self.has_edge(row, column) as u8)?;
413            }
414            if row != self.dimension() - 1 {
415                write!(f, "_")?;
416            }
417        }
418        write!(f, ")")?;
419        Ok(())
420    }
421}
422
423/// Model statistics.
424#[derive(Debug, Default, PartialEq)]
425pub struct ModelStats {
426    /// Number of trainable parameters in the model.
427    pub trainable_parameters: u32,
428
429    /// Map from epoch number to evaluation metrics at that epoch.
430    ///
431    /// Because each model has evaluated multiple times, each epoch is associated with multiple metrics.
432    pub epochs: BTreeMap<u8, Vec<EpochStats>>,
433}
434impl ModelStats {
435    pub(crate) fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
436        let trainable_parameters = track_any_err!(reader.read_u32::<BigEndian>())?;
437
438        let len = track_any_err!(reader.read_u8())? as usize;
439        let mut epochs = BTreeMap::new();
440        for _ in 0..len {
441            let epoch_num = track_any_err!(reader.read_u8())?;
442
443            let len = track_any_err!(reader.read_u8())? as usize;
444            let mut stats_list = Vec::with_capacity(len);
445            for _ in 0..len {
446                stats_list.push(track!(EpochStats::from_reader(&mut reader))?);
447            }
448
449            epochs.insert(epoch_num, stats_list);
450        }
451
452        Ok(Self {
453            trainable_parameters,
454            epochs,
455        })
456    }
457
458    pub(crate) fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
459        track_any_err!(writer.write_u32::<BigEndian>(self.trainable_parameters))?;
460
461        track_any_err!(writer.write_u8(self.epochs.len() as u8))?;
462        for (epoch_num, stats_list) in &self.epochs {
463            track_any_err!(writer.write_u8(*epoch_num))?;
464
465            track_any_err!(writer.write_u8(stats_list.len() as u8))?;
466            for s in stats_list {
467                track!(s.to_writer(&mut writer))?;
468            }
469        }
470
471        Ok(())
472    }
473}
474
475/// Epoch statistics.
476#[derive(Debug, PartialEq)]
477pub struct EpochStats {
478    /// Evaluation metrics at the half-way point of training.
479    pub halfway: EvaluationMetrics,
480
481    /// Evaluation metrics at the end of training.
482    pub complete: EvaluationMetrics,
483}
484impl EpochStats {
485    fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
486        let halfway = track!(EvaluationMetrics::from_reader(&mut reader))?;
487        let complete = track!(EvaluationMetrics::from_reader(&mut reader))?;
488        Ok(Self { halfway, complete })
489    }
490
491    fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
492        track!(self.halfway.to_writer(&mut writer))?;
493        track!(self.complete.to_writer(&mut writer))?;
494        Ok(())
495    }
496}
497
/// Evaluation metrics: cumulative training time plus training/validation/test accuracy.
#[derive(Clone, Debug, PartialEq)]
pub struct EvaluationMetrics {
    /// Total training time, in seconds, up to this point.
    pub training_time: f64,

    /// Accuracy on the training set.
    pub training_accuracy: f64,

    /// Accuracy on the validation set.
    pub validation_accuracy: f64,

    /// Accuracy on the test set.
    pub test_accuracy: f64,
}
513impl EvaluationMetrics {
514    fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
515        let training_time = track_any_err!(reader.read_f64::<BigEndian>())?;
516        let training_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
517        let validation_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
518        let test_accuracy = track_any_err!(reader.read_f64::<BigEndian>())?;
519        Ok(Self {
520            training_time,
521            training_accuracy,
522            validation_accuracy,
523            test_accuracy,
524        })
525    }
526
527    fn to_writer<W: Write>(&self, mut writer: W) -> Result<()> {
528        track_any_err!(writer.write_f64::<BigEndian>(self.training_time))?;
529        track_any_err!(writer.write_f64::<BigEndian>(self.training_accuracy))?;
530        track_any_err!(writer.write_f64::<BigEndian>(self.validation_accuracy))?;
531        track_any_err!(writer.write_f64::<BigEndian>(self.test_accuracy))?;
532        Ok(())
533    }
534}
535impl From<EvaluationData> for EvaluationMetrics {
536    fn from(f: EvaluationData) -> Self {
537        Self {
538            training_time: f.training_time,
539            training_accuracy: f.train_accuracy,
540            validation_accuracy: f.validation_accuracy,
541            test_accuracy: f.test_accuracy,
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use trackable::result::TopLevelResult;

    #[test]
    fn model_spec_works() -> TopLevelResult {
        let model0 = ModelSpec::new(vec![Op::Input, Op::Output], "0100".parse()?)?;
        let model1 = ModelSpec::new(
            vec![Op::Input, Op::Conv1x1, Op::Output],
            "001000000".parse()?,
        )?;
        assert_eq!(model0, model1);

        let model2 = ModelSpec::new(
            vec![Op::Input, Op::Conv3x3, Op::MaxPool3x3, Op::Output],
            "0101001000010000".parse()?,
        )?;
        let model3 = ModelSpec::new(
            vec![
                Op::Input,
                Op::Conv1x1,
                Op::Conv3x3,
                Op::MaxPool3x3,
                Op::Conv3x3,
                Op::Output,
            ],
            "001001000000000100000001000000000000".parse()?,
        )?;
        assert_eq!(model2, model3);

        let model4 = ModelSpec::new(vec![Op::Input, Op::Output], "0000".parse()?)?;
        let model5 = ModelSpec::new(
            vec![
                Op::Input,
                Op::Conv1x1,
                Op::MaxPool3x3,
                Op::Conv3x3,
                Op::Output,
            ],
            "0000000000000000000000000".parse()?,
        )?;
        assert_eq!(model4, model5);

        let model6 = ModelSpec::new(vec![Op::Input, Op::Output], "0100".parse()?)?;
        let model7 = ModelSpec::new(
            vec![Op::Input, Op::Conv3x3, Op::Conv1x1, Op::Output],
            "0111000000000000".parse()?,
        )?;
        assert_eq!(model6, model7);

        Ok(())
    }

    #[test]
    fn module_hash_works() -> TopLevelResult {
        let matrix = track!(AdjacencyMatrix::new(vec![
            vec![false, true, true, true, false, true, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, true, false, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, false],
        ]))?;
        let ops = vec![
            Op::Input,
            Op::Conv1x1,
            Op::Conv3x3,
            Op::Conv3x3,
            Op::Conv3x3,
            Op::MaxPool3x3,
            Op::Output,
        ];

        let spec = track!(ModelSpec::new(ops, matrix))?;
        assert_eq!(spec.module_hash, 0x28cfc7874f6d200472e1a9dcd8650aa0);

        Ok(())
    }

    #[test]
    fn model_spec_serialization_round_trip() -> TopLevelResult {
        // This spec gets pruned, so the round trip also covers a re-packed matrix.
        let spec = ModelSpec::new(
            vec![
                Op::Input,
                Op::Conv1x1,
                Op::Conv3x3,
                Op::MaxPool3x3,
                Op::Conv3x3,
                Op::Output,
            ],
            "001001000000000100000001000000000000".parse()?,
        )?;

        let mut bytes = Vec::new();
        track!(spec.to_writer(&mut bytes))?;
        let decoded = track!(ModelSpec::from_reader(&bytes[..]))?;

        assert_eq!(decoded, spec);
        assert_eq!(decoded.ops(), spec.ops());
        assert_eq!(decoded.adjacency(), spec.adjacency());
        track!(decoded.validate_module_hash())?;
        Ok(())
    }

    #[test]
    fn model_stats_serialization_round_trip() -> TopLevelResult {
        let metrics = |seconds: f64| EvaluationMetrics {
            training_time: seconds,
            training_accuracy: 0.5,
            validation_accuracy: 0.25,
            test_accuracy: 0.125,
        };
        let mut stats = ModelStats {
            trainable_parameters: 227_274,
            ..ModelStats::default()
        };
        stats.epochs.insert(
            4,
            vec![EpochStats {
                halfway: metrics(10.0),
                complete: metrics(20.0),
            }],
        );
        stats.epochs.insert(
            108,
            vec![
                EpochStats {
                    halfway: metrics(100.0),
                    complete: metrics(200.0),
                },
                EpochStats {
                    halfway: metrics(110.0),
                    complete: metrics(210.0),
                },
            ],
        );

        let mut bytes = Vec::new();
        track!(stats.to_writer(&mut bytes))?;
        let decoded = track!(ModelStats::from_reader(&bytes[..]))?;
        assert_eq!(decoded, stats);
        Ok(())
    }

    #[test]
    fn op_works() {
        assert_eq!("input".parse().ok(), Some(Op::Input));
        assert_eq!("conv1x1-bn-relu".parse().ok(), Some(Op::Conv1x1));
        assert_eq!("conv3x3-bn-relu".parse().ok(), Some(Op::Conv3x3));
        assert_eq!("maxpool3x3".parse().ok(), Some(Op::MaxPool3x3));
        assert_eq!("output".parse().ok(), Some(Op::Output));
    }

    #[test]
    fn adjacency_matrix_works() -> TopLevelResult {
        let original_matrix = vec![
            vec![false, true, false, false, true, true, false],
            vec![false, false, true, false, false, false, false],
            vec![false, false, false, true, false, false, true],
            vec![false, false, false, false, false, true, false],
            vec![false, false, false, false, false, true, false],
            vec![false, false, false, false, false, false, true],
            vec![false, false, false, false, false, false, false],
        ];

        let matrix0 = track!(AdjacencyMatrix::new(original_matrix.clone()))?;
        assert_eq!(matrix0.dimension(), 7);

        let matrix1 = track!("0100110001000000010010000010000001000000010000000".parse())?;
        assert_eq!(matrix0, matrix1);

        for row in 0..original_matrix.len() {
            for column in 0..original_matrix.len() {
                assert_eq!(
                    matrix0.has_edge(row, column),
                    original_matrix[row][column],
                    "row={}, column={}",
                    row,
                    column
                );
            }
        }
        Ok(())
    }

    #[test]
    fn small_adjacency_matrix_works() -> TopLevelResult {
        // Fully connected 3-vertex DAG: (0, 1), (0, 2), (1, 2).
        let matrix: AdjacencyMatrix = "011001000".parse()?;
        assert_eq!(matrix.dimension(), 3);

        let expected = [
            [false, true, true],
            [false, false, true],
            [false, false, false],
        ];
        for (row, columns) in expected.iter().enumerate() {
            for (column, &edge) in columns.iter().enumerate() {
                assert_eq!(
                    matrix.has_edge(row, column),
                    edge,
                    "row={}, column={}",
                    row,
                    column
                );
            }
        }
        assert_eq!(matrix.out_edges(0), 2);
        assert_eq!(matrix.in_edges(2), 2);
        Ok(())
    }
}