quantrs2_ml/computer_vision/
classificationhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::ClassificationHead;
19
20impl TaskHead for ClassificationHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, _, _, _) = features.dim();
23 let pooled = features
24 .mean_axis(Axis(2))
25 .ok_or_else(|| {
26 MLError::ComputationError("Failed to compute mean over axis 2".to_string())
27 })?
28 .mean_axis(Axis(2))
29 .ok_or_else(|| {
30 MLError::ComputationError(
31 "Failed to compute mean over axis 2 (second pass)".to_string(),
32 )
33 })?;
34 let mut logits = Array2::zeros((batch_size, self.num_classes));
35 let mut probabilities = Array2::zeros((batch_size, self.num_classes));
36 for i in 0..batch_size {
37 let feature_vec = pooled.slice(s![i, ..]).to_owned();
38 let class_logits = self.classifier.forward(&feature_vec)?;
39 let max_logit = class_logits
40 .iter()
41 .cloned()
42 .fold(f64::NEG_INFINITY, f64::max);
43 let exp_logits = class_logits.mapv(|x| (x - max_logit).exp());
44 let sum_exp = exp_logits.sum();
45 let probs = exp_logits / sum_exp;
46 logits.slice_mut(s![i, ..]).assign(&class_logits);
47 probabilities.slice_mut(s![i, ..]).assign(&probs);
48 }
49 Ok(TaskOutput::Classification {
50 logits,
51 probabilities,
52 })
53 }
54 fn parameters(&self) -> &Array1<f64> {
55 &self.classifier.parameters
56 }
57 fn update_parameters(&mut self, params: &Array1<f64>) -> Result<()> {
58 self.classifier.parameters = params.clone();
59 Ok(())
60 }
61 fn clone_box(&self) -> Box<dyn TaskHead> {
62 Box::new(self.clone())
63 }
64}