Skip to main content

entrenar/train/tui/charts/
feature_importance.rs

//! Feature importance bar chart for terminal display.
2
/// Feature importance bar chart for terminal display.
///
/// Holds at most `top_k` features, ranked by the absolute value of their
/// importance score, and renders them as horizontal bars of `█` characters.
#[derive(Debug, Clone)]
pub struct FeatureImportanceChart {
    /// Display names of the retained features, most important first
    /// (parallel to `scores`).
    pub(crate) names: Vec<String>,
    /// Signed importance scores, parallel to `names`.
    pub(crate) scores: Vec<f32>,
    /// Length, in characters, of the bar drawn for the largest-magnitude score.
    pub(crate) bar_width: usize,
    /// Maximum number of features retained by `update`.
    pub(crate) top_k: usize,
}

impl FeatureImportanceChart {
    /// Create an empty chart that keeps up to `top_k` features and draws bars
    /// at most `bar_width` characters long.
    pub fn new(top_k: usize, bar_width: usize) -> Self {
        Self { names: Vec::new(), scores: Vec::new(), bar_width, top_k }
    }

    /// Replace the chart contents with the `top_k` highest-magnitude entries
    /// of `importances`.
    ///
    /// Each entry is `(feature_index, score)`. Entries are ranked by `|score|`
    /// in descending order; ties keep their input order (the sort is stable)
    /// and NaN scores are ranked last. The display name is looked up in
    /// `feature_names` by index; when no names are supplied, or the index is
    /// out of range, the name falls back to `feature_{index}`.
    pub fn update(&mut self, importances: &[(usize, f32)], feature_names: Option<&[String]>) {
        // NaN has no meaningful magnitude. Mapping it below every real
        // magnitude keeps the comparator a total order (`sort_by` may panic
        // otherwise) and stops NaN entries from crowding out real scores.
        fn rank_key(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score.abs()
            }
        }

        let mut ranked = importances.to_vec();
        ranked.sort_by(|a, b| rank_key(b.1).total_cmp(&rank_key(a.1)));
        ranked.truncate(self.top_k);

        self.names.clear();
        self.scores.clear();
        self.names.reserve(ranked.len());
        self.scores.reserve(ranked.len());

        for (idx, score) in ranked {
            let name = feature_names
                .and_then(|all| all.get(idx))
                .cloned()
                .unwrap_or_else(|| format!("feature_{idx}"));
            self.names.push(name);
            self.scores.push(score);
        }
    }

    /// Render the chart as a multi-line string.
    ///
    /// Returns `"No feature importance data"` when the chart is empty.
    /// Otherwise the output is a header line, one row per feature (name, bar,
    /// score with three decimals) and a footer line, each terminated by `\n`.
    /// Bar length is proportional to `|score|` relative to the largest
    /// magnitude in the chart, so negative importances are drawn too.
    pub fn render(&self) -> String {
        if self.names.is_empty() {
            return String::from("No feature importance data");
        }

        // The formatter pads by `char` count, so measure names the same way;
        // byte length would over-widen the column for non-ASCII names.
        let name_width = self.names.iter().map(|n| n.chars().count()).max().unwrap_or(0);
        // `f32::max` ignores NaN operands, so a NaN score cannot poison the scale.
        let max_magnitude = self.scores.iter().map(|s| s.abs()).fold(0.0f32, f32::max);

        let mut output = String::new();
        // NOTE(review): the border literals are fixed-width; rows are not padded
        // to match them, so the right edge only lines up for some widths.
        output.push_str("┌─ Feature Importance ─────────────────────────────┐\n");

        for (name, &score) in self.names.iter().zip(&self.scores) {
            let bar_len = if max_magnitude > 0.0 {
                // Float-to-int `as` casts saturate and map NaN to 0, so a NaN
                // score simply gets an empty bar.
                ((score.abs() / max_magnitude) * self.bar_width as f32).round() as usize
            } else {
                0
            };
            let bar = "█".repeat(bar_len);
            output.push_str(&format!(
                "│  {:width$}  {:bar_width$}  {:.3}  │\n",
                name,
                bar,
                score,
                width = name_width,
                bar_width = self.bar_width
            ));
        }

        output.push_str("└──────────────────────────────────────────────────┘\n");
        output
    }
}