use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, ArrayD, ArrayStatCompat, Dimension, IxDyn};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;
use statrs::statistics::Statistics;
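/// Distance metric used to compare an input with a counterfactual candidate.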
#[derive(Debug, Clone, PartialEq)]
pub enum DistanceMetric {
    L1,
    L2,
    LInf,
    Weighted(Vec<f64>),
}
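/// Strategy used to perturb inputs when generating LIME samples.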
pub enum PerturbationStrategy {
    BinaryMask,
    GaussianNoise { std: f64 },
    UniformNoise { range: f64 },
    FeatureDropping { drop_prob: f64 },
    SuperpixelMask { num_superpixels: usize },
}
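/// How per-head attention maps are aggregated for visualization.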
pub enum AttentionAggregation {
    Mean,
    Max,
    Min,
    Std,
    SelectHead(usize),
}
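/// Propagation rule used by layer-wise relevance propagation (LRP).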
#[derive(Debug, Clone)]
pub enum LRPRule {
    Epsilon,
    Gamma { gamma: f64 },
    AlphaBeta { alpha: f64, beta: f64 },
    ZPlus,
    ZB { low: f64, high: f64 },
}
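/// Attribution methods that the `ModelInterpreter` can evaluate.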
#[derive(Debug, Clone)]
pub enum AttributionMethod {
    Saliency,
    IntegratedGradients {
        baseline: BaselineMethod,
        num_steps: usize,
    },
    GradCAM {
        target_layer: String,
    },
    GuidedBackprop,
    DeepLIFT {
        baseline: BaselineMethod,
    },
    SHAP {
        background_samples: usize,
        num_samples: usize,
    },
    LayerWiseRelevancePropagation {
        rule: LRPRule,
        epsilon: f64,
    },
    SmoothGrad {
        base_method: Box<AttributionMethod>,
        noise_std: f64,
        num_samples: usize,
    },
    InputXGradient,
    ExpectedGradients {
        num_references: usize,
        num_steps: usize,
    },
}
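/// Reference (baseline) input used by path-based attribution methods such as
/// integrated gradients.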
#[derive(Debug, Clone)]
pub enum BaselineMethod {
    Zero,
    Random { seed: u64 },
    GaussianBlur { sigma: f64 },
    TrainingMean,
    Custom(ArrayD<f32>),
}
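/// Feature-visualization techniques that can be applied to a trained network.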
pub enum VisualizationMethod {
    ActivationMaximization {
        target_unit: Option<usize>,
        num_iterations: usize,
        learning_rate: f64,
    },
    DeepDream {
        amplify_factor: f64,
    },
    FeatureInversion {
        regularization_weight: f64,
    },
    ClassActivationMapping {
        target_class: usize,
    },
    NetworkDissection {
        concept_data: Vec<ArrayD<f32>>,
        concept_labels: Vec<String>,
    },
}
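/// Configuration for gradient-based counterfactual search.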
#[allow(dead_code)]
pub struct CounterfactualGenerator<F: Float + Debug> {
    max_features: usize,
    learning_rate: f64,
    max_iterations: usize,
    distance_metric: DistanceMetric,
    original_predictions: HashMap<String, F>,
}
#[derive(Debug, Clone)]
pub struct ConceptActivationVector<F: Float + Debug> {
    pub name: String,
    pub layer_name: String,
    pub direction_vector: ArrayD<F>,
    pub sensitivity: F,
    pub positive_examples: Vec<ArrayD<F>>,
    pub negative_examples: Vec<ArrayD<F>>,
}
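/// Configuration for LIME (Local Interpretable Model-agnostic Explanations).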
pub struct LIMEExplainer<F: Float + Debug> {
    num_samples: usize,
    num_features: usize,
    kernel_width: f64,
    perturbation_strategy: PerturbationStrategy,
    random_seed: Option<u64>,
    importance_threshold: F,
}
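/// Caches attention maps per head and aggregates them for visualization.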
pub struct AttentionVisualizer<F: Float + Debug> {
    attention_heads: Vec<String>,
    aggregation_method: AttentionAggregation,
    attention_cache: HashMap<String, ArrayD<F>>,
}
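/// Result of an adversarial attack together with the metadata needed to explain it.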
#[derive(Debug, Clone)]
pub struct AdversarialExplanation<F: Float + Debug> {
    pub original_input: ArrayD<F>,
    pub adversarial_input: ArrayD<F>,
    pub perturbation: ArrayD<F>,
    pub original_prediction: usize,
    pub adversarial_prediction: usize,
    pub original_confidence: F,
    pub adversarial_confidence: F,
    pub attack_method: String,
    pub attack_parameters: HashMap<String, f64>,
}
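/// Per-neuron concept alignment scores produced by network dissection.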
#[derive(Debug, Clone)]
pub struct NetworkDissectionResult<F: Float + Debug> {
    pub neuron_index: usize,
    pub concept_scores: HashMap<String, F>,
    pub top_concepts: Vec<(String, F)>,
    pub selectivity_threshold: F,
    pub num_test_images: usize,
}
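/// Summary statistics computed from a layer's activations.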
#[derive(Debug, Clone)]
pub struct LayerAnalysisStats<F: Float + Debug> {
    pub mean_activation: F,
    pub std_activation: F,
    pub max_activation: F,
    pub min_activation: F,
    pub dead_neuron_percentage: f64,
    pub sparsity: f64,
    pub histogram: Vec<u32>,
    pub bin_edges: Vec<F>,
}
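/// Central coordinator for attribution methods, layer statistics, and
/// interpretation reports.
///
/// A minimal usage sketch (the `input` and `input_gradient` arrays are
/// placeholders assumed to come from the surrounding model code):
///
/// ```ignore
/// let mut interpreter = ModelInterpreter::<f64>::new();
/// interpreter.add_attribution_method(AttributionMethod::Saliency);
/// interpreter.cache_gradients("input_gradient".to_string(), input_gradient);
/// let report = interpreter.generate_interpretation_report(&input, Some(0))?;
/// println!("{}", report);
/// ```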
pub struct ModelInterpreter<
    F: Float
        + Debug
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::FromPrimitive
        + Sum
        + Clone
        + Copy,
> {
    attribution_methods: Vec<AttributionMethod>,
    gradient_cache: HashMap<String, ArrayD<F>>,
    activation_cache: HashMap<String, ArrayD<F>>,
    #[allow(dead_code)]
    layer_statistics: HashMap<String, LayerAnalysisStats<F>>,
    counterfactual_generator: Option<CounterfactualGenerator<F>>,
    concept_vectors: HashMap<String, ConceptActivationVector<F>>,
    lime_explainer: Option<LIMEExplainer<F>>,
    attention_visualizer: Option<AttentionVisualizer<F>>,
}
impl<
F: Float
+ Debug
+ 'static
+ scirs2_core::ndarray::ScalarOperand
+ scirs2_core::numeric::FromPrimitive
+ Sum
+ Clone
+ Copy,
> ModelInterpreter<F>
{
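    /// Create an interpreter with no registered attribution methods and empty caches.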
pub fn new() -> Self {
Self {
attribution_methods: Vec::new(),
gradient_cache: HashMap::new(),
activation_cache: HashMap::new(),
layer_statistics: HashMap::new(),
counterfactual_generator: None,
concept_vectors: HashMap::new(),
lime_explainer: None,
attention_visualizer: None,
}
}
    pub fn add_attribution_method(&mut self, method: AttributionMethod) {
        self.attribution_methods.push(method);
    }

    pub fn cache_activations(&mut self, layer_name: String, activations: ArrayD<F>) {
        self.activation_cache.insert(layer_name, activations);
    }

    pub fn cache_gradients(&mut self, layer_name: String, gradients: ArrayD<F>) {
        self.gradient_cache.insert(layer_name, gradients);
    }
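    /// Compute an attribution map for `input` using the given method,
    /// optionally targeting a specific output class.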
    pub fn compute_attribution(
        &self,
        method: &AttributionMethod,
        input: &ArrayD<F>,
        target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        match method {
            AttributionMethod::Saliency => self.compute_saliency_attribution(input, target_class),
            AttributionMethod::IntegratedGradients {
                baseline,
                num_steps,
            } => self.compute_integrated_gradients(input, baseline, *num_steps, target_class),
            AttributionMethod::GradCAM { target_layer } => {
                self.compute_gradcam_attribution(input, target_layer, target_class)
            }
            AttributionMethod::GuidedBackprop => {
                self.compute_guided_backprop_attribution(input, target_class)
            }
            AttributionMethod::DeepLIFT { baseline } => {
                self.compute_deeplift_attribution(input, baseline, target_class)
            }
            AttributionMethod::SHAP {
                background_samples,
                num_samples,
            } => self.compute_shap_attribution(
                input,
                *background_samples,
                *num_samples,
                target_class,
            ),
            AttributionMethod::LayerWiseRelevancePropagation { rule, epsilon } => {
                self.compute_lrp_attribution(input, rule, *epsilon, target_class)
            }
            AttributionMethod::SmoothGrad {
                base_method,
                noise_std,
                num_samples,
            } => self.compute_smoothgrad_attribution(
                base_method,
                input,
                *noise_std,
                *num_samples,
                target_class,
            ),
            AttributionMethod::InputXGradient => {
                self.compute_input_x_gradient_attribution(input, target_class)
            }
            AttributionMethod::ExpectedGradients {
                num_references,
                num_steps,
            } => self.compute_expected_gradients_attribution(
                input,
                *num_references,
                *num_steps,
                target_class,
            ),
        }
    }
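    /// Enable counterfactual explanations with the given search configuration.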
    pub fn enable_counterfactual_explanations(
        &mut self,
        max_features: usize,
        learning_rate: f64,
        max_iterations: usize,
        distance_metric: DistanceMetric,
    ) {
        self.counterfactual_generator = Some(CounterfactualGenerator {
            max_features,
            learning_rate,
            max_iterations,
            distance_metric,
            original_predictions: HashMap::new(),
        });
    }
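    /// Saliency map: absolute value of the cached input gradient, with a
    /// uniform fallback when no gradient has been cached.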
    fn compute_saliency_attribution(
        &self,
        input: &ArrayD<F>,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let grad_key = "input_gradient";
        if let Some(gradient) = self.gradient_cache.get(grad_key) {
            Ok(gradient.mapv(|x| x.abs()))
        } else {
            let attribution =
                input.mapv(|_| F::from(0.5).expect("Failed to convert constant to float"));
            Ok(attribution)
        }
    }
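    /// Simplified integrated gradients: interpolates from the baseline to the
    /// input and accumulates placeholder per-step gradients.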
    fn compute_integrated_gradients(
        &self,
        input: &ArrayD<F>,
        baseline: &BaselineMethod,
        num_steps: usize,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let baseline_input = self.create_baseline(input, baseline)?;
        let mut accumulated_gradients = Array::zeros(input.raw_dim());
        // Guard against division by zero when num_steps == 1.
        let denom = num_steps.saturating_sub(1).max(1) as f64;
        for i in 0..num_steps {
            let alpha = F::from(i as f64 / denom).expect("Operation failed");
            let interpolated_input = &baseline_input + &((input - &baseline_input) * alpha);
            let step_gradient = interpolated_input
                .mapv(|x| x * F::from(0.1).expect("Failed to convert constant to float"));
            accumulated_gradients = accumulated_gradients + step_gradient;
        }
        let integrated_gradients = (input - &baseline_input) * accumulated_gradients
            / F::from(num_steps).expect("Failed to convert to float");
        Ok(integrated_gradients)
    }
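    /// Grad-CAM: weights each channel of the target layer's activations by the
    /// mean of its cached gradient, sums the weighted channels, and applies ReLU.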
    fn compute_gradcam_attribution(
        &self,
        input: &ArrayD<F>,
        target_layer: &str,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let activations = self.activation_cache.get(target_layer).ok_or_else(|| {
            NeuralError::ComputationError(format!(
                "Activations not found for layer: {}",
                target_layer
            ))
        })?;
        let gradients = self.gradient_cache.get(target_layer).ok_or_else(|| {
            NeuralError::ComputationError(format!(
                "Gradients not found for layer: {}",
                target_layer
            ))
        })?;
        if activations.ndim() < 3 {
            return Err(NeuralError::InvalidArchitecture(
                "GradCAM requires at least 3D activations (batch, channels, spatial)".to_string(),
            ));
        }
        let mut weights = Vec::new();
        let num_channels = activations.shape()[1];
        for c in 0..num_channels {
            let channel_grad = gradients.index_axis(scirs2_core::ndarray::Axis(1), c);
            let weight = channel_grad.mean_or(F::zero());
            weights.push(weight);
        }
        let first_channel = activations
            .index_axis(scirs2_core::ndarray::Axis(1), 0)
            .to_owned()
            .into_dyn();
        let mut gradcam = Array::zeros(first_channel.raw_dim());
        for (c, &weight) in weights.iter().enumerate().take(num_channels) {
            let channel_activation = activations
                .index_axis(scirs2_core::ndarray::Axis(1), c)
                .to_owned()
                .into_dyn();
            let weighted_activation = channel_activation * weight;
            gradcam = gradcam + weighted_activation;
        }
        let gradcam_relu = gradcam.mapv(|x: F| x.max(F::zero()));
        if gradcam_relu.raw_dim() != input.raw_dim() {
            self.resize_attribution(&gradcam_relu, input.raw_dim())
        } else {
            Ok(gradcam_relu)
        }
    }
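    /// Guided backpropagation: keeps only the positive part of the cached input gradient.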
    fn compute_guided_backprop_attribution(
        &self,
        input: &ArrayD<F>,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        if let Some(gradient) = self.gradient_cache.get("input_gradient") {
            Ok(gradient.mapv(|x| x.max(F::zero())))
        } else {
            Ok(Array::zeros(input.raw_dim()))
        }
    }
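    /// Simplified DeepLIFT: difference from the baseline, scaled by the cached
    /// input gradient when one is available.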
    fn compute_deeplift_attribution(
        &self,
        input: &ArrayD<F>,
        baseline: &BaselineMethod,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let baseline_input = self.create_baseline(input, baseline)?;
        let diff = input - &baseline_input;
        if let Some(gradient) = self.gradient_cache.get("input_gradient") {
            Ok(&diff * gradient)
        } else {
            Ok(diff)
        }
    }
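    /// Monte-Carlo SHAP approximation: averages placeholder marginal
    /// contributions over random feature coalitions.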
    fn compute_shap_attribution(
        &self,
        input: &ArrayD<F>,
        _background_samples: usize,
        num_samples: usize,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let mut total_attribution = Array::zeros(input.raw_dim());
        for _ in 0..num_samples {
            // Randomly include or exclude each feature to form a coalition.
            let coalition_mask = input.mapv(|_| {
                if scirs2_core::random::random::<f64>() > 0.5 {
                    F::one()
                } else {
                    F::zero()
                }
            });
            let marginal_contribution = input
                * &coalition_mask
                * F::from(0.1).expect("Failed to convert constant to float");
            total_attribution = total_attribution + marginal_contribution;
        }
        Ok(total_attribution / F::from(num_samples).expect("Failed to convert to float"))
    }
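    /// Layer-wise relevance propagation over the cached input gradient,
    /// dispatching on the selected propagation rule.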
    fn compute_lrp_attribution(
        &self,
        input: &ArrayD<F>,
        rule: &LRPRule,
        epsilon: f64,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        match rule {
            LRPRule::Epsilon => {
                if let Some(gradient) = self.gradient_cache.get("input_gradient") {
                    let eps = F::from(epsilon).expect("Failed to convert to float");
                    // Stabilize the denominator with a signed epsilon term.
                    let denominator = gradient.mapv(|x| x + eps.copysign(x));
                    Ok(input * gradient / denominator)
                } else {
                    Ok(input.clone())
                }
            }
            LRPRule::Gamma { gamma } => {
                if let Some(gradient) = self.gradient_cache.get("input_gradient") {
                    let gamma_val = F::from(*gamma).expect("Failed to convert to float");
                    let positive_part = gradient.mapv(|x| x.max(F::zero()));
                    let negative_part = gradient.mapv(|x| x.min(F::zero()));
                    Ok(input * (positive_part * (F::one() + gamma_val) + negative_part))
                } else {
                    Ok(input.clone())
                }
            }
            LRPRule::AlphaBeta { alpha, beta } => {
                if let Some(gradient) = self.gradient_cache.get("input_gradient") {
                    let alpha_val = F::from(*alpha).expect("Failed to convert to float");
                    let beta_val = F::from(*beta).expect("Failed to convert to float");
                    let positive_part = gradient.mapv(|x| x.max(F::zero()));
                    let negative_part = gradient.mapv(|x| x.min(F::zero()));
                    Ok(input * (positive_part * alpha_val - negative_part * beta_val))
                } else {
                    Ok(input.clone())
                }
            }
            LRPRule::ZPlus => {
                let positive_input = input.mapv(|x| x.max(F::zero()));
                if let Some(gradient) = self.gradient_cache.get("input_gradient") {
                    Ok(positive_input * gradient)
                } else {
                    Ok(positive_input)
                }
            }
            LRPRule::ZB { low, high } => {
                let low_val = F::from(*low).expect("Failed to convert to float");
                let high_val = F::from(*high).expect("Failed to convert to float");
                let clamped_input = input.mapv(|x| x.max(low_val).min(high_val));
                if let Some(gradient) = self.gradient_cache.get("input_gradient") {
                    Ok(clamped_input * gradient)
                } else {
                    Ok(clamped_input)
                }
            }
        }
    }
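    /// SmoothGrad: averages the base method's attributions over several noisy
    /// copies of the input.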
    fn compute_smoothgrad_attribution(
        &self,
        base_method: &AttributionMethod,
        input: &ArrayD<F>,
        noise_std: f64,
        num_samples: usize,
        target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let mut accumulated_attribution = Array::zeros(input.raw_dim());
        let noise_scale = F::from(noise_std).expect("Failed to convert to float");
        for _ in 0..num_samples {
            let noisy_input = input.mapv(|x| {
                let noise = F::from(scirs2_core::random::random::<f64>() - 0.5)
                    .expect("Operation failed")
                    * noise_scale;
                x + noise
            });
            let attribution = self.compute_attribution(base_method, &noisy_input, target_class)?;
            accumulated_attribution = accumulated_attribution + attribution;
        }
        Ok(accumulated_attribution / F::from(num_samples).expect("Failed to convert to float"))
    }
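    /// Input x gradient: elementwise product of the input with the cached input gradient.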
    fn compute_input_x_gradient_attribution(
        &self,
        input: &ArrayD<F>,
        _target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        if let Some(gradient) = self.gradient_cache.get("input_gradient") {
            Ok(input * gradient)
        } else {
            Ok(input.clone())
        }
    }
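    /// Expected gradients: averages integrated-gradients attributions over
    /// randomly sampled reference baselines.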
    fn compute_expected_gradients_attribution(
        &self,
        input: &ArrayD<F>,
        num_references: usize,
        num_steps: usize,
        target_class: Option<usize>,
    ) -> Result<ArrayD<F>> {
        let mut total_attribution = Array::zeros(input.raw_dim());
        for _ in 0..num_references {
            let reference = input
                .mapv(|_| F::from(scirs2_core::random::random::<f64>()).expect("Operation failed"));
            let baseline =
                BaselineMethod::Custom(reference.mapv(|x| x.to_f64().unwrap_or(0.0) as f32));
            let attribution =
                self.compute_integrated_gradients(input, &baseline, num_steps, target_class)?;
            total_attribution = total_attribution + attribution;
        }
        Ok(total_attribution / F::from(num_references).expect("Failed to convert to float"))
    }
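    /// Materialize a concrete baseline input for the given baseline method.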
    fn create_baseline(&self, input: &ArrayD<F>, method: &BaselineMethod) -> Result<ArrayD<F>> {
        match method {
            BaselineMethod::Zero => Ok(Array::zeros(input.raw_dim())),
            BaselineMethod::Random { seed: _ } => Ok(input.mapv(|_| {
                F::from(scirs2_core::random::random::<f64>()).expect("Operation failed")
            })),
            BaselineMethod::GaussianBlur { sigma: _ } => {
                // Simplified blur: scale the input rather than applying a true Gaussian filter.
                let blurred =
                    input.mapv(|x| x * F::from(0.5).expect("Failed to convert constant to float"));
                Ok(blurred)
            }
            BaselineMethod::TrainingMean => Ok(Array::zeros(input.raw_dim())),
            BaselineMethod::Custom(baseline) => {
                if baseline.shape() == input.shape() {
                    let converted = baseline.mapv(|x| F::from(x as f64).unwrap_or(F::zero()));
                    Ok(converted)
                } else {
                    Err(NeuralError::DimensionMismatch(
                        "Custom baseline shape doesn't match input".to_string(),
                    ))
                }
            }
        }
    }
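    /// Reshape an attribution map to the target shape, falling back to a
    /// constant map of the mean value when the sizes differ.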
    fn resize_attribution(
        &self,
        attribution: &ArrayD<F>,
        targetshape: IxDyn,
    ) -> Result<ArrayD<F>> {
        if attribution.len() == targetshape.size() {
            Ok(attribution.clone().into_shape_with_order(targetshape)?)
        } else {
            let mean_val = attribution.mean_or(F::zero());
            Ok(Array::from_elem(targetshape, mean_val))
        }
    }
    /// Analyze layer activations and store statistics
    pub fn analyze_layer_activations(
        &mut self,
        layer_name: String,
        activations: &ArrayD<F>,
    ) -> Result<()> {
        // Cache the activations
        self.cache_activations(layer_name.clone(), activations.clone());
        // Compute statistics
        let flattened = activations
            .view()
            .into_shape_with_order(activations.len())?;
        let mean_activation = flattened.mean_or(F::zero());
        let variance = flattened
            .mapv(|x| (x - mean_activation) * (x - mean_activation))
            .mean()
            .unwrap_or(F::zero());
        let std_activation = variance.sqrt();
        let max_activation =
            flattened.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc });
        let min_activation =
            flattened.fold(F::infinity(), |acc, &x| if x < acc { x } else { acc });
        let dead_count = flattened.iter().filter(|&&x| x == F::zero()).count();
        let dead_neuron_percentage = (dead_count as f64 / flattened.len() as f64) * 100.0;
        let threshold = F::from(1e-6).unwrap_or(F::zero());
        let sparse_count = flattened.iter().filter(|&&x| x.abs() < threshold).count();
        let sparsity = (sparse_count as f64 / flattened.len() as f64) * 100.0;
        // Build a fixed-width histogram of activation values
        let num_bins = 10;
        let range = max_activation - min_activation;
        let bin_width = if range > F::zero() {
            range / F::from(num_bins).expect("Failed to convert to float")
        } else {
            F::one()
        };
        let mut histogram = vec![0u32; num_bins];
        let mut bin_edges = Vec::new();
        for i in 0..=num_bins {
            bin_edges
                .push(min_activation + F::from(i).expect("Failed to convert to float") * bin_width);
        }
        // Fill histogram
        for &activation in flattened.iter() {
            if range > F::zero() {
                let bin_index = ((activation - min_activation) / bin_width)
                    .to_usize()
                    .unwrap_or(0);
                let bin_index = bin_index.min(num_bins - 1);
                histogram[bin_index] += 1;
            } else {
                histogram[0] += 1; // All values in first bin if no range
            }
        }
        let stats = LayerAnalysisStats {
            mean_activation,
            std_activation,
            max_activation,
            min_activation,
            dead_neuron_percentage,
            sparsity,
            histogram,
            bin_edges,
        };
        self.layer_statistics.insert(layer_name, stats);
        Ok(())
    }
    /// Get statistics for a specific layer
    pub fn get_layer_statistics(&self, layer_name: &str) -> Option<&LayerAnalysisStats<F>> {
        self.layer_statistics.get(layer_name)
    }
    /// Generate a comprehensive interpretation report
    pub fn generate_interpretation_report(
        &self,
        input: &ArrayD<F>,
        target_class: Option<usize>,
    ) -> Result<InterpretationReport<F>> {
        let inputshape = input.shape().to_vec();
        // Generate attributions for all available methods
        let mut attributions = HashMap::new();
        let mut attribution_statistics = HashMap::new();
        for method in &self.attribution_methods {
            if let Ok(attribution) = self.compute_attribution(method, input, target_class) {
                let method_name = format!("{:?}", method);
                // Compute attribution statistics
                let flattened = attribution
                    .view()
                    .into_shape_with_order(attribution.len())?;
                let mean = flattened.mean_or(F::zero());
                let mean_absolute = flattened.mapv(|x| x.abs()).mean_or(F::zero());
                let max_absolute = flattened.fold(
                    F::zero(),
                    |acc, &x| if x.abs() > acc { x.abs() } else { acc },
                );
                let positive_count = flattened.iter().filter(|&&x| x > F::zero()).count();
                let positive_attribution_ratio = positive_count as f64 / flattened.len() as f64;
                let total_positive_attribution = flattened
                    .iter()
                    .filter(|&&x| x > F::zero())
                    .fold(F::zero(), |acc, &x| acc + x);
                let total_negative_attribution = flattened
                    .iter()
                    .filter(|&&x| x < F::zero())
                    .fold(F::zero(), |acc, &x| acc + x);
                let stats = AttributionStatistics {
                    mean,
                    mean_absolute,
                    max_absolute,
                    positive_attribution_ratio,
                    total_positive_attribution,
                    total_negative_attribution,
                };
                attributions.insert(method_name.clone(), attribution);
                attribution_statistics.insert(method_name, stats);
            }
        }
        // Clone layer statistics
        let layer_statistics = self.layer_statistics.clone();
        // Create interpretation summary
        let interpretation_summary = InterpretationSummary {
            num_attribution_methods: self.attribution_methods.len(),
            average_method_consistency: 0.75, // Placeholder consistency
            most_important_features: vec![0, 1, 2], // Placeholder important features
            interpretation_confidence: 0.85,  // Placeholder confidence
        };
        Ok(InterpretationReport {
            inputshape: IxDyn(&inputshape),
            target_class,
            attributions,
            attribution_statistics,
            layer_statistics,
            interpretation_summary,
        })
    }
}
impl<
        F: Float
            + Debug
            + 'static
            + scirs2_core::ndarray::ScalarOperand
            + scirs2_core::numeric::FromPrimitive
            + Sum
            + Clone
            + Copy,
    > Default for ModelInterpreter<F>
{
    fn default() -> Self {
        Self::new()
    }
}
/// Summary of interpretation analysis
#[derive(Debug, Clone)]
pub struct InterpretationSummary {
    /// Number of attribution methods used
    pub num_attribution_methods: usize,
    /// Average consistency across methods
    pub average_method_consistency: f64,
    /// Indices of most important features
    pub most_important_features: Vec<usize>,
    /// Overall interpretation confidence (0-1)
    pub interpretation_confidence: f64,
}
/// Statistics for attribution methods
#[derive(Debug, Clone)]
pub struct AttributionStatistics<F: Float + Debug> {
    /// Mean attribution value
    pub mean: F,
    /// Mean absolute attribution value
    pub mean_absolute: F,
    /// Maximum absolute attribution value
    pub max_absolute: F,
    /// Ratio of positive attributions
    pub positive_attribution_ratio: f64,
    /// Total positive attribution
    pub total_positive_attribution: F,
    /// Total negative attribution
    pub total_negative_attribution: F,
}
/// Comprehensive interpretation report with all explanation types
#[derive(Debug)]
pub struct ComprehensiveInterpretationReport<F: Float + Debug> {
    /// Basic interpretation report
    pub basic_report: InterpretationReport<F>,
    /// Counterfactual explanation
    pub counterfactual_explanation: Option<ArrayD<F>>,
    /// LIME explanation
    pub lime_explanation: Option<ArrayD<F>>,
    /// Concept activation scores
    pub concept_activations: HashMap<String, F>,
    /// Attention visualization maps
    pub attention_visualizations: Option<HashMap<String, ArrayD<F>>>,
    /// Adversarial explanations
    pub adversarial_explanations: Vec<AdversarialExplanation<F>>,
    /// Network dissection results
    pub network_dissection_results: Vec<NetworkDissectionResult<F>>,
}
/// Basic interpretation report
#[derive(Debug)]
pub struct InterpretationReport<F: Float + Debug> {
    /// Shape of input that was interpreted
    pub inputshape: IxDyn,
    /// Target class (if specified)
    pub target_class: Option<usize>,
    /// Attribution maps for each method
    pub attributions: HashMap<String, ArrayD<F>>,
    /// Statistics for each attribution method
    pub attribution_statistics: HashMap<String, AttributionStatistics<F>>,
    /// Layer analysis statistics
    pub layer_statistics: HashMap<String, LayerAnalysisStats<F>>,
    /// Summary of interpretation
    pub interpretation_summary: InterpretationSummary,
}
impl<F: Float + Debug> std::fmt::Display for InterpretationReport<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Neural Network Interpretation Report")?;
        writeln!(f, "===================================")?;
        writeln!(f, "Input Shape: {:?}", self.inputshape)?;
        writeln!(f, "Target Class: {:?}", self.target_class)?;
        writeln!(
            f,
            "Attribution Methods: {}",
            self.attribution_statistics.len()
        )?;
        writeln!(
            f,
            "Interpretation Confidence: {:.3}",
            self.interpretation_summary.interpretation_confidence
        )?;
        writeln!(
            f,
            "Average Method Consistency: {:.3}",
            self.interpretation_summary.average_method_consistency
        )?;
        writeln!(
            f,
            "Top Important Features: {:?}",
            self.interpretation_summary.most_important_features
        )?;
        writeln!(f, "\nLayer Statistics:")?;
        for (layer_name, stats) in &self.layer_statistics {
            writeln!(
                f,
                "  {}: mean={:.3}, std={:.3}, sparsity={:.1}%, dead_neurons={:.1}%",
                layer_name,
                stats.mean_activation.to_f64().unwrap_or(0.0),
                stats.std_activation.to_f64().unwrap_or(0.0),
                stats.sparsity,
                stats.dead_neuron_percentage
            )?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
    #[test]
    fn test_model_interpreter_creation() {
        let interpreter = ModelInterpreter::<f64>::new();
        assert_eq!(interpreter.attribution_methods.len(), 0);
        assert_eq!(interpreter.gradient_cache.len(), 0);
    }
    #[test]
    fn test_saliency_attribution() {
        let mut interpreter = ModelInterpreter::<f64>::new();
        let gradients = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, -0.3, 0.4, -0.5, 0.6])
            .expect("Operation failed")
            .into_dyn();
        interpreter.cache_gradients("input_gradient".to_string(), gradients);
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        let attribution = interpreter.compute_saliency_attribution(&input, None);
        assert!(attribution.is_ok());
        let attr = attribution.expect("Operation failed");
        assert_eq!(attr.shape(), input.shape());
        assert_eq!(attr[[0, 0]], 0.1);
        assert_eq!(attr[[0, 2]], 0.3); // abs(-0.3)
    }
    #[test]
    fn test_integrated_gradients() {
        let interpreter = ModelInterpreter::<f64>::new();
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        let baseline = BaselineMethod::Zero;
        let attribution = interpreter.compute_integrated_gradients(&input, &baseline, 10, None);
        assert!(attribution.is_ok());
    }
    #[test]
    fn test_lrp_attribution() {
        let interpreter = ModelInterpreter::<f64>::new();
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        let attribution =
            interpreter.compute_lrp_attribution(&input, &LRPRule::Epsilon, 1e-6, None);
        assert!(attribution.is_ok());
        let gamma_attribution = interpreter.compute_lrp_attribution(
            &input,
            &LRPRule::Gamma { gamma: 0.25 },
            1e-6,
            None,
        );
        assert!(gamma_attribution.is_ok());
        let zplus_attribution =
            interpreter.compute_lrp_attribution(&input, &LRPRule::ZPlus, 1e-6, None);
        assert!(zplus_attribution.is_ok());
    }
    #[test]
    fn test_counterfactual_explanations() {
        let mut interpreter = ModelInterpreter::<f64>::new();
        interpreter.enable_counterfactual_explanations(5, 0.01, 10, DistanceMetric::L2);
        let _input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        assert!(interpreter.counterfactual_generator.is_some());
    }
    #[test]
    fn test_input_x_gradient_attribution() {
        let mut interpreter = ModelInterpreter::<f64>::new();
        let gradients = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
            .expect("Operation failed")
            .into_dyn();
        interpreter.cache_gradients("input_gradient".to_string(), gradients);
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        let attribution = interpreter.compute_input_x_gradient_attribution(&input, None);
        assert!(attribution.is_ok());
        let attr = attribution.expect("Operation failed");
        assert!((attr[[0, 0]] - 0.1).abs() < 1e-10); // 1.0 * 0.1
        assert!((attr[[1, 2]] - 3.6).abs() < 1e-10); // 6.0 * 0.6
    }
    #[test]
    fn test_baseline_creation() {
        let interpreter = ModelInterpreter::<f64>::new();
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed")
            .into_dyn();
        let zero_baseline = interpreter.create_baseline(&input, &BaselineMethod::Zero);
        assert!(zero_baseline.is_ok());
        let baseline = zero_baseline.expect("Operation failed");
        assert_eq!(baseline.shape(), input.shape());
        assert!(baseline.iter().all(|&x| x == 0.0));
        let custom_array = Array2::from_elem((2, 3), 0.5f32).into_dyn();
        let custom_baseline =
            interpreter.create_baseline(&input, &BaselineMethod::Custom(custom_array));
        assert!(custom_baseline.is_ok());
        let baseline = custom_baseline.expect("Operation failed");
        assert!(baseline.iter().all(|&x| x == 0.5));
    }
}