use crate::array::Array;
use crate::array_ops::manipulation::ravel;
use crate::error::{NumRs2Error, Result};
use std::fmt::Display;
/// One-dimensional piecewise-linear interpolation, following NumPy's `interp`.
///
/// For every element of `x`, evaluates the piecewise-linear function passing
/// through the knots `(xp[i], fp[i])`. The result has the same shape as `x`.
///
/// # Arguments
/// * `x` - Query points, any shape. They are flattened for evaluation and the
///   result is reshaped back to `x`'s shape.
/// * `xp` - 1-D, strictly increasing knot x-coordinates (at least 2 entries).
/// * `fp` - 1-D knot values, same length as `xp`.
/// * `left` - Value returned for queries below `xp[0]`; defaults to `fp[0]`.
/// * `right` - Value returned for queries above the last `xp`; defaults to the
///   last `fp`.
/// * `period` - When given, x-coordinates are treated as periodic.
///   - Both `x` and `xp` are reduced modulo `|period|`.
///   - Interpolation wraps from the last knot back round to the first.
///   - `left` and `right` are ignored, as in NumPy.
///   - The strictly-increasing check on `xp` still applies.
///
/// # Errors
/// * `DimensionMismatch` if `xp` or `fp` is not 1-D, or their lengths differ.
/// * `ValueError` in any of these cases:
///   - fewer than 2 knots are supplied;
///   - `xp` is not strictly increasing;
///   - `period` is zero, NaN or infinite.
/// * Any error propagated from flattening `x`.
pub fn interp<T>(
    x: &Array<T>,
    xp: &Array<T>,
    fp: &Array<T>,
    left: Option<T>,
    right: Option<T>,
    period: Option<T>,
) -> Result<Array<T>>
where
    T: Clone
        + PartialOrd
        + std::ops::Sub<Output = T>
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>
        + std::ops::Div<Output = T>
        + num_traits::Float,
{
    if xp.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "xp must be 1-dimensional".into(),
        ));
    }
    if fp.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "fp must be 1-dimensional".into(),
        ));
    }
    if xp.len() != fp.len() {
        return Err(NumRs2Error::DimensionMismatch(
            "xp and fp must have the same length".into(),
        ));
    }
    if xp.len() < 2 {
        return Err(NumRs2Error::ValueError(
            "xp and fp must have at least 2 elements".into(),
        ));
    }

    let xp_vals = xp.to_vec();
    let fp_vals = fp.to_vec();
    // `<=` (rather than `!(>)`) keeps the original behaviour of letting NaN knots through.
    if xp_vals.windows(2).any(|w| w[1] <= w[0]) {
        return Err(NumRs2Error::ValueError(
            "xp must be strictly increasing".into(),
        ));
    }

    let x_shape = x.shape();
    let x_flat = ravel(x, None)?.to_vec();

    let values: Vec<T> = match period {
        None => {
            let left_val = left.unwrap_or(fp_vals[0]);
            let right_val = right.unwrap_or(fp_vals[fp_vals.len() - 1]);
            x_flat
                .into_iter()
                .map(|xv| interp_point(&xp_vals, &fp_vals, xv, left_val, right_val))
                .collect()
        }
        Some(p) => {
            let p = p.abs();
            if p == T::zero() || !p.is_finite() {
                return Err(NumRs2Error::ValueError(
                    "period must be a non-zero finite value".into(),
                ));
            }
            let (knots_x, knots_f) = interp_periodic_knots(&xp_vals, &fp_vals, p);
            // Wrapped queries always lie inside the padded knot table, so these
            // fill values are never used. They exist only to satisfy
            // `interp_point`'s signature.
            let lo_fill = knots_f[0];
            let hi_fill = knots_f[knots_f.len() - 1];
            x_flat
                .into_iter()
                .map(|xv| {
                    interp_point(&knots_x, &knots_f, interp_wrap_period(xv, p), lo_fill, hi_fill)
                })
                .collect()
        }
    };

    Ok(Array::from_vec(values).reshape(&x_shape))
}

