use crate::dtype::Element;
use std::cmp::Ordering;
#[inline]
pub unsafe fn sort_kernel<T: Element + PartialOrd>(
a: *const T,
out_values: *mut T,
out_indices: *mut i64,
outer_size: usize,
sort_size: usize,
inner_size: usize,
descending: bool,
) {
if sort_size == 0 {
return;
}
let mut pairs: Vec<(T, i64)> = Vec::with_capacity(sort_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
pairs.clear();
for i in 0..sort_size {
let idx = outer * sort_size * inner_size + i * inner_size + inner;
pairs.push((*a.add(idx), i as i64));
}
if descending {
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
} else {
pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
}
for (i, (val, orig_idx)) in pairs.iter().enumerate() {
let out_idx = outer * sort_size * inner_size + i * inner_size + inner;
*out_values.add(out_idx) = *val;
*out_indices.add(out_idx) = *orig_idx;
}
}
}
}
#[inline]
pub unsafe fn sort_values_kernel<T: Element + PartialOrd>(
a: *const T,
out: *mut T,
outer_size: usize,
sort_size: usize,
inner_size: usize,
descending: bool,
) {
if sort_size == 0 {
return;
}
let mut values: Vec<T> = Vec::with_capacity(sort_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
values.clear();
for i in 0..sort_size {
let idx = outer * sort_size * inner_size + i * inner_size + inner;
values.push(*a.add(idx));
}
if descending {
values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
} else {
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
}
for (i, val) in values.iter().enumerate() {
let out_idx = outer * sort_size * inner_size + i * inner_size + inner;
*out.add(out_idx) = *val;
}
}
}
}
#[inline]
pub unsafe fn argsort_kernel<T: Element + PartialOrd>(
a: *const T,
out: *mut i64,
outer_size: usize,
sort_size: usize,
inner_size: usize,
descending: bool,
) {
if sort_size == 0 {
return;
}
let mut pairs: Vec<(T, i64)> = Vec::with_capacity(sort_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
pairs.clear();
for i in 0..sort_size {
let idx = outer * sort_size * inner_size + i * inner_size + inner;
pairs.push((*a.add(idx), i as i64));
}
if descending {
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
} else {
pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
}
for (i, (_, orig_idx)) in pairs.iter().enumerate() {
let out_idx = outer * sort_size * inner_size + i * inner_size + inner;
*out.add(out_idx) = *orig_idx;
}
}
}
}
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn topk_kernel<T: Element + PartialOrd>(
a: *const T,
out_values: *mut T,
out_indices: *mut i64,
outer_size: usize,
sort_size: usize,
inner_size: usize,
k: usize,
largest: bool,
sorted: bool,
) {
if sort_size == 0 || k == 0 {
return;
}
let k = k.min(sort_size);
let mut pairs: Vec<(T, i64)> = Vec::with_capacity(sort_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
pairs.clear();
for i in 0..sort_size {
let idx = outer * sort_size * inner_size + i * inner_size + inner;
pairs.push((*a.add(idx), i as i64));
}
if largest {
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
} else {
pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
}
let topk = &pairs[..k];
let output: Vec<(T, i64)> = if sorted {
topk.to_vec()
} else {
let mut temp = topk.to_vec();
temp.sort_by_key(|(_, idx)| *idx);
temp
};
for (i, (val, orig_idx)) in output.iter().enumerate() {
let out_idx = outer * k * inner_size + i * inner_size + inner;
*out_values.add(out_idx) = *val;
*out_indices.add(out_idx) = *orig_idx;
}
}
}
}
#[inline]
pub unsafe fn count_unique_kernel<T: Element + PartialEq>(a: *const T, len: usize) -> usize {
if len == 0 {
return 0;
}
let mut count = 1;
let mut prev = *a;
for i in 1..len {
let curr = *a.add(i);
if curr != prev {
count += 1;
prev = curr;
}
}
count
}
#[inline]
pub unsafe fn extract_unique_kernel<T: Element + PartialEq>(
a: *const T,
out: *mut T,
len: usize,
_unique_count: usize,
) {
if len == 0 {
return;
}
let mut out_idx = 0;
let mut prev = *a;
*out.add(out_idx) = prev;
out_idx += 1;
for i in 1..len {
let curr = *a.add(i);
if curr != prev {
*out.add(out_idx) = curr;
out_idx += 1;
prev = curr;
}
}
}
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn unique_with_counts_kernel<T: Element + PartialEq>(
_a: *const T,
sorted: *const T,
sort_indices: *const i64,
out_unique: *mut T,
out_inverse: *mut i64,
out_counts: *mut i64,
numel: usize,
_unique_count: usize,
) {
if numel == 0 {
return;
}
let mut unique_idx: i64 = 0;
let mut prev = *sorted;
*out_unique.add(0) = prev;
*out_counts.add(0) = 1;
let orig_idx = *sort_indices.add(0) as usize;
*out_inverse.add(orig_idx) = 0;
for i in 1..numel {
let curr = *sorted.add(i);
let orig_idx = *sort_indices.add(i) as usize;
if curr != prev {
unique_idx += 1;
*out_unique.add(unique_idx as usize) = curr;
*out_counts.add(unique_idx as usize) = 1;
prev = curr;
} else {
*out_counts.add(unique_idx as usize) += 1;
}
*out_inverse.add(orig_idx) = unique_idx;
}
}
#[inline]
pub unsafe fn count_nonzero_kernel<T: Element>(a: *const T, numel: usize) -> usize {
let a_slice = std::slice::from_raw_parts(a, numel);
a_slice.iter().filter(|&&x| x.to_f64() != 0.0).count()
}
#[inline]
pub unsafe fn nonzero_flat_kernel<T: Element>(a: *const T, out: *mut i64, numel: usize) {
let a_slice = std::slice::from_raw_parts(a, numel);
let mut out_idx = 0;
for (i, &val) in a_slice.iter().enumerate() {
if val.to_f64() != 0.0 {
*out.add(out_idx) = i as i64;
out_idx += 1;
}
}
}
#[inline]
pub unsafe fn flat_to_multi_index_kernel(
flat_indices: *const i64,
out: *mut i64,
nnz: usize,
shape: &[usize],
) {
let ndim = shape.len();
if ndim == 0 || nnz == 0 {
return;
}
let mut strides = vec![1usize; ndim];
for i in (0..ndim - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
for i in 0..nnz {
let mut remaining = *flat_indices.add(i) as usize;
for d in 0..ndim {
let coord = remaining / strides[d];
remaining %= strides[d];
*out.add(i * ndim + d) = coord as i64;
}
}
}
#[inline]
pub unsafe fn searchsorted_kernel<T: Element + PartialOrd>(
sorted_seq: *const T,
values: *const T,
out: *mut i64,
seq_len: usize,
num_values: usize,
right: bool,
) {
if seq_len == 0 {
for i in 0..num_values {
*out.add(i) = 0;
}
return;
}
let seq_slice = std::slice::from_raw_parts(sorted_seq, seq_len);
for i in 0..num_values {
let val = *values.add(i);
let idx = if right {
seq_slice.partition_point(|x| {
x.partial_cmp(&val).unwrap_or(Ordering::Less) != Ordering::Greater
})
} else {
seq_slice.partition_point(|x| {
x.partial_cmp(&val).unwrap_or(Ordering::Greater) == Ordering::Less
})
};
*out.add(i) = idx as i64;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sort_ascending() {
let a = [3.0f32, 1.0, 4.0, 1.0, 5.0];
let mut values = [0.0f32; 5];
let mut indices = [0i64; 5];
unsafe {
sort_kernel(
a.as_ptr(),
values.as_mut_ptr(),
indices.as_mut_ptr(),
1,
5,
1,
false,
);
}
assert_eq!(values, [1.0, 1.0, 3.0, 4.0, 5.0]);
assert_eq!(indices, [1, 3, 0, 2, 4]);
}
#[test]
fn test_sort_descending() {
let a = [3.0f32, 1.0, 4.0, 1.0, 5.0];
let mut values = [0.0f32; 5];
let mut indices = [0i64; 5];
unsafe {
sort_kernel(
a.as_ptr(),
values.as_mut_ptr(),
indices.as_mut_ptr(),
1,
5,
1,
true,
);
}
assert_eq!(values, [5.0, 4.0, 3.0, 1.0, 1.0]);
assert_eq!(indices, [4, 2, 0, 1, 3]);
}
#[test]
fn test_argsort() {
let a = [3.0f32, 1.0, 4.0];
let mut indices = [0i64; 3];
unsafe {
argsort_kernel(a.as_ptr(), indices.as_mut_ptr(), 1, 3, 1, false);
}
assert_eq!(indices, [1, 0, 2]); }
#[test]
fn test_topk() {
let a = [3.0f32, 1.0, 4.0, 1.0, 5.0];
let mut values = [0.0f32; 2];
let mut indices = [0i64; 2];
unsafe {
topk_kernel(
a.as_ptr(),
values.as_mut_ptr(),
indices.as_mut_ptr(),
1,
5,
1,
2,
true, true, );
}
assert_eq!(values, [5.0, 4.0]);
assert_eq!(indices, [4, 2]);
}
#[test]
fn test_count_unique() {
let sorted = [1.0f32, 1.0, 2.0, 3.0, 3.0, 3.0];
let count = unsafe { count_unique_kernel(sorted.as_ptr(), 6) };
assert_eq!(count, 3);
}
#[test]
fn test_searchsorted() {
let sorted_seq = [1.0f32, 3.0, 5.0, 7.0];
let values = [2.0f32, 4.0, 6.0, 0.0, 8.0];
let mut out = [0i64; 5];
unsafe {
searchsorted_kernel(
sorted_seq.as_ptr(),
values.as_ptr(),
out.as_mut_ptr(),
4,
5,
false, );
}
assert_eq!(out, [1, 2, 3, 0, 4]);
}
#[test]
fn test_count_nonzero() {
let a = [0.0f32, 1.0, 0.0, 2.0, 3.0, 0.0];
let count = unsafe { count_nonzero_kernel(a.as_ptr(), 6) };
assert_eq!(count, 3);
}
}