// linfa_trees/decision_trees/tikz.rs
use std::collections::HashSet;
use std::fmt::{Debug, Write as _};

use linfa::{Float, Label};

use super::{DecisionTree, TreeNode};

6#[derive(Debug, Clone, PartialEq)]
31pub struct Tikz<'a, F: Float, L: Label + Debug> {
32 legend: bool,
33 complete: bool,
34 tree: &'a DecisionTree<F, L>,
35}
36
37impl<'a, F: Float, L: Debug + Label> Tikz<'a, F, L> {
38 pub fn new(tree: &'a DecisionTree<F, L>) -> Self {
44 Tikz {
45 legend: false,
46 complete: true,
47 tree,
48 }
49 }
50
51 fn format_node(node: &'a TreeNode<F, L>) -> String {
52 let depth = vec![""; node.depth() + 1].join("\t");
53 if let Some(prediction) = node.prediction() {
54 format!("{}[Label: {:?}]", depth, prediction)
55 } else {
56 let (idx, value, impurity_decrease) = node.split();
57 let mut out = format!(
58 "{}[Val(${}$) $ \\leq {:.2}$ \\\\ Imp. ${:.2}$",
59 depth, idx, value, impurity_decrease
60 );
61 for child in node.children().into_iter().filter_map(|x| x.as_ref()) {
62 out.push('\n');
63 out.push_str(&Self::format_node(child));
64 }
65 out.push(']');
66
67 out
68 }
69 }
70
71 pub fn complete(mut self, complete: bool) -> Self {
73 self.complete = complete;
74
75 self
76 }
77
78 pub fn with_legend(mut self) -> Self {
80 self.legend = true;
81
82 self
83 }
84
85 fn legend(&self) -> String {
86 if self.legend {
87 let mut map = HashSet::new();
88 let mut out = "\n".to_string()
89 + r#"\node [anchor=north west] at (current bounding box.north east) {%
90 \begin{tabular}{c c c}
91 \multicolumn{3}{@{}l@{}}{Legend:}\\
92 Imp.&:&Impurity decrease\\"#;
93 for node in self.tree.iter_nodes() {
94 if !node.is_leaf() && !map.contains(&node.split().0) {
95 let var = format!(
96 "Var({})&:&{}\\\\",
97 node.split().0,
98 node.feature_name().unwrap_or(&"".to_string())
100 );
101 out.push_str(&var);
102 map.insert(node.split().0);
103 }
104 }
105 out.push_str("\\end{tabular}};");
106 out
107 } else {
108 "".to_string()
109 }
110 }
111}
112
113use std::fmt;
114
115impl<F: Float, L: Debug + Label> fmt::Display for Tikz<'_, F, L> {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 let mut out = if self.complete {
118 String::from(
119 r#"
120\documentclass[margin=10pt]{standalone}
121\usepackage{tikz,forest}
122\usetikzlibrary{arrows.meta}"#,
123 )
124 } else {
125 String::from("")
126 };
127 out.push_str(
128 r#"
129\forestset{
130default preamble={
131before typesetting nodes={
132 !r.replace by={[, coordinate, append]}
133},
134where n children=0{
135 tier=word,
136}{
137 %diamond, aspect=2,
138},
139where level=0{}{
140 if n=1{
141 edge label={node[pos=.2, above] {Y}},
142 }{
143 edge label={node[pos=.2, above] {N}},
144 }
145},
146for tree={
147 edge+={thick, -Latex},
148 s sep'+=2cm,
149 draw,
150 thick,
151 edge path'={ (!u) -| (.parent)},
152 align=center,
153}
154}
155}"#,
156 );
157
158 if self.complete {
159 out.push_str(r#"\begin{document}"#);
160 }
161 out.push_str(r#"\begin{forest}"#);
162
163 out.push_str(&Self::format_node(self.tree.root_node()));
164 out.push_str(&self.legend());
165 out.push_str("\n\t\\end{forest}\n");
166 if self.complete {
167 out.push_str("\\end{document}");
168 }
169
170 write!(f, "{}", out)
171 }
172}