use scivex_gpu::GpuTensor;
use super::variable::GpuVariable;
/// Element-wise `a + b`, recorded as a differentiable op over parents `[a, b]`.
///
/// Backward: the upstream gradient passes through unchanged to both parents. It is
/// downloaded to the host once, then uploaded twice to the device reported by `a.device()`.
/// That gives each parent its own tensor.
///
/// NOTE(review): no shape reduction is applied to either gradient. This assumes `a` and
/// `b` have identical shapes — confirm whether `ops::add` broadcasts.
///
/// # Panics
///
/// Panics if the forward kernel fails or the gradient cannot be downloaded.
pub fn gpu_add(a: &GpuVariable, b: &GpuVariable) -> GpuVariable {
    let sum = a.with_data(|lhs| {
        b.with_data(|rhs| scivex_gpu::ops::add(lhs, rhs).expect("gpu add forward"))
    });
    let dev = a.device();
    let backward = move |grad: &GpuTensor| {
        let host_grad = grad.to_tensor().expect("grad download");
        vec![
            GpuTensor::from_tensor(&dev, &host_grad),
            GpuTensor::from_tensor(&dev, &host_grad),
        ]
    };
    GpuVariable::from_op(sum, vec![a.clone(), b.clone()], Box::new(backward))
}
/// Element-wise `a - b`, recorded as a differentiable op over parents `[a, b]`.
///
/// Backward: `a` receives a device copy of the upstream gradient, made by downloading it
/// and re-uploading to `a.device()`. `b` receives the negation of that copy.
///
/// NOTE(review): no shape reduction is applied. This assumes `a` and `b` share a shape.
///
/// # Panics
///
/// Panics if the forward kernel, the gradient download, or the negation fails.
pub fn gpu_sub(a: &GpuVariable, b: &GpuVariable) -> GpuVariable {
    let difference = a.with_data(|lhs| {
        b.with_data(|rhs| scivex_gpu::ops::sub(lhs, rhs).expect("gpu sub forward"))
    });
    let dev = a.device();
    let backward = move |grad: &GpuTensor| {
        let host_grad = grad.to_tensor().expect("grad download");
        let grad_lhs = GpuTensor::from_tensor(&dev, &host_grad);
        let grad_rhs = scivex_gpu::ops::negate(&grad_lhs).expect("negate grad");
        vec![grad_lhs, grad_rhs]
    };
    GpuVariable::from_op(difference, vec![a.clone(), b.clone()], Box::new(backward))
}
/// Element-wise `a * b`, recorded as a differentiable op over parents `[a, b]`.
///
/// Both operands are snapshotted to the host *before* the forward kernel runs. The backward
/// closure re-uploads those snapshots to `a.device()` and returns `[grad * b, grad * a]`.
/// Because of the snapshot, the gradients use the operand values seen at forward time.
///
/// # Panics
///
/// Panics if downloading either operand, the forward kernel, or a backward multiply fails.
pub fn gpu_mul(a: &GpuVariable, b: &GpuVariable) -> GpuVariable {
    let (lhs_host, rhs_host) = (
        a.data_cpu().expect("download a for backward"),
        b.data_cpu().expect("download b for backward"),
    );
    let dev = a.device();
    let product = a.with_data(|lhs| {
        b.with_data(|rhs| scivex_gpu::ops::mul(lhs, rhs).expect("gpu mul forward"))
    });
    let backward = move |grad: &GpuTensor| {
        let lhs_dev = GpuTensor::from_tensor(&dev, &lhs_host);
        let rhs_dev = GpuTensor::from_tensor(&dev, &rhs_host);
        vec![
            scivex_gpu::ops::mul(grad, &rhs_dev).expect("grad_a = grad * b"),
            scivex_gpu::ops::mul(grad, &lhs_dev).expect("grad_b = grad * a"),
        ]
    };
    GpuVariable::from_op(product, vec![a.clone(), b.clone()], Box::new(backward))
}
/// Element-wise negation `-a`, recorded as a differentiable op over parent `[a]`.
///
/// Backward: the parent's gradient is the negated upstream gradient.
///
/// # Panics
///
/// Panics if either negation kernel fails.
pub fn gpu_neg(a: &GpuVariable) -> GpuVariable {
    let negated = a.with_data(|src| scivex_gpu::ops::negate(src).expect("gpu negate forward"));
    let backward = |grad: &GpuTensor| {
        let flipped = scivex_gpu::ops::negate(grad).expect("negate grad");
        vec![flipped]
    };
    GpuVariable::from_op(negated, vec![a.clone()], Box::new(backward))
}
/// Matrix product `a @ b`, recorded as a differentiable op over parents `[a, b]`.
///
/// Both operands are snapshotted to the host before the forward kernel runs. Backward
/// re-uploads them to `a.device()` and returns `[grad @ bᵀ, aᵀ @ grad]`.
///
/// NOTE(review): relies on `ops::transpose`, which presumably expects 2-D operands — confirm.
///
/// # Panics
///
/// Panics if downloading either operand, the forward kernel, or any backward transpose or
/// matmul fails.
pub fn gpu_matmul(a: &GpuVariable, b: &GpuVariable) -> GpuVariable {
    let (lhs_host, rhs_host) = (
        a.data_cpu().expect("download a for backward"),
        b.data_cpu().expect("download b for backward"),
    );
    let dev = a.device();
    let product = a.with_data(|lhs| {
        b.with_data(|rhs| scivex_gpu::ops::matmul(lhs, rhs).expect("gpu matmul forward"))
    });
    let backward = move |grad: &GpuTensor| {
        let lhs_dev = GpuTensor::from_tensor(&dev, &lhs_host);
        let rhs_dev = GpuTensor::from_tensor(&dev, &rhs_host);
        // dL/da = grad @ bᵀ
        let rhs_t = scivex_gpu::ops::transpose(&rhs_dev).expect("transpose b");
        let grad_lhs = scivex_gpu::ops::matmul(grad, &rhs_t).expect("grad @ b^T");
        // dL/db = aᵀ @ grad
        let lhs_t = scivex_gpu::ops::transpose(&lhs_dev).expect("transpose a");
        let grad_rhs = scivex_gpu::ops::matmul(&lhs_t, grad).expect("a^T @ grad");
        vec![grad_lhs, grad_rhs]
    };
    GpuVariable::from_op(product, vec![a.clone(), b.clone()], Box::new(backward))
}
/// Sums every element of `a` into a one-element tensor of shape `[1]`.
///
/// Backward: the upstream gradient is downloaded and its first element is read. A tensor
/// with `a`'s original shape is then filled with that value. This follows from
/// d(sum)/dxᵢ = 1 for every element.
///
/// # Panics
///
/// Panics if any of these fail: the forward reduction, building the scalar result, the
/// gradient download, or the fill. It also panics if the downloaded gradient is empty.
pub fn gpu_sum(a: &GpuVariable) -> GpuVariable {
    let shape = a.shape();
    let device = a.device();
    let sum_val = a.with_data(|a_gpu| scivex_gpu::ops::sum(a_gpu).expect("gpu sum forward"));
    let result =
        GpuTensor::from_slice(&device, &[sum_val], vec![1]).expect("scalar tensor from sum");
    GpuVariable::from_op(
        result,
        vec![a.clone()],
        // `device` is not used again in this scope, so the closure takes ownership of it
        // directly rather than capturing a redundant clone.
        Box::new(move |grad: &GpuTensor| {
            let g_cpu = grad.to_tensor().expect("grad download");
            let g_val = g_cpu
                .as_slice()
                .first()
                .copied()
                .expect("sum backward: upstream gradient is empty");
            // `shape` is cloned per call since the closure may be invoked more than once
            // (presumably `from_op` stores it as `Fn` — confirm).
            let full = scivex_gpu::ops::fill(&device, shape.clone(), g_val)
                .expect("fill for sum backward");
            vec![full]
        }),
    )
}
/// Averages every element of `a` into a one-element tensor of shape `[1]`.
///
/// Backward: the upstream gradient is downloaded and its first element `g` is read. A
/// tensor with `a`'s original shape is then filled with `g / n`, where `n = a.numel()`.
/// This follows from d(mean)/dxᵢ = 1/n.
///
/// # Panics
///
/// Panics if any of these fail: the forward reduction, building the scalar result, the
/// gradient download, or the fill. It also panics if the downloaded gradient is empty.
pub fn gpu_mean(a: &GpuVariable) -> GpuVariable {
    let n = a.numel();
    let shape = a.shape();
    let device = a.device();
    let mean_val = a.with_data(|a_gpu| scivex_gpu::ops::mean(a_gpu).expect("gpu mean forward"));
    let result =
        GpuTensor::from_slice(&device, &[mean_val], vec![1]).expect("scalar tensor from mean");
    GpuVariable::from_op(
        result,
        vec![a.clone()],
        // `device` is not used again in this scope, so the closure takes ownership of it
        // directly rather than capturing a redundant clone.
        Box::new(move |grad: &GpuTensor| {
            let g_cpu = grad.to_tensor().expect("grad download");
            let g_val = g_cpu
                .as_slice()
                .first()
                .copied()
                .expect("mean backward: upstream gradient is empty");
            let scale = g_val / n as f32;
            // `shape` is cloned per call since the closure may be invoked more than once
            // (presumably `from_op` stores it as `Fn` — confirm).
            let full = scivex_gpu::ops::fill(&device, shape.clone(), scale)
                .expect("fill for mean backward");
            vec![full]
        }),
    )
}
/// Multiplies every element of `a` by the constant `scalar`, recorded over parent `[a]`.
///
/// Backward: the parent's gradient is the upstream gradient scaled by the same `scalar`.
/// No gradient is produced for `scalar` itself.
///
/// # Panics
///
/// Panics if either `mul_scalar` kernel fails.
pub fn gpu_scalar_mul(a: &GpuVariable, scalar: f32) -> GpuVariable {
    let scaled = a.with_data(|src| {
        scivex_gpu::ops::mul_scalar(src, scalar).expect("gpu mul_scalar forward")
    });
    let backward = move |grad: &GpuTensor| {
        let grad_in = scivex_gpu::ops::mul_scalar(grad, scalar).expect("grad * scalar");
        vec![grad_in]
    };
    GpuVariable::from_op(scaled, vec![a.clone()], Box::new(backward))
}