// affinitree/distill/builder.rs
//   Copyright 2025 affinitree developers
//
//   Licensed under the Apache License, Version 2.0 (the "License");
//   you may not use this file except in compliance with the License.
//   You may obtain a copy of the License at
//
//       http://www.apache.org/licenses/LICENSE-2.0
//
//   Unless required by applicable law or agreed to in writing, software
//   distributed under the License is distributed on an "AS IS" BASIS,
//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//   See the License for the specific language governing permissions and
//   limitations under the License.
14
15//! A collection of high-level methods to distill AffTree instances out of neural networks
16
17use std::borrow::Borrow;
18use std::cmp::min;
19use std::fs::File;
20use std::path::Path;
21use std::time::{Duration, Instant};
22
23use console::style;
24use indicatif::{HumanDuration, ProgressBar, ProgressStyle};
25use itertools::Itertools;
26use ndarray_npy::{NpzReader, ReadNpyError, ReadNpzError};
27use regex::Regex;
28
29use crate::distill::schema::{
30    argmax, class_characterization, partial_ReLU, partial_hard_sigmoid, partial_hard_tanh,
31    partial_leaky_ReLU,
32};
33use crate::linalg::affine::AffFunc;
34use crate::pwl::afftree::AffTree;
35
/// An estimator for the number of nodes that will be created during the distillation.
///
/// Can be used to reserve space in advance or to display progress.
/// Returns an estimate for the number of layers and the number of nodes.
pub trait NodeEstimator {
    /// Estimates the workload of distilling ``layers``.
    ///
    /// * ``dim`` - input dimension of the first layer
    /// * ``current_depth`` - depth of the tree distilled so far
    /// * ``layers`` - the (remaining) layers to be distilled
    ///
    /// Returns a pair ``(n_layers, n_nodes)``: the estimated number of
    /// distillation steps and the estimated number of nodes created.
    fn estimate_nodes<Item: Borrow<Layer>>(
        &self,
        dim: usize,
        current_depth: usize,
        layers: &[Item],
    ) -> (usize, usize);
}
48
/// A simple [``NodeEstimator``] based on experiments.
/// Estimation is rather conservative.
///
/// Stateless; the heuristic lives entirely in its [``NodeEstimator``] impl.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleNodeEstimator;
53
54impl NodeEstimator for SimpleNodeEstimator {
55    fn estimate_nodes<Item>(
56        &self,
57        dim: usize,
58        _current_depth: usize,
59        layers: &[Item],
60    ) -> (usize, usize)
61    where
62        Item: Borrow<Layer>,
63    {
64        let mut n_splits = 0;
65        let mut first_dim = 0;
66        let mut last_dim = dim;
67
68        for layer in layers {
69            match layer.borrow() {
70                Layer::Linear(func) => {
71                    last_dim = func.outdim();
72                    if first_dim == 0 {
73                        first_dim = func.outdim();
74                    }
75                }
76                Layer::ReLU(_) => n_splits += 1,
77                Layer::LeakyReLU(_, _) => n_splits += 1,
78                Layer::HardTanh(_) => n_splits += 1,
79                Layer::HardSigmoid(_) => n_splits += 1,
80                Layer::ClassChar(_) => n_splits += 1,
81                Layer::Argmax => n_splits += last_dim,
82            }
83        }
84
85        // Rough underapproximation of the number of nodes
86        // It is assumed that after the first layer each node has on average q children.
87        let n = first_dim as i32;
88        let m = n_splits.saturating_sub(first_dim) as i32;
89        let q = 1.3f64;
90
91        let estimated_capacity =
92            (2.0f64.powi(n) * (2.0 + (q.powi(m + 1) - q) / (q - 1.0))) as usize;
93
94        (n_splits, estimated_capacity)
95    }
96}
97
/// A visitor pattern for the distillation process.
///
/// Use cases include logging or to display progress.
pub trait DistillVisitor {
    /// Called once at the start of the distillation process.
    ///
    /// ``dim`` is the input dimension; ``n_layers`` and ``n_nodes`` are the
    /// estimates produced by the [``NodeEstimator``].
    fn start_distill(&mut self, dim: usize, n_layers: usize, n_nodes: usize);
    /// Called just before each layer is processed.
    fn start_layer(&mut self, layer: &Layer);
    /// Called after each layer is finished.
    ///
    /// ``new_nodes`` is the number of nodes added by this layer;
    /// ``decision_nodes`` and ``terminal_nodes`` are the current totals.
    fn finish_layer(
        &mut self,
        layer: &Layer,
        new_nodes: usize,
        decision_nodes: usize,
        terminal_nodes: usize,
    );
    /// Called once at the end of the distillation process.
    fn finish_distill(&mut self, total_decisions: usize, total_terminals: usize);
}
117
/// A [``DistillVisitor``] which displays a progress bar of the current state
/// of the distillation at the console.
/// Additionally, every layer is logged on a single line with the duration that layer took,
/// and the number of nodes in the tree.
#[derive(Clone, Debug)]
pub struct DistillConsole {
    // Progress bar, advanced once per reported activation layer
    pb: ProgressBar,
    // Restarted at the start of each layer to measure its duration
    timer: Instant,
    // Estimated number of layers, as reported to `start_distill`
    len: usize,
}
128
129impl DistillConsole {
130    pub fn new() -> DistillConsole {
131        DistillConsole {
132            pb: ProgressBar::hidden(),
133            timer: Instant::now(),
134            len: 0,
135        }
136    }
137}
138
139impl Default for DistillConsole {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl DistillVisitor for DistillConsole {
146    fn start_distill(&mut self, dim: usize, n_layers: usize, n_nodes: usize) {
147        self.pb = ProgressBar::new(n_layers as u64);
148        let sty = ProgressStyle::default_bar()
149            .template(&format!(
150                "{: >12} {}",
151                style("Building").cyan().bold(),
152                "[{bar:25}] {pos:>2}/{len:2} ({elapsed})"
153            ))
154            .unwrap()
155            .progress_chars("=> ");
156        self.pb.set_style(sty.clone());
157        self.pb.enable_steady_tick(Duration::from_secs(5));
158
159        println!("Input dim: {}", dim);
160        println!("Estimated number of layers: {}", n_layers);
161        println!("Estimated number of nodes: {}", n_nodes);
162
163        self.timer = Instant::now();
164        self.len = n_layers;
165    }
166
167    fn start_layer(&mut self, _layer: &Layer) {
168        self.timer = Instant::now();
169    }
170
171    fn finish_layer(
172        &mut self,
173        layer: &Layer,
174        _new_nodes: usize,
175        decision_nodes: usize,
176        terminal_nodes: usize,
177    ) {
178        let duration = self.timer.elapsed();
179        match layer {
180            Layer::Linear(_) => {}
181            Layer::ReLU(_) => {
182                self.pb.println(format!(
183                    "{: >12} partial ReLU in {:#} ({} nodes, {} terminals)",
184                    style("Finished").green().bold(),
185                    HumanDuration(duration),
186                    decision_nodes + terminal_nodes,
187                    terminal_nodes
188                ));
189                self.pb.inc(1);
190            }
191            Layer::LeakyReLU(_, _) => {
192                self.pb.println(format!(
193                    "{: >12} partial leaky ReLU in {:#} ({} nodes, {} terminals)",
194                    style("Finished").green().bold(),
195                    HumanDuration(duration),
196                    decision_nodes + terminal_nodes,
197                    terminal_nodes
198                ));
199                self.pb.inc(1);
200            }
201            Layer::HardTanh(_) => {
202                self.pb.println(format!(
203                    "{: >12} partial hard tanh in {:#} ({} nodes, {} terminals)",
204                    style("Finished").green().bold(),
205                    HumanDuration(duration),
206                    decision_nodes + terminal_nodes,
207                    terminal_nodes
208                ));
209                self.pb.inc(1);
210            }
211            Layer::HardSigmoid(_) => {
212                self.pb.println(format!(
213                    "{: >12} partial hard sigmoid in {:#} ({} nodes, {} terminals)",
214                    style("Finished").green().bold(),
215                    HumanDuration(duration),
216                    decision_nodes + terminal_nodes,
217                    terminal_nodes
218                ));
219                self.pb.inc(1);
220            }
221            Layer::ClassChar(_) => {}
222            Layer::Argmax => {}
223        }
224    }
225
226    fn finish_distill(&mut self, total_decisions: usize, total_terminals: usize) {
227        self.pb.finish_and_clear();
228        println!(
229            "\n{: >12} constructing decision tree ({} decisions, {} terminals)",
230            style("Completed").green().bold(),
231            total_decisions,
232            total_terminals
233        );
234    }
235}
236
// A single record of the csv log (one row per activation layer).
#[derive(serde::Serialize)]
struct CsvRow {
    // Number of activation layers processed so far
    depth: usize,
    // Current number of decision (inner) nodes in the tree
    inner_nodes: usize,
    // Current number of terminal nodes in the tree
    terminal_nodes: usize,
    // Duration of this distillation step in milliseconds
    time_ms: u128,
    // Input dimension of the distilled network
    in_dim: usize,
}
245
/// A [``DistillVisitor``] which logs key metrics of the distillation process
/// to the given csv file.
/// Also includes a simple progress bar.
///
/// The following metrics are logged for each layer: current depth, number of decision predicates,
/// number of terminal nodes, duration the distillation took (in ms), and input dimension.
#[derive(Debug)]
pub struct DistillCsv {
    // Writes one `CsvRow` per activation layer
    writer: csv::Writer<File>,
    // Restarted at the start of each layer to measure its duration
    timer: Instant,
    // Number of activation layers processed so far
    depth: usize,
    // Input dimension, as reported to `start_distill`
    in_dim: usize,
    // Progress bar, advanced once per logged layer
    pb: ProgressBar,
}
260
261impl DistillCsv {
262    pub fn new<P: AsRef<Path>>(path: P) -> DistillCsv {
263        DistillCsv {
264            writer: csv::Writer::from_path(path).unwrap(),
265            timer: Instant::now(),
266            depth: 0,
267            in_dim: 0,
268            pb: ProgressBar::hidden(),
269        }
270    }
271}
272
273impl DistillVisitor for DistillCsv {
274    fn start_distill(&mut self, dim: usize, n_layers: usize, _n_nodes: usize) {
275        self.timer = Instant::now();
276        self.depth = 0;
277        self.in_dim = dim;
278        self.pb = ProgressBar::new(n_layers as u64);
279        let sty = ProgressStyle::default_bar()
280            .template(&format!(
281                "{: >12} {}",
282                style("Building").cyan().bold(),
283                "[{bar:20}] {pos:>2}/{len:2} ({elapsed})"
284            ))
285            .unwrap()
286            .progress_chars("=> ");
287        self.pb.set_style(sty.clone());
288    }
289
290    fn start_layer(&mut self, _layer: &Layer) {
291        self.timer = Instant::now();
292    }
293
294    fn finish_layer(
295        &mut self,
296        layer: &Layer,
297        _new_nodes: usize,
298        decision_nodes: usize,
299        terminal_nodes: usize,
300    ) {
301        let duration = self.timer.elapsed();
302        match layer {
303            Layer::Linear(_) => {}
304            Layer::ReLU(_)
305            | Layer::LeakyReLU(_, _)
306            | Layer::HardTanh(_)
307            | Layer::HardSigmoid(_) => {
308                self.depth += 1;
309                self.writer
310                    .serialize(CsvRow {
311                        depth: self.depth,
312                        inner_nodes: decision_nodes,
313                        terminal_nodes,
314                        time_ms: duration.as_millis(),
315                        in_dim: self.in_dim,
316                    })
317                    .unwrap();
318                self.writer.flush().unwrap();
319                self.pb.inc(1);
320            }
321            Layer::ClassChar(_) => {}
322            Layer::Argmax => {}
323        }
324    }
325
326    fn finish_distill(&mut self, _total_decisions: usize, _total_terminals: usize) {
327        self.writer.flush().unwrap();
328        self.pb.finish_and_clear();
329    }
330}
331
/// A [``DistillVisitor``] which performs no operation.
///
/// Used when no progress reporting or logging is desired.
#[derive(Clone, Debug)]
pub struct NoOpVis {}
335
336impl DistillVisitor for NoOpVis {
337    fn start_distill(&mut self, _: usize, _: usize, _: usize) {}
338
339    fn start_layer(&mut self, _: &Layer) {}
340
341    fn finish_layer(&mut self, _: &Layer, _: usize, _: usize, _: usize) {}
342
343    fn finish_distill(&mut self, _: usize, _: usize) {}
344}
345
/// A simple enum type to conveniently specify the layer structure of a neural network.
/// Each ``Layer`` corresponds to one piece-wise linear function.
#[derive(Debug, Clone)]
pub enum Layer {
    /// A fully connected linear layer
    Linear(AffFunc),
    /// The ReLU applied to the i-th component of the input
    ReLU(usize),
    /// The Leaky ReLU applied to the i-th component of the input;
    /// the second field is the slope parameter ``alpha``
    LeakyReLU(usize, f64),
    /// The hard hyperbolic tangent applied to the i-th component of the input
    HardTanh(usize),
    /// The hard sigmoid applied to the i-th component of the input
    HardSigmoid(usize),
    /// The argmax function
    Argmax,
    /// A binary version of the argmax called class characterization.
    /// The result is a boolean indicating whether the input belongs to the specified class or not.
    ClassChar(usize),
}
366
367/// Wrapper of [`afftree_from_layers_generic`] that prints nothing to the terminal.
368pub fn afftree_from_layers<I>(dim: usize, layers: I, precondition: Option<AffTree<2>>) -> AffTree<2>
369where
370    I: IntoIterator,
371    I::Item: Borrow<Layer>,
372{
373    afftree_from_layers_generic(
374        dim,
375        layers,
376        precondition,
377        &mut SimpleNodeEstimator {},
378        &mut NoOpVis {},
379    )
380}
381
382/// Wrapper of [`afftree_from_layers_generic`] that logs the progress
383/// after each layer to the console.
384///
385/// For details on the console output, see also [``DistillConsole``].
386pub fn afftree_from_layers_verbose<I>(
387    dim: usize,
388    layers: I,
389    precondition: Option<AffTree<2>>,
390) -> AffTree<2>
391where
392    I: IntoIterator,
393    I::Item: Borrow<Layer>,
394{
395    afftree_from_layers_generic(
396        dim,
397        layers,
398        precondition,
399        &mut SimpleNodeEstimator {},
400        &mut DistillConsole::new(),
401    )
402}
403
404/// Wrapper of [`afftree_from_layers_generic`] that logs characteristics of the tree
405/// after each layer to a csv file located at ``path``.
406///
407/// For details on the logging, see also [``DistillCsv``].
408pub fn afftree_from_layers_csv<I, P: AsRef<Path>>(
409    dim: usize,
410    layers: I,
411    path: P,
412    precondition: Option<AffTree<2>>,
413) -> AffTree<2>
414where
415    I: IntoIterator,
416    I::Item: Borrow<Layer>,
417{
418    afftree_from_layers_generic(
419        dim,
420        layers,
421        precondition,
422        &mut SimpleNodeEstimator {},
423        &mut DistillCsv::new(path),
424    )
425}
426
/// Generic implementation of the distillation process.
///
/// The provided sequence of ``layers`` is mapped to equivalent
/// [``AffTree``] instances based on [``crate::distill::schema``]. Then, this sequence is composed into a single
/// [``AffTree``] using [`AffTree::compose`]. In between each composition
/// step the tree is pruned using [`AffTree::infeasible_elimination`].
///
/// Behavior can be customized by providing an appropriate ``visitor``.
///
/// # Panics
///
/// Panics when ``dim`` does not match the input dimension of ``precondition``,
/// or when the input dimension of a linear layer does not match the output
/// dimension of the preceding layer.
pub fn afftree_from_layers_generic<I, Estimator, Visitor>(
    dim: usize,
    layers: I,
    precondition: Option<AffTree<2>>,
    node_estimator: &mut Estimator,
    visitor: &mut Visitor,
) -> AffTree<2>
where
    I: IntoIterator,
    I::Item: Borrow<Layer>,
    Estimator: NodeEstimator,
    Visitor: DistillVisitor,
{
    // Materialize the layers so the estimator can inspect them all upfront.
    let container = layers.into_iter().collect_vec();

    let (n_layers, n_nodes) = node_estimator.estimate_nodes(dim, 0, &container);
    // `dim` tracks the output dimension of the most recently applied layer.
    let mut dim = dim;

    // bound maximum number of memory reserved in advance
    let n_nodes = min(n_nodes, 524288);

    let mut dd = if let Some(dd) = precondition {
        assert_eq!(
            dd.in_dim(),
            dim,
            "Specified dim does not match dim of precondition: {} vs {}",
            dim,
            dd.in_dim()
        );
        // Continue the distillation in the output space of the precondition,
        // taken from its first terminal.
        dim = dd.terminals().map(|x| x.aff.outdim()).next().unwrap();
        dd
    } else {
        AffTree::<2>::with_capacity(dim, n_nodes)
    };

    dd.reserve(n_nodes);

    visitor.start_distill(dim, n_layers, n_nodes);

    for layer in container.into_iter() {
        let layer = layer.borrow();
        let old_len = dd.len();
        visitor.start_layer(layer);
        match layer {
            Layer::Linear(aff) => {
                assert!(
                    aff.indim() == dim,
                    "Input dimension of current layer does not match output dimension of previous layer: prev={} vs current={}",
                    dim,
                    aff.indim()
                );
                // Linear layers are applied directly; no new decisions arise.
                dd.apply_func(aff);
                dim = aff.outdim();
            }
            // Piece-wise linear activations are composed component-wise and the
            // tree is pruned after each composition step.
            Layer::ReLU(row) => {
                dd.compose::<false, false>(&partial_ReLU(dim, *row));
                dd.infeasible_elimination();
            }
            Layer::LeakyReLU(row, alpha) => {
                dd.compose::<false, false>(&partial_leaky_ReLU(dim, *row, *alpha));
                dd.infeasible_elimination();
            }
            Layer::HardTanh(row) => {
                dd.compose::<false, false>(&partial_hard_tanh(dim, *row, -1., 1.));
                dd.infeasible_elimination();
            }
            Layer::HardSigmoid(row) => {
                dd.compose::<false, false>(&partial_hard_sigmoid(dim, *row));
                dd.infeasible_elimination();
            }
            Layer::Argmax => {
                dd.compose::<true, true>(&argmax(dim));
            }
            Layer::ClassChar(clazz) => {
                dd.compose::<true, false>(&class_characterization(dim, *clazz));
            }
        }
        visitor.finish_layer(
            layer,
            dd.len() - old_len,
            dd.len() - dd.tree.num_terminals(),
            dd.tree.num_terminals(),
        );
    }
    visitor.finish_distill(dd.len() - dd.tree.num_terminals(), dd.tree.num_terminals());

    dd
}
523
524/// Parses numpy's "npz" file format and returns the contained layers.
525///
526/// Affinitree supports a dialect based on numpy's "npz" format to encode the layers of
527/// a piece-wise linear neural network. The structure is significantly simpler than
528/// other formats such as "onnx".
529///
530/// The underlying format from numpy is easy to use. It is described at
531/// <https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html>
532pub fn read_layers<P: AsRef<Path>>(path: &P) -> Result<Vec<Layer>, ReadNpzError> {
533    let file = File::open(path).map_err(ReadNpyError::from)?;
534    let mut npz = NpzReader::new(file)?;
535
536    let mut names = npz.names()?;
537    names.sort_unstable();
538
539    let pattern = Regex::new(r"^(\d+)\.([A-Za-z._]*?)(\.npy)?$").unwrap();
540
541    let mut layers = Vec::with_capacity(names.len());
542    let mut dim: usize = 0;
543
544    for layer_descr in names.iter() {
545        let layer_name = pattern.captures(layer_descr);
546        //    .expect(&format!("Layer description unknown: {}", layer_descr));
547        if layer_name.is_none() {
548            continue;
549        }
550        let layer_name = layer_name.unwrap();
551        match layer_name.get(2).unwrap().as_str() {
552            "relu" => {
553                for idx in 0..dim {
554                    layers.push(Layer::ReLU(idx));
555                }
556            }
557            "hard_tanh" => {
558                for idx in 0..dim {
559                    layers.push(Layer::HardTanh(idx));
560                }
561            }
562            "hard_sigmoid" => {
563                for idx in 0..dim {
564                    layers.push(Layer::HardSigmoid(idx));
565                }
566            }
567            "linear.weights" => {
568                let aff = AffFunc::from_mats(
569                    npz.by_name(&format!(
570                        "{}.linear.weights.npy",
571                        layer_name.get(1).unwrap().as_str()
572                    ))?,
573                    npz.by_name(&format!(
574                        "{}.linear.bias.npy",
575                        layer_name.get(1).unwrap().as_str()
576                    ))?,
577                );
578                dim = aff.outdim();
579                layers.push(Layer::Linear(aff));
580            }
581            "linear.bias" => {}
582            "layers" => {}
583            other => {
584                panic!("Unknown layer type:\"{}\"", other);
585            }
586        }
587    }
588
589    Ok(layers)
590}
591
#[cfg(test)]
mod tests {

    use std::path::Path;

    use approx::assert_relative_eq;
    use ndarray::arr1;

    use super::*;
    use crate::aff;
    use crate::pwl::afftree::AffTree;

    // Checks that the npz reader expands each activation entry into one
    // component-wise ReLU per output dimension of the preceding linear layer.
    #[test]
    pub fn test_read_npy() {
        let layers = read_layers(&Path::new("res/nn/ecoli.npz")).unwrap();

        assert!(matches!(layers[0], Layer::Linear(_)));
        assert!(matches!(layers[1], Layer::ReLU(0)));
        assert!(matches!(layers[2], Layer::ReLU(1)));
        assert!(matches!(layers[3], Layer::ReLU(2)));
        assert!(matches!(layers[4], Layer::ReLU(3)));
        assert!(matches!(layers[5], Layer::ReLU(4)));
        assert!(matches!(layers[6], Layer::Linear(_)));
        assert!(matches!(layers[7], Layer::ReLU(0)));
        assert!(matches!(layers[8], Layer::ReLU(1)));
        assert!(matches!(layers[9], Layer::ReLU(2)));
        assert!(matches!(layers[10], Layer::ReLU(3)));
        assert!(matches!(layers[11], Layer::ReLU(4)));
        assert!(matches!(layers[12], Layer::Linear(_)));
    }

    // Checks that a precondition tree can be composed in front of the network
    // and that the resulting tree keeps the precondition's input dimension.
    #[test]
    pub fn test_precondition() {
        let pre = AffTree::from_aff(aff!([[1, 0], [0, 1], [1, 0], [0, 1]] + [0, 0, 0, 0]));

        let layers = vec![
            Layer::Linear(aff!([[0, 1, 2, 3], [4, 5, 6, 7]] + [-1, 1])),
            Layer::ReLU(0),
            Layer::ReLU(1),
            Layer::Linear(aff!([2, -2] + 0)),
        ];

        let dd = afftree_from_layers(2, &layers, Some(pre));

        assert_eq!(dd.in_dim(), 2);
    }

    // Compares the distilled tree against reference outputs of the original
    // network for a few sample inputs (fixture: res/nn/ecoli.npz).
    #[test]
    pub fn test_afftree_from_layers() {
        let layers = read_layers(&Path::new("res/nn/ecoli.npz")).unwrap();

        let dd = afftree_from_layers(7, &layers, None);

        assert_relative_eq!(
            dd.evaluate(&arr1(&[0.0, 1.0, 6.0, 2.0, -100.0, 7.0, -1.0]))
                .unwrap(),
            arr1(&[-19.85087719, 79.9919784, 20.50838996, -114.81462218]),
            epsilon = 1e-08,
            max_relative = 1e-05
        );

        assert_relative_eq!(
            dd.evaluate(&arr1(&[0.0, 1.0, 6.0, 2.0, 0.0, 7.0, -1.0]))
                .unwrap(),
            arr1(&[-6.84282267, 10.55842791, -1.14444066, -23.78016759]),
            epsilon = 1e-08,
            max_relative = 1e-05
        );

        assert_relative_eq!(
            dd.evaluate(&arr1(&[0.0, 1.0, 6.0, 2.0, 0.0, -7.0, -1.0]))
                .unwrap(),
            arr1(&[4.25559725, -9.07232097, -16.83579659, -5.00034567]),
            epsilon = 1e-08,
            max_relative = 1e-05
        );

        assert_relative_eq!(
            dd.evaluate(&arr1(&[5.0, -2.0, 3.0, 50.0, -5.0, -2.0, 8.0]))
                .unwrap(),
            arr1(&[4.59256626, 21.26106201, 26.68923948, -43.60172981]),
            epsilon = 1e-08,
            max_relative = 1e-05
        );
    }
}