use crate::error::{NeuralError, Result};
use crate::utils::colors::{
colored_metric_cell, colorize, gradient_color, stylize, ColorOptions, Style,
};
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use std::fmt::{Debug, Display};
/// Importance scores for a set of named features.
///
/// `feature_names[i]` labels `importance[i]`; [`FeatureImportance::new`] verifies that
/// both collections have the same length.
pub struct FeatureImportance<F>
where
    F: Float + Debug + Display,
{
    /// Human-readable name of each feature.
    pub feature_names: Vec<String>,
    /// Importance score of each feature, index-aligned with `feature_names`.
    pub importance: Array1<F>,
}
impl<F: Float + Debug + Display> FeatureImportance<F> {
    /// Creates a new set of feature importances from index-aligned names and scores.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::ValidationError`] when `feature_names` and `importance`
    /// have different lengths.
    pub fn new(feature_names: Vec<String>, importance: Array1<F>) -> Result<Self> {
        if feature_names.len() != importance.len() {
            return Err(NeuralError::ValidationError(
                "Number of feature names must match number of importance scores".to_string(),
            ));
        }
        Ok(FeatureImportance {
            feature_names,
            importance,
        })
    }

    /// Compares two scores for a descending sort, placing NaN after every number.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is involved
    /// (NaN would compare "equal" to both 1 and 2 even though 1 < 2), and
    /// `slice::sort_by` may panic on non-total comparators. This comparator is total.
    fn descending(a: F, b: F) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Non-NaN floats (infinities included) are always comparable.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    /// Returns the indices `0..min(len, scores.len())` ordered by descending score.
    ///
    /// The sort is stable, so ties keep their original relative order; NaN scores come
    /// last. Clamping to `scores.len()` avoids out-of-bounds panics if the public
    /// fields were modified to different lengths after construction.
    fn ranked_indices(len: usize, scores: &Array1<F>) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..len.min(scores.len())).collect();
        indices.sort_by(|&a, &b| Self::descending(scores[a], scores[b]));
        indices
    }

    /// Returns the names and scores of the `k` most important features.
    ///
    /// Results are ordered by descending importance (NaN last). If `k` exceeds the
    /// number of features, all features are returned.
    pub fn top_k(&self, k: usize) -> (Vec<String>, Array1<F>) {
        let ranked = Self::ranked_indices(self.feature_names.len(), &self.importance);
        let top = &ranked[..k.min(ranked.len())];
        let top_names = top
            .iter()
            .map(|&i| self.feature_names[i].clone())
            .collect();
        let top_importance = Array1::from_iter(top.iter().map(|&i| self.importance[i]));
        (top_names, top_importance)
    }

    /// Renders the scores as an ASCII bar chart using `ColorOptions::default()`.
    ///
    /// See [`Self::to_ascii_with_options`] for the meaning of the arguments.
    pub fn to_ascii(&self, title: Option<&str>, width: usize, k: Option<usize>) -> String {
        self.to_ascii_with_options(title, width, k, &ColorOptions::default())
    }

    /// Renders the scores as an ASCII bar chart.
    ///
    /// The output starts with `title` (bold when colors are enabled) followed by a blank
    /// line, then one `name | score | bar` line per feature in descending importance.
    ///
    /// * `width` - target line width; bars get `width - (longest name + 10)` columns
    ///   (saturating at zero).
    /// * `k` - when `Some`, only the `k` most important features are shown.
    /// * `color_options` - controls whether names, scores and bars are colorized.
    ///
    /// Bar lengths are proportional to each score divided by the largest positive
    /// score; if no score is positive, every bar is empty.
    pub fn to_ascii_with_options(
        &self,
        title: Option<&str>,
        width: usize,
        k: Option<usize>,
        color_options: &ColorOptions,
    ) -> String {
        // Extra name-column width used when names are stylized; presumably sized for the
        // escape codes `stylize` adds, which `{:<width$}` counts as characters.
        const ANSI_PADDING: usize = 9;

        let (features, importance) = match k {
            Some(count) => self.top_k(count),
            None => (self.feature_names.clone(), self.importance.clone()),
        };

        let mut result = String::with_capacity(features.len() * 80);
        if let Some(title_text) = title {
            if color_options.enabled {
                result.push_str(&stylize(title_text, Style::Bold));
            } else {
                result.push_str(title_text);
            }
            result.push_str("\n\n");
        }

        // Largest score, floored at zero (the fold starts from zero and skips NaN).
        let max_importance = importance
            .iter()
            .copied()
            .fold(F::zero(), |acc, v| if v > acc { v } else { acc });

        // Measure names in characters, not bytes, to match how `{:<width$}` pads them.
        let max_name_len = features
            .iter()
            .map(|name| name.chars().count())
            .fold(10, usize::max);
        let bar_area_width = width.saturating_sub(max_name_len + 10);
        let name_width = if color_options.enabled {
            max_name_len + ANSI_PADDING
        } else {
            max_name_len
        };

        for idx in Self::ranked_indices(features.len(), &importance) {
            let feature_name = &features[idx];
            let imp = importance[idx];

            let normalized_imp = if max_importance > F::zero() {
                (imp / max_importance).to_f64().unwrap_or(0.0)
            } else {
                0.0
            };
            // Negative or NaN ratios saturate to 0 in the `as usize` cast.
            let bar_length = (normalized_imp * bar_area_width as f64).round() as usize;
            let bar = "█".repeat(bar_length);
            let imp_text = format!("{imp:.3}");

            let (formatted_name, formatted_imp, bar) = if color_options.enabled {
                let bar = match gradient_color(normalized_imp, color_options) {
                    Some(color) => colorize(bar, color),
                    None => bar,
                };
                (
                    stylize(feature_name, Style::Bold).to_string(),
                    colored_metric_cell(imp_text, normalized_imp, color_options),
                    bar,
                )
            } else {
                (feature_name.clone(), imp_text, bar)
            };

            result.push_str(&format!(
                "{formatted_name:<name_width$} | {formatted_imp} | {bar}\n"
            ));
        }
        result
    }
}