use crate::backend::{get_backend, Backend};
use crate::tensors::{WithGrad, Ten64};
/// Boxed callback returned by [`matmul`]: maps one tensor to a pair of
/// tensors — presumably the backward closure producing the gradients
/// w.r.t. both operands (TODO confirm against the cpu/cuda/wgpu impls).
/// No lifetime parameter, so boxed values must own their captures.
pub type FnToDoubleTen64 = dyn Fn(&Ten64) -> (Ten64, Ten64);
/// Callback returned by [`mse_loss`]: maps a scalar to a tensor —
/// presumably upstream loss gradient in, prediction gradient out
/// (TODO confirm). May borrow from its inputs for `'a`.
pub type FnF64Ten64<'a> = dyn Fn(f64) -> Ten64 + 'a;
/// Callback returned by [`relu`]: maps one tensor to one tensor —
/// presumably the unary backward closure (TODO confirm). May borrow
/// from its input for `'a`.
pub type FnTen64To<'a> = dyn Fn(&Ten64) -> Ten64 + 'a;
/// Matrix multiplication, dispatched on the active [`Backend`].
///
/// When the matching Cargo feature is compiled in, the CUDA or WGPU
/// implementation is tried first; a `None` from the accelerated path
/// (or a backend whose feature is absent) falls through to the CPU
/// implementation, which always succeeds.
///
/// Returns the product tensor together with its boxed callback
/// (see [`FnToDoubleTen64`]).
pub fn matmul(
    a: &WithGrad<Ten64>,
    b: &WithGrad<Ten64>,
) -> (Ten64, Box<FnToDoubleTen64>) {
    match get_backend() {
        Backend::Cuda => {
            // Statement-level cfg: the whole `if let` vanishes when the
            // "cuda" feature is not compiled in.
            #[cfg(feature = "cuda")]
            if let Some(pair) = super::cuda::cuda_matmul(a, b) {
                return pair;
            }
        }
        Backend::Wgpu => {
            #[cfg(feature = "wgpu")]
            if let Some(pair) = super::wgpu::wgpu_matmul(a, b) {
                return pair;
            }
        }
        _ => {}
    }
    // Universal fallback: pure-CPU implementation.
    super::cpu::matmul(a, b)
}
/// Mean-squared-error loss, dispatched on the active [`Backend`].
///
/// Tries the feature-gated CUDA/WGPU implementation first; a `None`
/// from the accelerated path (or a backend compiled out) falls through
/// to the CPU implementation.
///
/// Returns the scalar loss together with its boxed callback
/// (see [`FnF64Ten64`]), which may borrow `prediction`/`target` for `'a`.
pub fn mse_loss<'a>(
    prediction: &'a WithGrad<Ten64>,
    target: &'a Ten64,
) -> (f64, Box<FnF64Ten64<'a>>) {
    match get_backend() {
        Backend::Cuda => {
            // Removed entirely when the "cuda" feature is off.
            #[cfg(feature = "cuda")]
            if let Some(pair) = super::cuda::cuda_mse_loss(prediction, target) {
                return pair;
            }
        }
        Backend::Wgpu => {
            #[cfg(feature = "wgpu")]
            if let Some(pair) = super::wgpu::wgpu_mse_loss(prediction, target) {
                return pair;
            }
        }
        _ => {}
    }
    // Universal fallback: pure-CPU implementation.
    super::cpu::mse_loss(prediction, target)
}
/// In-place SGD update of `w` with learning rate `lr`, dispatched on
/// the active [`Backend`].
///
/// The feature-gated CUDA/WGPU path signals via its `bool` return
/// whether it handled the update; `false` (or a backend compiled out)
/// falls through to the CPU implementation.
pub fn sgd(w: &mut WithGrad<Ten64>, lr: f64) {
    match get_backend() {
        Backend::Cuda => {
            // Removed entirely when the "cuda" feature is off.
            #[cfg(feature = "cuda")]
            if super::cuda::cuda_sgd(w, lr) {
                return;
            }
        }
        Backend::Wgpu => {
            #[cfg(feature = "wgpu")]
            if super::wgpu::wgpu_sgd(w, lr) {
                return;
            }
        }
        _ => {}
    }
    // Universal fallback: pure-CPU implementation.
    super::cpu::sgd(w, lr)
}
/// ReLU activation, dispatched on the active [`Backend`].
///
/// Tries the feature-gated CUDA/WGPU implementation first; a `None`
/// from the accelerated path (or a backend compiled out) falls through
/// to the CPU implementation.
///
/// Returns the activated tensor together with its boxed callback
/// (see [`FnTen64To`]).
pub fn relu(
    input: &WithGrad<Ten64>,
) -> (Ten64, Box<FnTen64To>) {
    match get_backend() {
        Backend::Cuda => {
            // Removed entirely when the "cuda" feature is off.
            #[cfg(feature = "cuda")]
            if let Some(pair) = super::cuda::cuda_relu(input) {
                return pair;
            }
        }
        Backend::Wgpu => {
            #[cfg(feature = "wgpu")]
            if let Some(pair) = super::wgpu::wgpu_relu(input) {
                return pair;
            }
        }
        _ => {}
    }
    // Universal fallback: pure-CPU implementation.
    super::cpu::relu(input)
}