quantrs2_ml/computer_vision/
featureextractionhead_traits.rs

//! # FeatureExtractionHead - Trait Implementations
//!
//! This module contains trait implementations for `FeatureExtractionHead`.
//!
//! ## Implemented Traits
//!
//! - `TaskHead`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

11use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::FeatureExtractionHead;
19
20impl TaskHead for FeatureExtractionHead {
21    fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22        let (batch_size, channels, _, _) = features.dim();
23        let pooled = features
24            .mean_axis(Axis(2))
25            .ok_or_else(|| {
26                MLError::ComputationError("Failed to compute mean over axis 2".to_string())
27            })?
28            .mean_axis(Axis(2))
29            .ok_or_else(|| {
30                MLError::ComputationError(
31                    "Failed to compute mean over axis 2 (second pass)".to_string(),
32                )
33            })?;
34        let mut extracted_features = Array2::zeros((batch_size, self.feature_dim));
35        for i in 0..batch_size {
36            let feature_vec = pooled.slice(s![i, ..]).to_owned();
37            for j in 0..self.feature_dim {
38                extracted_features[[i, j]] = feature_vec[j % channels];
39            }
40            if self.normalize {
41                let norm = extracted_features
42                    .slice(s![i, ..])
43                    .mapv(|x| x * x)
44                    .sum()
45                    .sqrt();
46                if norm > 1e-10 {
47                    extracted_features
48                        .slice_mut(s![i, ..])
49                        .mapv_inplace(|x| x / norm);
50                }
51            }
52        }
53        Ok(TaskOutput::Features {
54            features: extracted_features,
55            attention_maps: None,
56        })
57    }
58    fn parameters(&self) -> &Array1<f64> {
59        &self.parameters
60    }
61    fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
62        Ok(())
63    }
64    fn clone_box(&self) -> Box<dyn TaskHead> {
65        Box::new(self.clone())
66    }
67}