linfa_trees/decision_trees/
tikz.rs

1use super::{DecisionTree, TreeNode};
2use linfa::{Float, Label};
3use std::collections::HashSet;
4use std::fmt::Debug;
5
6/// Struct to print a fitted decision tree in Tex using tikz and forest.
7///
8/// There are two settable parameters:
9///
10/// * `legend`: if true, a box with the names of the split features will appear in the top right of the tree
11/// * `complete`: if true, a complete and standalone Tex document will be generated; otherwise the result will an embeddable Tex tree.
12///
13/// ### Usage
14///
15/// ```rust
16/// use linfa::prelude::*;
17/// use linfa_datasets;
18/// use linfa_trees::DecisionTree;
19///
20/// // Load dataset
21/// let dataset = linfa_datasets::iris();
22/// // Fit the tree
23/// let tree = DecisionTree::params().fit(&dataset).unwrap();
24/// // Export to tikz
25/// let tikz = tree.export_to_tikz().with_legend();
26/// let latex_tree = tikz.to_string();
27/// // Now you can write latex_tree to the preferred destination
28///
29/// ```
30#[derive(Debug, Clone, PartialEq)]
31pub struct Tikz<'a, F: Float, L: Label + Debug> {
32    legend: bool,
33    complete: bool,
34    tree: &'a DecisionTree<F, L>,
35}
36
37impl<'a, F: Float, L: Debug + Label> Tikz<'a, F, L> {
38    /// Creates a new Tikz structure for the decision tree
39    /// with the following default parameters:
40    ///
41    /// * `legend=false`
42    /// * `complete=true`
43    pub fn new(tree: &'a DecisionTree<F, L>) -> Self {
44        Tikz {
45            legend: false,
46            complete: true,
47            tree,
48        }
49    }
50
51    fn format_node(node: &'a TreeNode<F, L>) -> String {
52        let depth = vec![""; node.depth() + 1].join("\t");
53        if let Some(prediction) = node.prediction() {
54            format!("{}[Label: {:?}]", depth, prediction)
55        } else {
56            let (idx, value, impurity_decrease) = node.split();
57            let mut out = format!(
58                "{}[Val(${}$) $ \\leq {:.2}$ \\\\ Imp. ${:.2}$",
59                depth, idx, value, impurity_decrease
60            );
61            for child in node.children().into_iter().filter_map(|x| x.as_ref()) {
62                out.push('\n');
63                out.push_str(&Self::format_node(child));
64            }
65            out.push(']');
66
67            out
68        }
69    }
70
71    /// Whether a complete Tex document should be generated
72    pub fn complete(mut self, complete: bool) -> Self {
73        self.complete = complete;
74
75        self
76    }
77
78    /// Add a legend to the generated tree
79    pub fn with_legend(mut self) -> Self {
80        self.legend = true;
81
82        self
83    }
84
85    fn legend(&self) -> String {
86        if self.legend {
87            let mut map = HashSet::new();
88            let mut out = "\n".to_string()
89                + r#"\node [anchor=north west] at (current bounding box.north east) {%
90                \begin{tabular}{c c c}
91                  \multicolumn{3}{@{}l@{}}{Legend:}\\
92                  Imp.&:&Impurity decrease\\"#;
93            for node in self.tree.iter_nodes() {
94                if !node.is_leaf() && !map.contains(&node.split().0) {
95                    let var = format!(
96                        "Var({})&:&{}\\\\",
97                        node.split().0,
98                        // TODO:: why use lengend if there are no feature names? Should it be allowed?
99                        node.feature_name().unwrap_or(&"".to_string())
100                    );
101                    out.push_str(&var);
102                    map.insert(node.split().0);
103                }
104            }
105            out.push_str("\\end{tabular}};");
106            out
107        } else {
108            "".to_string()
109        }
110    }
111}
112
113use std::fmt;
114
115impl<F: Float, L: Debug + Label> fmt::Display for Tikz<'_, F, L> {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        let mut out = if self.complete {
118            String::from(
119                r#"
120\documentclass[margin=10pt]{standalone}
121\usepackage{tikz,forest}
122\usetikzlibrary{arrows.meta}"#,
123            )
124        } else {
125            String::from("")
126        };
127        out.push_str(
128            r#"
129\forestset{
130default preamble={
131before typesetting nodes={
132  !r.replace by={[, coordinate, append]}
133},  
134where n children=0{
135  tier=word,
136}{  
137  %diamond, aspect=2,
138},  
139where level=0{}{
140  if n=1{
141    edge label={node[pos=.2, above] {Y}},
142  }{  
143    edge label={node[pos=.2, above] {N}},
144  }   
145},  
146for tree={
147  edge+={thick, -Latex},
148  s sep'+=2cm,
149  draw,
150  thick,
151  edge path'={ (!u) -| (.parent)},
152  align=center,
153}   
154}
155}"#,
156        );
157
158        if self.complete {
159            out.push_str(r#"\begin{document}"#);
160        }
161        out.push_str(r#"\begin{forest}"#);
162
163        out.push_str(&Self::format_node(self.tree.root_node()));
164        out.push_str(&self.legend());
165        out.push_str("\n\t\\end{forest}\n");
166        if self.complete {
167            out.push_str("\\end{document}");
168        }
169
170        write!(f, "{}", out)
171    }
172}