use crate::dtype::DType;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// A tensor value (`primal`) paired with an optional companion tensor
/// (`tangent`).
///
/// NOTE(review): the primal/tangent naming suggests this carries a value and
/// its directional derivative for forward-mode differentiation — confirm
/// against the callers, which are not visible from this file.
///
/// `DualTensor::new` asserts that a supplied tangent has the same shape as
/// the primal. A missing tangent is represented as `None` rather than as an
/// explicit tensor.
#[derive(Clone, Debug)]
pub struct DualTensor<R>
where
    R: Runtime,
{
    /// The underlying value; every shape/dtype/device query is answered from
    /// this tensor.
    primal: Tensor<R>,
    /// The companion tensor, or `None` when no tangent is attached.
    tangent: Option<Tensor<R>>,
}
impl<R: Runtime> DualTensor<R> {
    /// Builds a dual tensor from a primal and an optional tangent.
    ///
    /// # Panics
    ///
    /// Panics when `tangent` is `Some` and its shape differs from the shape
    /// of `primal`.
    pub fn new(primal: Tensor<R>, tangent: Option<Tensor<R>>) -> Self {
        if let Some(candidate) = tangent.as_ref() {
            assert_eq!(
                primal.shape(),
                candidate.shape(),
                "Tangent shape {:?} must match primal shape {:?}",
                candidate.shape(),
                primal.shape()
            );
        }
        Self { primal, tangent }
    }

    /// Wraps `primal` without attaching any tangent (`tangent` is `None`).
    pub fn constant(primal: Tensor<R>) -> Self {
        // With `None` there is nothing for `new` to validate, so this is a
        // plain construction.
        Self::new(primal, None)
    }

    /// Wraps `primal` together with a tangent produced by `Tensor::ones`,
    /// called with the primal's shape and dtype and the supplied `device`.
    ///
    /// The `device` argument is passed through as given; it is not compared
    /// with the primal's own device.
    pub fn with_unit_tangent(primal: Tensor<R>, device: &R::Device) -> Self
    where
        R: Runtime<DType = DType>,
    {
        let ones = Tensor::ones(primal.shape(), primal.dtype(), device);
        Self {
            primal,
            tangent: Some(ones),
        }
    }

    /// Wraps `primal` together with an explicit `tangent`.
    ///
    /// # Panics
    ///
    /// Panics when the two shapes differ (the check is performed by
    /// [`DualTensor::new`]).
    pub fn with_tangent(primal: Tensor<R>, tangent: Tensor<R>) -> Self {
        let tangent = Some(tangent);
        Self::new(primal, tangent)
    }

    /// Borrows the primal tensor.
    #[inline]
    pub fn primal(&self) -> &Tensor<R> {
        &self.primal
    }

    /// Borrows the tangent tensor, or returns `None` when none is attached.
    #[inline]
    pub fn tangent(&self) -> Option<&Tensor<R>> {
        self.tangent.as_ref()
    }

    /// Consumes the dual tensor and returns only the primal; any tangent is
    /// dropped.
    pub fn into_primal(self) -> Tensor<R> {
        let Self { primal, .. } = self;
        primal
    }

    /// Consumes the dual tensor and returns `(primal, tangent)`.
    pub fn into_parts(self) -> (Tensor<R>, Option<Tensor<R>>) {
        let Self { primal, tangent } = self;
        (primal, tangent)
    }

    /// Returns `true` when a tangent is attached.
    #[inline]
    pub fn has_tangent(&self) -> bool {
        self.tangent.is_some()
    }

