use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use scirs2_core::parallel_ops::*;
/// Applies `f` to every element of `array` in parallel and returns the results
/// as a new array with the same shape as the input.
///
/// Each element is cloned out of a flattened copy of `array` before being
/// handed to `f`; results are collected in element order.
pub fn parallel_map<T, U, F>(array: &Array<T>, f: F) -> Array<U>
where
    T: Send + Sync + Clone,
    U: Send + Clone,
    F: Fn(T) -> U + Send + Sync,
{
    let original_shape = array.shape();
    let inputs = array.to_vec();
    let outputs: Vec<U> = inputs
        .par_iter()
        .map(|item| f(item.clone()))
        .collect();
    Array::from_vec(outputs).reshape(&original_shape)
}
/// Element ordering requested from `optimize_layout`.
///
/// Derives the usual value-type traits so callers can copy, compare, hash and
/// debug-print a layout choice.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Last axis varies fastest (C order).
    RowMajor,
    /// First axis varies fastest (Fortran order).
    ColumnMajor,
}
/// Returns a copy of `array` arranged for the requested `layout`.
///
/// NOTE(review): no reordering is implemented yet — both `RowMajor` and
/// `ColumnMajor` currently return an unchanged clone of `array`.
pub fn optimize_layout<T: Clone>(array: &Array<T>, layout: MemoryLayout) -> Array<T> {
    match layout {
        MemoryLayout::RowMajor | MemoryLayout::ColumnMajor => array.clone(),
    }
}
/// Reports whether an operation may write its result into `_array` in place.
///
/// No aliasing or layout checks are performed: every array is reported as
/// eligible.
pub fn can_operate_inplace<T>(_array: &Array<T>) -> bool {
    const ALWAYS_ELIGIBLE: bool = true;
    ALWAYS_ELIGIBLE
}
/// Broadcasts every array in `arrays` to a single common shape.
///
/// Shapes are combined left to right after aligning trailing dimensions (shorter
/// shapes are left-padded with 1s). An axis of size 1 stretches to match the
/// other operand; any other disagreement is an error.
///
/// # Errors
/// Returns `NumRs2Error::ShapeMismatch` when two shapes are incompatible. In that
/// error, `expected` is the shape accumulated so far and `actual` is the offending
/// operand's shape. Any error from expanding an operand with `broadcast_to` is
/// propagated; the first failure wins.
///
/// An empty input yields an empty vector.
pub fn broadcast_arrays<T: Clone>(arrays: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    if arrays.is_empty() {
        return Ok(Vec::new());
    }

    let mut target: Vec<usize> = Vec::new();
    for operand in arrays.iter() {
        let operand_shape = operand.shape();
        if target.is_empty() {
            target = operand_shape.clone();
            continue;
        }

        let rank = target.len().max(operand_shape.len());
        let lhs = pad_shape(&target, rank);
        let rhs = pad_shape(&operand_shape, rank);

        let mut merged = Vec::new();
        for (&left, &right) in lhs.iter().zip(rhs.iter()) {
            let dim = if left == 1 {
                right
            } else if right == 1 || left == right {
                left
            } else {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: target,
                    actual: operand_shape.clone(),
                });
            };
            merged.push(dim);
        }
        target = merged;
    }

    arrays
        .iter()
        .map(|operand| broadcast_to(operand, &target))
        .collect()
}
/// Left-pads `shape` with 1s so that it has exactly `target_len` dimensions.
///
/// This aligns trailing dimensions for broadcasting: `[3, 4]` padded to 4
/// dims becomes `[1, 1, 3, 4]`.
///
/// # Panics
/// Panics if `shape` already has more than `target_len` dimensions. Callers
/// must guarantee `shape.len() <= target_len`. The check is explicit so a
/// violation produces a clear message in both debug and release builds. Without
/// it, the code would hit an arithmetic underflow or an out-of-bounds index.
fn pad_shape(shape: &[usize], target_len: usize) -> Vec<usize> {
    assert!(
        shape.len() <= target_len,
        "pad_shape: shape has {} dims, more than target_len {}",
        shape.len(),
        target_len
    );
    let mut padded = Vec::with_capacity(target_len);
    padded.resize(target_len - shape.len(), 1);
    padded.extend_from_slice(shape);
    padded
}
/// Expands `array` to `shape` by repeating data along size-1 (or missing
/// leading) axes, returning a newly materialised array.
///
/// Data is read from `array.to_vec()` and written in row-major order.
/// Presumably `to_vec` yields row-major order; confirm against `Array`.
///
/// # Errors
/// Returns `NumRs2Error::ShapeMismatch` (with `expected` = `shape` and
/// `actual` = the source shape) when:
/// * the source has more dimensions than `shape`, or
/// * any source axis is neither 1 nor equal to the matching target axis.
fn broadcast_to<T: Clone>(array: &Array<T>, shape: &[usize]) -> Result<Array<T>> {
    // Row-major strides (in elements) for `dims`; empty or 1-D input is safe.
    fn row_major_strides(dims: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; dims.len()];
        for axis in (1..dims.len()).rev() {
            strides[axis - 1] = strides[axis] * dims[axis];
        }
        strides
    }

    let orig_shape = array.shape();
    if orig_shape == shape {
        return Ok(array.clone());
    }

    // A higher-rank source can never broadcast down; report it instead of
    // letting `pad_shape` fail on the impossible padding.
    if orig_shape.len() > shape.len() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape.to_vec(),
            actual: orig_shape,
        });
    }

    let padded_orig = pad_shape(&orig_shape, shape.len());
    let compatible = padded_orig
        .iter()
        .zip(shape)
        .all(|(&src, &dst)| src == 1 || src == dst);
    if !compatible {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape.to_vec(),
            actual: orig_shape,
        });
    }

    let orig_data = array.to_vec();
    let strides = row_major_strides(shape);
    let orig_strides = row_major_strides(&padded_orig);
    let size: usize = shape.iter().product();

    let mut result_data = Vec::with_capacity(size);
    for flat in 0..size {
        // Decompose the flat output index into per-axis coordinates and map
        // them onto the source; broadcast (size-1) axes contribute nothing.
        let mut remainder = flat;
        let mut src_index = 0;
        for axis in 0..shape.len() {
            let coord = remainder / strides[axis];
            remainder %= strides[axis];
            if padded_orig[axis] > 1 {
                src_index += coord * orig_strides[axis];
            }
        }
        result_data.push(orig_data[src_index].clone());
    }

    Ok(Array::from_vec(result_data).reshape(shape))
}
/// Converts every element of `array` to `U` via `From<T>`, keeping the
/// original shape.
pub fn astype<T: Clone, U: Clone + From<T>>(array: &Array<T>) -> Array<U> {
    let source = array.to_vec();
    let mut converted: Vec<U> = Vec::with_capacity(source.len());
    converted.extend(source.into_iter().map(U::from));
    Array::from_vec(converted).reshape(&array.shape())
}
/// Sums all elements of `array` with a parallel reduction seeded by zero.
///
/// Because the additions are grouped by the parallel reduction rather than
/// performed strictly left to right, the floating-point result may differ
/// slightly from a sequential sum.
pub fn fast_sum<T: Float + Send + Sync>(array: &Array<T>) -> T {
    let elements = array.to_vec();
    let add = |lhs: T, rhs: T| lhs + rhs;
    elements.par_iter().cloned().reduce(|| T::zero(), add)
}
/// Heuristically reports whether a value of type `T` is a scalar.
///
/// A type counts as scalar when it occupies at most 16 bytes and its
/// `std::any::type_name` does not start with `numrs2::`.
///
/// Limitations:
/// * Small tuples, small fixed-size arrays and references are also reported as
///   scalar. This includes references to crate types, whose type name begins
///   with `&`.
/// * `type_name`'s output format is not guaranteed stable by the standard
///   library.
pub fn isscalar<T>(_value: &T) -> bool {
    let fits_in_scalar_budget = std::mem::size_of::<T>() <= 16;
    let is_crate_type = std::any::type_name::<T>().starts_with("numrs2::");
    fits_in_scalar_budget && !is_crate_type
}
/// Returns `true` when `array` is zero-dimensional (`ndim() == 0`).
pub fn isscalar_array<T: Clone>(array: &Array<T>) -> bool {
    matches!(array.ndim(), 0)
}
/// Reports whether dtype `from_type` may be cast to `to_type` under the
/// NumPy-style `casting` rule.
///
/// Recognised dtype names are `bool`, `i8`..`i64`, `u8`..`u64`, `f32` and `f64`.
///
/// * `"no"` / `"equiv"`: only identical type names.
/// * `"safe"`: identical names, or a widening that preserves every value:
///   * `bool` to any numeric type;
///   * an integer to a wider integer or float that can hold it;
///   * `f32` to `f64`.
/// * `"same_kind"`: a safe cast, or a cast between two *recognised* types of
///   the same kind (e.g. `f64` to `f32`). Signed and unsigned integers share
///   the `"integer"` kind here, so `i64` to `u8` is allowed.
/// * `"unsafe"`: always allowed.
///
/// Any other `casting` string yields `false`.
pub fn can_cast(from_type: &str, to_type: &str, casting: &str) -> bool {
    match casting {
        "no" | "equiv" => from_type == to_type,
        "safe" => {
            from_type == to_type
                || matches!(
                    (from_type, to_type),
                    (
                        "bool",
                        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64"
                    ) | ("i8", "i16" | "i32" | "i64" | "f32" | "f64")
                        | ("i16", "i32" | "i64" | "f32" | "f64")
                        | ("i32", "i64" | "f64")
                        | ("i64", "f64")
                        | (
                            "u8",
                            "u16" | "u32" | "u64" | "i16" | "i32" | "i64" | "f32" | "f64"
                        )
                        | ("u16", "u32" | "u64" | "i32" | "i64" | "f32" | "f64")
                        | ("u32", "u64" | "i64" | "f64")
                        | ("u64", "f64")
                        | ("f32", "f64")
                )
        }
        "same_kind" => {
            let from_kind = get_type_kind(from_type);
            // Unrecognised names all map to "other"; that must not count as a
            // shared kind, otherwise e.g. ("foo", "bar") would be castable.
            let same_known_kind = from_kind != "other" && from_kind == get_type_kind(to_type);
            same_known_kind || can_cast(from_type, to_type, "safe")
        }
        "unsafe" => true,
        _ => false,
    }
}

