use crate::dispatch_registry_examples::initialize_dispatch_registrations;
use crate::dispatch_registry_extended::initialize_extended_registrations;
use lazy_static::lazy_static;
lazy_static! {
/// One-shot guard that populates the dispatch registries.
///
/// The initializer block runs the first time this static is dereferenced (as
/// `ensure_initialized` does) and never again. It calls
/// `initialize_dispatch_registrations` and `initialize_extended_registrations`, and, in
/// builds with the `gpu` feature, `initialize_gpu_operations` to add the GPU reduction kernels.
pub static ref DISPATCH_INIT: () = {
initialize_dispatch_registrations();
initialize_extended_registrations();
// GPU kernels only exist in `gpu` builds; without the feature this statement is compiled out.
#[cfg(feature = "gpu")]
initialize_gpu_operations();
};
}
/// Registers every GPU-backed kernel provided by this module.
///
/// Only compiled with the `gpu` feature. At present the reductions are the only GPU
/// operations registered here.
#[cfg(feature = "gpu")]
fn initialize_gpu_operations() {
    register_gpu_reductions()
}
/// Registers the GPU `sum` and `mean` reduction kernels in the `f32` dispatch registry.
#[cfg(feature = "gpu")]
fn register_gpu_reductions() {
    register_gpu_reduction("sum", sum_f32_gpu);
    register_gpu_reduction("mean", mean_f32_gpu);
}

/// Registers `kernel` as the GPU implementation of the `f32` reduction called `name`.
///
/// If the registry does not yet know the operation, a `"reduction"` descriptor
/// restricted to `DType::Float32` is registered first. Errors returned by the registry
/// are deliberately discarded: registration here is best-effort.
#[cfg(feature = "gpu")]
fn register_gpu_reduction(
    name: &'static str,
    kernel: fn(&crate::Tensor<f32>) -> crate::Result<crate::Tensor<f32>>,
) {
    use crate::dispatch_registry::{
        BackendType, KernelImplementation, OperationDescriptor, F32_REGISTRY,
    };
    use crate::DType;

    // Build the descriptor only when the operation is actually missing; an existing
    // descriptor (e.g. from an earlier registration) is left untouched.
    if F32_REGISTRY.get_operation(name).is_none() {
        let desc =
            OperationDescriptor::new(name, "reduction").with_dtypes(vec![DType::Float32]);
        let _ = F32_REGISTRY.register_operation(desc);
    }

    let _ = F32_REGISTRY.register_kernel(
        name,
        KernelImplementation::unary(BackendType::Gpu, kernel),
    );
}
/// GPU `sum` kernel for `f32` tensors; returns a 0-dimensional tensor.
///
/// An empty input is answered by `sum_f32_cpu` rather than by launching a reduction over
/// a zero-length GPU buffer, because the GPU path reads the reduced value at index 0.
#[cfg(feature = "gpu")]
fn sum_f32_gpu(x: &crate::Tensor<f32>) -> crate::Result<crate::Tensor<f32>> {
    use crate::gpu::ops::operation_types::ReductionOp;

    if x.data().is_empty() {
        return sum_f32_cpu(x);
    }
    reduce_all_f32_gpu(x, ReductionOp::Sum)
}

/// GPU `mean` kernel for `f32` tensors; returns a 0-dimensional tensor.
///
/// An empty input is answered by `mean_f32_cpu` rather than by launching a reduction over
/// a zero-length GPU buffer, because the GPU path reads the reduced value at index 0.
#[cfg(feature = "gpu")]
fn mean_f32_gpu(x: &crate::Tensor<f32>) -> crate::Result<crate::Tensor<f32>> {
    use crate::gpu::ops::operation_types::ReductionOp;

    if x.data().is_empty() {
        return mean_f32_cpu(x);
    }
    reduce_all_f32_gpu(x, ReductionOp::Mean)
}

