quantrs2_ml/computer_vision/
featureextractionhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::FeatureExtractionHead;
19
20impl TaskHead for FeatureExtractionHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, channels, _, _) = features.dim();
23 let pooled = features
24 .mean_axis(Axis(2))
25 .ok_or_else(|| {
26 MLError::ComputationError("Failed to compute mean over axis 2".to_string())
27 })?
28 .mean_axis(Axis(2))
29 .ok_or_else(|| {
30 MLError::ComputationError(
31 "Failed to compute mean over axis 2 (second pass)".to_string(),
32 )
33 })?;
34 let mut extracted_features = Array2::zeros((batch_size, self.feature_dim));
35 for i in 0..batch_size {
36 let feature_vec = pooled.slice(s![i, ..]).to_owned();
37 for j in 0..self.feature_dim {
38 extracted_features[[i, j]] = feature_vec[j % channels];
39 }
40 if self.normalize {
41 let norm = extracted_features
42 .slice(s![i, ..])
43 .mapv(|x| x * x)
44 .sum()
45 .sqrt();
46 if norm > 1e-10 {
47 extracted_features
48 .slice_mut(s![i, ..])
49 .mapv_inplace(|x| x / norm);
50 }
51 }
52 }
53 Ok(TaskOutput::Features {
54 features: extracted_features,
55 attention_maps: None,
56 })
57 }
58 fn parameters(&self) -> &Array1<f64> {
59 &self.parameters
60 }
61 fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
62 Ok(())
63 }
64 fn clone_box(&self) -> Box<dyn TaskHead> {
65 Box::new(self.clone())
66 }
67}