use crate::array::Array;
use crate::array_ops::manipulation::ravel;
use crate::error::{NumRs2Error, Result};
use std::fmt::Display;
/// One-dimensional piecewise-linear interpolation, following NumPy's `interp`.
///
/// Every element of `x` (any shape) is interpolated against the sample points
/// `(xp[k], fp[k])`; the returned array has the same shape as `x`.
///
/// # Arguments
/// * `x` - Points at which to evaluate the interpolant.
/// * `xp` - 1-D sample abscissae. Must be strictly increasing unless `period` is given.
/// * `fp` - 1-D sample values, same length as `xp`.
/// * `left` - Value returned for `x < xp[0]`; defaults to `fp[0]`.
/// * `right` - Value returned for `x > xp[xp.len() - 1]`; defaults to `fp[fp.len() - 1]`.
/// * `period` - Period of the abscissae. When given, both `x` and `xp` are wrapped into
///   `[0, |period|)`, `xp` is sorted (carrying `fp` along), and interpolation wraps across
///   the seam between the last and first samples; `left` and `right` are then ignored.
///
/// # Errors
/// * `DimensionMismatch` if `xp` or `fp` is not 1-D, or if their lengths differ.
/// * `ValueError` if there are fewer than 2 samples, if `xp` is not strictly increasing
///   in non-periodic mode (a NaN in `xp` also fails this check), or if `period` is zero.
/// * Any error from flattening `x`.
pub fn interp<T>(
    x: &Array<T>,
    xp: &Array<T>,
    fp: &Array<T>,
    left: Option<T>,
    right: Option<T>,
    period: Option<T>,
) -> Result<Array<T>>
where
    T: Clone
        + PartialOrd
        + std::ops::Sub<Output = T>
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>
        + std::ops::Div<Output = T>
        + num_traits::Float,
{
    if xp.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "xp must be 1-dimensional".into(),
        ));
    }
    if fp.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "fp must be 1-dimensional".into(),
        ));
    }
    if xp.len() != fp.len() {
        return Err(NumRs2Error::DimensionMismatch(
            "xp and fp must have the same length".into(),
        ));
    }
    if xp.len() < 2 {
        return Err(NumRs2Error::ValueError(
            "xp and fp must have at least 2 elements".into(),
        ));
    }

    // Read the samples once instead of re-fetching them for every query point.
    let xp_vals = xp.to_vec();
    let fp_vals = fp.to_vec();

    // Knot table, out-of-range fill values, and optional wrap period for the loop below.
    let (xs, ys, lo_fill, hi_fill, wrap) = match period {
        Some(period) => {
            if period == T::zero() {
                return Err(NumRs2Error::ValueError(
                    "period must be a non-zero value".into(),
                ));
            }
            let p = period.abs();
            let (xs, ys) = periodic_knots(&xp_vals, &fp_vals, p);
            // Wrapped queries lie in [0, p), which is always strictly inside the padded
            // knot range, so these fill values are never actually returned.
            let (lo, hi) = (ys[0], ys[ys.len() - 1]);
            (xs, ys, lo, hi, Some(p))
        }
        None => {
            // Compare via `partial_cmp` so that NaN samples are rejected as well.
            let not_increasing = xp_vals
                .windows(2)
                .any(|w| w[1].partial_cmp(&w[0]) != Some(std::cmp::Ordering::Greater));
            if not_increasing {
                return Err(NumRs2Error::ValueError(
                    "xp must be strictly increasing".into(),
                ));
            }
            let lo = left.unwrap_or(fp_vals[0]);
            let hi = right.unwrap_or(fp_vals[fp_vals.len() - 1]);
            (xp_vals, fp_vals, lo, hi, None)
        }
    };

    let x_shape = x.shape();
    let out: Vec<T> = ravel(x, None)?
        .to_vec()
        .into_iter()
        .map(|v| {
            let v = wrap.map_or(v, |p| wrap_period(v, p));
            interp_sorted(v, &xs, &ys, lo_fill, hi_fill)
        })
        .collect();

    Ok(Array::from_vec(out).reshape(&x_shape))
}

/// Linearly interpolates `x` against sorted knots `xs`/`ys` (equal length, at least 2),
/// returning `left` for `x < xs[0]` and `right` for `x > xs[last]`.
fn interp_sorted<T: num_traits::Float>(x: T, xs: &[T], ys: &[T], left: T, right: T) -> T {
    let last = xs.len() - 1;
    if x < xs[0] {
        return left;
    }
    if x > xs[last] {
        return right;
    }
    // First knot strictly greater than `x`, clamped so `[hi - 1, hi]` is a valid segment:
    // `x == xs[last]` uses the final segment, and a NaN `x` still yields NaN below.
    let hi = xs.partition_point(|&k| k <= x).clamp(1, last);
    let lo = hi - 1;
    let t = (x - xs[lo]) / (xs[hi] - xs[lo]);
    ys[lo] * (T::one() - t) + ys[hi] * t
}

