quantrs2_ml/computer_vision/
segmentationhead_traits.rs1use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::SegmentationHead;
19
20impl TaskHead for SegmentationHead {
21 fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22 let (batch_size, _, height, width) = features.dim();
23 let masks = Array4::zeros((batch_size, self.num_classes, height, width));
24 let class_scores = Array4::zeros((batch_size, self.num_classes, height, width));
25 Ok(TaskOutput::Segmentation {
26 masks,
27 class_scores,
28 })
29 }
30 fn parameters(&self) -> &Array1<f64> {
31 &self.parameters
32 }
33 fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
34 Ok(())
35 }
36 fn clone_box(&self) -> Box<dyn TaskHead> {
37 Box::new(self.clone())
38 }
39}