/// Classifies a dtype name into a coarse kind: `"integer"`, `"float"`,
/// `"bool"`, or `"other"` for names this module does not recognise.
fn get_type_kind(type_name: &str) -> &'static str {
    match type_name {
        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" => "integer",
        "f32" | "f64" => "float",
        "bool" => "bool",
        _ => "other",
    }
}
/// Returns the dtype name that all of `types` promote to.
///
/// The first name is canonicalised with `find_common_type_static`. Each
/// remaining name is then folded in pairwise with `find_common_type`. An empty
/// slice yields `"f64"`.
pub fn common_type(types: &[&str]) -> &'static str {
    match types.split_first() {
        None => "f64",
        Some((&first, rest)) => rest
            .iter()
            .fold(find_common_type_static(first), |acc, &next| {
                find_common_type(acc, next)
            }),
    }
}
/// Promotes two dtype names to the smallest type that can represent values of
/// both. It follows NumPy's promotion rules for the recognised dtypes
/// (`bool`, `i8`..`i64`, `u8`..`u64`, `f32`, `f64`):
///
/// * `bool` promotes to the other operand's type.
/// * Two integers of the same signedness, or two floats, promote to the wider width.
/// * A signed and an unsigned integer promote to a signed type wide enough for
///   both, e.g. `i8` + `u8` gives `i16`. If that would need more than 64 bits
///   (e.g. `i64` + `u64`), the result is `f64`.
/// * An integer of at most 16 bits with `f32` stays `f32`. A wider integer with
///   `f32`, or anything with `f64`, gives `f64`.
///
/// Unrecognised names promote to `"f64"`.
///
/// A plain linear ordering of the dtypes is not enough here. For example, it
/// would promote `i8` + `u8` to `u8`, which cannot hold negative values.
fn find_common_type(type1: &str, type2: &str) -> &'static str {
    #[derive(Clone, Copy)]
    enum Kind {
        Bool,
        Signed,
        Unsigned,
        Float,
    }

    // Splits a dtype name into its kind and bit width.
    fn parse(name: &str) -> Option<(Kind, u32)> {
        let parsed = match name {
            "bool" => (Kind::Bool, 8),
            "i8" => (Kind::Signed, 8),
            "i16" => (Kind::Signed, 16),
            "i32" => (Kind::Signed, 32),
            "i64" => (Kind::Signed, 64),
            "u8" => (Kind::Unsigned, 8),
            "u16" => (Kind::Unsigned, 16),
            "u32" => (Kind::Unsigned, 32),
            "u64" => (Kind::Unsigned, 64),
            "f32" => (Kind::Float, 32),
            "f64" => (Kind::Float, 64),
            _ => return None,
        };
        Some(parsed)
    }

    // Maps a kind/width pair back to its canonical static name.
    fn name_of(kind: Kind, bits: u32) -> &'static str {
        match (kind, bits) {
            (Kind::Bool, _) => "bool",
            (Kind::Signed, 8) => "i8",
            (Kind::Signed, 16) => "i16",
            (Kind::Signed, 32) => "i32",
            (Kind::Signed, 64) => "i64",
            (Kind::Unsigned, 8) => "u8",
            (Kind::Unsigned, 16) => "u16",
            (Kind::Unsigned, 32) => "u32",
            (Kind::Unsigned, 64) => "u64",
            (Kind::Float, 32) => "f32",
            _ => "f64",
        }
    }

    // Smallest signed type holding both a signed and an unsigned integer.
    fn promote_mixed_sign(signed_bits: u32, unsigned_bits: u32) -> &'static str {
        if unsigned_bits < signed_bits {
            name_of(Kind::Signed, signed_bits)
        } else if unsigned_bits < 64 {
            name_of(Kind::Signed, unsigned_bits * 2)
        } else {
            "f64"
        }
    }

    // Float type for an integer combined with a float.
    fn promote_int_float(int_bits: u32, float_bits: u32) -> &'static str {
        if float_bits == 32 && int_bits <= 16 {
            "f32"
        } else {
            "f64"
        }
    }

    match (parse(type1), parse(type2)) {
        (Some((k1, b1)), Some((k2, b2))) => match (k1, k2) {
            (Kind::Bool, _) => name_of(k2, b2),
            (_, Kind::Bool) => name_of(k1, b1),
            (Kind::Float, Kind::Float) => name_of(Kind::Float, b1.max(b2)),
            (Kind::Float, _) => promote_int_float(b2, b1),
            (_, Kind::Float) => promote_int_float(b1, b2),
            (Kind::Signed, Kind::Signed) | (Kind::Unsigned, Kind::Unsigned) => {
                name_of(k1, b1.max(b2))
            }
            (Kind::Signed, Kind::Unsigned) => promote_mixed_sign(b1, b2),
            (Kind::Unsigned, Kind::Signed) => promote_mixed_sign(b2, b1),
        },
        _ => "f64",
    }
}
/// Canonicalises a dtype name to a `'static` string.
///
/// Recognised names (`bool`, `i8`..`i64`, `u8`..`u64`, `f32`, `f64`) are
/// returned as-is. Anything else falls back to `"f64"`.
fn find_common_type_static(type_name: &str) -> &'static str {
    const KNOWN_TYPES: [&str; 11] = [
        "bool", "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64",
    ];
    KNOWN_TYPES
        .iter()
        .copied()
        .find(|&candidate| candidate == type_name)
        .unwrap_or("f64")
}
/// Returns the dtype name resulting from combining `types`.
///
/// This is an alias that delegates directly to `common_type`.
pub fn result_type(types: &[&str]) -> &'static str {
    let promoted = common_type(types);
    promoted
}