use ferrotorch_core::{FerrotorchError, FerrotorchResult, Tensor};
use crate::runtime::CubeRuntime;
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
use ferrotorch_core::from_vec;
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
use crate::kernels;
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
use crate::runtime::CubeClient;
/// Verifies that two tensors report identical shapes.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch`, with both shapes in the message,
/// when `a.shape()` and `b.shape()` differ.
fn check_same_shape(a: &Tensor<f32>, b: &Tensor<f32>) -> FerrotorchResult<()> {
    let (lhs, rhs) = (a.shape(), b.shape());
    if lhs == rhs {
        return Ok(());
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!("cubecl elementwise: lhs {:?} vs rhs {:?}", lhs, rhs),
    })
}
/// Produces an owned `Vec<f32>` holding the tensor's elements.
///
/// The borrowed view from `data()` is tried first and copied; if that call
/// fails its error is discarded and `data_vec()` is used instead.
/// NOTE(review): presumably `data()` fails when the storage cannot be viewed
/// as a plain slice — confirm against `ferrotorch_core`.
///
/// # Errors
///
/// Returns whatever error `data_vec()` reports when the fallback also fails.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
fn contiguous_data(t: &Tensor<f32>) -> FerrotorchResult<Vec<f32>> {
    if let Ok(view) = t.data() {
        return Ok(view.to_vec());
    }
    t.data_vec()
}
/// Expands to a `match` over `$runtime.client()` that calls
/// `$launch(client, $lhs, $rhs)` for whichever `CubeClient` variant is active.
///
/// Each arm is gated on its own backend feature, so only the variants that
/// are compiled in take part in the match.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
macro_rules! dispatch_binary {
    ($runtime:expr, $launch:path, $lhs:expr, $rhs:expr) => {{
        match $runtime.client() {
            #[cfg(feature = "wgpu")]
            CubeClient::Wgpu(client) => $launch(client, $lhs, $rhs),
            #[cfg(feature = "cuda")]
            CubeClient::Cuda(client) => $launch(client, $lhs, $rhs),
            #[cfg(feature = "rocm")]
            CubeClient::Rocm(client) => $launch(client, $lhs, $rhs),
        }
    }};
}
/// Expands to a `match` over `$runtime.client()` that calls
/// `$launch(client, $input)` for whichever `CubeClient` variant is active.
///
/// Each arm is gated on its own backend feature, so only the variants that
/// are compiled in take part in the match.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
macro_rules! dispatch_unary {
    ($runtime:expr, $launch:path, $input:expr) => {{
        match $runtime.client() {
            #[cfg(feature = "wgpu")]
            CubeClient::Wgpu(client) => $launch(client, $input),
            #[cfg(feature = "cuda")]
            CubeClient::Cuda(client) => $launch(client, $input),
            #[cfg(feature = "rocm")]
            CubeClient::Rocm(client) => $launch(client, $input),
        }
    }};
}
/// Expands to a `match` over `$runtime.client()` that forwards the two
/// operands and the three dimension arguments to `kernels::run_matmul` for
/// whichever `CubeClient` variant is active.
///
/// Each arm is gated on its own backend feature, so only the variants that
/// are compiled in take part in the match.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
macro_rules! dispatch_matmul {
    ($runtime:expr, $lhs:expr, $rhs:expr, $rows:expr, $inner:expr, $cols:expr) => {{
        match $runtime.client() {
            #[cfg(feature = "wgpu")]
            CubeClient::Wgpu(client) => {
                kernels::run_matmul(client, $lhs, $rhs, $rows, $inner, $cols)
            }
            #[cfg(feature = "cuda")]
            CubeClient::Cuda(client) => {
                kernels::run_matmul(client, $lhs, $rhs, $rows, $inner, $cols)
            }
            #[cfg(feature = "rocm")]
            CubeClient::Rocm(client) => {
                kernels::run_matmul(client, $lhs, $rhs, $rows, $inner, $cols)
            }
        }
    }};
}
/// Elementwise addition of two same-shaped `f32` tensors, executed by the
/// `kernels::run_add` launcher on the runtime's client.
///
/// The result tensor takes the shape of `a`.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, and
/// propagates any error from reading the inputs or building the output.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn portable_add(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b)?;
    let (lhs, rhs) = (contiguous_data(a)?, contiguous_data(b)?);
    from_vec(dispatch_binary!(rt, kernels::run_add, &lhs, &rhs), a.shape())
}
/// Stub for `portable_add` in builds with no GPU backend feature.
///
/// The operand shapes are still validated so callers see the same
/// `ShapeMismatch` error they would get with a backend enabled.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, otherwise
/// always `FerrotorchError::DeviceUnavailable`.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn portable_add(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    _rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b).and_then(|()| Err(FerrotorchError::DeviceUnavailable))
}
/// Elementwise subtraction (`a - b`) of two same-shaped `f32` tensors,
/// executed by the `kernels::run_sub` launcher on the runtime's client.
///
/// The result tensor takes the shape of `a`.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, and
/// propagates any error from reading the inputs or building the output.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn portable_sub(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b)?;
    let (lhs, rhs) = (contiguous_data(a)?, contiguous_data(b)?);
    from_vec(dispatch_binary!(rt, kernels::run_sub, &lhs, &rhs), a.shape())
}
/// Stub for `portable_sub` in builds with no GPU backend feature.
///
/// The operand shapes are still validated so callers see the same
/// `ShapeMismatch` error they would get with a backend enabled.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, otherwise
/// always `FerrotorchError::DeviceUnavailable`.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn portable_sub(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    _rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b).and_then(|()| Err(FerrotorchError::DeviceUnavailable))
}
/// Elementwise multiplication of two same-shaped `f32` tensors, executed by
/// the `kernels::run_mul` launcher on the runtime's client.
///
/// The result tensor takes the shape of `a`.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, and
/// propagates any error from reading the inputs or building the output.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn portable_mul(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b)?;
    let (lhs, rhs) = (contiguous_data(a)?, contiguous_data(b)?);
    from_vec(dispatch_binary!(rt, kernels::run_mul, &lhs, &rhs), a.shape())
}
/// Stub for `portable_mul` in builds with no GPU backend feature.
///
/// The operand shapes are still validated so callers see the same
/// `ShapeMismatch` error they would get with a backend enabled.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when the shapes differ, otherwise
/// always `FerrotorchError::DeviceUnavailable`.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn portable_mul(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    _rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_same_shape(a, b).and_then(|()| Err(FerrotorchError::DeviceUnavailable))
}
/// Elementwise ReLU of an `f32` tensor, executed by the `kernels::run_relu`
/// launcher on the runtime's client.
///
/// The result tensor takes the shape of `x`.
///
/// # Errors
///
/// Propagates any error from reading the input or building the output.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn portable_relu(x: &Tensor<f32>, rt: &CubeRuntime) -> FerrotorchResult<Tensor<f32>> {
    let input = contiguous_data(x)?;
    from_vec(dispatch_unary!(rt, kernels::run_relu, &input), x.shape())
}
/// Stub for `portable_relu` in builds with no GPU backend feature.
///
/// # Errors
///
/// Always returns `FerrotorchError::DeviceUnavailable`.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn portable_relu(_x: &Tensor<f32>, _rt: &CubeRuntime) -> FerrotorchResult<Tensor<f32>> {
    let unavailable = FerrotorchError::DeviceUnavailable;
    Err(unavailable)
}
/// Matrix product of two 2-D `f32` tensors, executed by `kernels::run_matmul`
/// on the runtime's client.
///
/// For `a` of shape `[m, k]` and `b` of shape `[k, n]` the result has shape
/// `[m, n]`.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when either operand is not 2-D or
/// the inner dimensions disagree, and propagates any error from reading the
/// inputs or building the output.
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn portable_matmul(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    let (rows, inner, cols) = check_matmul_shapes(a, b)?;
    let (lhs, rhs) = (contiguous_data(a)?, contiguous_data(b)?);
    let product = dispatch_matmul!(rt, &lhs, &rhs, rows, inner, cols);
    from_vec(product, &[rows, cols])
}
/// Stub for `portable_matmul` in builds with no GPU backend feature.
///
/// The operand shapes are still validated so callers see the same
/// `ShapeMismatch` error they would get with a backend enabled.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` for invalid operand shapes,
/// otherwise always `FerrotorchError::DeviceUnavailable`.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
pub fn portable_matmul(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
    _rt: &CubeRuntime,
) -> FerrotorchResult<Tensor<f32>> {
    check_matmul_shapes(a, b).and_then(|_dims| Err(FerrotorchError::DeviceUnavailable))
}
/// Validates the operands of a 2-D matrix product and extracts its dimensions.
///
/// Returns `(m, k, n)` for `a` of shape `[m, k]` and `b` of shape `[k, n]`.
///
/// # Errors
///
/// Returns `FerrotorchError::ShapeMismatch` when either tensor is not 2-D, or
/// when the second dimension of `a` differs from the first dimension of `b`.
fn check_matmul_shapes(
    a: &Tensor<f32>,
    b: &Tensor<f32>,
) -> FerrotorchResult<(usize, usize, usize)> {
    let (lhs, rhs) = (a.shape(), b.shape());

    // The rank check must come first: the indexing below relies on it.
    let both_rank_two = lhs.len() == 2 && rhs.len() == 2;
    if !both_rank_two {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "cubecl matmul requires 2-D tensors, got lhs {:?} rhs {:?}",
                lhs, rhs
            ),
        });
    }

    let (m, k) = (lhs[0], lhs[1]);
    let (k_rhs, n) = (rhs[0], rhs[1]);
    if k == k_rhs {
        return Ok((m, k, n));
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "cubecl matmul inner dims mismatch: lhs {:?} vs rhs {:?}",
            lhs, rhs
        ),
    })
}
/// Generates a public elementwise unary op named `$name` that launches
/// `kernels::$runner`.
///
/// Two cfg-gated functions with the same signature are emitted:
/// - with a backend feature (`wgpu`, `cuda` or `rocm`): copies the input out
///   with `contiguous_data`, runs the launcher through `dispatch_unary!`, and
///   builds the result with the input's shape;
/// - with no backend feature: a stub that always returns
///   `FerrotorchError::DeviceUnavailable`.
///
/// Both variants carry rustdoc, so the generated function is documented
/// whichever one is compiled.
macro_rules! define_portable_unary {
    ($name:ident, $runner:ident) => {
        #[doc = concat!("Elementwise `", stringify!($name), "(x)` on the GPU.")]
        ///
        /// # Errors
        ///
        /// Propagates any error from reading the input tensor's data or from
        /// building the output tensor.
        #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
        pub fn $name(x: &Tensor<f32>, rt: &CubeRuntime) -> FerrotorchResult<Tensor<f32>> {
            let x_data = contiguous_data(x)?;
            let out = dispatch_unary!(rt, kernels::$runner, &x_data);
            from_vec(out, x.shape())
        }

        #[doc = concat!("Elementwise `", stringify!($name), "(x)` on the GPU.")]
        ///
        /// # Errors
        ///
        /// This build has no GPU backend feature enabled, so the call always
        /// returns `FerrotorchError::DeviceUnavailable`.
        #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
        pub fn $name(_x: &Tensor<f32>, _rt: &CubeRuntime) -> FerrotorchResult<Tensor<f32>> {
            Err(FerrotorchError::DeviceUnavailable)
        }
    };
}
/// Generates a public elementwise binary op named `$name` that launches
/// `kernels::$runner`.
///
/// Two cfg-gated functions with the same signature are emitted:
/// - with a backend feature (`wgpu`, `cuda` or `rocm`): checks that both
///   operands share a shape, copies them out with `contiguous_data`, runs the
///   launcher through `dispatch_binary!`, and builds the result with the
///   shape of `a`;
/// - with no backend feature: a stub that still validates the shapes and then
///   returns `FerrotorchError::DeviceUnavailable`.
///
/// Both variants carry rustdoc, so the generated function is documented
/// whichever one is compiled.
macro_rules! define_portable_binary {
    ($name:ident, $runner:ident) => {
        #[doc = concat!("Elementwise `", stringify!($name), "(a, b)` on the GPU.")]
        ///
        /// # Errors
        ///
        /// Returns `FerrotorchError::ShapeMismatch` when the operand shapes
        /// differ, and propagates any error from reading the inputs or from
        /// building the output tensor.
        #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
        pub fn $name(
            a: &Tensor<f32>,
            b: &Tensor<f32>,
            rt: &CubeRuntime,
        ) -> FerrotorchResult<Tensor<f32>> {
            check_same_shape(a, b)?;
            let a_data = contiguous_data(a)?;
            let b_data = contiguous_data(b)?;
            let out = dispatch_binary!(rt, kernels::$runner, &a_data, &b_data);
            from_vec(out, a.shape())
        }

        #[doc = concat!("Elementwise `", stringify!($name), "(a, b)` on the GPU.")]
        ///
        /// # Errors
        ///
        /// Returns `FerrotorchError::ShapeMismatch` when the operand shapes
        /// differ. Otherwise, because this build has no GPU backend feature
        /// enabled, always returns `FerrotorchError::DeviceUnavailable`.
        #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
        pub fn $name(
            a: &Tensor<f32>,
            b: &Tensor<f32>,
            _rt: &CubeRuntime,
        ) -> FerrotorchResult<Tensor<f32>> {
            check_same_shape(a, b)?;
            Err(FerrotorchError::DeviceUnavailable)
        }
    };
}
// Macro-generated public ops. Each line pairs the public function name with
// the `kernels` launcher it forwards to; every op gets both a backend-enabled
// definition and a no-backend stub from the defining macro.
//
// Binary elementwise op (operands must share a shape):
define_portable_binary!(portable_div, run_div);
// Unary elementwise ops:
define_portable_unary!(portable_neg, run_neg);
define_portable_unary!(portable_abs, run_abs);
define_portable_unary!(portable_exp, run_exp);
define_portable_unary!(portable_ln, run_ln);
define_portable_unary!(portable_sqrt, run_sqrt);
define_portable_unary!(portable_sin, run_sin);
define_portable_unary!(portable_cos, run_cos);
define_portable_unary!(portable_tanh, run_tanh);
define_portable_unary!(portable_sigmoid, run_sigmoid);
/// End-to-end tests that run the portable ops on a wgpu runtime; compiled
/// only for test builds with the `wgpu` feature enabled.
#[cfg(all(test, feature = "wgpu"))]
mod tests {
    use super::*;
    use crate::runtime::CubeDevice;

