quantrs2_ml/computer_vision/
generationhead_traits.rs

//! # GenerationHead - Trait Implementations
//!
//! This module contains trait implementations for `GenerationHead`.
//!
//! ## Implemented Traits
//!
//! - `TaskHead`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use super::*;
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::*;
14use scirs2_core::random::prelude::*;
15use scirs2_core::{Complex32, Complex64};
16use std::f64::consts::PI;
17
18use super::types::GenerationHead;
19
20impl TaskHead for GenerationHead {
21    fn forward(&self, features: &Array4<f64>) -> Result<TaskOutput> {
22        let (batch_size, _, height, width) = features.dim();
23        let images = Array4::zeros((batch_size, self.output_channels, height, width));
24        let latent_codes = Array2::zeros((batch_size, self.latent_dim));
25        Ok(TaskOutput::Generation {
26            images,
27            latent_codes,
28        })
29    }
30    fn parameters(&self) -> &Array1<f64> {
31        &self.parameters
32    }
33    fn update_parameters(&mut self, _params: &Array1<f64>) -> Result<()> {
34        Ok(())
35    }
36    fn clone_box(&self) -> Box<dyn TaskHead> {
37        Box::new(self.clone())
38    }
39}