use super::{DecisionTree, TreeNode};
use linfa::{Float, Label};
use std::collections::HashSet;
use std::fmt::Debug;
/// Renders a decision tree as LaTeX source for the `tikz`/`forest` packages.
///
/// The LaTeX text is produced by this type's `Display` implementation. Two
/// options control the output: whether a legend is appended and whether the
/// tree is wrapped in a complete standalone document.
#[derive(Debug, Clone, PartialEq)]
pub struct Tikz<'a, F: Float, L: Label + Debug> {
/// When `true`, a legend mapping feature indices to feature names is
/// emitted after the tree.
legend: bool,
/// When `true`, the output includes the `\documentclass` preamble and the
/// `document` environment; when `false`, only the `forest` setup and tree
/// are emitted.
complete: bool,
/// The decision tree whose nodes are rendered.
tree: &'a DecisionTree<F, L>,
}
impl<'a, F: Float, L: Debug + Label> Tikz<'a, F, L> {
    /// Creates a renderer for `tree` with the default options:
    ///
    /// * `legend = false`
    /// * `complete = true`
    pub fn new(tree: &'a DecisionTree<F, L>) -> Self {
        Tikz {
            legend: false,
            complete: true,
            tree,
        }
    }

    /// Recursively renders `node` and its subtree in `forest` bracket notation.
    ///
    /// A node that carries a prediction is rendered as `[Label: ..]`. Any other
    /// node is rendered with its split feature index, split value and impurity
    /// decrease (both with two decimals), followed by its existing children,
    /// one per line. Every line is indented with one tab per depth level.
    fn format_node(node: &'a TreeNode<F, L>) -> String {
        let indent = "\t".repeat(node.depth());
        if let Some(prediction) = node.prediction() {
            return format!("{indent}[Label: {prediction:?}]");
        }

        // NOTE(review): nodes are labelled `Val(i)` here while the legend lists
        // `Var(i)`; confirm which spelling is intended before unifying them.
        let (idx, value, impurity_decrease) = node.split();
        let mut out = format!(
            "{indent}[Val(${idx}$) $ \\leq {value:.2}$ \\\\ Imp. ${impurity_decrease:.2}$"
        );
        // Absent children are skipped.
        for child in node.children().into_iter().flatten() {
            out.push('\n');
            out.push_str(&Self::format_node(child));
        }
        out.push(']');
        out
    }

    /// Sets whether a complete standalone LaTeX document is generated
    /// (`true`) or only the `forest` setup and tree (`false`).
    pub fn complete(mut self, complete: bool) -> Self {
        self.complete = complete;
        self
    }

    /// Enables the legend that maps each split feature index to its name.
    pub fn with_legend(mut self) -> Self {
        self.legend = true;
        self
    }

    /// Builds the legend `\node`, or an empty string when the legend is disabled.
    ///
    /// Each feature index used by a non-leaf node is listed once, in the order
    /// in which `iter_nodes` first yields it. A missing feature name is shown
    /// as an empty cell.
    fn legend(&self) -> String {
        if !self.legend {
            return String::new();
        }

        let mut out = String::from("\n");
        out.push_str(
            r#"\node [anchor=north west] at (current bounding box.north east) {%
\begin{tabular}{c c c}
\multicolumn{3}{@{}l@{}}{Legend:}\\
Imp.&:&Impurity decrease\\"#,
        );

        let mut listed = HashSet::new();
        for node in self.tree.iter_nodes() {
            if node.is_leaf() {
                continue;
            }
            let idx = node.split().0;
            // `insert` returns `false` when the index is already listed, so a
            // separate `contains` probe is unnecessary.
            if listed.insert(idx) {
                let name = node.feature_name().map_or("", String::as_str);
                out.push_str(&format!("Var({idx})&:&{name}\\\\"));
            }
        }
        out.push_str("\\end{tabular}};");
        out
    }
}
use std::fmt;
impl<F: Float, L: Debug + Label> fmt::Display for Tikz<'_, F, L> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut out = if self.complete {
String::from(
r#"
\documentclass[margin=10pt]{standalone}
\usepackage{tikz,forest}
\usetikzlibrary{arrows.meta}"#,
)
} else {
String::from("")
};
out.push_str(
r#"
\forestset{
default preamble={
before typesetting nodes={
!r.replace by={[, coordinate, append]}
},
where n children=0{
tier=word,
}{
%diamond, aspect=2,
},
where level=0{}{
if n=1{
edge label={node[pos=.2, above] {Y}},
}{
edge label={node[pos=.2, above] {N}},
}
},
for tree={
edge+={thick, -Latex},
s sep'+=2cm,
draw,
thick,
edge path'={ (!u) -| (.parent)},
align=center,
}
}
}"#,
);
if self.complete {
out.push_str(r#"\begin{document}"#);
}
out.push_str(r#"\begin{forest}"#);
out.push_str(&Self::format_node(self.tree.root_node()));
out.push_str(&self.legend());
out.push_str("\n\t\\end{forest}\n");
if self.complete {
out.push_str("\\end{document}");
}
write!(f, "{out}")
}
}