/// Evaluates the piecewise-linear function through `(xs[i], fs[i])` at `xv`.
///
/// `xs` must be non-decreasing and have at least two entries.
/// - Queries below `xs[0]` yield `left`.
/// - Queries above the last knot yield `right`.
/// - A query equal to a bracketing knot returns that knot's value exactly, which
///   avoids a NaN when the neighbouring value is infinite.
/// - A NaN query propagates to a NaN result.
fn interp_point<T: num_traits::Float>(xs: &[T], fs: &[T], xv: T, left: T, right: T) -> T {
    let last = xs.len() - 1;
    if xv < xs[0] {
        return left;
    }
    if xv > xs[last] {
        return right;
    }
    // Bisection. For non-NaN `xv` it maintains xs[lo] <= xv and
    // (xv < xs[hi] or hi == last).
    let (mut lo, mut hi) = (0usize, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if xv < xs[mid] {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    let (x0, x1, y0, y1) = (xs[lo], xs[hi], fs[lo], fs[hi]);
    if xv == x0 {
        return y0;
    }
    if xv == x1 {
        return y1;
    }
    let t = (xv - x0) / (x1 - x0);
    y0 * (T::one() - t) + y1 * t
}

/// Reduces `v` modulo `p` with floored (Python-style) semantics, as NumPy does.
/// `p` must be positive and finite.
fn interp_wrap_period<T: num_traits::Float>(v: T, p: T) -> T {
    let r = v % p;
    // Rust's `%` keeps the dividend's sign; shift negative remainders up by one period.
    if r < T::zero() {
        r + p
    } else {
        r
    }
}

/// Builds the knot table for periodic interpolation.
///
/// `xp` is reduced modulo `p` and sorted together with `fp`. The table is then
/// padded with the last knot shifted down by `p` and the first knot shifted up
/// by `p`, so every wrapped query point lies inside it.
fn interp_periodic_knots<T: num_traits::Float>(xp: &[T], fp: &[T], p: T) -> (Vec<T>, Vec<T>) {
    let mut pairs: Vec<(T, T)> = xp
        .iter()
        .zip(fp)
        .map(|(&xv, &fv)| (interp_wrap_period(xv, p), fv))
        .collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

    let (first_x, first_f) = pairs[0];
    let (last_x, last_f) = pairs[pairs.len() - 1];

    let mut knots_x = Vec::with_capacity(pairs.len() + 2);
    let mut knots_f = Vec::with_capacity(pairs.len() + 2);
    knots_x.push(last_x - p);
    knots_f.push(last_f);
    for (kx, kf) in pairs {
        knots_x.push(kx);
        knots_f.push(kf);
    }
    knots_x.push(first_x + p);
    knots_f.push(first_f);
    (knots_x, knots_f)
}
/// Element-wise ternary selection.
///
/// `condition`, `x` and `y` are first broadcast to their common shape. Each
/// output element is taken from `x` where the condition is `true` and from `y`
/// otherwise. Inputs with at least 1000 broadcast elements are processed with a
/// parallel iterator; smaller inputs use a plain sequential pass.
///
/// # Errors
/// Propagates shape-broadcasting errors from the inputs.
pub fn where_cond<T: Clone + Display + Send + Sync>(
    condition: &Array<bool>,
    x: &Array<T>,
    y: &Array<T>,
) -> Result<Array<T>> {
    const PARALLEL_THRESHOLD: usize = 1000;

    let xy_shape = Array::<T>::broadcast_shape(&x.shape(), &y.shape())?;
    let out_shape = Array::<bool>::broadcast_shape(&condition.shape(), &xy_shape)?;

    let mask_arr = condition.broadcast_to(&out_shape)?;
    let on_true_arr = x.broadcast_to(&out_shape)?;
    let on_false_arr = y.broadcast_to(&out_shape)?;

    let mask = mask_arr.to_vec();
    let on_true = on_true_arr.to_vec();
    let on_false = on_false_arr.to_vec();

    // Shared per-element rule for both execution strategies.
    let pick = |k: usize| -> T {
        if mask[k] {
            on_true[k].clone()
        } else {
            on_false[k].clone()
        }
    };

    let picked: Vec<T> = if mask.len() < PARALLEL_THRESHOLD {
        (0..mask.len()).map(pick).collect()
    } else {
        use scirs2_core::parallel_ops::*;
        (0..mask.len()).into_par_iter().map(pick).collect()
    };

    Ok(Array::from_vec(picked).reshape(&out_shape))
}
/// Multi-way element-wise selection.
///
/// All conditions and choices are broadcast to one common shape. At each
/// position the output takes the value of the choice paired with the first
/// condition that is `true` there. Positions where no condition holds keep
/// `default`, which falls back to `T::zero()` when `None`.
///
/// # Errors
/// * `InvalidOperation` if the two lists differ in length or are empty.
/// * Broadcasting errors if the shapes are incompatible.
pub fn select<T: Clone + num_traits::Zero>(
    condlist: &[&Array<bool>],
    choicelist: &[&Array<T>],
    default: Option<T>,
) -> Result<Array<T>> {
    if condlist.len() != choicelist.len() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and choicelist must have the same length".to_string(),
        ));
    }
    if condlist.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and choicelist cannot be empty".to_string(),
        ));
    }

    // Common shape of every condition and every choice.
    let mut shape = condlist[0].shape();
    for cond in condlist {
        shape = Array::<bool>::broadcast_shape(&shape, &cond.shape())?;
    }
    for choice in choicelist {
        shape = Array::<T>::broadcast_shape(&shape, &choice.shape())?;
    }

    let mut conds = Vec::with_capacity(condlist.len());
    for cond in condlist {
        conds.push(cond.broadcast_to(&shape)?);
    }
    let mut picks = Vec::with_capacity(choicelist.len());
    for choice in choicelist {
        picks.push(choice.broadcast_to(&shape)?);
    }

    let mut result = Array::full(&shape, default.unwrap_or_else(T::zero));
    let element_count = shape.iter().product::<usize>();

    // Reused multi-index buffer, refilled for every flat position.
    let mut idx = vec![0usize; shape.len()];
    for flat in 0..element_count {
        // Row-major unravel of `flat` into `idx`.
        let mut rem = flat;
        for (slot, &extent) in idx.iter_mut().zip(shape.iter()).rev() {
            *slot = rem % extent;
            rem /= extent;
        }

        let winner = conds.iter().zip(picks.iter()).find(|(cond, _)| {
            *cond
                .array()
                .get(scirs2_core::ndarray::IxDyn(&idx))
                .expect("indices should be valid within broadcast shape")
        });

        if let Some((_, pick)) = winner {
            let value = pick
                .array()
                .get(scirs2_core::ndarray::IxDyn(&idx))
                .expect("indices should be valid within broadcast shape");
            result.set(&idx, value.clone())?;
        }
    }
    Ok(result)
}
/// Builds an array by choosing, element-wise, among several candidate arrays.
///
/// `a` and every array in `choices` are broadcast to one common shape. At each
/// position the index stored in `a` selects which choice supplies the output
/// element. The result has the broadcast shape.
///
/// # Arguments
/// * `a` - Indices into `choices`.
/// * `choices` - Candidate arrays, broadcast-compatible with `a`.
/// * `mode` - Handling of indices `>= choices.len()`:
///   - `"raise"` returns an error;
///   - `"clip"` uses the last choice;
///   - `"wrap"` reduces the index modulo the number of choices.
///
/// # Errors
/// * `InvalidOperation` if `choices` is empty or `mode` is unrecognised.
///   The mode is checked even when `a` is empty.
/// * `IndexOutOfBounds` for an out-of-range index in `"raise"` mode.
/// * Broadcasting or element-access errors from the underlying arrays.
pub fn choose<T: Clone + num_traits::Zero>(
    a: &Array<usize>,
    choices: &[&Array<T>],
    mode: &str,
) -> Result<Array<T>> {
    if choices.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "choices cannot be empty".to_string(),
        ));
    }
    let n_choices = choices.len();

    enum Mode {
        Raise,
        Clip,
        Wrap,
    }
    let mode_kind = match mode {
        "raise" => Mode::Raise,
        "clip" => Mode::Clip,
        "wrap" => Mode::Wrap,
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Invalid mode '{}'. Use 'raise', 'clip', or 'wrap'",
                mode
            )))
        }
    };

    let mut broadcast_shape = a.shape();
    for choice in choices.iter() {
        broadcast_shape = Array::<T>::broadcast_shape(&broadcast_shape, &choice.shape())?;
    }

    // `a` must be broadcast too. Otherwise its shape no longer matches the
    // choices whenever the choices add or expand dimensions.
    let a_broadcast = a.broadcast_to(&broadcast_shape)?;
    let mut choice_broadcasts = Vec::with_capacity(n_choices);
    for choice in choices.iter() {
        choice_broadcasts.push(choice.broadcast_to(&broadcast_shape)?);
    }

    let total: usize = broadcast_shape.iter().product();
    let mut result_data = Vec::with_capacity(total);
    let mut indices = vec![0usize; broadcast_shape.len()];
    for flat in 0..total {
        // Row-major unravel of `flat` into `indices` over the broadcast shape.
        let mut rem = flat;
        for (slot, &extent) in indices.iter_mut().zip(broadcast_shape.iter()).rev() {
            *slot = rem % extent;
            rem /= extent;
        }

        let idx = a_broadcast.get(&indices)?;
        let actual_idx = match mode_kind {
            Mode::Raise => {
                if idx >= n_choices {
                    return Err(NumRs2Error::IndexOutOfBounds(format!(
                        "index {} is out of bounds for choices of size {}",
                        idx, n_choices
                    )));
                }
                idx
            }
            Mode::Clip => idx.min(n_choices - 1),
            Mode::Wrap => idx % n_choices,
        };

        result_data.push(choice_broadcasts[actual_idx].get(&indices)?);
    }

    Ok(Array::from_vec(result_data).reshape(&broadcast_shape))
}
/// Evaluates a piecewise-defined function over `x`.
///
/// Each `funclist[i]` is applied to the whole of `x` and must return an array of
/// `x`'s shape. An output element takes the value of the function paired with
/// the *first* condition that is `true` at that position; later conditions do
/// not overwrite it. Positions matched by no condition hold `fill_value`, which
/// falls back to `T::zero()` when `None`.
///
/// Every function is called exactly once, in order, even if its condition
/// selects nothing.
///
/// # Errors
/// * `InvalidOperation` if the lists differ in length or are empty.
/// * `ShapeMismatch` if a condition or a function result does not have `x`'s shape.
pub fn piecewise<T, F>(
    x: &Array<T>,
    condlist: &[&Array<bool>],
    funclist: &[&F],
    fill_value: Option<T>,
) -> Result<Array<T>>
where
    T: Clone + num_traits::Zero,
    F: Fn(&Array<T>) -> Array<T>,
{
    if condlist.len() != funclist.len() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and funclist must have the same length".to_string(),
        ));
    }
    if condlist.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "condlist and funclist cannot be empty".to_string(),
        ));
    }
    for cond in condlist {
        if cond.shape() != x.shape() {
            return Err(NumRs2Error::ShapeMismatch {
                expected: x.shape(),
                actual: cond.shape(),
            });
        }
    }

    let fill_val = fill_value.unwrap_or_else(T::zero);
    // Work on flat buffers for the whole loop. The output Array is built only
    // once at the end, instead of being rebuilt for every condition.
    let mut values = Array::full(&x.shape(), fill_val).to_vec();
    let mut taken = vec![false; values.len()];

    for (cond, func) in condlist.iter().zip(funclist.iter()) {
        let func_result = func(x);
        if func_result.shape() != x.shape() {
            return Err(NumRs2Error::ShapeMismatch {
                expected: x.shape(),
                actual: func_result.shape(),
            });
        }
        let cond_data = cond.to_vec();
        let func_data = func_result.to_vec();
        for ((slot, used), (&hit, candidate)) in values
            .iter_mut()
            .zip(taken.iter_mut())
            .zip(cond_data.iter().zip(func_data))
        {
            // First matching condition wins; later ones leave the slot alone.
            if hit && !*used {
                *slot = candidate;
                *used = true;
            }
        }
    }

    Ok(Array::from_vec(values).reshape(&x.shape()))
}