1use std::borrow::Borrow;
18use std::cmp::min;
19use std::fs::File;
20use std::path::Path;
21use std::time::{Duration, Instant};
22
23use console::style;
24use indicatif::{HumanDuration, ProgressBar, ProgressStyle};
25use itertools::Itertools;
26use ndarray_npy::{NpzReader, ReadNpyError, ReadNpzError};
27use regex::Regex;
28
29use crate::distill::schema::{
30 argmax, class_characterization, partial_ReLU, partial_hard_sigmoid, partial_hard_tanh,
31 partial_leaky_ReLU,
32};
33use crate::linalg::affine::AffFunc;
34use crate::pwl::afftree::AffTree;
35
36pub trait NodeEstimator {
41 fn estimate_nodes<Item: Borrow<Layer>>(
42 &self,
43 dim: usize,
44 current_depth: usize,
45 layers: &[Item],
46 ) -> (usize, usize);
47}
48
/// Stateless, heuristic [`NodeEstimator`] used by the convenience
/// `afftree_from_layers*` entry points.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleNodeEstimator;
53
54impl NodeEstimator for SimpleNodeEstimator {
55 fn estimate_nodes<Item>(
56 &self,
57 dim: usize,
58 _current_depth: usize,
59 layers: &[Item],
60 ) -> (usize, usize)
61 where
62 Item: Borrow<Layer>,
63 {
64 let mut n_splits = 0;
65 let mut first_dim = 0;
66 let mut last_dim = dim;
67
68 for layer in layers {
69 match layer.borrow() {
70 Layer::Linear(func) => {
71 last_dim = func.outdim();
72 if first_dim == 0 {
73 first_dim = func.outdim();
74 }
75 }
76 Layer::ReLU(_) => n_splits += 1,
77 Layer::LeakyReLU(_, _) => n_splits += 1,
78 Layer::HardTanh(_) => n_splits += 1,
79 Layer::HardSigmoid(_) => n_splits += 1,
80 Layer::ClassChar(_) => n_splits += 1,
81 Layer::Argmax => n_splits += last_dim,
82 }
83 }
84
85 let n = first_dim as i32;
88 let m = n_splits.saturating_sub(first_dim) as i32;
89 let q = 1.3f64;
90
91 let estimated_capacity =
92 (2.0f64.powi(n) * (2.0 + (q.powi(m + 1) - q) / (q - 1.0))) as usize;
93
94 (n_splits, estimated_capacity)
95 }
96}
97
98pub trait DistillVisitor {
102 fn start_distill(&mut self, dim: usize, n_layers: usize, n_nodes: usize);
104 fn start_layer(&mut self, layer: &Layer);
106 fn finish_layer(
108 &mut self,
109 layer: &Layer,
110 new_nodes: usize,
111 decision_nodes: usize,
112 terminal_nodes: usize,
113 );
114 fn finish_distill(&mut self, total_decisions: usize, total_terminals: usize);
116}
117
118#[derive(Clone, Debug)]
123pub struct DistillConsole {
124 pb: ProgressBar,
125 timer: Instant,
126 len: usize,
127}
128
129impl DistillConsole {
130 pub fn new() -> DistillConsole {
131 DistillConsole {
132 pb: ProgressBar::hidden(),
133 timer: Instant::now(),
134 len: 0,
135 }
136 }
137}
138
139impl Default for DistillConsole {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl DistillVisitor for DistillConsole {
146 fn start_distill(&mut self, dim: usize, n_layers: usize, n_nodes: usize) {
147 self.pb = ProgressBar::new(n_layers as u64);
148 let sty = ProgressStyle::default_bar()
149 .template(&format!(
150 "{: >12} {}",
151 style("Building").cyan().bold(),
152 "[{bar:25}] {pos:>2}/{len:2} ({elapsed})"
153 ))
154 .unwrap()
155 .progress_chars("=> ");
156 self.pb.set_style(sty.clone());
157 self.pb.enable_steady_tick(Duration::from_secs(5));
158
159 println!("Input dim: {}", dim);
160 println!("Estimated number of layers: {}", n_layers);
161 println!("Estimated number of nodes: {}", n_nodes);
162
163 self.timer = Instant::now();
164 self.len = n_layers;
165 }
166
167 fn start_layer(&mut self, _layer: &Layer) {
168 self.timer = Instant::now();
169 }
170
171 fn finish_layer(
172 &mut self,
173 layer: &Layer,
174 _new_nodes: usize,
175 decision_nodes: usize,
176 terminal_nodes: usize,
177 ) {
178 let duration = self.timer.elapsed();
179 match layer {
180 Layer::Linear(_) => {}
181 Layer::ReLU(_) => {
182 self.pb.println(format!(
183 "{: >12} partial ReLU in {:#} ({} nodes, {} terminals)",
184 style("Finished").green().bold(),
185 HumanDuration(duration),
186 decision_nodes + terminal_nodes,
187 terminal_nodes
188 ));
189 self.pb.inc(1);
190 }
191 Layer::LeakyReLU(_, _) => {
192 self.pb.println(format!(
193 "{: >12} partial leaky ReLU in {:#} ({} nodes, {} terminals)",
194 style("Finished").green().bold(),
195 HumanDuration(duration),
196 decision_nodes + terminal_nodes,
197 terminal_nodes
198 ));
199 self.pb.inc(1);
200 }
201 Layer::HardTanh(_) => {
202 self.pb.println(format!(
203 "{: >12} partial hard tanh in {:#} ({} nodes, {} terminals)",
204 style("Finished").green().bold(),
205 HumanDuration(duration),
206 decision_nodes + terminal_nodes,
207 terminal_nodes
208 ));
209 self.pb.inc(1);
210 }
211 Layer::HardSigmoid(_) => {
212 self.pb.println(format!(
213 "{: >12} partial hard sigmoid in {:#} ({} nodes, {} terminals)",
214 style("Finished").green().bold(),
215 HumanDuration(duration),
216 decision_nodes + terminal_nodes,
217 terminal_nodes
218 ));
219 self.pb.inc(1);
220 }
221 Layer::ClassChar(_) => {}
222 Layer::Argmax => {}
223 }
224 }
225
226 fn finish_distill(&mut self, total_decisions: usize, total_terminals: usize) {
227 self.pb.finish_and_clear();
228 println!(
229 "\n{: >12} constructing decision tree ({} decisions, {} terminals)",
230 style("Completed").green().bold(),
231 total_decisions,
232 total_terminals
233 );
234 }
235}
236
237#[derive(serde::Serialize)]
238struct CsvRow {
239 depth: usize,
240 inner_nodes: usize,
241 terminal_nodes: usize,
242 time_ms: u128,
243 in_dim: usize,
244}
245
246#[derive(Debug)]
253pub struct DistillCsv {
254 writer: csv::Writer<File>,
255 timer: Instant,
256 depth: usize,
257 in_dim: usize,
258 pb: ProgressBar,
259}
260
261impl DistillCsv {
262 pub fn new<P: AsRef<Path>>(path: P) -> DistillCsv {
263 DistillCsv {
264 writer: csv::Writer::from_path(path).unwrap(),
265 timer: Instant::now(),
266 depth: 0,
267 in_dim: 0,
268 pb: ProgressBar::hidden(),
269 }
270 }
271}
272
273impl DistillVisitor for DistillCsv {
274 fn start_distill(&mut self, dim: usize, n_layers: usize, _n_nodes: usize) {
275 self.timer = Instant::now();
276 self.depth = 0;
277 self.in_dim = dim;
278 self.pb = ProgressBar::new(n_layers as u64);
279 let sty = ProgressStyle::default_bar()
280 .template(&format!(
281 "{: >12} {}",
282 style("Building").cyan().bold(),
283 "[{bar:20}] {pos:>2}/{len:2} ({elapsed})"
284 ))
285 .unwrap()
286 .progress_chars("=> ");
287 self.pb.set_style(sty.clone());
288 }
289
290 fn start_layer(&mut self, _layer: &Layer) {
291 self.timer = Instant::now();
292 }
293
294 fn finish_layer(
295 &mut self,
296 layer: &Layer,
297 _new_nodes: usize,
298 decision_nodes: usize,
299 terminal_nodes: usize,
300 ) {
301 let duration = self.timer.elapsed();
302 match layer {
303 Layer::Linear(_) => {}
304 Layer::ReLU(_)
305 | Layer::LeakyReLU(_, _)
306 | Layer::HardTanh(_)
307 | Layer::HardSigmoid(_) => {
308 self.depth += 1;
309 self.writer
310 .serialize(CsvRow {
311 depth: self.depth,
312 inner_nodes: decision_nodes,
313 terminal_nodes,
314 time_ms: duration.as_millis(),
315 in_dim: self.in_dim,
316 })
317 .unwrap();
318 self.writer.flush().unwrap();
319 self.pb.inc(1);
320 }
321 Layer::ClassChar(_) => {}
322 Layer::Argmax => {}
323 }
324 }
325
326 fn finish_distill(&mut self, _total_decisions: usize, _total_terminals: usize) {
327 self.writer.flush().unwrap();
328 self.pb.finish_and_clear();
329 }
330}
331
/// [`DistillVisitor`] that ignores every notification.
#[derive(Debug, Clone)]
pub struct NoOpVis {}
335
336impl DistillVisitor for NoOpVis {
337 fn start_distill(&mut self, _: usize, _: usize, _: usize) {}
338
339 fn start_layer(&mut self, _: &Layer) {}
340
341 fn finish_layer(&mut self, _: &Layer, _: usize, _: usize, _: usize) {}
342
343 fn finish_distill(&mut self, _: usize, _: usize) {}
344}
345
346#[derive(Debug, Clone)]
349pub enum Layer {
350 Linear(AffFunc),
352 ReLU(usize),
354 LeakyReLU(usize, f64),
356 HardTanh(usize),
358 HardSigmoid(usize),
360 Argmax,
362 ClassChar(usize),
365}
366
367pub fn afftree_from_layers<I>(dim: usize, layers: I, precondition: Option<AffTree<2>>) -> AffTree<2>
369where
370 I: IntoIterator,
371 I::Item: Borrow<Layer>,
372{
373 afftree_from_layers_generic(
374 dim,
375 layers,
376 precondition,
377 &mut SimpleNodeEstimator {},
378 &mut NoOpVis {},
379 )
380}
381
382pub fn afftree_from_layers_verbose<I>(
387 dim: usize,
388 layers: I,
389 precondition: Option<AffTree<2>>,
390) -> AffTree<2>
391where
392 I: IntoIterator,
393 I::Item: Borrow<Layer>,
394{
395 afftree_from_layers_generic(
396 dim,
397 layers,
398 precondition,
399 &mut SimpleNodeEstimator {},
400 &mut DistillConsole::new(),
401 )
402}
403
404pub fn afftree_from_layers_csv<I, P: AsRef<Path>>(
409 dim: usize,
410 layers: I,
411 path: P,
412 precondition: Option<AffTree<2>>,
413) -> AffTree<2>
414where
415 I: IntoIterator,
416 I::Item: Borrow<Layer>,
417{
418 afftree_from_layers_generic(
419 dim,
420 layers,
421 precondition,
422 &mut SimpleNodeEstimator {},
423 &mut DistillCsv::new(path),
424 )
425}
426
427pub fn afftree_from_layers_generic<I, Estimator, Visitor>(
436 dim: usize,
437 layers: I,
438 precondition: Option<AffTree<2>>,
439 node_estimator: &mut Estimator,
440 visitor: &mut Visitor,
441) -> AffTree<2>
442where
443 I: IntoIterator,
444 I::Item: Borrow<Layer>,
445 Estimator: NodeEstimator,
446 Visitor: DistillVisitor,
447{
448 let container = layers.into_iter().collect_vec();
449
450 let (n_layers, n_nodes) = node_estimator.estimate_nodes(dim, 0, &container);
451 let mut dim = dim;
452
453 let n_nodes = min(n_nodes, 524288);
455
456 let mut dd = if let Some(dd) = precondition {
457 assert_eq!(
458 dd.in_dim(),
459 dim,
460 "Specified dim does not match dim of precondition: {} vs {}",
461 dim,
462 dd.in_dim()
463 );
464 dim = dd.terminals().map(|x| x.aff.outdim()).next().unwrap();
465 dd
466 } else {
467 AffTree::<2>::with_capacity(dim, n_nodes)
468 };
469
470 dd.reserve(n_nodes);
471
472 visitor.start_distill(dim, n_layers, n_nodes);
473
474 for layer in container.into_iter() {
475 let layer = layer.borrow();
476 let old_len = dd.len();
477 visitor.start_layer(layer);
478 match layer {
479 Layer::Linear(aff) => {
480 assert!(
481 aff.indim() == dim,
482 "Input dimension of current layer does not match output dimension of previous layer: prev={} vs current={}",
483 dim,
484 aff.indim()
485 );
486 dd.apply_func(aff);
487 dim = aff.outdim();
488 }
489 Layer::ReLU(row) => {
490 dd.compose::<false, false>(&partial_ReLU(dim, *row));
491 dd.infeasible_elimination();
492 }
493 Layer::LeakyReLU(row, alpha) => {
494 dd.compose::<false, false>(&partial_leaky_ReLU(dim, *row, *alpha));
495 dd.infeasible_elimination();
496 }
497 Layer::HardTanh(row) => {
498 dd.compose::<false, false>(&partial_hard_tanh(dim, *row, -1., 1.));
499 dd.infeasible_elimination();
500 }
501 Layer::HardSigmoid(row) => {
502 dd.compose::<false, false>(&partial_hard_sigmoid(dim, *row));
503 dd.infeasible_elimination();
504 }
505 Layer::Argmax => {
506 dd.compose::<true, true>(&argmax(dim));
507 }
508 Layer::ClassChar(clazz) => {
509 dd.compose::<true, false>(&class_characterization(dim, *clazz));
510 }
511 }
512 visitor.finish_layer(
513 layer,
514 dd.len() - old_len,
515 dd.len() - dd.tree.num_terminals(),
516 dd.tree.num_terminals(),
517 );
518 }
519 visitor.finish_distill(dd.len() - dd.tree.num_terminals(), dd.tree.num_terminals());
520
521 dd
522}
523
524pub fn read_layers<P: AsRef<Path>>(path: &P) -> Result<Vec<Layer>, ReadNpzError> {
533 let file = File::open(path).map_err(ReadNpyError::from)?;
534 let mut npz = NpzReader::new(file)?;
535
536 let mut names = npz.names()?;
537 names.sort_unstable();
538
539 let pattern = Regex::new(r"^(\d+)\.([A-Za-z._]*?)(\.npy)?$").unwrap();
540
541 let mut layers = Vec::with_capacity(names.len());
542 let mut dim: usize = 0;
543
544 for layer_descr in names.iter() {
545 let layer_name = pattern.captures(layer_descr);
546 if layer_name.is_none() {
548 continue;
549 }
550 let layer_name = layer_name.unwrap();
551 match layer_name.get(2).unwrap().as_str() {
552 "relu" => {
553 for idx in 0..dim {
554 layers.push(Layer::ReLU(idx));
555 }
556 }
557 "hard_tanh" => {
558 for idx in 0..dim {
559 layers.push(Layer::HardTanh(idx));
560 }
561 }
562 "hard_sigmoid" => {
563 for idx in 0..dim {
564 layers.push(Layer::HardSigmoid(idx));
565 }
566 }
567 "linear.weights" => {
568 let aff = AffFunc::from_mats(
569 npz.by_name(&format!(
570 "{}.linear.weights.npy",
571 layer_name.get(1).unwrap().as_str()
572 ))?,
573 npz.by_name(&format!(
574 "{}.linear.bias.npy",
575 layer_name.get(1).unwrap().as_str()
576 ))?,
577 );
578 dim = aff.outdim();
579 layers.push(Layer::Linear(aff));
580 }
581 "linear.bias" => {}
582 "layers" => {}
583 other => {
584 panic!("Unknown layer type:\"{}\"", other);
585 }
586 }
587 }
588
589 Ok(layers)
590}
591
#[cfg(test)]
mod tests {

    use std::path::Path;

    use approx::assert_relative_eq;
    use ndarray::arr1;

    use super::*;
    use crate::aff;
    use crate::pwl::afftree::AffTree;

    /// The archive holds linear layers at positions 0, 6 and 12; the first
    /// two are each followed by five partial ReLUs (rows 0 to 4).
    #[test]
    pub fn test_read_npy() {
        let layers = read_layers(&Path::new("res/nn/ecoli.npz")).unwrap();

        for block in 0..2 {
            let base = block * 6;
            assert!(matches!(layers[base], Layer::Linear(_)));
            for row in 0..5 {
                assert!(matches!(layers[base + 1 + row], Layer::ReLU(r) if r == row));
            }
        }
        assert!(matches!(layers[12], Layer::Linear(_)));
    }

    /// A precondition mapping 2 inputs to 4 outputs can be combined with
    /// layers that expect 4 inputs; the result still takes 2 inputs.
    #[test]
    pub fn test_precondition() {
        let pre = AffTree::from_aff(aff!([[1, 0], [0, 1], [1, 0], [0, 1]] + [0, 0, 0, 0]));

        let first = Layer::Linear(aff!([[0, 1, 2, 3], [4, 5, 6, 7]] + [-1, 1]));
        let last = Layer::Linear(aff!([2, -2] + 0));
        let layers = vec![first, Layer::ReLU(0), Layer::ReLU(1), last];

        let dd = afftree_from_layers(2, &layers, Some(pre));

        assert_eq!(dd.in_dim(), 2);
    }

    /// The distilled tree reproduces reference outputs of the network.
    #[test]
    pub fn test_afftree_from_layers() {
        let layers = read_layers(&Path::new("res/nn/ecoli.npz")).unwrap();

        let dd = afftree_from_layers(7, &layers, None);

        // (input, expected output) pairs
        let cases: [([f64; 7], [f64; 4]); 4] = [
            (
                [0.0, 1.0, 6.0, 2.0, -100.0, 7.0, -1.0],
                [-19.85087719, 79.9919784, 20.50838996, -114.81462218],
            ),
            (
                [0.0, 1.0, 6.0, 2.0, 0.0, 7.0, -1.0],
                [-6.84282267, 10.55842791, -1.14444066, -23.78016759],
            ),
            (
                [0.0, 1.0, 6.0, 2.0, 0.0, -7.0, -1.0],
                [4.25559725, -9.07232097, -16.83579659, -5.00034567],
            ),
            (
                [5.0, -2.0, 3.0, 50.0, -5.0, -2.0, 8.0],
                [4.59256626, 21.26106201, 26.68923948, -43.60172981],
            ),
        ];

        for (input, expected) in cases.iter() {
            assert_relative_eq!(
                dd.evaluate(&arr1(input)).unwrap(),
                arr1(expected),
                epsilon = 1e-08,
                max_relative = 1e-05
            );
        }
    }
}