use crate::tensors::{WithGrad, Ten64};
use crate::ops::dispatch::{FnToDoubleTen64, FnF64Ten64, FnTen64To};
/// Matrix multiplication entry point for the `cuda` backend.
///
/// This does not run any CUDA-specific code. It forwards both operands unchanged to
/// `super::wgpu::wgpu_matmul` and returns that function's result as-is.
///
/// NOTE(review): presumably the `Ten64` is the product and the boxed closure is the
/// backward function, with `None` meaning the backend could not handle the call.
/// This is inferred from the types and names; confirm against `wgpu_matmul`.
pub fn cuda_matmul(
    a: &WithGrad<Ten64>,
    b: &WithGrad<Ten64>,
) -> Option<(Ten64, Box<FnToDoubleTen64>)> {
    // Delegate to the wgpu implementation.
    super::wgpu::wgpu_matmul(a, b)
}
/// Mean-squared-error loss entry point for the `cuda` backend.
///
/// This forwards `prediction` and `target` unchanged to `super::wgpu::wgpu_mse_loss`
/// and returns its result as-is.
///
/// The shared lifetime `'a` ties the returned boxed closure (`FnF64Ten64<'a>`) to
/// both borrowed inputs. The closure therefore cannot outlive either of them.
///
/// NOTE(review): presumably the `f64` is the scalar loss and the closure is the
/// backward function, with `None` meaning the backend could not handle the call.
/// Confirm against `wgpu_mse_loss`.
pub fn cuda_mse_loss<'a>(
    prediction: &'a WithGrad<Ten64>,
    target: &'a Ten64,
) -> Option<(f64, Box<FnF64Ten64<'a>>)> {
    // Delegate to the wgpu implementation.
    super::wgpu::wgpu_mse_loss(prediction, target)
}
/// ReLU activation entry point for the `cuda` backend.
///
/// This forwards `input` unchanged to `super::wgpu::wgpu_relu` and returns its
/// result as-is.
///
/// NOTE(review): presumably the `Ten64` is the activated tensor and the boxed closure
/// is the backward function, with `None` meaning the backend could not handle the
/// call. Confirm against `wgpu_relu`.
pub fn cuda_relu(input: &WithGrad<Ten64>) -> Option<(Ten64, Box<FnTen64To>)> {
    // Delegate to the wgpu implementation.
    super::wgpu::wgpu_relu(input)
}
/// SGD parameter-update entry point for the `cuda` backend.
///
/// This forwards the mutable parameter `w` and learning rate `lr` unchanged to
/// `super::wgpu::wgpu_sgd` and returns its `bool` result as-is. Any mutation of `w`
/// is performed by that function.
///
/// NOTE(review): presumably `true` means the update was applied by this backend and
/// `false` means the caller should fall back elsewhere. Confirm against `wgpu_sgd`.
pub fn cuda_sgd(w: &mut WithGrad<Ten64>, lr: f64) -> bool {
    // Delegate to the wgpu implementation.
    super::wgpu::wgpu_sgd(w, lr)
}