use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive};
#[cfg(feature = "cuda")]
use cudarc::cublas::{CudaBlas, Gemm};
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, ValidAsZeroBits};
#[cfg(feature = "metal")]
use metal::{Buffer, CommandBuffer, CommandQueue, Device as MetalDevice, MTLSize};
#[cfg(feature = "opencl")]
use opencl3::memory::ClMem;
/// Executor for single (non-batched) matrix multiplications on one device.
///
/// The element type `T` is only a compile-time parameter; no values of `T`
/// are stored in the executor itself.
pub struct GpuMatrixExecutor<T>
where
    T: Float + FromPrimitive + ScalarOperand + 'static,
{
    /// Device this executor was created for; checked before dispatching.
    device_type: super::DeviceType,
    /// Zero-sized marker tying the executor to its element type.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Float + FromPrimitive + ScalarOperand + 'static> GpuMatrixExecutor<T> {
    /// Creates an executor bound to `device_type`.
    ///
    /// No device resources are acquired here, so this always returns `Ok`;
    /// the `Result` return type is kept for interface stability.
    pub fn new(device_type: super::DeviceType) -> RusTorchResult<Self> {
        Ok(Self {
            device_type,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Multiplies two 2-D tensors through the Metal kernel.
    ///
    /// # Errors
    ///
    /// Returns a GPU error when the executor was not created for a Metal
    /// device, or when the underlying multiplication fails.
    #[cfg(feature = "metal")]
    pub fn metal_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        match self.device_type {
            super::DeviceType::Metal(_) => self.execute_metal_matmul(a, b),
            _ => Err(RusTorchError::gpu(
                "Device type not supported for Metal operations",
            )),
        }
    }

    /// Multiplies two tensors through the CoreML backend.
    ///
    /// # Errors
    ///
    /// Returns a GPU error when the executor was not created for a CoreML
    /// device, or when the CoreML multiplication fails.
    #[cfg(any(
        feature = "coreml",
        feature = "coreml-hybrid",
        feature = "coreml-fallback"
    ))]
    pub fn coreml_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        // The enclosing function is already gated on the CoreML features, so
        // the `CoreML` variant needs no additional `cfg` on the arm.
        match self.device_type {
            super::DeviceType::CoreML(_) => self.execute_coreml_matmul(a, b),
            _ => Err(RusTorchError::gpu(
                "Device type not supported for CoreML operations",
            )),
        }
    }

    /// Runs `a x b` on Metal by round-tripping the data through `f32`.
    ///
    /// Shapes are validated *before* any buffer is allocated so that invalid
    /// input is rejected without copying either operand.
    ///
    /// The result keeps `a`'s device and requires gradients when either
    /// operand does.
    ///
    /// # Errors
    ///
    /// Returns a GPU error when either operand is not 2-D, when the inner
    /// dimensions differ, when an element cannot be converted to or from
    /// `f32`, when the kernel fails, or when the result array cannot be built.
    #[cfg(feature = "metal")]
    fn execute_metal_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        use crate::gpu::metal_kernels::metal_matmul_f32;

        let a_shape = a.data.shape();
        let b_shape = b.data.shape();
        if a_shape.len() != 2 || b_shape.len() != 2 {
            return Err(RusTorchError::gpu(
                "Only 2D matrix multiplication supported",
            ));
        }
        let (m, k) = (a_shape[0], a_shape[1]);
        let (k2, n) = (b_shape[0], b_shape[1]);
        if k != k2 {
            return Err(RusTorchError::gpu(
                "Matrix dimensions don't match for multiplication",
            ));
        }

        // Collecting into `Option<Vec<_>>` stops at the first element that
        // cannot be represented, turning it into an error instead of a panic.
        let a_data = a
            .data
            .iter()
            .map(|&x| x.to_f32())
            .collect::<Option<Vec<f32>>>()
            .ok_or_else(|| RusTorchError::gpu("Failed to convert left operand to f32"))?;
        let b_data = b
            .data
            .iter()
            .map(|&x| x.to_f32())
            .collect::<Option<Vec<f32>>>()
            .ok_or_else(|| RusTorchError::gpu("Failed to convert right operand to f32"))?;

        let mut c_data = vec![0.0f32; m * n];
        metal_matmul_f32(&a_data, &b_data, &mut c_data, m, n, k)?;

        let result_data = c_data
            .into_iter()
            .map(T::from_f32)
            .collect::<Option<Vec<T>>>()
            .ok_or_else(|| RusTorchError::gpu("Failed to convert Metal result from f32"))?;
        let result_array = ndarray::Array::from_shape_vec((m, n), result_data)
            .map_err(|e| RusTorchError::gpu(format!("Failed to create result array: {}", e)))?;

        Ok(Tensor {
            data: result_array.into_dyn(),
            device: a.device.clone(),
            requires_grad: a.requires_grad || b.requires_grad,
        })
    }

    /// Delegates to the tensor's CoreML matmul and wraps any failure in a
    /// GPU error.
    #[cfg(any(
        feature = "coreml",
        feature = "coreml-hybrid",
        feature = "coreml-fallback"
    ))]
    fn execute_coreml_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        // The trait import is what makes `coreml_matmul` callable on `Tensor`.
        use crate::gpu::coreml::operations::linear_algebra::CoreMLLinearAlgebra;
        a.coreml_matmul(b)
            .map_err(|e| RusTorchError::gpu(format!("CoreML matmul failed: {}", e)))
    }
}
/// Executor that dispatches matrix multiplication to a GPU backend, keeping
/// an optional GPU context alongside the requested device.
pub struct GpuBatchMatrixExecutor<T>
where
    T: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    /// Device the executor was asked to use.
    device_type: super::DeviceType,
    /// GPU context for `device_type`; `None` when none could be created.
    context: Option<super::GpuContext>,
    /// Zero-sized marker tying the executor to its element type.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static> GpuBatchMatrixExecutor<T> {
    /// Creates an executor for `device_type`.
    ///
    /// A failure to create the GPU context is not reported as an error: the
    /// context is simply left empty and `batch_matmul` then takes the
    /// non-GPU path. This function therefore always returns `Ok`.
    pub fn new(device_type: super::DeviceType) -> RusTorchResult<Self> {
        let context = super::GpuContext::new(device_type).ok();
        Ok(Self {
            device_type,
            context,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Device this executor was created for.
    pub fn device_type(&self) -> &super::DeviceType {
        &self.device_type
    }

    /// Whether a GPU context was successfully created in [`Self::new`].
    pub fn is_gpu_available(&self) -> bool {
        self.context.is_some()
    }

    /// Multiplies `a` by `b` on the executor's device.
    ///
    /// With a context present, CUDA, Metal and OpenCL devices go to their
    /// dedicated backends and every other device type uses the tensor's own
    /// `matmul`. Without a context, the multiplication runs through
    /// `crate::linalg::optimized_matmul` when the `blas-optimized` feature is
    /// enabled, and through the tensor's `matmul` otherwise.
    ///
    /// Despite the name, the GPU backends only accept 2-D operands.
    ///
    /// # Errors
    ///
    /// Propagates backend failures; errors from the tensor's `matmul` are
    /// re-wrapped as GPU errors.
    pub fn batch_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        if self.context.is_some() {
            // Shared fallback for every device type without a dedicated kernel.
            let cpu_matmul = || a.matmul(b).map_err(|e| RusTorchError::gpu(e.to_string()));
            match &self.device_type {
                super::DeviceType::Cuda(_) => self.cuda_batch_matmul(a, b),
                super::DeviceType::Metal(_) => self.metal_batch_matmul(a, b),
                super::DeviceType::OpenCL(_) => self.opencl_batch_matmul(a, b),
                super::DeviceType::Cpu | super::DeviceType::Auto => cpu_matmul(),
                #[cfg(any(
                    feature = "coreml",
                    feature = "coreml-hybrid",
                    feature = "coreml-fallback"
                ))]
                super::DeviceType::CoreML(_) => cpu_matmul(),
                #[cfg(feature = "mac-hybrid")]
                super::DeviceType::MacHybrid => cpu_matmul(),
            }
        } else {
            #[cfg(feature = "blas-optimized")]
            {
                crate::linalg::optimized_matmul(a, b)
            }
            #[cfg(not(feature = "blas-optimized"))]
            {
                a.matmul(b).map_err(|e| RusTorchError::gpu(e.to_string()))
            }
        }
    }

    /// Multiplies two 2-D tensors on CUDA device 0 via an `f32` round trip.
    ///
    /// Elements that cannot be converted to or from `f32` are replaced by
    /// zero rather than reported.
    ///
    /// # Errors
    ///
    /// Returns `DeviceNotAvailable` when the `cuda` feature is disabled, the
    /// executor cannot be created, the operands are not 2-D with matching
    /// inner dimensions, or either tensor cannot be viewed as a slice.
    /// Returns a GPU error when the kernel or result construction fails.
    fn cuda_batch_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        // Without the backend compiled in, the operands are intentionally unused.
        #[cfg(not(feature = "cuda"))]
        let _ = (a, b);

        #[cfg(feature = "cuda")]
        {
            use crate::gpu::cuda_enhanced::CudaMatrixExecutor;
            if let Ok(executor) = CudaMatrixExecutor::new(0) {
                let a_shape = a.shape();
                let b_shape = b.shape();
                if a_shape.len() == 2 && b_shape.len() == 2 && a_shape[1] == b_shape[0] {
                    if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
                        let (rows, inner, cols) = (a_shape[0], a_shape[1], b_shape[1]);
                        let a_f32: Vec<f32> =
                            a_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();
                        let b_f32: Vec<f32> =
                            b_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();
                        let mut result_f32 = vec![0.0f32; rows * cols];

                        executor
                            .matmul_f32(
                                &a_f32,
                                &b_f32,
                                &mut result_f32,
                                rows,
                                cols,
                                inner,
                                false,
                            )
                            .map_err(|e| {
                                RusTorchError::gpu(format!(
                                    "CUDA matrix multiplication failed: {}",
                                    e
                                ))
                            })?;

                        let result_t: Vec<T> = result_f32
                            .iter()
                            .map(|&x| T::from_f32(x).unwrap_or_else(T::zero))
                            .collect();
                        let array = ndarray::ArrayD::from_shape_vec(vec![rows, cols], result_t)
                            .map_err(|e| {
                                RusTorchError::gpu(format!(
                                    "CUDA result tensor creation failed: {}",
                                    e
                                ))
                            })?;
                        return Ok(Tensor::new(array));
                    }
                }
            }
        }
        Err(RusTorchError::DeviceNotAvailable(
            "CUDA not available or failed to execute matrix multiplication".to_string(),
        ))
    }

    /// Multiplies two 2-D tensors on Metal via an `f32` round trip.
    ///
    /// Operands are checked for rank 2 and matching inner dimensions before
    /// any shape index is read, mirroring the CUDA and OpenCL paths; invalid
    /// shapes fall through to the `DeviceNotAvailable` error instead of
    /// panicking or reaching the kernel with inconsistent sizes.
    ///
    /// Elements that cannot be converted to or from `f32` are replaced by
    /// zero rather than reported.
    ///
    /// # Errors
    ///
    /// Returns `DeviceNotAvailable` when the `metal` feature is disabled, the
    /// executor cannot be created, the shapes are unsupported, or either
    /// tensor cannot be viewed as a slice. Returns a GPU error when the
    /// kernel or result construction fails.
    fn metal_batch_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        // Without the backend compiled in, the operands are intentionally unused.
        #[cfg(not(feature = "metal"))]
        let _ = (a, b);

        #[cfg(feature = "metal")]
        {
            use crate::gpu::metal_kernels::MetalKernelExecutor;
            if let Ok(executor) = MetalKernelExecutor::new() {
                let a_shape = a.shape();
                let b_shape = b.shape();
                if a_shape.len() == 2 && b_shape.len() == 2 && a_shape[1] == b_shape[0] {
                    if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
                        let (rows, inner, cols) = (a_shape[0], a_shape[1], b_shape[1]);
                        let a_f32: Vec<f32> =
                            a_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();
                        let b_f32: Vec<f32> =
                            b_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();

                        let result_data = executor
                            .matrix_multiply_f32(&a_f32, &b_f32, rows, cols, inner)
                            .map_err(|e| {
                                RusTorchError::gpu(format!(
                                    "Metal matrix multiplication failed: {}",
                                    e
                                ))
                            })?;

                        let result_t: Vec<T> = result_data
                            .iter()
                            .map(|&x| T::from_f32(x).unwrap_or_else(T::zero))
                            .collect();
                        let array = ndarray::ArrayD::from_shape_vec(vec![rows, cols], result_t)
                            .map_err(|e| {
                                RusTorchError::gpu(format!("Metal batch matmul failed: {}", e))
                            })?;
                        return Ok(Tensor::new(array));
                    }
                }
            }
        }
        Err(RusTorchError::DeviceNotAvailable(
            "Metal not available or failed to execute batch matrix multiplication".to_string(),
        ))
    }

    /// Multiplies two 2-D tensors on OpenCL device 0 via an `f32` round trip.
    ///
    /// Elements that cannot be converted to or from `f32` are replaced by
    /// zero rather than reported.
    ///
    /// # Errors
    ///
    /// Returns `DeviceNotAvailable` when the `opencl` feature is disabled,
    /// the executor cannot be created, the operands are not 2-D with matching
    /// inner dimensions, or either tensor cannot be viewed as a slice.
    /// Returns a GPU error when the kernel or result construction fails.
    fn opencl_batch_matmul(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        // Without the backend compiled in, the operands are intentionally unused.
        #[cfg(not(feature = "opencl"))]
        let _ = (a, b);

        #[cfg(feature = "opencl")]
        {
            use crate::gpu::opencl_kernels::OpenClKernelExecutor;
            if let Ok(executor) = OpenClKernelExecutor::new(0) {
                let a_shape = a.shape();
                let b_shape = b.shape();
                if a_shape.len() == 2 && b_shape.len() == 2 && a_shape[1] == b_shape[0] {
                    if let (Some(a_slice), Some(b_slice)) = (a.as_slice(), b.as_slice()) {
                        let (rows, inner, cols) = (a_shape[0], a_shape[1], b_shape[1]);
                        let a_f32: Vec<f32> =
                            a_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();
                        let b_f32: Vec<f32> =
                            b_slice.iter().map(|&x| x.to_f32().unwrap_or(0.0)).collect();

                        let result_data = executor
                            .matrix_multiply_f32(&a_f32, &b_f32, rows, cols, inner)
                            .map_err(|e| {
                                RusTorchError::gpu(format!(
                                    "OpenCL matrix multiplication failed: {}",
                                    e
                                ))
                            })?;

                        let result_t: Vec<T> = result_data
                            .iter()
                            .map(|&x| T::from_f32(x).unwrap_or_else(T::zero))
                            .collect();
                        let array = ndarray::ArrayD::from_shape_vec(vec![rows, cols], result_t)
                            .map_err(|e| {
                                RusTorchError::gpu(format!(
                                    "OpenCL result tensor creation failed: {}",
                                    e
                                ))
                            })?;
                        return Ok(Tensor::new(array));
                    }
                }
            }
        }
        Err(RusTorchError::DeviceNotAvailable(
            "OpenCL not available or failed to execute batch matrix multiplication".to_string(),
        ))
    }
}
/// GPU-accelerated linear-algebra operations producing a [`Tensor`].
///
/// Every method borrows both operands and returns a newly created tensor or
/// an error describing why the operation could not be carried out.
pub trait GpuLinearAlgebra<T>
where
    T: Float + FromPrimitive + ScalarOperand + 'static,
{
    /// Matrix product of `self` and `other`.
    fn gpu_matmul(&self, other: &Self) -> RusTorchResult<Tensor<T>>;

    /// Batched matrix product of `self` and `other`.
    fn gpu_batch_matmul(&self, other: &Self) -> RusTorchResult<Tensor<T>>;

    /// Product of the matrix `self` with `vector`.
    fn gpu_matvec(&self, vector: &Self) -> RusTorchResult<Tensor<T>>;
}
impl<T> GpuLinearAlgebra<T> for Tensor<T>
where
    T: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    /// Multiplies `self` by `other` on the preferred available device.
    ///
    /// CUDA device 0 is preferred, then Metal device 0, then the CPU. Metal
    /// availability is only queried when CUDA is unavailable.
    fn gpu_matmul(&self, other: &Self) -> RusTorchResult<Tensor<T>> {
        let preferred = if super::DeviceManager::is_cuda_available() {
            super::DeviceType::Cuda(0)
        } else if super::DeviceManager::is_metal_available() {
            super::DeviceType::Metal(0)
        } else {
            super::DeviceType::Cpu
        };

        GpuBatchMatrixExecutor::<T>::new(preferred)?.batch_matmul(self, other)
    }

    /// Same device selection and dispatch as [`GpuLinearAlgebra::gpu_matmul`].
    fn gpu_batch_matmul(&self, other: &Self) -> RusTorchResult<Tensor<T>> {
        self.gpu_matmul(other)
    }

    /// Forwards to [`GpuLinearAlgebra::gpu_matmul`] with `vector` as the
    /// right-hand operand.
    fn gpu_matvec(&self, vector: &Self) -> RusTorchResult<Tensor<T>> {
        self.gpu_matmul(vector)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A 2x2 by 2x2 product must succeed and yield a 2x2 result on whichever
    /// device gets selected.
    #[test]
    fn test_gpu_matmul_cpu_fallback() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);

        let product = lhs.gpu_matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
    }

    /// Placeholder: this test only prints a notice and asserts nothing.
    #[test]
    fn test_gpu_matrix_executor_creation() {
        println!("Matrix executor test skipped - see simple_metal_test for GPU testing");
    }

    /// A CPU executor multiplies a 1x2 row by a 2x1 column into a 1x1 result.
    #[test]
    fn test_batch_matrix_executor() {
        let cpu_executor =
            GpuBatchMatrixExecutor::<f32>::new(super::super::DeviceType::Cpu).unwrap();
        let row = Tensor::<f32>::from_vec(vec![1.0, 2.0], vec![1, 2]);
        let column = Tensor::<f32>::from_vec(vec![3.0, 4.0], vec![2, 1]);

        let product = cpu_executor.batch_matmul(&row, &column).unwrap();

        assert_eq!(product.shape(), &[1, 1]);
    }
}