use crate::{Result, VisionError};
use scirs2_core::ndarray::Array3; use std::sync::Arc;
use torsh_core::device::Device;
use torsh_nn::Module;
use torsh_tensor::Tensor;
/// Gradient-weighted Class Activation Mapping (Grad-CAM) heatmap generator.
///
/// NOTE(review): `target_layer_name` is stored but no layer-hook mechanism is
/// visible in this file, so the target layer's activations are not yet captured.
pub struct GradCAM {
/// Name of the layer whose activations/gradients should drive the heatmap.
target_layer_name: String,
/// Device the tensors are expected to live on (currently unused in this impl).
device: Arc<dyn Device>,
}
impl GradCAM {
    /// Creates a Grad-CAM generator targeting the named layer on the given device.
    pub fn new(target_layer_name: String, device: Arc<dyn Device>) -> Self {
        Self {
            target_layer_name,
            device,
        }
    }

    /// Generates a class-activation heatmap for `target_class`.
    ///
    /// NOTE(review): this is currently a placeholder. A forward and backward pass
    /// are executed, but the activations/gradients of `target_layer_name` are
    /// never captured (no layer hooks exist here), so an all-zero heatmap of the
    /// input's spatial size is returned instead of a real Grad-CAM map.
    ///
    /// # Errors
    /// Propagates any tensor/model error from the forward or backward pass.
    pub fn generate_heatmap(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Promote a single CHW image to a batch of one so the model sees NCHW.
        let batched_input = if input.ndim() == 3 {
            input.unsqueeze(0)?
        } else {
            input.clone()
        };
        let output = model.forward(&batched_input)?;
        // Select the requested class score and backpropagate from it.
        let target_score = output.narrow(1, target_class as i64, 1)?;
        target_score.backward()?;
        // Spatial dims of the batched input — assumes [N, C, H, W] layout.
        let input_shape = batched_input.shape();
        let height = input_shape.dims()[2] as usize;
        let width = input_shape.dims()[3] as usize;
        let heatmap = self.create_placeholder_heatmap(height, width)?;
        Ok(heatmap)
    }

    /// Grad-CAM++ variant; currently delegates to plain Grad-CAM.
    pub fn generate_gradcam_plus_plus(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        self.generate_heatmap(model, input, target_class)
    }

    /// Alpha-blends a colorized heatmap onto `image`:
    /// `(1 - alpha) * image + alpha * colormap(normalized_heatmap)`.
    ///
    /// The heatmap is min-max normalized first. A small epsilon guards the
    /// division: a constant heatmap (max == min) previously divided by zero.
    pub fn overlay_heatmap(&self, image: &Tensor, heatmap: &Tensor, alpha: f32) -> Result<Tensor> {
        let hmax = heatmap.max(None, false)?;
        let hmin = heatmap.min()?;
        // Epsilon prevents a divide-by-zero for constant heatmaps.
        let range = hmax.sub(&hmin)?.add_scalar(1e-8)?;
        let normalized = heatmap.sub(&hmin)?.div(&range)?;
        let colored_heatmap = self.apply_colormap(&normalized)?;
        let blended = image
            .mul_scalar(1.0 - alpha)?
            .add(&colored_heatmap.mul_scalar(alpha)?)?;
        Ok(blended)
    }

    /// Maps a single-channel heatmap in [0, 1] to a 3-channel RGB tensor using
    /// simple per-channel linear ramps (clamped to [0, 1]).
    fn apply_colormap(&self, heatmap: &Tensor) -> Result<Tensor> {
        let r = heatmap.mul_scalar(1.5)?.clamp(0.0, 1.0)?.unsqueeze(0)?;
        let g = heatmap
            .mul_scalar(2.0)?
            .sub_scalar(0.5)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;
        let b = heatmap
            .mul_scalar(1.5)?
            .sub_scalar(1.0)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;
        let colored = Tensor::cat(&[&r, &g, &b], 0)?;
        Ok(colored)
    }

