mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};
/// Unary kernel operation used by [`nans_to`]; the wrapped field is the
/// replacement value, already converted to the tensor's element type `E`.
///
/// Instances are built in `Tensor::try_nans_to` and handed to `try_unary_op`.
// NOTE(review): `#[repr(C)]` is presumably required so the CUDA kernel sees the
// same field layout as the Rust side — confirm against `cuda_kernel`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct NansToKernelOp<E>(E);
/// Replaces every NaN element of `t` with `value`, leaving all other elements
/// untouched.
///
/// `value` is taken as anything convertible to `f64` and is then converted to
/// the tensor's element type `E`.
///
/// This is the free-function form of `Tensor::nans_to`; see
/// `Tensor::try_nans_to` for the fallible variant.
///
/// # Panics
///
/// Panics if `value` cannot be converted to `E`, or if the underlying unary
/// operation returns an error.
pub fn nans_to<S, E, D, T>(t: Tensor<S, E, D, T>, value: impl Into<f64>) -> Tensor<S, E, D, T>
where
    S: Shape,
    E: Dtype,
    D: UnaryKernel<NansToKernelOp<E>, E>,
    T: Tape<E, D>,
{
    t.try_nans_to(value).unwrap()
}
impl<S, E, D, T> Tensor<S, E, D, T>
where
    S: Shape,
    E: Dtype,
    D: UnaryKernel<NansToKernelOp<E>, E>,
    T: Tape<E, D>,
{
    /// Replaces every NaN element with `value`. See [`nans_to`].
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be converted to `E`, or if
    /// [`Self::try_nans_to`] returns an error.
    pub fn nans_to(self, value: impl Into<f64>) -> Self {
        let outcome = self.try_nans_to(value);
        outcome.unwrap()
    }

    /// Fallible version of [`Self::nans_to`].
    ///
    /// # Errors
    ///
    /// Returns `D::Err` if the underlying unary operation fails.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be converted from `f64` into `E`.
    pub fn try_nans_to(self, value: impl Into<f64>) -> Result<Self, D::Err> {
        // Widen to f64 first, then narrow to the tensor's own element type.
        let requested: f64 = value.into();
        let replacement = E::from_f64(requested).unwrap();
        let op = NansToKernelOp(replacement);
        try_unary_op(op, self)
    }
}
#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::*};

    /// Checks the forward result and the gradients of `nans_to` on a 1-D
    /// tensor containing both a positive and a negative NaN.
    #[test]
    fn test_nans_1d() {
        let device: TestDevice = Default::default();
        let input = device
            .tensor([1.0, f64::NAN, -f64::NAN, 4.0])
            .to_dtype::<TestDtype>();

        // Forward: both NaN entries become 0.0, finite entries are unchanged.
        let replaced = input.leaky_trace().nans_to(0.0);
        assert_close_to_literal!(replaced, [1.0, 0.0, 0.0, 4.0]);

        // Backward through mean(exp(x)) over 4 elements: finite positions get
        // exp(1)/4 and exp(4)/4, while the replaced positions get zero gradient.
        let grads = replaced.exp().mean().backward();
        assert_close_to_literal!(grads.get(&input), [0.67957044, 0.0, 0.0, 13.649537]);
    }
}