    /// Shape of the primal tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.primal().shape()
    }

    /// Dtype of the primal tensor.
    #[inline]
    pub fn dtype(&self) -> DType
    where
        R: Runtime<DType = DType>,
    {
        self.primal().dtype()
    }

    /// Device reported by the primal tensor.
    #[inline]
    pub fn device(&self) -> &R::Device {
        self.primal().device()
    }

    /// Value of `numel()` on the primal tensor.
    #[inline]
    pub fn numel(&self) -> usize {
        self.primal().numel()
    }

    /// Value of `ndim()` on the primal tensor.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.primal().ndim()
    }

    /// Returns a new dual tensor holding a clone of the primal and no
    /// tangent; `self` is left untouched.
    pub fn detach(&self) -> Self {
        Self::constant(self.primal.clone())
    }

    /// Creates a fresh tensor via `Tensor::zeros`, called with the primal's
    /// shape and dtype and the supplied `device`.
    ///
    /// This does not read or modify the stored tangent.
    pub fn zero_tangent(&self, device: &R::Device) -> Tensor<R>
    where
        R: Runtime<DType = DType>,
    {
        let reference = &self.primal;
        Tensor::zeros(reference.shape(), reference.dtype(), device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::cpu::{CpuDevice, CpuRuntime};

    /// `new` with a matching tangent keeps both tensors readable through the
    /// accessors.
    #[test]
    fn test_dual_tensor_new() {
        let device = CpuDevice::new();
        let primal_values = [1.0f32, 2.0, 3.0];
        let tangent_values = [0.1f32, 0.2, 0.3];
        let primal = Tensor::<CpuRuntime>::from_slice(&primal_values, &[3], &device);
        let tangent = Tensor::<CpuRuntime>::from_slice(&tangent_values, &[3], &device);

        let dual = DualTensor::new(primal.clone(), Some(tangent.clone()));

        assert_eq!(dual.shape(), &[3]);
        assert!(dual.has_tangent());
        assert_eq!(dual.primal().to_vec::<f32>(), primal_values);
        assert_eq!(dual.tangent().unwrap().to_vec::<f32>(), tangent_values);
    }

    /// `constant` attaches no tangent.
    #[test]
    fn test_dual_tensor_constant() {
        let device = CpuDevice::new();
        let values = [1.0f32, 2.0, 3.0];
        let primal = Tensor::<CpuRuntime>::from_slice(&values, &[3], &device);

        let dual = DualTensor::constant(primal);

        assert!(!dual.has_tangent());
        assert!(dual.tangent().is_none());
    }

    /// `with_unit_tangent` attaches a tangent whose elements are all one.
    #[test]
    fn test_dual_tensor_with_unit_tangent() {
        let device = CpuDevice::new();
        let values = [1.0f32, 2.0, 3.0];
        let primal = Tensor::<CpuRuntime>::from_slice(&values, &[3], &device);

        let dual = DualTensor::with_unit_tangent(primal, &device);

        assert!(dual.has_tangent());
        assert_eq!(dual.tangent().unwrap().to_vec::<f32>(), [1.0, 1.0, 1.0]);
    }

    /// `into_parts` hands back exactly the tensors that were passed in.
    #[test]
    fn test_dual_tensor_into_parts() {
        let device = CpuDevice::new();
        let primal_values = [1.0f32, 2.0];
        let tangent_values = [0.5f32, 0.5];
        let primal = Tensor::<CpuRuntime>::from_slice(&primal_values, &[2], &device);
        let tangent = Tensor::<CpuRuntime>::from_slice(&tangent_values, &[2], &device);

        let (primal_out, tangent_out) = DualTensor::new(primal, Some(tangent)).into_parts();

        assert_eq!(primal_out.to_vec::<f32>(), primal_values);
        assert_eq!(tangent_out.unwrap().to_vec::<f32>(), tangent_values);
    }

    /// A tangent of a different shape makes `new` panic with the shape
    /// message.
    #[test]
    #[should_panic(expected = "Tangent shape")]
    fn test_dual_tensor_shape_mismatch() {
        let device = CpuDevice::new();
        let primal = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let tangent = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device);

        // Three elements versus two: construction must not succeed.
        let _ = DualTensor::new(primal, Some(tangent));
    }
}