    /// Returns an all-zero `height x width` heatmap — a stand-in until real
    /// layer-activation capture is implemented.
    fn create_placeholder_heatmap(&self, height: usize, width: usize) -> Result<Tensor> {
        use torsh_tensor::creation;
        let heatmap: Tensor<f32> = creation::zeros(&[height, width])?;
        Ok(heatmap)
    }
}
/// Vanilla-gradient and SmoothGrad saliency-map generator.
pub struct SaliencyMap {
/// Device the tensors are expected to live on (currently unused in this impl).
device: Arc<dyn Device>,
}
impl SaliencyMap {
    /// Builds a saliency-map generator for the given device.
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Computes a vanilla gradient saliency map for `target_class`:
    /// the absolute input gradient of the class score, reduced with a max
    /// over the channel dimension so each pixel gets a single value.
    ///
    /// # Errors
    /// Fails if the forward/backward pass errors or no gradient was recorded.
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        let tracked = input.clone().requires_grad_(true);
        // Promote a CHW image to NCHW so the model always sees a batch dim.
        let batch = if tracked.ndim() == 3 {
            tracked.unsqueeze(0)?
        } else {
            tracked.clone()
        };
        let logits = model.forward(&batch)?;
        let class_score = logits.narrow(1, target_class as i64, 1)?;
        class_score.backward()?;
        let gradient = batch
            .grad()
            .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
        // Channel-wise max of |grad| (dim 1) collapses C channels to one map.
        let per_pixel = gradient.abs()?.max(Some(1), false)?;
        Ok(per_pixel)
    }

    /// SmoothGrad: averages saliency maps over `num_samples` noisy copies of
    /// the input, each perturbed by Gaussian noise scaled by `noise_stddev`.
    pub fn generate_smooth(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
        num_samples: usize,
        noise_stddev: f32,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;
        let dims: Vec<usize> = input.shape().dims().iter().map(|&d| d as usize).collect();
        let mut total: Tensor<f32> = creation::zeros(&dims)?;
        for _sample in 0..num_samples {
            let noise: Tensor<f32> = creation::randn(&dims)?;
            let perturbed = input.add(&noise.mul_scalar(noise_stddev)?)?;
            let sample_map = self.generate(model, &perturbed, target_class)?;
            total = total.add(&sample_map)?;
        }
        // Mean over all noisy samples.
        Ok(total.div_scalar(num_samples as f32)?)
    }
}
/// Integrated Gradients attribution: averages input gradients along a straight
/// path from a baseline to the input, then scales by (input - baseline).
pub struct IntegratedGradients {
/// Strategy used to build the path's starting baseline.
baseline_type: BaselineType,
/// Number of interpolation steps along the baseline-to-input path.
num_steps: usize,
/// Device the tensors are expected to live on (currently unused in this impl).
device: Arc<dyn Device>,
}
/// Baseline choices for Integrated Gradients.
#[derive(Debug, Clone, Copy)]
pub enum BaselineType {
/// All-zero (black) image.
Black,
/// Small-magnitude Gaussian noise (randn scaled by 0.1).
Random,
/// Intended to be a blurred input; currently returns the input unchanged.
Blurred,
}
impl IntegratedGradients {
    /// Creates an Integrated Gradients attributor with the given baseline
    /// strategy and number of path-integration steps.
    pub fn new(baseline_type: BaselineType, num_steps: usize, device: Arc<dyn Device>) -> Self {
        Self {
            baseline_type,
            num_steps,
            device,
        }
    }

