use crate::Result;
use scirs2_core::ndarray::{Array1, ArrayView1, Zip};
#[cfg(target_arch = "x86_64")]
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
/// Namespace for masked, gather/scatter, conversion and histogram kernels on 1-D arrays.
///
/// Only [`Avx512Ops::masked_add`] branches on [`Avx512Ops::is_available`]; every other
/// operation runs a single portable scalar code path, so results never depend on the
/// CPU the code happens to run on.
#[cfg(target_arch = "x86_64")]
pub struct Avx512Ops;

#[cfg(target_arch = "x86_64")]
impl Avx512Ops {
    /// Number of `f64` elements handled per chunk (8 x 64 bit = 512 bit).
    const F64_LANES: usize = 8;

    /// Returns the `avx512_available` flag reported by `PlatformCapabilities::detect()`.
    pub fn is_available() -> bool {
        PlatformCapabilities::detect().avx512_available
    }

    /// Packs the `F64_LANES` mask entries starting at `start` into a bit set.
    ///
    /// Bit `k` of the result is set when `mask[start + k]` is `true`. The caller must
    /// guarantee that `start + F64_LANES <= mask.len()`; indexing panics otherwise.
    fn lane_mask_bits(mask: &ArrayView1<bool>, start: usize) -> u8 {
        (0..Self::F64_LANES).fold(0u8, |bits, lane| {
            if mask[start + lane] {
                bits | (1u8 << lane)
            } else {
                bits
            }
        })
    }

    /// Element-wise masked addition.
    ///
    /// Output element `i` is `a[i] + b[i]` where `mask[i]` is `true` and `a[i]` otherwise.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::DimensionMismatch` when `a`, `b` and `mask` do not all have
    /// the same length.
    pub fn masked_add(
        a: &ArrayView1<f64>,
        b: &ArrayView1<f64>,
        mask: &ArrayView1<bool>,
    ) -> Result<Array1<f64>> {
        let len = a.len();
        if b.len() != len || mask.len() != len {
            return Err(crate::NumRs2Error::DimensionMismatch(
                "All arrays must have the same length".to_string(),
            ));
        }
        if !Self::is_available() {
            return Ok(Self::masked_add_fallback(a, b, mask));
        }

        let mut result = Array1::zeros(len);
        // Everything before `tail_start` is covered by whole chunks.
        let tail_start = len - len % Self::F64_LANES;

        for start in (0..tail_start).step_by(Self::F64_LANES) {
            let end = start + Self::F64_LANES;
            let bits = Self::lane_mask_bits(mask, start);

            if bits == 0 {
                // No lane selected: the chunk is a plain copy of `a`.
                result
                    .slice_mut(scirs2_core::ndarray::s![start..end])
                    .assign(&a.slice(scirs2_core::ndarray::s![start..end]));
                continue;
            }

            // Add the whole chunk once, then pick per lane between the sum and `a`.
            let summed = f64::simd_add(
                &a.slice(scirs2_core::ndarray::s![start..end]),
                &b.slice(scirs2_core::ndarray::s![start..end]),
            );
            for lane in 0..Self::F64_LANES {
                let idx = start + lane;
                result[idx] = if bits & (1u8 << lane) != 0 {
                    summed[lane]
                } else {
                    a[idx]
                };
            }
        }

        // Trailing elements that do not fill a whole chunk.
        for idx in tail_start..len {
            result[idx] = if mask[idx] { a[idx] + b[idx] } else { a[idx] };
        }
        Ok(result)
    }

    /// Scalar implementation of [`Avx512Ops::masked_add`].
    ///
    /// Lengths are not validated here; `masked_add` checks them before calling.
    fn masked_add_fallback(
        a: &ArrayView1<f64>,
        b: &ArrayView1<f64>,
        mask: &ArrayView1<bool>,
    ) -> Array1<f64> {
        let mut result = Array1::zeros(a.len());
        Zip::from(&mut result)
            .and(a)
            .and(b)
            .and(mask)
            .for_each(|out, &lhs, &rhs, &selected| {
                *out = if selected { lhs + rhs } else { lhs };
            });
        result
    }

    /// Collects `data[indices[k]]` for every `k` into a new array.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::IndexOutOfBounds` for the first index that is not smaller
    /// than `data.len()`. All indices are validated before any element is read.
    pub fn gather(data: &ArrayView1<f64>, indices: &ArrayView1<usize>) -> Result<Array1<f64>> {
        let len = data.len();
        if let Some(&bad) = indices.iter().find(|&&idx| idx >= len) {
            return Err(crate::NumRs2Error::IndexOutOfBounds(format!(
                "Index {} out of bounds for array of size {}",
                bad, len
            )));
        }
        Ok(Array1::from_vec(
            indices.iter().map(|&idx| data[idx]).collect(),
        ))
    }

    /// Writes `values[k]` to position `indices[k]` of a zero-initialised array of length
    /// `output_size`.
    ///
    /// Pairs are applied in order, so when an index occurs more than once the last
    /// value written wins. Positions that are never addressed stay `0.0`.
    ///
    /// # Errors
    ///
    /// * `NumRs2Error::DimensionMismatch` when `values` and `indices` differ in length.
    /// * `NumRs2Error::IndexOutOfBounds` for the first index that is not smaller than
    ///   `output_size`.
    pub fn scatter(
        values: &ArrayView1<f64>,
        indices: &ArrayView1<usize>,
        output_size: usize,
    ) -> Result<Array1<f64>> {
        if values.len() != indices.len() {
            return Err(crate::NumRs2Error::DimensionMismatch(
                "Values and indices must have the same length".to_string(),
            ));
        }
        if let Some(&bad) = indices.iter().find(|&&idx| idx >= output_size) {
            return Err(crate::NumRs2Error::IndexOutOfBounds(format!(
                "Index {} out of bounds for output size {}",
                bad, output_size
            )));
        }

        let mut result = Array1::zeros(output_size);
        for (&value, &idx) in values.iter().zip(indices.iter()) {
            result[idx] = value;
        }
        Ok(result)
    }

