use bytemuck::{try_cast_slice, try_cast_slice_mut};
/// Writes the element-wise minimum of two strided `f32` inputs into `dst_ptr`.
///
/// The output is filled in order, with the last dimension of `dst_dims` varying
/// fastest. For every output element the matching input element is located by
/// walking `strides_a` / `strides_b`, which are expressed in `f32` elements (not
/// bytes). A stride of `0` re-reads the same input element along that dimension.
///
/// The combination uses [`f32::min`], so when exactly one operand is NaN the
/// other operand is written.
///
/// # Parameters
/// - `strides_a`, `strides_b`: per-dimension element strides for `src1_bytes`
///   and `src2_bytes`; each needs at least `dst_dims.len()` entries.
/// - `dst_dims`: extent of every output dimension; must be non-negative.
/// - `src1_bytes`, `src2_bytes`: raw bytes of the two `f32` inputs.
/// - `dst_ptr`: raw bytes of the `f32` output; must hold exactly the product of
///   `dst_dims` elements.
///
/// # Panics
/// - if a dimension is negative or does not fit in `usize`,
/// - if the element count overflows `usize`,
/// - if any byte slice cannot be reinterpreted as `f32` (alignment/length),
/// - if the output length differs from the element count,
/// - if a stride vector is shorter than the rank (for non-empty outputs),
/// - if a computed input offset falls outside its source slice.
pub fn f32_f32_f32_cpu(
    strides_a: Vec<usize>,
    strides_b: Vec<usize>,
    dst_dims: Vec<i64>,
    src1_bytes: &[u8],
    src2_bytes: &[u8],
    dst_ptr: &mut [u8],
) {
    // Local import keeps the block compiling on editions without `TryFrom` in the prelude.
    use std::convert::TryFrom;

    // Convert the dimensions once, rejecting negative values instead of letting
    // an `as` cast wrap them into enormous extents.
    let dims: Vec<usize> = dst_dims
        .iter()
        .map(|&d| usize::try_from(d).expect("dst_dims must be non-negative and fit in usize"))
        .collect();

    // A zero extent anywhere means an empty output; otherwise multiply with
    // overflow detection so a wrapped count can never pass the size check.
    let num_elements: usize = if dims.contains(&0) {
        0
    } else {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .expect("dst_dims element count overflows usize")
    };

    let src1_f32: &[f32] = try_cast_slice(src1_bytes)
        .expect("src1 byte slice cannot be cast to f32 slice (alignment/length mismatch)");
    let src2_f32: &[f32] = try_cast_slice(src2_bytes)
        .expect("src2 byte slice cannot be cast to f32 slice (alignment/length mismatch)");
    let dst_f32: &mut [f32] = try_cast_slice_mut(dst_ptr)
        .expect("dst byte slice cannot be cast to f32 slice (alignment/length mismatch)");
    assert_eq!(dst_f32.len(), num_elements, "dst buffer size mismatch");

    // Nothing to write; the strides are never consulted for an empty output.
    if num_elements == 0 {
        return;
    }

    let rank = dims.len();
    assert!(
        strides_a.len() >= rank,
        "strides_a has fewer entries than dst_dims"
    );
    assert!(
        strides_b.len() >= rank,
        "strides_b has fewer entries than dst_dims"
    );

    // Distance to step back when a dimension rolls over. Computed once, and with
    // wrapping arithmetic so it stays consistent with the wrapping offset
    // updates below (identical results in debug and release builds).
    let rewind_a: Vec<usize> = dims
        .iter()
        .zip(&strides_a)
        .map(|(&extent, &stride)| stride.wrapping_mul(extent))
        .collect();
    let rewind_b: Vec<usize> = dims
        .iter()
        .zip(&strides_b)
        .map(|(&extent, &stride)| stride.wrapping_mul(extent))
        .collect();

    let mut idxs = vec![0usize; rank];
    let mut off_a: usize = 0;
    let mut off_b: usize = 0;

    for dst_slot in dst_f32.iter_mut() {
        *dst_slot = src1_f32[off_a].min(src2_f32[off_b]);

        // Odometer-style increment: bump the innermost index and carry outward
        // while a dimension overflows its extent.
        for d in (0..rank).rev() {
            idxs[d] += 1;
            off_a = off_a.wrapping_add(strides_a[d]);
            off_b = off_b.wrapping_add(strides_b[d]);
            if idxs[d] < dims[d] {
                break;
            }
            // Dimension `d` rolled over: reset it and undo its accumulated offset.
            idxs[d] = 0;
            off_a = off_a.wrapping_sub(rewind_a[d]);
            off_b = off_b.wrapping_sub(rewind_b[d]);
        }
    }
}