1use pyo3::exceptions::{PyNotImplementedError, PyTypeError, PyValueError};
18use pyo3::prelude::*;
19
20#[pyfunction]
27pub fn gpu_device_info() -> String {
28 "cpu (cuda_bridge feature not enabled)".to_string()
29}
30
31#[pyfunction]
48pub fn gpu_matmul(
49 a_data: Vec<f64>,
50 a_rows: usize,
51 a_cols: usize,
52 b_data: Vec<f64>,
53 b_cols: usize,
54) -> PyResult<Vec<f64>> {
55 if a_data.len() != a_rows * a_cols {
56 return Err(PyValueError::new_err(format!(
57 "a_data length {} does not match a_rows * a_cols = {} * {} = {}",
58 a_data.len(),
59 a_rows,
60 a_cols,
61 a_rows * a_cols,
62 )));
63 }
64 if b_data.len() != a_cols * b_cols {
65 return Err(PyValueError::new_err(format!(
66 "b_data length {} does not match a_cols * b_cols = {} * {} = {}",
67 b_data.len(),
68 a_cols,
69 b_cols,
70 a_cols * b_cols,
71 )));
72 }
73
74 let mut c = vec![0.0f64; a_rows * b_cols];
76 for i in 0..a_rows {
77 for k in 0..a_cols {
78 let a_ik = a_data[i * a_cols + k];
79 for j in 0..b_cols {
80 c[i * b_cols + j] += a_ik * b_data[k * b_cols + j];
81 }
82 }
83 }
84 Ok(c)
85}
86
87#[pyfunction]
100pub fn gpu_elementwise(data: Vec<f64>, op: &str) -> PyResult<Vec<f64>> {
101 let result: Vec<f64> = match op {
102 "exp" => data.iter().map(|&x| x.exp()).collect(),
103 "log" => data
104 .iter()
105 .map(|&x| if x > 0.0 { x.ln() } else { f64::NEG_INFINITY })
106 .collect(),
107 "sqrt" => data
108 .iter()
109 .map(|&x| if x >= 0.0 { x.sqrt() } else { f64::NAN })
110 .collect(),
111 "relu" => data.iter().map(|&x| x.max(0.0)).collect(),
112 "sigmoid" => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
113 "tanh" => data.iter().map(|&x| x.tanh()).collect(),
114 "abs" => data.iter().map(|&x| x.abs()).collect(),
115 "square" => data.iter().map(|&x| x * x).collect(),
116 _ => {
117 return Err(PyValueError::new_err(format!(
118 "Unknown op '{op}'. Supported: exp, log, sqrt, relu, sigmoid, tanh, abs, square"
119 )))
120 }
121 };
122 Ok(result)
123}
124
125#[pyfunction]
134pub fn gpu_matrix_add(a_data: Vec<f64>, b_data: Vec<f64>) -> PyResult<Vec<f64>> {
135 if a_data.len() != b_data.len() {
136 return Err(PyValueError::new_err(format!(
137 "Length mismatch: a has {} elements, b has {}",
138 a_data.len(),
139 b_data.len(),
140 )));
141 }
142 Ok(a_data
143 .iter()
144 .zip(b_data.iter())
145 .map(|(&a, &b)| a + b)
146 .collect())
147}
148
149#[pyfunction]
151pub fn gpu_matrix_scale(data: Vec<f64>, scalar: f64) -> Vec<f64> {
152 data.iter().map(|&x| x * scalar).collect()
153}
154
155#[pyfunction]
157pub fn gpu_frobenius_norm(data: Vec<f64>) -> f64 {
158 data.iter().map(|&x| x * x).sum::<f64>().sqrt()
159}
160
161#[pyfunction]
184pub fn cuda_tensor_matmul<'py>(
185 _py: Python<'py>,
186 tensor_a: &Bound<'py, PyAny>,
187 _tensor_b: &Bound<'py, PyAny>,
188) -> PyResult<Py<PyAny>> {
189 let has_dlpack = tensor_a.hasattr("__dlpack__").unwrap_or(false);
192 if !has_dlpack {
193 return Err(PyTypeError::new_err(
194 "Tensors must implement the __dlpack__ protocol (e.g. PyTorch or JAX tensors)",
195 ));
196 }
197
198 Err(PyNotImplementedError::new_err(
200 "CUDA tensor bridge is not yet compiled in. \
201 Enable the `cuda_bridge` Cargo feature and install `cudarc`. \
202 For a CPU fallback that accepts Python lists, use gpu_matmul().",
203 ))
204}
205
/// Registers every GPU-fallback function in this file on the given Python
/// module. Call once from the extension's module-init entry point.
///
/// Returns the first registration error, if any.
pub fn register_gpu_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(gpu_device_info, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matmul, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_elementwise, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_add, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_scale, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_frobenius_norm, m)?)?;
    m.add_function(wrap_pyfunction!(cuda_tensor_matmul, m)?)?;
    Ok(())
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within 1e-12 of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-12
    }

    #[test]
    fn test_gpu_device_info_non_empty() {
        let info = gpu_device_info();
        assert!(!info.is_empty());
        assert!(info.contains("cpu"));
    }

    #[test]
    fn test_matmul_2x2_identity() {
        // I * B == B for the 2x2 identity.
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let rhs = vec![5.0, 6.0, 7.0, 8.0];
        let product = gpu_matmul(identity, 2, 2, rhs.clone(), 2).expect("matmul should not fail");
        for (got, want) in product.iter().zip(rhs.iter()) {
            assert!(close(*got, *want));
        }
    }

    #[test]
    fn test_matmul_2x2_general() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![5.0, 6.0, 7.0, 8.0];
        let product = gpu_matmul(lhs, 2, 2, rhs, 2).expect("matmul should not fail");
        let expected = [19.0, 22.0, 43.0, 50.0];
        for (got, want) in product.iter().zip(expected.iter()) {
            assert!(close(*got, *want));
        }
    }

    #[test]
    fn test_matmul_non_square() {
        // (2x3) * (3x2) -> 2x2 result.
        let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let product = gpu_matmul(lhs, 2, 3, rhs, 2).expect("non-square matmul should succeed");
        let expected = [58.0, 64.0, 139.0, 154.0];
        for (got, want) in product.iter().zip(expected.iter()) {
            assert!(close(*got, *want));
        }
    }

    #[test]
    fn test_matmul_a_length_mismatch_returns_error() {
        // `a` has 2 elements but claims a 2x2 shape.
        let lhs = vec![1.0, 2.0];
        let rhs = vec![1.0, 2.0, 3.0, 4.0];
        assert!(gpu_matmul(lhs, 2, 2, rhs, 2).is_err());
    }

    #[test]
    fn test_matmul_b_length_mismatch_returns_error() {
        // `b` has 2 elements but must have a_cols * b_cols = 4.
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![1.0, 2.0];
        assert!(gpu_matmul(lhs, 2, 2, rhs, 2).is_err());
    }

    #[test]
    fn test_elementwise_relu() {
        let out = gpu_elementwise(vec![-2.0, -1.0, 0.0, 1.0, 2.0], "relu")
            .expect("relu should succeed");
        assert_eq!(out, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_elementwise_sigmoid_bounds() {
        let out = gpu_elementwise(vec![-100.0, 0.0, 100.0], "sigmoid")
            .expect("sigmoid should succeed");
        assert!(out[0] < 1e-3, "sigmoid(-100) should be near 0");
        assert!((out[1] - 0.5).abs() < 1e-12, "sigmoid(0) should be 0.5");
        assert!(out[2] > 1.0 - 1e-3, "sigmoid(100) should be near 1");
    }

    #[test]
    fn test_elementwise_tanh() {
        let out = gpu_elementwise(vec![-1.0, 0.0, 1.0], "tanh").expect("tanh should succeed");
        assert!(close(out[1], 0.0));
        assert!(close(out[2], 1.0_f64.tanh()));
    }

    #[test]
    fn test_elementwise_exp_log_roundtrip() {
        // log(exp(x)) should recover x to high precision.
        let input = vec![1.0, 2.0, 3.0];
        let exped = gpu_elementwise(input.clone(), "exp").expect("exp should succeed");
        let logged = gpu_elementwise(exped, "log").expect("log should succeed");
        for (orig, rt) in input.iter().zip(logged.iter()) {
            assert!((orig - rt).abs() < 1e-10, "exp-log roundtrip failed");
        }
    }

    #[test]
    fn test_elementwise_sqrt_non_negative() {
        let out = gpu_elementwise(vec![0.0, 1.0, 4.0, 9.0, 16.0], "sqrt")
            .expect("sqrt should succeed");
        assert!(close(out[0], 0.0));
        assert!(close(out[1], 1.0));
        assert!(close(out[2], 2.0));
        assert!(close(out[4], 4.0));
    }

    #[test]
    fn test_elementwise_abs() {
        let out = gpu_elementwise(vec![-3.0, -1.5, 0.0, 2.5], "abs").expect("abs should succeed");
        assert_eq!(out, vec![3.0, 1.5, 0.0, 2.5]);
    }

    #[test]
    fn test_elementwise_square() {
        let out = gpu_elementwise(vec![-2.0, 3.0], "square").expect("square should succeed");
        assert!(close(out[0], 4.0));
        assert!(close(out[1], 9.0));
    }

    #[test]
    fn test_elementwise_unknown_op_returns_error() {
        assert!(gpu_elementwise(vec![1.0, 2.0], "unknown_activation").is_err());
    }

    #[test]
    fn test_matrix_add_correct() {
        let out = gpu_matrix_add(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0])
            .expect("matrix_add should succeed");
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_add_length_mismatch_returns_error() {
        assert!(gpu_matrix_add(vec![1.0, 2.0, 3.0], vec![4.0, 5.0]).is_err());
    }

    #[test]
    fn test_matrix_scale() {
        let out = gpu_matrix_scale(vec![1.0, 2.0, 3.0, 4.0], 2.5);
        assert_eq!(out, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_frobenius_norm_identity() {
        // ||I_2||_F = sqrt(1 + 0 + 0 + 1) = sqrt(2).
        let norm = gpu_frobenius_norm(vec![1.0, 0.0, 0.0, 1.0]);
        assert!(close(norm, 2.0_f64.sqrt()));
    }
}