    /// Sums the elements of `data` whose corresponding `mask` entry is `true`.
    ///
    /// Selected values are added left to right starting from `0.0`; an empty selection
    /// therefore yields `0.0`.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::DimensionMismatch` when `data` and `mask` differ in length.
    pub fn masked_sum(data: &ArrayView1<f64>, mask: &ArrayView1<bool>) -> Result<f64> {
        if data.len() != mask.len() {
            return Err(crate::NumRs2Error::DimensionMismatch(
                "Data and mask must have the same length".to_string(),
            ));
        }
        let total = data
            .iter()
            .zip(mask.iter())
            .filter(|&(_, &selected)| selected)
            .fold(0.0_f64, |acc, (&value, _)| acc + value);
        Ok(total)
    }

    /// Converts every element of `data` to `f32` with an `as` cast.
    ///
    /// All elements are converted, including the trailing ones when the length is not
    /// a multiple of eight (a previous chunked path left those at `0.0`).
    pub fn convert_f64_to_f32(data: &ArrayView1<f64>) -> Array1<f32> {
        data.map(|&value| value as f32)
    }

    /// Counts how many values of `data` fall into each of `bins` equal-width bins that
    /// partition the half-open interval `[min_val, max_val)`.
    ///
    /// Values outside the interval are ignored; this includes values equal to `max_val`
    /// and NaN (every comparison with NaN is false).
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::InvalidOperation` when `bins` is zero or when
    /// `min_val >= max_val`.
    pub fn histogram(
        data: &ArrayView1<f64>,
        bins: usize,
        min_val: f64,
        max_val: f64,
    ) -> Result<Array1<usize>> {
        if bins == 0 {
            return Err(crate::NumRs2Error::InvalidOperation(
                "Number of bins must be greater than 0".to_string(),
            ));
        }
        if min_val >= max_val {
            return Err(crate::NumRs2Error::InvalidOperation(
                "min_val must be less than max_val".to_string(),
            ));
        }

        let bin_width = (max_val - min_val) / bins as f64;
        let last_bin = bins - 1;
        let mut hist = Array1::<usize>::zeros(bins);

        for &value in data
            .iter()
            .filter(|&&value| value >= min_val && value < max_val)
        {
            // Clamp so that floating-point rounding near `max_val` cannot produce an
            // index one past the last bin.
            let bin_idx = (((value - min_val) / bin_width) as usize).min(last_bin);
            hist[bin_idx] += 1;
        }
        Ok(hist)
    }
}
/// Namespace for small fixed-size matrix kernels.
#[cfg(target_arch = "x86_64")]
pub struct Avx512MatrixOps;

#[cfg(target_arch = "x86_64")]
impl Avx512MatrixOps {
    /// Transposes an 8x8 block stored in row-major order.
    ///
    /// Element `(row, col)` of `input` (at `row * 8 + col`) is written to position
    /// `(col, row)` of `output` (at `col * 8 + row`). Every element of `output` is
    /// overwritten. The implementation is a plain scalar loop on every CPU.
    pub fn transpose_8x8_block(input: &[f64; 64], output: &mut [f64; 64]) {
        const DIM: usize = 8;
        for (row, src_row) in input.chunks_exact(DIM).enumerate() {
            for (col, &value) in src_row.iter().enumerate() {
                output[col * DIM + row] = value;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Selected positions receive `a + b`; unselected positions keep `a`.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_masked_add() {
        let lhs = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let rhs = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let selection = Array1::from_vec(vec![true, false, true, false, true]);

        let summed = Avx512Ops::masked_add(&lhs.view(), &rhs.view(), &selection.view())
            .expect("masked_add should succeed for equal length arrays");

        assert_eq!(summed.to_vec(), vec![11.0, 2.0, 33.0, 4.0, 55.0]);
    }

    /// Output element `k` is the source element at `indices[k]`.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_gather() {
        let source = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let picks = Array1::from_vec(vec![4_usize, 2, 0, 3, 1]);

        let gathered = Avx512Ops::gather(&source.view(), &picks.view())
            .expect("gather should succeed for valid indices");

        assert_eq!(gathered.to_vec(), vec![50.0, 30.0, 10.0, 40.0, 20.0]);
    }

    /// Addressed positions receive their value; all others stay zero.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_scatter() {
        let payload = Array1::from_vec(vec![100.0, 200.0, 300.0]);
        let targets = Array1::from_vec(vec![2_usize, 0, 4]);

        let scattered = Avx512Ops::scatter(&payload.view(), &targets.view(), 5)
            .expect("scatter should succeed for valid indices");

        assert_eq!(scattered.to_vec(), vec![200.0, 0.0, 100.0, 0.0, 300.0]);
    }

    /// Ten samples spread evenly over five unit-width bins give two per bin.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_histogram() {
        let samples =
            Array1::from_vec(vec![0.5, 1.5, 2.5, 3.5, 4.5, 0.2, 1.8, 2.2, 3.8, 4.2]);

        let counts = Avx512Ops::histogram(&samples.view(), 5, 0.0, 5.0)
            .expect("histogram should succeed for valid parameters");

        assert_eq!(counts.to_vec(), vec![2_usize; 5]);
    }
}