/// Maps `v` into `[0, p)` for a positive period `p` (Rust's `%` keeps the sign of `v`).
fn wrap_period<T: num_traits::Float>(v: T, p: T) -> T {
    let r = v % p;
    if r < T::zero() {
        // `r + p` can round up to exactly `p`; the second `%` folds that back to 0.
        (r + p) % p
    } else {
        r
    }
}

/// Builds the knot table for periodic interpolation, as NumPy does. `xp` is wrapped into
/// `[0, period)` and sorted, with `fp` carried along. The table is then padded with the
/// last knot shifted down one period and the first knot shifted up one period. With that
/// padding, every wrapped query has a bracketing segment across the seam.
fn periodic_knots<T: num_traits::Float>(xp: &[T], fp: &[T], period: T) -> (Vec<T>, Vec<T>) {
    let mut knots: Vec<(T, T)> = xp
        .iter()
        .zip(fp)
        .map(|(&a, &b)| (wrap_period(a, period), b))
        .collect();
    // NaN abscissae (from non-finite `xp`) sort last so the comparator is a total order.
    knots.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
    });

    let (first_x, first_y) = knots[0];
    let (last_x, last_y) = knots[knots.len() - 1];

    let mut xs = Vec::with_capacity(knots.len() + 2);
    let mut ys = Vec::with_capacity(knots.len() + 2);
    xs.push(last_x - period);
    ys.push(last_y);
    for (a, b) in knots {
        xs.push(a);
        ys.push(b);
    }
    xs.push(first_x + period);
    ys.push(first_y);
    (xs, ys)
}
/// Picks elements from `x` where `condition` is `true` and from `y` where it is `false`.
///
/// `x` and `y` are first broadcast against each other, and that shape is then broadcast
/// with the shape of `condition`; the result has this common broadcast shape.
///
/// # Errors
/// Propagates any error from computing the broadcast shape or broadcasting an input.
pub fn where_cond<T: Clone + Display>(
    condition: &Array<bool>,
    x: &Array<T>,
    y: &Array<T>,
) -> Result<Array<T>> {
    let xy_shape = Array::<T>::broadcast_shape(&x.shape(), &y.shape())?;
    let out_shape = Array::<bool>::broadcast_shape(&condition.shape(), &xy_shape)?;

    let mask = condition.broadcast_to(&out_shape)?.to_vec();
    let when_true = x.broadcast_to(&out_shape)?.to_vec();
    let when_false = y.broadcast_to(&out_shape)?.to_vec();

    // The flattened vectors are already owned, so values can be moved out rather than cloned.
    let mut picked = Vec::with_capacity(mask.len());
    for ((take_x, a), b) in mask.into_iter().zip(when_true).zip(when_false) {
        picked.push(if take_x { a } else { b });
    }

    Ok(Array::from_vec(picked).reshape(&out_shape))
}
/// Builds an array by picking, element-wise, from the first choice whose condition holds.
///
/// All conditions and choices are broadcast to one common shape. At each position the
/// conditions are scanned in order. The value comes from `choicelist[k]`, where
/// `condlist[k]` is the first condition that is `true` there. Positions where no condition
/// holds receive `default`, or `T::zero()` when `default` is `None`.
///
/// # Errors
/// * `InvalidOperation` if `condlist` and `choicelist` differ in length or are empty.
/// * Any error from computing the common broadcast shape or broadcasting an input.
pub fn select<T: Clone + num_traits::Zero>(
    condlist: &[&Array<bool>],
    choicelist: &[&Array<T>],
    default: Option<T>,
) -> Result<Array<T>> {
    if condlist.len() != choicelist.len() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and choicelist must have the same length".to_string(),
        ));
    }
    if condlist.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and choicelist cannot be empty".to_string(),
        ));
    }

    let mut broadcast_shape = condlist[0].shape();
    for cond in &condlist[1..] {
        broadcast_shape = Array::<bool>::broadcast_shape(&broadcast_shape, &cond.shape())?;
    }
    for choice in choicelist {
        broadcast_shape = Array::<T>::broadcast_shape(&broadcast_shape, &choice.shape())?;
    }

    // Flatten every operand once, in the same order `where_cond` relies on. The per-element
    // scan then walks plain slices, with no multi-index decoding and no dynamic-dim lookups.
    let mut cond_flat: Vec<Vec<bool>> = Vec::with_capacity(condlist.len());
    for cond in condlist {
        cond_flat.push(cond.broadcast_to(&broadcast_shape)?.to_vec());
    }
    let mut choice_flat: Vec<Vec<T>> = Vec::with_capacity(choicelist.len());
    for choice in choicelist {
        choice_flat.push(choice.broadcast_to(&broadcast_shape)?.to_vec());
    }

    let default_val = default.unwrap_or_else(T::zero);
    let total_size: usize = broadcast_shape.iter().product();
    let mut data = Vec::with_capacity(total_size);
    for i in 0..total_size {
        // The first condition that is true at this position wins.
        let picked = cond_flat
            .iter()
            .zip(&choice_flat)
            .find_map(|(c, ch)| c[i].then(|| ch[i].clone()));
        data.push(picked.unwrap_or_else(|| default_val.clone()));
    }

    Ok(Array::from_vec(data).reshape(&broadcast_shape))
}