use pyo3::exceptions::{PyNotImplementedError, PyTypeError, PyValueError};
use pyo3::prelude::*;
/// Report the active compute backend as a human-readable string.
///
/// In this build the `cuda_bridge` feature is disabled, so the answer is a
/// fixed CPU marker string; callers can sniff it for the substring "cpu".
#[pyfunction]
pub fn gpu_device_info() -> String {
    String::from("cpu (cuda_bridge feature not enabled)")
}
/// CPU fallback matrix multiplication: C = A (a_rows x a_cols) * B (a_cols x b_cols).
///
/// Matrices are row-major flat `Vec<f64>` buffers. Returns the row-major
/// product of length `a_rows * b_cols`.
///
/// # Errors
/// Returns `ValueError` when a dimension product overflows `usize` or when a
/// buffer length does not match its declared dimensions.
#[pyfunction]
pub fn gpu_matmul(
    a_data: Vec<f64>,
    a_rows: usize,
    a_cols: usize,
    b_data: Vec<f64>,
    b_cols: usize,
) -> PyResult<Vec<f64>> {
    // Compute all dimension products with overflow checks. A plain `*` wraps
    // silently in release builds, which could let a mismatched buffer pass
    // the length validation below and then panic inside the multiply loop
    // instead of raising a clean Python ValueError.
    let a_len = a_rows.checked_mul(a_cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "matrix dimensions overflow: a_rows * a_cols = {a_rows} * {a_cols}"
        ))
    })?;
    let b_len = a_cols.checked_mul(b_cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "matrix dimensions overflow: a_cols * b_cols = {a_cols} * {b_cols}"
        ))
    })?;
    let c_len = a_rows.checked_mul(b_cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "matrix dimensions overflow: a_rows * b_cols = {a_rows} * {b_cols}"
        ))
    })?;
    if a_data.len() != a_len {
        return Err(PyValueError::new_err(format!(
            "a_data length {} does not match a_rows * a_cols = {} * {} = {}",
            a_data.len(),
            a_rows,
            a_cols,
            a_len,
        )));
    }
    if b_data.len() != b_len {
        return Err(PyValueError::new_err(format!(
            "b_data length {} does not match a_cols * b_cols = {} * {} = {}",
            b_data.len(),
            a_cols,
            b_cols,
            b_len,
        )));
    }
    let mut c = vec![0.0f64; c_len];
    // i-k-j loop order: the innermost loop walks both `c` and `b_data`
    // sequentially, which is cache-friendly and auto-vectorizes well.
    for i in 0..a_rows {
        for k in 0..a_cols {
            let a_ik = a_data[i * a_cols + k];
            for j in 0..b_cols {
                c[i * b_cols + j] += a_ik * b_data[k * b_cols + j];
            }
        }
    }
    Ok(c)
}
/// Apply a named elementwise operation to every value in `data`.
///
/// Supported ops: exp, log (NEG_INFINITY for x <= 0), sqrt (NaN for x < 0),
/// relu, sigmoid, tanh, abs, square.
///
/// # Errors
/// Returns `ValueError` for an unrecognized op name.
#[pyfunction]
pub fn gpu_elementwise(data: Vec<f64>, op: &str) -> PyResult<Vec<f64>> {
    // Resolve the op name to a scalar kernel once, then map it over the data.
    let kernel: fn(f64) -> f64 = match op {
        "exp" => |x| x.exp(),
        "log" => |x| if x > 0.0 { x.ln() } else { f64::NEG_INFINITY },
        "sqrt" => |x| if x >= 0.0 { x.sqrt() } else { f64::NAN },
        "relu" => |x| x.max(0.0),
        "sigmoid" => |x| 1.0 / (1.0 + (-x).exp()),
        "tanh" => |x| x.tanh(),
        "abs" => |x| x.abs(),
        "square" => |x| x * x,
        _ => {
            return Err(PyValueError::new_err(format!(
                "Unknown op '{op}'. Supported: exp, log, sqrt, relu, sigmoid, tanh, abs, square"
            )))
        }
    };
    Ok(data.into_iter().map(kernel).collect())
}
/// Elementwise sum of two equally-sized flat matrices.
///
/// # Errors
/// Returns `ValueError` when the two buffers differ in length.
#[pyfunction]
pub fn gpu_matrix_add(a_data: Vec<f64>, b_data: Vec<f64>) -> PyResult<Vec<f64>> {
    let (len_a, len_b) = (a_data.len(), b_data.len());
    if len_a != len_b {
        return Err(PyValueError::new_err(format!(
            "Length mismatch: a has {len_a} elements, b has {len_b}"
        )));
    }
    // Both vectors are owned, so consume them directly instead of borrowing.
    let sum = a_data.into_iter().zip(b_data).map(|(a, b)| a + b).collect();
    Ok(sum)
}
/// Multiply every element of `data` by `scalar`, returning the scaled vector.
#[pyfunction]
pub fn gpu_matrix_scale(data: Vec<f64>, scalar: f64) -> Vec<f64> {
    // We own the buffer, so scale it in place and hand it back.
    let mut scaled = data;
    for value in scaled.iter_mut() {
        *value *= scalar;
    }
    scaled
}
/// Frobenius norm of a flat matrix: sqrt of the sum of squared elements.
#[pyfunction]
pub fn gpu_frobenius_norm(data: Vec<f64>) -> f64 {
    // Left-to-right accumulation, same order as `.sum::<f64>()`.
    let sum_of_squares = data.iter().fold(0.0_f64, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
/// Stub for DLPack-based CUDA tensor matmul.
///
/// Validates that the first tensor exposes the `__dlpack__` protocol, then
/// always fails with `NotImplementedError` because the CUDA bridge is
/// feature-gated out of this build.
///
/// # Errors
/// `TypeError` when `tensor_a` lacks `__dlpack__`; otherwise
/// `NotImplementedError` unconditionally.
#[pyfunction]
pub fn cuda_tensor_matmul<'py>(
    _py: Python<'py>,
    tensor_a: &Bound<'py, PyAny>,
    _tensor_b: &Bound<'py, PyAny>,
) -> PyResult<Py<PyAny>> {
    // Guard: reject objects that cannot participate in DLPack exchange.
    // NOTE(review): a failing hasattr call is treated the same as "absent".
    if !tensor_a.hasattr("__dlpack__").unwrap_or(false) {
        return Err(PyTypeError::new_err(
            "Tensors must implement the __dlpack__ protocol (e.g. PyTorch or JAX tensors)",
        ));
    }
    // The real CUDA path only exists behind the `cuda_bridge` feature.
    Err(PyNotImplementedError::new_err(
        "CUDA tensor bridge is not yet compiled in. \
        Enable the `cuda_bridge` Cargo feature and install `cudarc`. \
        For a CPU fallback that accepts Python lists, use gpu_matmul().",
    ))
}
/// Register every GPU-related function on the given Python module.
///
/// Intended to be called once from the extension module's init hook; any
/// PyO3 registration failure is propagated to the caller via `?`.
pub fn register_gpu_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(gpu_device_info, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matmul, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_elementwise, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_add, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_scale, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_frobenius_norm, m)?)?;
    m.add_function(wrap_pyfunction!(cuda_tensor_matmul, m)?)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison used by the float-valued tests below.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-12
    }

    #[test]
    fn test_gpu_device_info_non_empty() {
        let info = gpu_device_info();
        assert!(!info.is_empty());
        assert!(info.contains("cpu"));
    }

    #[test]
    fn test_matmul_2x2_identity() {
        // I * B must reproduce B exactly (up to tolerance).
        let id = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = gpu_matmul(id, 2, 2, b.clone(), 2).expect("matmul should not fail");
        for (got, want) in c.iter().zip([5.0, 6.0, 7.0, 8.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_matmul_2x2_general() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = gpu_matmul(a, 2, 2, b, 2).expect("matmul should not fail");
        for (got, want) in c.iter().zip([19.0, 22.0, 43.0, 50.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_matmul_non_square() {
        // (2x3) * (3x2) -> 2x2 result.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c = gpu_matmul(a, 2, 3, b, 2).expect("non-square matmul should succeed");
        for (got, want) in c.iter().zip([58.0, 64.0, 139.0, 154.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_matmul_a_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0, 4.0];
        assert!(gpu_matmul(a, 2, 2, b, 2).is_err());
    }

    #[test]
    fn test_matmul_b_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0];
        assert!(gpu_matmul(a, 2, 2, b, 2).is_err());
    }

    #[test]
    fn test_elementwise_relu() {
        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let out = gpu_elementwise(data, "relu").expect("relu should succeed");
        assert_eq!(out, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_elementwise_sigmoid_bounds() {
        let out = gpu_elementwise(vec![-100.0, 0.0, 100.0], "sigmoid")
            .expect("sigmoid should succeed");
        assert!(out[0] < 1e-3, "sigmoid(-100) should be near 0");
        assert!((out[1] - 0.5).abs() < 1e-12, "sigmoid(0) should be 0.5");
        assert!(out[2] > 1.0 - 1e-3, "sigmoid(100) should be near 1");
    }

    #[test]
    fn test_elementwise_tanh() {
        let out = gpu_elementwise(vec![-1.0, 0.0, 1.0], "tanh").expect("tanh should succeed");
        assert!(close(out[1], 0.0));
        assert!(close(out[2], 1.0_f64.tanh()));
    }

    #[test]
    fn test_elementwise_exp_log_roundtrip() {
        // log(exp(x)) should recover x to high precision.
        let data = vec![1.0, 2.0, 3.0];
        let exped = gpu_elementwise(data.clone(), "exp").expect("exp should succeed");
        let logged = gpu_elementwise(exped, "log").expect("log should succeed");
        for (orig, rt) in data.iter().zip(logged.iter()) {
            assert!((orig - rt).abs() < 1e-10, "exp-log roundtrip failed");
        }
    }

    #[test]
    fn test_elementwise_sqrt_non_negative() {
        let out = gpu_elementwise(vec![0.0, 1.0, 4.0, 9.0, 16.0], "sqrt")
            .expect("sqrt should succeed");
        assert!(close(out[0], 0.0));
        assert!(close(out[1], 1.0));
        assert!(close(out[2], 2.0));
        assert!(close(out[4], 4.0));
    }

    #[test]
    fn test_elementwise_abs() {
        let out = gpu_elementwise(vec![-3.0, -1.5, 0.0, 2.5], "abs").expect("abs should succeed");
        assert_eq!(out, vec![3.0, 1.5, 0.0, 2.5]);
    }

    #[test]
    fn test_elementwise_square() {
        let out = gpu_elementwise(vec![-2.0, 3.0], "square").expect("square should succeed");
        assert!(close(out[0], 4.0));
        assert!(close(out[1], 9.0));
    }

    #[test]
    fn test_elementwise_unknown_op_returns_error() {
        assert!(gpu_elementwise(vec![1.0, 2.0], "unknown_activation").is_err());
    }

    #[test]
    fn test_matrix_add_correct() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let out = gpu_matrix_add(a, b).expect("matrix_add should succeed");
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_add_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        assert!(gpu_matrix_add(a, b).is_err());
    }

    #[test]
    fn test_matrix_scale() {
        let out = gpu_matrix_scale(vec![1.0, 2.0, 3.0, 4.0], 2.5);
        assert_eq!(out, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_frobenius_norm_identity() {
        // ||I_2||_F = sqrt(2).
        let norm = gpu_frobenius_norm(vec![1.0, 0.0, 0.0, 1.0]);
        assert!(close(norm, 2.0_f64.sqrt()));
    }
}