    /// Computes the Integrated Gradients attribution for `target_class`:
    /// `(input - baseline) * mean_over_steps(d score / d interpolated_input)`.
    ///
    /// # Errors
    /// Fails if the forward/backward pass errors or no gradient was recorded.
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;
        let baseline = self.create_baseline(input)?;
        // Accumulate gradients starting from ZERO. Previously this was seeded
        // with `baseline.clone()`, which biased the averaged gradients by
        // baseline/num_steps for any non-zero baseline (Random/Blurred).
        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();
        let mut accumulated_gradients: Tensor<f32> = creation::zeros(&shape)?;
        for step in 0..self.num_steps {
            // Left-endpoint Riemann sum: alpha sweeps [0, 1) in num_steps steps.
            let alpha = (step as f32) / (self.num_steps as f32);
            // Straight-line interpolation between baseline and input.
            let interpolated = baseline
                .mul_scalar(1.0 - alpha)?
                .add(&input.mul_scalar(alpha)?)?;
            let interp_with_grad = interpolated.clone().requires_grad_(true);
            let output = model.forward(&interp_with_grad)?;
            let score = output.narrow(1, target_class as i64, 1)?;
            score.backward()?;
            let grad = interp_with_grad
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            accumulated_gradients = accumulated_gradients.add(&grad)?;
        }
        let avg_gradients = accumulated_gradients.div_scalar(self.num_steps as f32)?;
        // Attribution = (input - baseline) ⊙ average path gradient.
        let attribution = input.sub(&baseline)?.mul(&avg_gradients)?;
        Ok(attribution)
    }

    /// Builds the path baseline matching `self.baseline_type` with the same
    /// shape as `input`.
    fn create_baseline(&self, input: &Tensor) -> Result<Tensor> {
        use torsh_tensor::creation;
        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();
        match self.baseline_type {
            BaselineType::Black => {
                let baseline: Tensor<f32> = creation::zeros(&shape)?;
                Ok(baseline)
            }
            BaselineType::Random => {
                // Low-magnitude Gaussian noise baseline.
                let baseline: Tensor<f32> = creation::randn(&shape)?;
                Ok(baseline.mul_scalar(0.1)?)
            }
            BaselineType::Blurred => {
                // TODO: apply an actual blur; the raw input is a placeholder
                // (it makes input - baseline == 0, i.e. a zero attribution).
                Ok(input.clone())
            }
        }
    }
}
/// Visualization helpers for transformer attention weights.
pub struct AttentionVisualizer {
/// Device the tensors are expected to live on (currently unused in this impl).
device: Arc<dyn Device>,
}
impl AttentionVisualizer {
    /// Builds an attention visualizer for the given device.
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Turns ViT-style attention weights into a patch-grid attention map.
    ///
    /// NOTE(review): upsampling from patch resolution to `image_size` is not
    /// implemented yet — the patch-resolution map is returned as-is (the
    /// previous redundant `.clone()` of that placeholder has been removed).
    ///
    /// # Errors
    /// Fails if the reduction/narrow/reshape operations fail (e.g. shape mismatch).
    pub fn visualize_attention(
        &self,
        attention_weights: &Tensor,
        patch_size: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        // Average over dim 1 — presumably the attention-head dimension of a
        // [batch, heads, tokens, tokens] tensor; confirm against the caller.
        let avg_attention = attention_weights.mean(Some(&[1]), false)?;
        // Attention row of the first token (assumed to be [CLS]).
        let cls_attention = avg_attention.narrow(1, 0, 1)?;
        let num_patches_h = image_size.0 / patch_size;
        let num_patches_w = image_size.1 / patch_size;
        // Lay the per-patch attention out as a 2-D grid.
        let reshaped =
            cls_attention.reshape(&[1i32, num_patches_h as i32, num_patches_w as i32])?;
        // TODO: bilinear upsampling to `image_size`.
        Ok(reshaped)
    }

    /// Simplified attention rollout: multiplies successive layers' attention
    /// matrices together.
    ///
    /// NOTE(review): the standard rollout also adds the identity (residual
    /// connection) and re-normalizes rows before multiplying; this version
    /// performs the raw matrix product only.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `attention_layers` is empty.
    pub fn attention_rollout(&self, attention_layers: Vec<Tensor>) -> Result<Tensor> {
        if attention_layers.is_empty() {
            return Err(VisionError::InvalidArgument(
                "No attention layers provided".to_string(),
            ));
        }
        let mut rollout = attention_layers[0].clone();
        for attention in attention_layers.iter().skip(1) {
            rollout = rollout.matmul(attention)?;
        }
        Ok(rollout)
    }
}
/// Activation-maximization visualizer: synthesizes an input image that
/// maximizes a chosen class score via gradient ascent.
pub struct FeatureVisualizer {
/// Step size for each gradient-ascent update.
learning_rate: f32,
/// Number of gradient-ascent iterations to run.
num_iterations: usize,
/// Device the tensors are expected to live on (currently unused in this impl).
device: Arc<dyn Device>,
}
impl FeatureVisualizer {
    /// Creates a visualizer with the given step size and iteration budget.
    pub fn new(learning_rate: f32, num_iterations: usize, device: Arc<dyn Device>) -> Self {
        Self {
            learning_rate,
            num_iterations,
            device,
        }
    }

