quantrs2_ml/computer_vision/
classificationhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::ClassificationHead;
19
20impl TaskHead for ClassificationHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, _, _, _) = features.dim();
23 let pooled = features
24 .mean_axis(Axis(2))
25 .unwrap()
26 .mean_axis(Axis(2))
27 .unwrap();
28 let mut logits = Array2::zeros((batch_size, self.num_classes));
29 let mut probabilities = Array2::zeros((batch_size, self.num_classes));
30 for i in 0..batch_size {
31 let feature_vec = pooled.slice(s![i, ..]).to_owned();
32 let class_logits = self.classifier.forward(&feature_vec)?;
33 let max_logit = class_logits
34 .iter()
35 .cloned()
36 .fold(f64::NEG_INFINITY, f64::max);
37 let exp_logits = class_logits.mapv(|x| (x - max_logit).exp());
38 let sum_exp = exp_logits.sum();
39 let probs = exp_logits / sum_exp;
40 logits.slice_mut(s![i, ..]).assign(&class_logits);
41 probabilities.slice_mut(s![i, ..]).assign(&probs);
42 }
43 Ok(TaskOutput::Classification {
44 logits,
45 probabilities,
46 })
47 }
48 fn parameters(&self) -> &Array1<f64> {
49 &self.classifier.parameters
50 }
51 fn update_parameters(&mut self, params: &Array1<f64>) -> Result<()> {
52 self.classifier.parameters = params.clone();
53 Ok(())
54 }
55 fn clone_box(&self) -> Box<dyn TaskHead> {
56 Box::new(self.clone())
57 }
58}