use crate::autograd::Var;
use crate::autograd::var_ops::ops::CastBackward;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::TypeConversionOps;
use crate::runtime::{Runtime, RuntimeClient};
use std::sync::Arc;
pub fn var_cast<R, C>(a: &Var<R>, dtype: DType, client: &C) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TypeConversionOps<R>,
R::Client: TypeConversionOps<R>,
{
let input_dtype = a.tensor().dtype();
if input_dtype == dtype {
return Ok(Var::with_id(a.tensor().clone(), a.id(), a.requires_grad()));
}
let output = client.cast(a.tensor(), dtype)?;
if a.requires_grad() {
let grad_fn = CastBackward::<R>::new(a.id(), input_dtype);
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::backward;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
use crate::tensor::Tensor;
#[test]
fn test_var_cast_noop_same_dtype() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let v = Var::new(t, true);
let result = var_cast(&v, DType::F32, &client).unwrap();
assert_eq!(result.tensor().dtype(), DType::F32);
let data = result.tensor().to_vec::<f32>();
assert_eq!(data, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_var_cast_f32_to_f64_gradient() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let x = Var::new(t, true);
let y = var_cast(&x, DType::F64, &client).unwrap();
assert_eq!(y.tensor().dtype(), DType::F64);
let sum = crate::autograd::var_sum(&y, &[], false, &client).unwrap();
let grads = backward(&sum, &client).unwrap();
let grad = grads.get(x.id()).unwrap();
assert_eq!(grad.dtype(), DType::F32);
let data = grad.to_vec::<f32>();
assert_eq!(data, vec![1.0, 1.0, 1.0]);
}
#[test]
fn test_var_cast_no_grad() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let v = Var::new(t, false);
let result = var_cast(&v, DType::F64, &client).unwrap();
assert!(!result.requires_grad());
assert_eq!(result.tensor().dtype(), DType::F64);
}
}