    /// Synthesizes an image that maximizes the score of `target_class` by
    /// gradient ascent starting from low-amplitude noise around mid-gray.
    ///
    /// # Errors
    /// Fails if the forward/backward pass errors or no gradient was recorded.
    pub fn visualize_class(
        &self,
        model: &dyn Module,
        target_class: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        use torsh_tensor::creation;
        // Start from noise * 0.1 + 0.5 (near mid-gray), with grad tracking on.
        let mut image: Tensor<f32> = creation::randn(&[1, 3, image_size.0, image_size.1])?;
        image = image.mul_scalar(0.1)?.add_scalar(0.5)?.requires_grad_(true);
        for iteration in 0..self.num_iterations {
            let output = model.forward(&image)?;
            let class_score = output.narrow(1, target_class as i64, 1)?;
            // Maximize the class score == minimize its negation.
            let loss = class_score.neg()?;
            loss.backward()?;
            let grad = image
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            // Apply the ascent step, clamp to a sane range, and RE-ENABLE grad
            // tracking: sub/clamp produce a fresh tensor which previously was
            // never marked requires_grad, so `image.grad()` returned None on
            // every iteration after the first.
            image = image
                .sub(&grad.mul_scalar(self.learning_rate)?)?
                .clamp(-2.0, 2.0)?
                .requires_grad_(true);
            // Lightweight progress logging every 10 iterations.
            if iteration % 10 == 0 {
                println!("Iteration {}: loss = {:?}", iteration, loss.item());
            }
        }
        Ok(image)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::CpuDevice;
use torsh_tensor::creation;
// Construction-only smoke test: verifies the target layer name is stored.
#[test]
fn test_gradcam_creation() {
let device = Arc::new(CpuDevice::new());
let gradcam = GradCAM::new("layer4".to_string(), device);
assert_eq!(gradcam.target_layer_name, "layer4");
}
// Smoke test: SaliencyMap constructs without panicking.
#[test]
fn test_saliency_map_creation() {
let device = Arc::new(CpuDevice::new());
let _saliency = SaliencyMap::new(device);
}
// Smoke test: IntegratedGradients constructs with a black baseline.
#[test]
fn test_integrated_gradients_creation() {
let device = Arc::new(CpuDevice::new());
let _ig = IntegratedGradients::new(BaselineType::Black, 50, device);
}
// Smoke test: AttentionVisualizer constructs without panicking.
#[test]
fn test_attention_visualizer_creation() {
let device = Arc::new(CpuDevice::new());
let _visualizer = AttentionVisualizer::new(device);
}
// Smoke test: FeatureVisualizer constructs without panicking.
#[test]
fn test_feature_visualizer_creation() {
let device = Arc::new(CpuDevice::new());
let _visualizer = FeatureVisualizer::new(0.1, 100, device);
}
// Verifies every BaselineType variant can produce a baseline tensor
// matching the input shape without erroring.
#[test]
fn test_baseline_types() {
let device: Arc<dyn Device> = Arc::new(CpuDevice::new());
let ig_black = IntegratedGradients::new(BaselineType::Black, 50, Arc::clone(&device));
let ig_random = IntegratedGradients::new(BaselineType::Random, 50, Arc::clone(&device));
let ig_blurred = IntegratedGradients::new(BaselineType::Blurred, 50, Arc::clone(&device));
let input: Tensor<f32> = creation::ones(&[1, 3, 224, 224]).unwrap();
assert!(ig_black.create_baseline(&input).is_ok());
assert!(ig_random.create_baseline(&input).is_ok());
assert!(ig_blurred.create_baseline(&input).is_ok());
}
}