quantrs2_ml/computer_vision/
classificationhead_traits.rs

//! # ClassificationHead - Trait Implementations
//!
//! This module contains trait implementations for `ClassificationHead`.
//!
//! ## Implemented Traits
//!
//! - `TaskHead`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::ClassificationHead;
19
20impl TaskHead for ClassificationHead {
21    fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22        let (batch_size, _, _, _) = features.dim();
23        let pooled = features
24            .mean_axis(Axis(2))
25            .unwrap()
26            .mean_axis(Axis(2))
27            .unwrap();
28        let mut logits = Array2::zeros((batch_size, self.num_classes));
29        let mut probabilities = Array2::zeros((batch_size, self.num_classes));
30        for i in 0..batch_size {
31            let feature_vec = pooled.slice(s![i, ..]).to_owned();
32            let class_logits = self.classifier.forward(&feature_vec)?;
33            let max_logit = class_logits
34                .iter()
35                .cloned()
36                .fold(f64::NEG_INFINITY, f64::max);
37            let exp_logits = class_logits.mapv(|x| (x - max_logit).exp());
38            let sum_exp = exp_logits.sum();
39            let probs = exp_logits / sum_exp;
40            logits.slice_mut(s![i, ..]).assign(&class_logits);
41            probabilities.slice_mut(s![i, ..]).assign(&probs);
42        }
43        Ok(TaskOutput::Classification {
44            logits,
45            probabilities,
46        })
47    }
48    fn parameters(&self) -> &Array1<f64> {
49        &self.classifier.parameters
50    }
51    fn update_parameters(&mut self, params: &Array1<f64>) -> Result<()> {
52        self.classifier.parameters = params.clone();
53        Ok(())
54    }
55    fn clone_box(&self) -> Box<dyn TaskHead> {
56        Box::new(self.clone())
57    }
58}