use crate::CausalTensor;
use crate::CausalTensorError;
use deep_causality_num::RealField;
/// Element-wise mathematical operations on a [`CausalTensor`].
///
/// Every method borrows the tensor through `&self` and hands back a newly
/// built tensor wrapped in a `Result`, so failures are reported through
/// [`CausalTensorError`].
pub trait CausalTensorMathExt<T> {
/// Natural logarithm of every element.
///
/// The `CausalTensor` implementation in this file applies `ln` to each
/// element, keeps the input shape, and always returns `Ok`.
fn log_nat(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Base-2 logarithm of every element.
///
/// The `CausalTensor` implementation in this file applies `log2` to each
/// element, keeps the input shape, and always returns `Ok`.
fn log2(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Base-10 logarithm of every element.
///
/// The `CausalTensor` implementation in this file applies `log10` to each
/// element, keeps the input shape, and always returns `Ok`.
fn log10(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Base-2 logarithm in which an element equal to zero maps to zero.
///
/// The `CausalTensor` implementation in this file returns `T::zero()` for
/// elements that compare equal to zero and `log2` of the element otherwise.
/// NOTE(review): presumably this serves entropy-style sums where
/// `0 * log2(0)` is taken to be `0` — confirm against the callers.
fn surd_log2(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Element-wise division of `self` by `rhs` that tolerates `0 / 0`.
///
/// The `CausalTensor` implementation in this file pairs the operands via
/// `broadcast_op` and treats a denominator whose absolute value is below
/// `T::epsilon()` as zero: the quotient is then `T::zero()` if the
/// numerator is also below that threshold, and the call fails with
/// `CausalTensorError::DivisionByZero` otherwise.
fn safe_div(&self, rhs: &CausalTensor<T>) -> Result<CausalTensor<T>, CausalTensorError>;
}
/// Element-wise math for any [`CausalTensor`] whose scalar type is a [`RealField`].
impl<T> CausalTensorMathExt<T> for CausalTensor<T>
where
    T: RealField,
{
    /// Returns a tensor of the same shape holding the natural logarithm of
    /// every element.
    ///
    /// No domain check is performed: the value produced for zero or negative
    /// elements is whatever `T::ln` yields. This method always returns `Ok`.
    fn log_nat(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        Ok(map_elements(self, |val| val.ln()))
    }

    /// Returns a tensor of the same shape holding the base-2 logarithm of
    /// every element.
    ///
    /// No domain check is performed: the value produced for zero or negative
    /// elements is whatever `T::log2` yields. This method always returns `Ok`.
    fn log2(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        Ok(map_elements(self, |val| val.log2()))
    }

    /// Returns a tensor of the same shape holding the base-10 logarithm of
    /// every element.
    ///
    /// No domain check is performed: the value produced for zero or negative
    /// elements is whatever `T::log10` yields. This method always returns `Ok`.
    fn log10(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        Ok(map_elements(self, |val| val.log10()))
    }

    /// Returns a tensor of the same shape holding the base-2 logarithm of
    /// every element, except that elements equal to zero map to zero.
    ///
    /// Only an exact comparison with `T::zero()` triggers the special case;
    /// every other element goes through `T::log2` unchecked. This method
    /// always returns `Ok`.
    fn surd_log2(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        let zero = T::zero();
        Ok(map_elements(self, |val| {
            if val == zero { zero } else { val.log2() }
        }))
    }

    /// Divides `self` by `rhs` element-wise, with the operands paired up by
    /// `broadcast_op`.
    ///
    /// A denominator whose absolute value is below `T::epsilon()` is treated
    /// as zero:
    /// * if the numerator's absolute value is also below `T::epsilon()`, the
    ///   quotient is defined as `T::zero()`;
    /// * otherwise the element fails with `CausalTensorError::DivisionByZero`.
    ///
    /// # Errors
    ///
    /// Returns whatever `broadcast_op` returns, which includes the
    /// `DivisionByZero` error raised by the per-element closure.
    ///
    /// NOTE(review): the `epsilon` threshold also rejects legitimately tiny,
    /// non-zero denominators and maps tiny/tiny to `0`, whereas `surd_log2`
    /// compares against exact zero. Confirm the tolerance is intended before
    /// relying on this for very small values.
    fn safe_div(&self, rhs: &CausalTensor<T>) -> Result<CausalTensor<T>, CausalTensorError> {
        let eps = T::epsilon();
        let zero = T::zero();
        self.broadcast_op(rhs, move |numerator: T, denominator: T| {
            if denominator.abs() < eps {
                if numerator.abs() < eps {
                    Ok(zero)
                } else {
                    Err(CausalTensorError::DivisionByZero)
                }
            } else {
                Ok(numerator / denominator)
            }
        })
    }
}

/// Builds a new tensor with the shape of `tensor` by applying `f` to each of
/// its elements in storage order.
///
/// An empty element slice needs no special case: mapping it produces an empty
/// `Vec`, which is handed to `from_slice` together with the original shape.
fn map_elements<T, F>(tensor: &CausalTensor<T>, f: F) -> CausalTensor<T>
where
    T: RealField,
    F: Fn(T) -> T,
{
    let mapped: Vec<T> = tensor.as_slice().iter().copied().map(f).collect();
    CausalTensor::from_slice(&mapped, tensor.shape())
}