use bytemuck::{try_cast_slice, try_cast_slice_mut};
pub fn f32_f32_cpu(dst_dims: &[i64], dim: usize, src_bytes: &[u8], dst_ptr: &mut [u8]) {
assert!(!dst_dims.is_empty(), "dst dims empty");
assert_eq!(
dim,
dst_dims.len() - 1,
"Softmax helper currently only supports the last dimension"
);
let num_elements: usize = dst_dims.iter().map(|d| *d as usize).product();
let src_f32: &[f32] = try_cast_slice(src_bytes)
.expect("src byte slice cannot be cast to f32 slice (alignment/length mismatch)");
let dst_f32: &mut [f32] = try_cast_slice_mut(dst_ptr)
.expect("dst byte slice cannot be cast to f32 slice (alignment/length mismatch)");
assert_eq!(dst_f32.len(), num_elements, "dst buffer size mismatch");
let feature_size = dst_dims[dim] as usize;
let batch_size = num_elements / feature_size;
for b in 0..batch_size {
let offset = b * feature_size;
let mut max_val = f32::NEG_INFINITY;
for i in 0..feature_size {
let v = src_f32[offset + i];
if v > max_val {
max_val = v;
}
}
let mut sum = 0.0f32;
for i in 0..feature_size {
let exp_val = (src_f32[offset + i] - max_val).exp();
dst_f32[offset + i] = exp_val;
sum += exp_val;
}
if sum == 0.0 {
let inv = 1.0 / feature_size as f32;
for i in 0..feature_size {
dst_f32[offset + i] = inv;
}
} else {
for i in 0..feature_size {
dst_f32[offset + i] /= sum;
}
}
}
}