quantrs2_ml/computer_vision/
classificationhead_traits.rs

1//! # ClassificationHead - Trait Implementations
2//!
3//! This module contains trait implementations for `ClassificationHead`.
4//!
5//! ## Implemented Traits
6//!
7//! - `TaskHead`
8//!
9//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::ClassificationHead;
19
20impl TaskHead for ClassificationHead {
21    fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22        let (batch_size, _, _, _) = features.dim();
23        let pooled = features
24            .mean_axis(Axis(2))
25            .ok_or_else(|| {
26                MLError::ComputationError("Failed to compute mean over axis 2".to_string())
27            })?
28            .mean_axis(Axis(2))
29            .ok_or_else(|| {
30                MLError::ComputationError(
31                    "Failed to compute mean over axis 2 (second pass)".to_string(),
32                )
33            })?;
34        let mut logits = Array2::zeros((batch_size, self.num_classes));
35        let mut probabilities = Array2::zeros((batch_size, self.num_classes));
36        for i in 0..batch_size {
37            let feature_vec = pooled.slice(s![i, ..]).to_owned();
38            let class_logits = self.classifier.forward(&feature_vec)?;
39            let max_logit = class_logits
40                .iter()
41                .cloned()
42                .fold(f64::NEG_INFINITY, f64::max);
43            let exp_logits = class_logits.mapv(|x| (x - max_logit).exp());
44            let sum_exp = exp_logits.sum();
45            let probs = exp_logits / sum_exp;
46            logits.slice_mut(s![i, ..]).assign(&class_logits);
47            probabilities.slice_mut(s![i, ..]).assign(&probs);
48        }
49        Ok(TaskOutput::Classification {
50            logits,
51            probabilities,
52        })
53    }
54    fn parameters(&self) -> &Array1<f64> {
55        &self.classifier.parameters
56    }
57    fn update_parameters(&mut self, params: &Array1<f64>) -> Result<()> {
58        self.classifier.parameters = params.clone();
59        Ok(())
60    }
61    fn clone_box(&self) -> Box<dyn TaskHead> {
62        Box::new(self.clone())
63    }
64}