/// Uploads `x` to `Device::Gpu(0)`, runs reduction `op` with no axes argument, copies the
/// result back, and wraps its first element in a 0-dimensional tensor.
///
/// Errors from the upload, the reduction, or the copy-back are propagated with `?`.
#[cfg(feature = "gpu")]
fn reduce_all_f32_gpu(
    x: &crate::Tensor<f32>,
    op: crate::gpu::ops::operation_types::ReductionOp,
) -> crate::Result<crate::Tensor<f32>> {
    use crate::gpu::buffer::GpuBuffer;
    use crate::gpu::ops::reduction_ops::execute_reduction_op;
    use crate::Device;
    use scirs2_core::ndarray::Array;

    let gpu_buffer = GpuBuffer::from_slice(x.data(), &Device::Gpu(0))?;
    // NOTE(review): `None` presumably means "reduce over all elements"; the code below
    // relies on the result buffer holding the reduced value at index 0.
    let result_buffer = execute_reduction_op(&gpu_buffer, op, None)?;
    let result_data = result_buffer.to_cpu()?;
    let result = Array::from_elem(vec![], result_data[0]).into_dyn();
    Ok(crate::Tensor::from_array(result))
}
/// CPU `sum` kernel: adds every element of `x` into a 0-dimensional `f32` tensor.
///
/// The running total is kept in `f64` and rounded to `f32` once at the end. A running
/// `f32` total stops growing once it reaches 2^24 when the addends are small (summing
/// 2^25 ones yields 16777216), so widening keeps large reductions accurate.
/// An empty tensor sums to zero.
fn sum_f32_cpu(x: &crate::Tensor<f32>) -> crate::Result<crate::Tensor<f32>> {
    use scirs2_core::ndarray::Array;

    let total: f64 = x.data().iter().copied().map(f64::from).sum();
    // Narrowing is intentional: this kernel's output dtype is f32.
    let result = Array::from_elem(vec![], total as f32).into_dyn();
    Ok(crate::Tensor::from_array(result))
}
/// CPU `mean` kernel: arithmetic mean of all elements of `x` as a 0-dimensional `f32` tensor.
///
/// Both the running sum and the element count are kept in `f64`, so the result stays
/// accurate for large inputs. A running `f32` sum stalls at 2^24 for small addends, and
/// `usize -> f32` is only exact up to 2^24. The quotient is rounded to `f32` once.
/// An empty tensor yields NaN (0 / 0).
fn mean_f32_cpu(x: &crate::Tensor<f32>) -> crate::Result<crate::Tensor<f32>> {
    use scirs2_core::ndarray::Array;

    let data = x.data();
    let total: f64 = data.iter().copied().map(f64::from).sum();
    // `usize -> f64` is exact for every count below 2^53.
    let mean = total / data.len() as f64;
    let result = Array::from_elem(vec![], mean as f32).into_dyn();
    Ok(crate::Tensor::from_array(result))
}
pub fn ensure_initialized() {
*DISPATCH_INIT;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatch_registry::F32_REGISTRY;
    use crate::Tensor;
    use scirs2_core::ndarray::array;

    /// Shared fixture for the CPU reduction tests: the 1-D tensor `[1, 2, 3, 4, 5]`.
    fn one_to_five() -> Tensor<f32> {
        Tensor::from_array(array![1.0f32, 2.0, 3.0, 4.0, 5.0].into_dyn())
    }

    #[test]
    fn test_initialization() {
        ensure_initialized();
        // These element-wise ops must be registered once initialization has run.
        for op in ["add", "mul", "div"] {
            assert!(F32_REGISTRY.get_operation(op).is_some());
        }
    }

    #[test]
    fn test_sum_cpu_fallback() {
        let total = sum_f32_cpu(&one_to_five()).expect("test: sum_f32_cpu should succeed");
        assert_eq!(total.data()[0], 15.0);
    }

    #[test]
    fn test_mean_cpu_fallback() {
        let average = mean_f32_cpu(&one_to_five()).expect("test: mean_f32_cpu should succeed");
        assert_eq!(average.data()[0], 3.0);
    }
}