quantrs2_ml/computer_vision/
instancesegmentationhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::InstanceSegmentationHead;
19
20impl TaskHead for InstanceSegmentationHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, _, _, _) = features.dim();
23 let masks = Array4::zeros((
24 batch_size,
25 self.num_classes,
26 self.mask_resolution.0,
27 self.mask_resolution.1,
28 ));
29 let class_scores = Array4::zeros((
30 batch_size,
31 self.num_classes,
32 self.mask_resolution.0,
33 self.mask_resolution.1,
34 ));
35 Ok(TaskOutput::Segmentation {
36 masks,
37 class_scores,
38 })
39 }
40 fn parameters(&self) -> &Array1<f64> {
41 &self.parameters
42 }
43 fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
44 Ok(())
45 }
46 fn clone_box(&self) -> Box<dyn TaskHead> {
47 Box::new(self.clone())
48 }
49}