use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::PrimInt;
use std::cmp::max;
use crate::error::{SparseError, SparseResult};
/// Picks the narrowest index type name able to address a sparse matrix.
///
/// The bound considered is the larger of `rows * cols` (saturating at
/// `usize::MAX` instead of wrapping) and the largest value stored in any of
/// `idx_arrays`.
///
/// # Returns
/// `"i32"` if that bound fits in `i32`, otherwise `"i64"` if it fits in `i64`,
/// otherwise `"usize"`.
///
/// # Arguments
/// * `shape` - `(rows, cols)` of the matrix.
/// * `idx_arrays` - index arrays whose contents must also be representable;
///   may be empty.
#[allow(dead_code)]
pub fn get_index_dtype(shape: (usize, usize), idx_arrays: &[ArrayView1<usize>]) -> &'static str {
    let (rows, cols) = shape;
    // Saturate so an enormous shape cannot wrap around to a small product.
    let theoretical_max = rows.saturating_mul(cols);
    // `max()` over no elements yields `None`; an empty input contributes 0.
    let observed_max = idx_arrays
        .iter()
        .flat_map(|arr| arr.iter().copied())
        .max()
        .unwrap_or(0);
    let max_value = max(theoretical_max, observed_max);
    // Checked conversions rather than `i32::MAX as usize` / `i64::MAX as usize`:
    // an `as` cast of a wider signed limit into `usize` silently truncates on
    // targets where `usize` is narrower than the source type.
    if i32::try_from(max_value).is_ok() {
        "i32"
    } else if i64::try_from(max_value).is_ok() {
        "i64"
    } else {
        "usize"
    }
}
/// Converts each `usize` index array into an owned array with element type `T`.
///
/// Every element goes through `T::try_from`, so no value is ever silently
/// truncated. The output preserves the order of `arrays` and of the elements
/// within each array.
///
/// # Errors
/// Returns [`SparseError::IndexCastOverflow`] for the first value (scanning
/// arrays in order, then elements in order) that does not fit in `T`. The
/// error's `target_type` is `std::any::type_name::<T>()`.
#[allow(dead_code)]
pub fn safely_cast_index_arrays<T>(arrays: &[ArrayView1<usize>]) -> SparseResult<Vec<Array1<T>>>
where
    T: PrimInt + 'static + TryFrom<usize>,
    <T as TryFrom<usize>>::Error: std::fmt::Debug,
{
    let mut result = Vec::with_capacity(arrays.len());
    for array in arrays {
        // Fill a plain Vec and hand it to ndarray at the end; this avoids
        // uninitialized storage and therefore needs no `unsafe` at all.
        let mut converted = Vec::with_capacity(array.len());
        for &value in array.iter() {
            let cast = T::try_from(value).map_err(|_| SparseError::IndexCastOverflow {
                value,
                target_type: std::any::type_name::<T>(),
            })?;
            converted.push(cast);
        }
        result.push(Array1::from_vec(converted));
    }
    Ok(result)
}
/// Reports whether every element of `array` converts to `T` without overflow.
///
/// The scan stops at the first element that does not fit. An empty array
/// trivially passes and yields `true`.
#[allow(dead_code)]
pub fn can_cast_safely<T>(array: ArrayView1<usize>) -> bool
where
    T: PrimInt + 'static + TryFrom<usize>,
    <T as TryFrom<usize>>::Error: std::fmt::Debug,
{
    array.iter().all(|&index| T::try_from(index).is_ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_get_index_dtype_small() {
        let shape = (100, 100);
        let dtype = get_index_dtype(shape, &[]);
        assert_eq!(dtype, "i32");
    }

    #[test]
    fn test_get_index_dtype_medium() {
        let shape = (50_000, 50_000);
        let dtype = get_index_dtype(shape, &[]);
        assert_eq!(dtype, "i64");
    }

    // `usize::MAX / 2 * 3` saturates to `usize::MAX`, which only exceeds
    // `i64::MAX` when `usize` is 64 bits wide; on 32-bit targets it fits in i64.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_get_index_dtype_large() {
        let shape = (usize::MAX / 2, 3);
        let dtype = get_index_dtype(shape, &[]);
        assert_eq!(dtype, "usize");
    }

    #[test]
    fn test_get_index_dtype_i32_boundary() {
        let at_limit = i32::MAX as usize;
        assert_eq!(get_index_dtype((1, at_limit), &[]), "i32");
        assert_eq!(get_index_dtype((1, at_limit + 1), &[]), "i64");
    }

    #[test]
    fn test_get_index_dtype_with_arrays() {
        let indices1 = Array1::from_vec(vec![0, 10, 20, 30]);
        let indices2 = Array1::from_vec(vec![5, 15, 25, 1000]);
        let dtype = get_index_dtype((100, 100), &[indices1.view(), indices2.view()]);
        assert_eq!(dtype, "i32");
    }

    #[test]
    fn test_get_index_dtype_with_large_values() {
        let indices = Array1::from_vec(vec![0, i32::MAX as usize + 1]);
        let dtype = get_index_dtype((100, 100), &[indices.view()]);
        assert_eq!(dtype, "i64");
    }

    #[test]
    fn test_safely_cast_valid() {
        let indices = Array1::from_vec(vec![0, 5, 10, 100]);
        let result = safely_cast_index_arrays::<i32>(&[indices.view()]);
        assert!(result.is_ok());
        let arrays = result.expect("Operation failed");
        assert_eq!(arrays.len(), 1);
        assert_eq!(arrays[0].len(), 4);
        assert_eq!(arrays[0][2], 10);
    }

    #[test]
    fn test_safely_cast_multiple() {
        let indices1 = Array1::from_vec(vec![0, 5, 10]);
        let indices2 = Array1::from_vec(vec![1, 2, 3, 4]);
        let result = safely_cast_index_arrays::<i32>(&[indices1.view(), indices2.view()]);
        assert!(result.is_ok());
        let arrays = result.expect("Operation failed");
        assert_eq!(arrays.len(), 2);
        assert_eq!(arrays[0].len(), 3);
        assert_eq!(arrays[1].len(), 4);
    }

    #[test]
    fn test_safely_cast_empty_inputs() {
        let no_arrays =
            safely_cast_index_arrays::<i32>(&[]).expect("an empty slice of arrays must succeed");
        assert!(no_arrays.is_empty());

        let empty = Array1::<usize>::from_vec(Vec::new());
        let arrays = safely_cast_index_arrays::<i32>(&[empty.view()])
            .expect("an empty index array must succeed");
        assert_eq!(arrays.len(), 1);
        assert!(arrays[0].is_empty());
    }

    #[test]
    fn test_safely_cast_invalid() {
        let indices = Array1::from_vec(vec![0, 5, 10, 200]);
        let result = safely_cast_index_arrays::<i8>(&[indices.view()]);
        assert!(result.is_err());
        match result {
            Err(SparseError::IndexCastOverflow { value, target_type }) => {
                assert_eq!(value, 200);
                assert_eq!(target_type, "i8");
            }
            _ => panic!("Expected IndexCastOverflow error"),
        }
    }

    #[test]
    fn test_safely_cast_error_in_later_array() {
        let fits = Array1::from_vec(vec![1usize, 2, 3]);
        let overflows = Array1::from_vec(vec![0usize, 300]);
        match safely_cast_index_arrays::<u8>(&[fits.view(), overflows.view()]) {
            Err(SparseError::IndexCastOverflow { value, .. }) => assert_eq!(value, 300),
            _ => panic!("Expected IndexCastOverflow error"),
        }
    }

    #[test]
    fn test_can_cast_safely() {
        let small_indices = Array1::from_vec(vec![0, 5, 10, 20]);
        assert!(can_cast_safely::<i8>(small_indices.view()));
        let large_indices = Array1::from_vec(vec![0, 5, 10, 200]);
        assert!(!can_cast_safely::<i8>(large_indices.view()));
    }

    #[test]
    fn test_can_cast_safely_boundaries() {
        let at_max = Array1::from_vec(vec![i8::MAX as usize]);
        assert!(can_cast_safely::<i8>(at_max.view()));
        let past_max = Array1::from_vec(vec![i8::MAX as usize + 1]);
        assert!(!can_cast_safely::<i8>(past_max.view()));
        let empty = Array1::<usize>::from_vec(Vec::new());
        assert!(can_cast_safely::<i8>(empty.view()));
    }
}