quantrs2_ml/computer_vision/
instancesegmentationhead_traits.rs

1//! # InstanceSegmentationHead - Trait Implementations
2//!
3//! This module contains trait implementations for `InstanceSegmentationHead`.
4//!
5//! ## Implemented Traits
6//!
7//! - `TaskHead`
8//!
9//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::InstanceSegmentationHead;
19
20impl TaskHead for InstanceSegmentationHead {
21    fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22        let (batch_size, _, _, _) = features.dim();
23        let masks = Array4::zeros((
24            batch_size,
25            self.num_classes,
26            self.mask_resolution.0,
27            self.mask_resolution.1,
28        ));
29        let class_scores = Array4::zeros((
30            batch_size,
31            self.num_classes,
32            self.mask_resolution.0,
33            self.mask_resolution.1,
34        ));
35        Ok(TaskOutput::Segmentation {
36            masks,
37            class_scores,
38        })
39    }
40    fn parameters(&self) -> &Array1<f64> {
41        &self.parameters
42    }
43    fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
44        Ok(())
45    }
46    fn clone_box(&self) -> Box<dyn TaskHead> {
47        Box::new(self.clone())
48    }
49}