    /// Creates a runtime on wgpu device 0, panicking if initialisation fails.
    fn runtime() -> CubeRuntime {
        let device = CubeDevice::Wgpu(0);
        CubeRuntime::new(device).expect("wgpu runtime init")
    }

    // ---- hand-written binary / unary ops -------------------------------

    #[test]
    fn portable_add_runs_on_gpu() {
        let rt = runtime();
        let lhs = ferrotorch_core::tensor(&[1.0_f32, 2.0, 3.0, 4.0]).unwrap();
        let rhs = ferrotorch_core::tensor(&[10.0_f32, 20.0, 30.0, 40.0]).unwrap();
        let sum = portable_add(&lhs, &rhs, &rt).unwrap();
        let got: &[f32] = sum.data().unwrap();
        assert_eq!(got, &[11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn portable_sub_runs_on_gpu() {
        let rt = runtime();
        let lhs = ferrotorch_core::tensor(&[10.0_f32, 20.0, 30.0]).unwrap();
        let rhs = ferrotorch_core::tensor(&[1.0_f32, 2.0, 3.0]).unwrap();
        let diff = portable_sub(&lhs, &rhs, &rt).unwrap();
        let got: &[f32] = diff.data().unwrap();
        assert_eq!(got, &[9.0, 18.0, 27.0]);
    }

    #[test]
    fn portable_mul_runs_on_gpu() {
        let rt = runtime();
        let lhs = ferrotorch_core::tensor(&[2.0_f32, 3.0, 4.0]).unwrap();
        let rhs = ferrotorch_core::tensor(&[10.0_f32, 10.0, 10.0]).unwrap();
        let product = portable_mul(&lhs, &rhs, &rt).unwrap();
        let got: &[f32] = product.data().unwrap();
        assert_eq!(got, &[20.0, 30.0, 40.0]);
    }

    #[test]
    fn portable_relu_runs_on_gpu() {
        let rt = runtime();
        let input = ferrotorch_core::tensor(&[-3.0_f32, -1.0, 0.0, 1.0, 3.0]).unwrap();
        let activated = portable_relu(&input, &rt).unwrap();
        let got: &[f32] = activated.data().unwrap();
        assert_eq!(got, &[0.0, 0.0, 0.0, 1.0, 3.0]);
    }

    // ---- matmul --------------------------------------------------------

    #[test]
    fn portable_matmul_runs_on_gpu() {
        let rt = runtime();
        let lhs =
            ferrotorch_core::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let rhs =
            ferrotorch_core::from_vec(vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]).unwrap();
        let product = portable_matmul(&lhs, &rhs, &rt).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
        let got: &[f32] = product.data().unwrap();
        for (idx, want) in [58.0_f32, 64.0, 139.0, 154.0].iter().enumerate() {
            assert!((got[idx] - want).abs() < 1e-4);
        }
    }

    #[test]
    fn portable_matmul_rejects_rank_mismatch() {
        let rt = runtime();
        let vector = ferrotorch_core::tensor(&[1.0_f32, 2.0, 3.0]).unwrap();
        let column = ferrotorch_core::from_vec(vec![1.0_f32, 2.0, 3.0], &[3, 1]).unwrap();
        let outcome = portable_matmul(&vector, &column, &rt);
        assert!(matches!(
            outcome,
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn portable_matmul_rejects_inner_dim_mismatch() {
        let rt = runtime();
        let lhs = ferrotorch_core::from_vec(vec![1.0_f32; 6], &[2, 3]).unwrap();
        let rhs = ferrotorch_core::from_vec(vec![1.0_f32; 8], &[4, 2]).unwrap();
        let outcome = portable_matmul(&lhs, &rhs, &rt);
        assert!(matches!(
            outcome,
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn portable_add_rejects_shape_mismatch() {
        let rt = runtime();
        let longer = ferrotorch_core::tensor(&[1.0_f32, 2.0, 3.0]).unwrap();
        let shorter = ferrotorch_core::tensor(&[1.0_f32, 2.0]).unwrap();
        let outcome = portable_add(&longer, &shorter, &rt);
        assert!(matches!(
            outcome,
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn portable_add_large_shape() {
        let rt = runtime();
        let len: usize = 1024;
        let lhs: Vec<f32> = (0..len).map(|i| i as f32).collect();
        let rhs: Vec<f32> = (0..len).map(|i| (2 * i) as f32).collect();
        let t_lhs = ferrotorch_core::from_vec(lhs.clone(), &[len]).unwrap();
        let t_rhs = ferrotorch_core::from_vec(rhs.clone(), &[len]).unwrap();
        let sum = portable_add(&t_lhs, &t_rhs, &rt).unwrap();
        let got: &[f32] = sum.data().unwrap();
        // Slicing to `len` keeps the out-of-bounds panic if the output is short.
        for ((out, x), y) in got[..len].iter().zip(&lhs).zip(&rhs) {
            assert_eq!(*out, x + y);
        }
    }

    #[test]
    fn portable_matmul_square_8x8() {
        let rt = runtime();
        let side = 8usize;
        // Row-major identity: ones exactly where row index equals column index.
        let identity: Vec<f32> = (0..side * side)
            .map(|idx| if idx / side == idx % side { 1.0 } else { 0.0 })
            .collect();
        let counting: Vec<f32> = (0..side * side).map(|i| i as f32).collect();
        let t_identity = ferrotorch_core::from_vec(identity, &[side, side]).unwrap();
        let t_counting = ferrotorch_core::from_vec(counting.clone(), &[side, side]).unwrap();
        let product = portable_matmul(&t_identity, &t_counting, &rt).unwrap();
        let got: &[f32] = product.data().unwrap();
        assert_eq!(got, counting.as_slice(), "I * B should equal B");
    }

    // ---- macro-generated ops -------------------------------------------

    /// Asserts both slices have equal length and agree elementwise to within
    /// `tol` (strict `<`), reporting the first offending index on failure.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        actual
            .iter()
            .zip(expected)
            .enumerate()
            .for_each(|(i, (a, e))| {
                assert!(
                    (a - e).abs() < tol,
                    "at index {i}: expected {e}, got {a} (tol={tol})"
                );
            });
    }

    #[test]
    fn portable_div_runs_on_gpu() {
        let rt = runtime();
        let numer = ferrotorch_core::tensor(&[10.0_f32, 20.0, 30.0, 40.0]).unwrap();
        let denom = ferrotorch_core::tensor(&[2.0_f32, 4.0, 5.0, 8.0]).unwrap();
        let quotient = portable_div(&numer, &denom, &rt).unwrap();
        let got: &[f32] = quotient.data().unwrap();
        assert_close(got, &[5.0, 5.0, 6.0, 5.0], 1e-4);
    }

    #[test]
    fn portable_neg_runs_on_gpu() {
        let rt = runtime();
        let input = ferrotorch_core::tensor(&[-3.0_f32, -1.0, 0.0, 1.0, 3.0]).unwrap();
        let negated = portable_neg(&input, &rt).unwrap();
        let got: &[f32] = negated.data().unwrap();
        assert_close(got, &[3.0, 1.0, 0.0, -1.0, -3.0], 1e-6);
    }

    #[test]
    fn portable_abs_runs_on_gpu() {
        let rt = runtime();
        let input = ferrotorch_core::tensor(&[-3.0_f32, -1.0, 0.0, 1.0, 3.0]).unwrap();
        let magnitudes = portable_abs(&input, &rt).unwrap();
        let got: &[f32] = magnitudes.data().unwrap();
        assert_close(got, &[3.0, 1.0, 0.0, 1.0, 3.0], 1e-6);
    }

    #[test]
    fn portable_exp_runs_on_gpu() {
        let rt = runtime();
        let inputs = [0.0_f32, 1.0, 2.0, -1.0];
        let t = ferrotorch_core::tensor(&inputs).unwrap();
        let out = portable_exp(&t, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        let want: Vec<f32> = inputs.iter().map(|v| v.exp()).collect();
        assert_close(got, &want, 1e-3);
    }

    #[test]
    fn portable_ln_runs_on_gpu() {
        let rt = runtime();
        let inputs = [1.0_f32, std::f32::consts::E, 10.0, 100.0];
        let t = ferrotorch_core::tensor(&inputs).unwrap();
        let out = portable_ln(&t, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        let want: Vec<f32> = inputs.iter().map(|v| v.ln()).collect();
        assert_close(got, &want, 1e-3);
    }

    #[test]
    fn portable_sqrt_runs_on_gpu() {
        let rt = runtime();
        let squares = ferrotorch_core::tensor(&[0.0_f32, 1.0, 4.0, 9.0, 16.0]).unwrap();
        let roots = portable_sqrt(&squares, &rt).unwrap();
        let got: &[f32] = roots.data().unwrap();
        assert_close(got, &[0.0, 1.0, 2.0, 3.0, 4.0], 1e-4);
    }

    #[test]
    fn portable_sin_runs_on_gpu() {
        use std::f32::consts::PI;

        let rt = runtime();
        let angles = ferrotorch_core::tensor(&[0.0_f32, PI / 2.0, PI, 3.0 * PI / 2.0]).unwrap();
        let out = portable_sin(&angles, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        assert_close(got, &[0.0, 1.0, 0.0, -1.0], 1e-3);
    }

    #[test]
    fn portable_cos_runs_on_gpu() {
        use std::f32::consts::PI;

        let rt = runtime();
        let angles = ferrotorch_core::tensor(&[0.0_f32, PI / 2.0, PI, 3.0 * PI / 2.0]).unwrap();
        let out = portable_cos(&angles, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        assert_close(got, &[1.0, 0.0, -1.0, 0.0], 1e-3);
    }

    #[test]
    fn portable_tanh_runs_on_gpu() {
        let rt = runtime();
        let inputs = [-2.0_f32, -0.5, 0.0, 0.5, 2.0];
        let t = ferrotorch_core::tensor(&inputs).unwrap();
        let out = portable_tanh(&t, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        let want: Vec<f32> = inputs.iter().map(|v| v.tanh()).collect();
        assert_close(got, &want, 1e-4);
    }

    #[test]
    fn portable_sigmoid_runs_on_gpu() {
        let rt = runtime();
        let inputs = [-4.0_f32, -1.0, 0.0, 1.0, 4.0];
        let t = ferrotorch_core::tensor(&inputs).unwrap();
        let out = portable_sigmoid(&t, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        let want: Vec<f32> = inputs.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect();
        assert_close(got, &want, 1e-4);
    }

    #[test]
    fn portable_sigmoid_large_shape() {
        let rt = runtime();
        let len: usize = 1024;
        let input: Vec<f32> = (0..len).map(|i| (i as f32 - 512.0) * 0.01).collect();
        let t = ferrotorch_core::from_vec(input.clone(), &[len]).unwrap();
        let out = portable_sigmoid(&t, &rt).unwrap();
        let got: &[f32] = out.data().unwrap();
        // Slicing to `len` keeps the out-of-bounds panic if the output is short.
        for (i, (&x, &actual)) in input.iter().zip(&got[..len]).enumerate() {
            let expected = 1.0 / (1.0 + (-x).exp());
            assert!(
                (actual - expected).abs() < 1e-4,
                "sigmoid mismatch at {i}: got {}, expected {}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn portable_exp_then_ln_is_identity() {
        let rt = runtime();
        let inputs = [-2.0_f32, -0.5, 0.1, 1.0, 3.0];
        let t = ferrotorch_core::tensor(&inputs).unwrap();
        let exponentiated = portable_exp(&t, &rt).unwrap();
        let round_trip = portable_ln(&exponentiated, &rt).unwrap();
        let got: &[f32] = round_trip.data().unwrap();
        assert_close(got, &inputs, 1e-3);
    }

    #[test]
    fn portable_div_rejects_shape_mismatch() {
        let rt = runtime();
        let longer = ferrotorch_core::tensor(&[1.0_f32, 2.0, 3.0]).unwrap();
        let shorter = ferrotorch_core::tensor(&[1.0_f32, 2.0]).unwrap();
        let outcome = portable_div(&longer, &shorter, &rt);
        assert!(matches!(
            outcome,
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }
}
/// Tests for builds compiled without any GPU backend feature.
#[cfg(all(test, not(any(feature = "wgpu", feature = "cuda", feature = "rocm"))))]
mod no_backend_tests {
    use super::*;
    use crate::runtime::CubeDevice;

    /// With no backend compiled in, constructing a runtime must fail with
    /// `DeviceUnavailable`.
    #[test]
    fn runtime_construction_errors_without_backend() {
        let outcome = CubeRuntime::new(CubeDevice::Wgpu(0));
        assert!(matches!(
            outcome,
            Err(FerrotorchError::DeviceUnavailable)
        ));
    }
}