use crate::autograd::DualTensor;
use crate::error::Result;
use crate::ops::{BinaryOps, ScalarOps};
use crate::runtime::{Runtime, RuntimeClient};
/// Adds a constant scalar to a dual tensor: `y = a + c`.
///
/// Only the primal is shifted by `scalar`. Because the constant contributes
/// nothing to the derivative (`dy = da`), the tangent, if present, is cloned
/// through unchanged. If `a` has no tangent, neither does the result.
///
/// # Errors
///
/// Propagates any error returned by `client.add_scalar` on the primal.
pub fn dual_add_scalar<R, C>(a: &DualTensor<R>, scalar: f64, client: &C) -> Result<DualTensor<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + ScalarOps<R>,
{
    Ok(DualTensor::new(
        client.add_scalar(a.primal(), scalar)?,
        // d(a + c) = da: the tangent passes through untouched.
        a.tangent().cloned(),
    ))
}
/// Subtracts a constant scalar from a dual tensor: `y = a - c`.
///
/// Only the primal is shifted by `-scalar`. Because `dy = da`, the tangent,
/// if present, is cloned through unchanged. If `a` has no tangent, neither
/// does the result.
///
/// # Errors
///
/// Propagates any error returned by `client.sub_scalar` on the primal.
pub fn dual_sub_scalar<R, C>(a: &DualTensor<R>, scalar: f64, client: &C) -> Result<DualTensor<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + ScalarOps<R>,
{
    Ok(DualTensor::new(
        client.sub_scalar(a.primal(), scalar)?,
        // d(a - c) = da: the tangent passes through untouched.
        a.tangent().cloned(),
    ))
}
/// Multiplies a dual tensor by a constant scalar: `y = c * a`.
///
/// Both the primal and the tangent are scaled by `scalar`, since
/// `dy = c * da`. If `a` has no tangent, neither does the result.
///
/// # Errors
///
/// Propagates any error returned by `client.mul_scalar`, whether it comes
/// from the primal or the tangent computation.
pub fn dual_mul_scalar<R, C>(a: &DualTensor<R>, scalar: f64, client: &C) -> Result<DualTensor<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + ScalarOps<R>,
{
    let scaled_primal = client.mul_scalar(a.primal(), scalar)?;
    // d(c * a) = c * da: scale the tangent by the same constant, if present.
    let scaled_tangent = a
        .tangent()
        .map(|t| client.mul_scalar(t, scalar))
        .transpose()?;
    Ok(DualTensor::new(scaled_primal, scaled_tangent))
}
/// Divides a dual tensor by a constant scalar: `y = a / c`.
///
/// Both the primal and the tangent are divided by `scalar`, since
/// `dy = da / c`. If `a` has no tangent, neither does the result. Division
/// by zero behaves however `client.div_scalar` behaves; no special handling
/// is done here.
///
/// # Errors
///
/// Propagates any error returned by `client.div_scalar`, whether it comes
/// from the primal or the tangent computation.
pub fn dual_div_scalar<R, C>(a: &DualTensor<R>, scalar: f64, client: &C) -> Result<DualTensor<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + ScalarOps<R>,
{
    let quotient = client.div_scalar(a.primal(), scalar)?;
    // d(a / c) = da / c: divide the tangent by the same constant, if present.
    let quotient_tangent = a
        .tangent()
        .map(|t| client.div_scalar(t, scalar))
        .transpose()?;
    Ok(DualTensor::new(quotient, quotient_tangent))
}
/// Raises a dual tensor to a constant scalar power: `y = a^n`.
///
/// The primal is `a^n` (via `client.pow_scalar`). The tangent follows the
/// power rule `dy = n * a^(n-1) * da`. If `a` has no tangent, neither does
/// the result.
///
/// For `n == 0` the output `a^0` is constant, so the tangent is computed
/// directly as `0 * da`. The general formula would evaluate `a^(-1)`, which
/// under IEEE-754 `pow` semantics is `+inf` at `a == 0`. That would turn the
/// tangent into `0 * inf = NaN` there, even though the true derivative is 0.
///
/// # Errors
///
/// Propagates any error returned by the client's `pow_scalar`, `mul_scalar`,
/// or `mul` operations.
pub fn dual_pow_scalar<R, C>(a: &DualTensor<R>, n: f64, client: &C) -> Result<DualTensor<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + ScalarOps<R> + BinaryOps<R>,
{
    let primal = client.pow_scalar(a.primal(), n)?;
    let tangent = match a.tangent() {
        // d(a^0)/da = 0. Skip a^(-1), which is infinite where a == 0 and
        // would poison the tangent with NaN (0 * inf).
        Some(at) if n == 0.0 => Some(client.mul_scalar(at, 0.0)?),
        Some(at) => {
            // Power rule: n * a^(n-1) * da.
            let a_pow_nm1 = client.pow_scalar(a.primal(), n - 1.0)?;
            let coeff = client.mul_scalar(&a_pow_nm1, n)?;
            Some(client.mul(&coeff, at)?)
        }
        None => None,
    };
    Ok(DualTensor::new(primal, tangent))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
    use crate::tensor::Tensor;

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (device, client)
    }

    /// Builds a one-element dual tensor with primal `value` and tangent `tangent`.
    fn scalar_dual(value: f32, tangent: f32, device: &CpuDevice) -> DualTensor<CpuRuntime> {
        let primal = Tensor::<CpuRuntime>::from_slice(&[value], &[1], device);
        let tangent = Tensor::<CpuRuntime>::from_slice(&[tangent], &[1], device);
        DualTensor::with_tangent(primal, tangent)
    }

    #[test]
    fn test_dual_add_scalar() {
        let (device, client) = setup();
        let x = scalar_dual(2.0, 1.0, &device);
        let y = dual_add_scalar(&x, 3.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [5.0]);
        // Adding a constant leaves the tangent unchanged.
        assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [1.0]);
    }

    #[test]
    fn test_dual_sub_scalar() {
        let (device, client) = setup();
        let x = scalar_dual(2.0, 1.0, &device);
        let y = dual_sub_scalar(&x, 3.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [-1.0]);
        // Subtracting a constant leaves the tangent unchanged.
        assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [1.0]);
    }

    #[test]
    fn test_dual_mul_scalar() {
        let (device, client) = setup();
        let x = scalar_dual(2.0, 1.0, &device);
        let y = dual_mul_scalar(&x, 3.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [6.0]);
        assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [3.0]);
    }

    #[test]
    fn test_dual_div_scalar() {
        let (device, client) = setup();
        let x = scalar_dual(2.0, 1.0, &device);
        let y = dual_div_scalar(&x, 4.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [0.5]);
        assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [0.25]);
    }

    #[test]
    fn test_dual_pow_scalar() {
        let (device, client) = setup();
        let x = scalar_dual(2.0, 1.0, &device);
        let y = dual_pow_scalar(&x, 3.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [8.0]);
        // d(x^3) = 3 * x^2 * dx = 12 at x = 2, dx = 1.
        assert!((y.tangent().unwrap().to_vec::<f32>()[0] - 12.0).abs() < 1e-5);
    }

    #[test]
    fn test_missing_tangent_stays_missing() {
        let (device, client) = setup();
        let primal = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
        let x = DualTensor::new(primal, None);
        assert!(dual_add_scalar(&x, 3.0, &client).unwrap().tangent().is_none());
        assert!(dual_sub_scalar(&x, 3.0, &client).unwrap().tangent().is_none());
        assert!(dual_mul_scalar(&x, 3.0, &client).unwrap().tangent().is_none());
        assert!(dual_div_scalar(&x, 4.0, &client).unwrap().tangent().is_none());
        let y = dual_pow_scalar(&x, 3.0, &client).unwrap();
        assert_eq!(y.primal().to_vec::<f32>(), [8.0]);
        assert!(y.tangent().is_none());
    }
}