use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use super::config::GpuConfig;
use crate::error::Result;
/// Capability for values that can be moved between host memory and a GPU.
///
/// `F` is the floating-point element type. None of the methods below mention
/// `F`; it only parameterises the trait, so an implementor chooses which
/// element type an impl applies to.
pub trait GpuAccelerated<F>
where
    F: Float + Debug,
{
    /// Residency query: intended to return `true` while `self` lives on a GPU.
    fn is_on_gpu(&self) -> bool;

    /// Amount of GPU memory attributed to `self`.
    ///
    /// NOTE(review): the unit is not stated anywhere in this file (presumably
    /// bytes) — confirm against the implementations.
    fn gpu_memory_usage(&self) -> usize;

    /// Builds a GPU-side counterpart of `self` using the settings in `config`.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn to_gpu(&self, config: &GpuConfig) -> Result<Self>
    where
        // Returning `Self` by value needs `Sized`; bounding only this method
        // keeps the rest of the trait usable through `dyn GpuAccelerated<F>`.
        Self: Sized;

    /// Builds a host-side (CPU) counterpart of `self`.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn to_cpu(&self) -> Result<Self>
    where
        // Same reasoning as `to_gpu`: required for the by-value `Self` return.
        Self: Sized;
}
/// Three `Array1<F>` components produced by one decomposition, returned as a tuple.
///
/// NOTE(review): the order and meaning of the three components (for example
/// trend / seasonal / residual) are not established in this file — confirm
/// against the implementations of `GpuDecomposition`.
pub type DecompositionResult<F> = (Array1<F>, Array1<F>, Array1<F>);
/// Forecasting operations that expose a GPU code path configured by [`GpuConfig`].
///
/// `F` is the floating-point element type of the produced forecasts.
pub trait GpuForecasting<F>
where
    F: Float + Debug,
{
    /// Produces `steps`-step forecasts for each series in `data`.
    ///
    /// NOTE(review): presumably yields one `Array1<F>` per input series, in the
    /// same order as `data` — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn batch_forecast_gpu(
        &self,
        data: &[Array1<F>],
        steps: usize,
        config: &GpuConfig,
    ) -> Result<Vec<Array1<F>>>;

    /// Produces a `steps`-step forecast from `self` using `config`.
    ///
    /// NOTE(review): the result length is presumably `steps` — confirm.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn forecast_gpu(&self, steps: usize, config: &GpuConfig) -> Result<Array1<F>>;
}
/// Decomposition operations that expose a GPU code path configured by [`GpuConfig`].
///
/// Each decomposition is returned as a `DecompositionResult<F>`, a tuple of three
/// `Array1<F>` components.
pub trait GpuDecomposition<F>
where
    F: Float + Debug,
{
    /// Decomposes every series in `data`.
    ///
    /// NOTE(review): presumably yields one result per input series, in the same
    /// order as `data` — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn batch_decompose_gpu(
        &self,
        data: &[Array1<F>],
        config: &GpuConfig,
    ) -> Result<Vec<DecompositionResult<F>>>;

    /// Decomposes `self` using `config`.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn decompose_gpu(&self, config: &GpuConfig) -> Result<DecompositionResult<F>>;
}
/// Feature-extraction operations that expose a GPU code path configured by [`GpuConfig`].
///
/// Features are returned as a one-dimensional `Array1<F>`.
pub trait GpuFeatureExtraction<F>
where
    F: Float + Debug,
{
    /// Extracts a feature vector for every series in `data`.
    ///
    /// NOTE(review): presumably yields one feature vector per input series, in
    /// the same order as `data` — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn batch_extract_features_gpu(
        &self,
        data: &[Array1<F>],
        config: &GpuConfig,
    ) -> Result<Vec<Array1<F>>>;

    /// Extracts a feature vector from `self` using `config`.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the crate's `Result` type.
    fn extract_features_gpu(&self, config: &GpuConfig) -> Result<Array1<F>>;
}