quantrs2_ml/computer_vision/
detectionhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::DetectionHead;
19
20impl TaskHead for DetectionHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, _, _, _) = features.dim();
23 let boxes = Array3::zeros((batch_size, 100, 4));
24 let scores = Array2::zeros((batch_size, 100));
25 let classes = Array2::<f64>::zeros((batch_size, 100));
26 Ok(TaskOutput::Detection {
27 boxes,
28 scores,
29 classes: classes.mapv(|x| x as usize),
30 })
31 }
32 fn parameters(&self) -> &Array1<f64> {
33 &self.parameters
34 }
35 fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
36 Ok(())
37 }
38 fn clone_box(&self) -> Box<dyn TaskHead> {
39 Box::new(self.clone())
40 }
41}