use super::rng;
use crate::dtype::Element;
/// Overwrites `len` consecutive elements starting at `out` with copies of `value`.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `len` initialized elements, as required by [`std::slice::from_raw_parts_mut`].
/// No other reference to that memory may be live during the call.
#[inline]
pub unsafe fn fill_kernel<T: Element>(out: *mut T, value: T, len: usize) {
    for slot in std::slice::from_raw_parts_mut(out, len) {
        *slot = value;
    }
}
/// Copies `len` elements of `T` from `src` into `dst`.
///
/// # Safety
///
/// `src` must be valid for reads and `dst` valid for writes of `len` elements,
/// both pointers must be aligned for `T`, and the two regions must not overlap
/// (the copy uses non-overlapping semantics).
#[inline]
pub unsafe fn copy_kernel<T: Element>(src: *const T, dst: *mut T, len: usize) {
    src.copy_to_nonoverlapping(dst, len);
}
/// Converts `len` elements from the buffer at `src`, read as `src_dtype`, into
/// the buffer at `dst`, written as `dst_dtype`.
///
/// Conversion rules implemented here:
///
/// * Every non-complex source element is first widened with `to_f64()` and
///   then narrowed to the destination type (`as` casts for the primitive
///   numeric types, `from_f64` / `from_f32` for the half and FP8 types).
/// * `Bool` sources are read as `u8`; a `Bool` destination receives `1` for
///   any value whose `f64` form is non-zero and `0` otherwise.
/// * A non-complex source cast to a complex destination gets an imaginary
///   part of `0.0`.
/// * Complex sources can only be cast to `Complex64` or `Complex128`.
///
/// NOTE(review): because every real conversion goes through `f64`, 64-bit
/// integers that `f64` cannot represent exactly are not preserved, even for an
/// identity cast such as `I64 -> I64`.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when
///
/// * the source or destination is `F16` / `BF16` and the `f16` feature is
///   disabled, or
/// * the source is complex and the destination is not complex. In that case
///   `dst` is left untouched.
///
/// # Safety
///
/// `src` must be valid for reads of `len` elements of the type selected by
/// `src_dtype`, `dst` must be valid for writes of `len` elements of the type
/// selected by `dst_dtype`, both must be aligned for those element types, and
/// the two buffers must not alias.
#[inline]
pub unsafe fn cast_kernel(
    src: *const u8,
    dst: *mut u8,
    len: usize,
    src_dtype: crate::dtype::DType,
    dst_dtype: crate::dtype::DType,
) -> crate::error::Result<()> {
    use crate::dtype::DType;
    use crate::error::Error;

    // Expands to the full destination dispatch for one concrete source element
    // type, so the source type is resolved once and each inner loop is
    // monomorphic.
    macro_rules! cast_from {
        ($src_ty:ty, $src_ptr:expr, $dst_ptr:expr, $len:expr, $dst_dtype:expr) => {{
            let src_slice = std::slice::from_raw_parts($src_ptr as *const $src_ty, $len);
            match $dst_dtype {
                DType::F64 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut f64, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64();
                    }
                }
                DType::F32 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut f32, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as f32;
                    }
                }
                DType::F16 => {
                    #[cfg(feature = "f16")]
                    {
                        let dst_slice =
                            std::slice::from_raw_parts_mut($dst_ptr as *mut half::f16, $len);
                        for i in 0..$len {
                            dst_slice[i] = half::f16::from_f64(src_slice[i].to_f64());
                        }
                    }
                    #[cfg(not(feature = "f16"))]
                    {
                        return Err(Error::UnsupportedDType {
                            dtype: DType::F16,
                            op: "cast",
                        });
                    }
                }
                DType::BF16 => {
                    #[cfg(feature = "f16")]
                    {
                        let dst_slice =
                            std::slice::from_raw_parts_mut($dst_ptr as *mut half::bf16, $len);
                        for i in 0..$len {
                            dst_slice[i] = half::bf16::from_f64(src_slice[i].to_f64());
                        }
                    }
                    #[cfg(not(feature = "f16"))]
                    {
                        return Err(Error::UnsupportedDType {
                            dtype: DType::BF16,
                            op: "cast",
                        });
                    }
                }
                DType::FP8E4M3 => {
                    let dst_slice = std::slice::from_raw_parts_mut(
                        $dst_ptr as *mut crate::dtype::FP8E4M3,
                        $len,
                    );
                    for i in 0..$len {
                        dst_slice[i] =
                            crate::dtype::FP8E4M3::from_f32(src_slice[i].to_f64() as f32);
                    }
                }
                DType::FP8E5M2 => {
                    let dst_slice = std::slice::from_raw_parts_mut(
                        $dst_ptr as *mut crate::dtype::FP8E5M2,
                        $len,
                    );
                    for i in 0..$len {
                        dst_slice[i] =
                            crate::dtype::FP8E5M2::from_f32(src_slice[i].to_f64() as f32);
                    }
                }
                DType::I64 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut i64, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as i64;
                    }
                }
                DType::I32 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut i32, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as i32;
                    }
                }
                DType::I16 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut i16, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as i16;
                    }
                }
                DType::I8 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut i8, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as i8;
                    }
                }
                DType::U64 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut u64, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as u64;
                    }
                }
                DType::U32 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut u32, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as u32;
                    }
                }
                DType::U16 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut u16, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as u16;
                    }
                }
                DType::U8 => {
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut u8, $len);
                    for i in 0..$len {
                        dst_slice[i] = src_slice[i].to_f64() as u8;
                    }
                }
                DType::Bool => {
                    // Booleans are stored as one byte: any non-zero value becomes 1.
                    let dst_slice = std::slice::from_raw_parts_mut($dst_ptr as *mut u8, $len);
                    for i in 0..$len {
                        dst_slice[i] = if src_slice[i].to_f64() != 0.0 { 1 } else { 0 };
                    }
                }
                DType::Complex64 => {
                    let dst_slice = std::slice::from_raw_parts_mut(
                        $dst_ptr as *mut crate::dtype::Complex64,
                        $len,
                    );
                    for i in 0..$len {
                        dst_slice[i] =
                            crate::dtype::Complex64::new(src_slice[i].to_f64() as f32, 0.0);
                    }
                }
                DType::Complex128 => {
                    let dst_slice = std::slice::from_raw_parts_mut(
                        $dst_ptr as *mut crate::dtype::Complex128,
                        $len,
                    );
                    for i in 0..$len {
                        dst_slice[i] = crate::dtype::Complex128::new(src_slice[i].to_f64(), 0.0);
                    }
                }
            }
        }};
    }

    match src_dtype {
        DType::F64 => cast_from!(f64, src, dst, len, dst_dtype),
        DType::F32 => cast_from!(f32, src, dst, len, dst_dtype),
        DType::F16 => {
            #[cfg(feature = "f16")]
            {
                cast_from!(half::f16, src, dst, len, dst_dtype)
            }
            #[cfg(not(feature = "f16"))]
            {
                return Err(Error::UnsupportedDType {
                    dtype: DType::F16,
                    op: "cast",
                });
            }
        }
        DType::BF16 => {
            #[cfg(feature = "f16")]
            {
                cast_from!(half::bf16, src, dst, len, dst_dtype)
            }
            #[cfg(not(feature = "f16"))]
            {
                return Err(Error::UnsupportedDType {
                    dtype: DType::BF16,
                    op: "cast",
                });
            }
        }
        DType::FP8E4M3 => {
            cast_from!(crate::dtype::FP8E4M3, src, dst, len, dst_dtype)
        }
        DType::FP8E5M2 => {
            cast_from!(crate::dtype::FP8E5M2, src, dst, len, dst_dtype)
        }
        DType::I64 => cast_from!(i64, src, dst, len, dst_dtype),
        DType::I32 => cast_from!(i32, src, dst, len, dst_dtype),
        DType::I16 => cast_from!(i16, src, dst, len, dst_dtype),
        DType::I8 => cast_from!(i8, src, dst, len, dst_dtype),
        DType::U64 => cast_from!(u64, src, dst, len, dst_dtype),
        DType::U32 => cast_from!(u32, src, dst, len, dst_dtype),
        DType::U16 => cast_from!(u16, src, dst, len, dst_dtype),
        DType::U8 => cast_from!(u8, src, dst, len, dst_dtype),
        DType::Bool => {
            // Booleans share the one-byte `u8` representation.
            cast_from!(u8, src, dst, len, dst_dtype)
        }
        DType::Complex64 => {
            let src_slice = std::slice::from_raw_parts(src as *const crate::dtype::Complex64, len);
            match dst_dtype {
                DType::Complex64 => {
                    let dst_slice =
                        std::slice::from_raw_parts_mut(dst as *mut crate::dtype::Complex64, len);
                    dst_slice.copy_from_slice(src_slice);
                }
                DType::Complex128 => {
                    let dst_slice =
                        std::slice::from_raw_parts_mut(dst as *mut crate::dtype::Complex128, len);
                    for i in 0..len {
                        dst_slice[i] = crate::dtype::Complex128::new(
                            src_slice[i].re as f64,
                            src_slice[i].im as f64,
                        );
                    }
                }
                _ => {
                    // Reject without touching `dst`: it is sized for `dst_dtype`,
                    // so writing `f32` values here could run past the end of the
                    // buffer and would corrupt it on a call that reports failure.
                    return Err(Error::UnsupportedDType {
                        dtype: dst_dtype,
                        op: "cast from Complex64",
                    });
                }
            }
        }
        DType::Complex128 => {
            let src_slice = std::slice::from_raw_parts(src as *const crate::dtype::Complex128, len);
            match dst_dtype {
                DType::Complex128 => {
                    let dst_slice =
                        std::slice::from_raw_parts_mut(dst as *mut crate::dtype::Complex128, len);
                    dst_slice.copy_from_slice(src_slice);
                }
                DType::Complex64 => {
                    let dst_slice =
                        std::slice::from_raw_parts_mut(dst as *mut crate::dtype::Complex64, len);
                    for i in 0..len {
                        dst_slice[i] = crate::dtype::Complex64::new(
                            src_slice[i].re as f32,
                            src_slice[i].im as f32,
                        );
                    }
                }
                _ => {
                    return Err(Error::UnsupportedDType {
                        dtype: dst_dtype,
                        op: "cast from Complex128",
                    });
                }
            }
        }
    }
    Ok(())
}
/// Fills `len` elements at `out` with uniform random draws converted to `T`.
///
/// Each value comes from `rng::sample_uniform` (presumably uniform on
/// `[0, 1)` — confirm against the `rng` module) and is converted with
/// `T::from_f64`. Narrowing can round a draw just below one up to exactly
/// `1.0`, so every converted value is checked and any result that reads back
/// as `>= 1.0` is replaced with `T::from_f64(0.0)`.
///
/// The check runs for every `T`: a one-off probe at a fixed value such as
/// `0.9999` only detects very coarse types and misses `f32`, where draws
/// within roughly `2^-25` of one still round up to `1.0`.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `len` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn rand_uniform_kernel<T: Element>(out: *mut T, len: usize) {
    let mut prng = rng::thread_rng();
    let out_slice = std::slice::from_raw_parts_mut(out, len);
    for elem in out_slice.iter_mut() {
        let candidate = T::from_f64(rng::sample_uniform(&mut prng));
        // Rounding during the f64 -> T conversion may have produced 1.0.
        *elem = if candidate.to_f64() >= 1.0 {
            T::from_f64(0.0)
        } else {
            candidate
        };
    }
}
/// Fills `len` elements at `out` with draws from `rng::sample_normal`, each
/// converted to `T` via `T::from_f64`.
///
/// Elements are written front to back, one draw per element.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `len` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn rand_normal_kernel<T: Element>(out: *mut T, len: usize) {
    let mut prng = rng::thread_rng();
    std::slice::from_raw_parts_mut(out, len)
        .fill_with(|| T::from_f64(rng::sample_normal(&mut prng)));
}
/// Fills `len` elements at `out` with integers drawn by
/// `rng::sample_uniform_int(.., low, high)`.
///
/// Each draw is widened to `f64` and converted with `T::from_f64`. Whether
/// `high` is inclusive is decided by `rng::sample_uniform_int` and is not
/// visible here.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `len` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn randint_kernel<T: Element>(out: *mut T, low: i64, high: i64, len: usize) {
    let mut prng = rng::thread_rng();
    std::slice::from_raw_parts_mut(out, len).fill_with(|| {
        let drawn = rng::sample_uniform_int(&mut prng, low, high);
        T::from_f64(drawn as f64)
    });
}
/// Writes the arithmetic sequence `start + step * i` for `i` in `0..len` into
/// `out`, converting each term with `T::from_f64`.
///
/// Each term is computed directly from its index rather than by repeated
/// addition.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `len` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn arange_kernel<T: Element>(out: *mut T, start: f64, step: f64, len: usize) {
    std::slice::from_raw_parts_mut(out, len)
        .iter_mut()
        .enumerate()
        .for_each(|(k, slot)| *slot = T::from_f64(start + step * (k as f64)));
}
/// Writes `steps` evenly spaced values from `start` to `stop` into `out`.
///
/// Element `i` is `start + (stop - start) * i / (steps - 1)`, converted with
/// `T::from_f64`. Special cases:
///
/// * `steps == 0` writes nothing and never touches `out`.
/// * `steps == 1` writes only `start`.
///
/// # Safety
///
/// When `steps > 0`, `out` must be non-null, aligned for `T`, and valid for
/// reads and writes of `steps` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn linspace_kernel<T: Element>(out: *mut T, start: f64, stop: f64, steps: usize) {
    // Nothing to write; return before `steps - 1` below can underflow.
    if steps == 0 {
        return;
    }
    let out_slice = std::slice::from_raw_parts_mut(out, steps);
    if steps == 1 {
        // A single point has no spacing; avoid dividing by zero below.
        out_slice[0] = T::from_f64(start);
        return;
    }
    let divisor = (steps - 1) as f64;
    let delta = stop - start;
    for (i, elem) in out_slice.iter_mut().enumerate() {
        let val = start + delta * (i as f64) / divisor;
        *elem = T::from_f64(val);
    }
}
/// Writes an `n` x `m` identity-like matrix into `out`, stored row by row
/// with `m` elements per row.
///
/// All `n * m` elements are first set to `T::from_f64(0.0)`; then the first
/// `min(n, m)` diagonal positions (`i * m + i`) are set to `T::from_f64(1.0)`.
///
/// # Safety
///
/// `out` must be non-null, aligned for `T`, and valid for reads and writes of
/// `n * m` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
#[inline]
pub unsafe fn eye_kernel<T: Element>(out: *mut T, n: usize, m: usize) {
    let cells = std::slice::from_raw_parts_mut(out, n * m);
    let zero = T::from_f64(0.0);
    for cell in cells.iter_mut() {
        *cell = zero;
    }
    (0..n.min(m)).for_each(|d| cells[d * m + d] = T::from_f64(1.0));
}
/// Draws `num_samples` category indices, with replacement, from each of
/// `num_distributions` weight rows.
///
/// `probs` is read as `num_distributions` consecutive rows of
/// `num_categories` weights; `out` receives `num_distributions` consecutive
/// rows of `num_samples` indices. Each row is normalised by its sum, turned
/// into a cumulative distribution whose last entry is pinned to `1.0`, and
/// sampled by binary-searching a draw from `rng::sample_uniform`.
///
/// A draw that lands on a category with non-positive weight is redirected to
/// the nearest category with positive weight. This can otherwise happen when
/// the uniform draw is exactly `0.0` and category 0 has zero weight, or when
/// rounding leaves the accumulated mass just under the pinned `1.0`.
///
/// With `num_categories == 0`, every requested sample is written as `0`.
/// Rows whose weights sum to zero produce NaN cumulative values; the result
/// for such rows is unspecified but always in range.
///
/// # Safety
///
/// `probs` must be valid for reads of `num_distributions * num_categories`
/// elements of `T`, and `out` must be valid for reads and writes of
/// `num_distributions * num_samples` initialized `i64` values. Both must be
/// non-null and aligned.
#[inline]
pub unsafe fn multinomial_kernel_with_replacement<T: Element>(
    probs: *const T,
    out: *mut i64,
    num_distributions: usize,
    num_categories: usize,
    num_samples: usize,
) {
    // Returns `idx` when it carries positive weight; otherwise the nearest
    // positive-weight index, searching forward first and then backward. Falls
    // back to `idx` when no weight is positive (or `idx` is out of range).
    fn snap_to_positive(weights: &[f64], idx: usize) -> usize {
        match weights.get(idx) {
            Some(&w) if !(w > 0.0) => {}
            _ => return idx,
        }
        if let Some(offset) = weights[idx + 1..].iter().position(|&w| w > 0.0) {
            return idx + 1 + offset;
        }
        weights[..idx]
            .iter()
            .rposition(|&w| w > 0.0)
            .unwrap_or(idx)
    }

    let mut prng = rng::thread_rng();
    // Saturating so that zero categories yields index 0 instead of underflowing.
    let last_index = num_categories.saturating_sub(1);
    for dist in 0..num_distributions {
        let prob_row = std::slice::from_raw_parts(probs.add(dist * num_categories), num_categories);
        let weights: Vec<f64> = prob_row.iter().map(|p| p.to_f64()).collect();
        let sum = weights.iter().fold(0.0f64, |acc, &w| acc + w);

        let mut cumsum = 0.0f64;
        let mut cdf: Vec<f64> = weights
            .iter()
            .map(|&w| {
                cumsum += w / sum;
                cumsum
            })
            .collect();
        // Pin the tail so accumulated rounding error cannot leave a gap below 1.
        if let Some(last) = cdf.last_mut() {
            *last = 1.0;
        }

        let out_row = std::slice::from_raw_parts_mut(out.add(dist * num_samples), num_samples);
        for sample in out_row {
            let u = rng::sample_uniform(&mut prng);
            let landed = cdf.partition_point(|&c| c < u).min(last_index);
            *sample = snap_to_positive(&weights, landed) as i64;
        }
    }
}
/// Draws `num_samples` distinct category indices from each of
/// `num_distributions` weight rows.
///
/// `probs` is read as `num_distributions` consecutive rows of
/// `num_categories` weights; `out` receives `num_distributions` consecutive
/// rows of `num_samples` indices. For every sample the remaining weights are
/// renormalised into a cumulative distribution (last entry pinned to `1.0`),
/// a draw from `rng::sample_uniform` is located by binary search, and the
/// chosen category's weight is zeroed so it is not chosen again.
///
/// A draw that lands on a category whose remaining weight is not positive is
/// redirected to the nearest category that still has weight. Without this, a
/// uniform draw of exactly `0.0`, or rounding just below the pinned `1.0`,
/// could return an already-drawn category a second time.
///
/// If more samples are requested than there are positive-weight categories,
/// the surplus samples are drawn from an all-zero row; their values are
/// unspecified but in range.
///
/// # Panics
///
/// Panics if `num_categories == 0` while `num_samples > 0` and
/// `num_distributions > 0`.
///
/// # Safety
///
/// `probs` must be valid for reads of `num_distributions * num_categories`
/// elements of `T`, and `out` must be valid for reads and writes of
/// `num_distributions * num_samples` initialized `i64` values. Both must be
/// non-null and aligned.
#[inline]
pub unsafe fn multinomial_kernel_without_replacement<T: Element>(
    probs: *const T,
    out: *mut i64,
    num_distributions: usize,
    num_categories: usize,
    num_samples: usize,
) {
    // Returns `idx` when it carries positive weight; otherwise the nearest
    // positive-weight index, searching forward first and then backward. Falls
    // back to `idx` when no weight is positive (or `idx` is out of range).
    fn snap_to_positive(weights: &[f64], idx: usize) -> usize {
        match weights.get(idx) {
            Some(&w) if !(w > 0.0) => {}
            _ => return idx,
        }
        if let Some(offset) = weights[idx + 1..].iter().position(|&w| w > 0.0) {
            return idx + 1 + offset;
        }
        weights[..idx]
            .iter()
            .rposition(|&w| w > 0.0)
            .unwrap_or(idx)
    }

    let mut prng = rng::thread_rng();
    // One buffer reused for every sample instead of a fresh allocation each time.
    let mut cdf: Vec<f64> = Vec::new();
    for dist in 0..num_distributions {
        let prob_row = std::slice::from_raw_parts(probs.add(dist * num_categories), num_categories);
        let mut remaining_probs: Vec<f64> = prob_row.iter().map(|p| p.to_f64()).collect();
        let out_row = std::slice::from_raw_parts_mut(out.add(dist * num_samples), num_samples);
        for sample in out_row {
            // Renormalise over whatever weight is still available.
            let sum: f64 = remaining_probs.iter().sum();
            let mut cumsum = 0.0f64;
            cdf.clear();
            cdf.extend(remaining_probs.iter().map(|&p| {
                cumsum += p / sum;
                cumsum
            }));
            // Pin the tail so accumulated rounding error cannot leave a gap below 1.
            if let Some(last) = cdf.last_mut() {
                *last = 1.0;
            }
            let u = rng::sample_uniform(&mut prng);
            let landed = cdf.partition_point(|&c| c < u).min(num_categories - 1);
            let idx = snap_to_positive(&remaining_probs, landed);
            *sample = idx as i64;
            // Remove the category from later draws in this row.
            remaining_probs[idx] = 0.0;
        }
    }
}
/// Writes a random permutation of `0..n` into `out`.
///
/// The buffer is first filled with the identity sequence and then shuffled
/// from the back: position `hi` is swapped with an index chosen as
/// `prng.next() % (hi + 1)`, for `hi` running from `n - 1` down to `1`.
///
/// # Safety
///
/// `out` must be non-null, aligned for `i64`, and valid for reads and writes
/// of `n` initialized elements, as required by
/// [`std::slice::from_raw_parts_mut`].
pub unsafe fn randperm_kernel(out: *mut i64, n: usize) {
    let mut prng = rng::thread_rng();
    let perm = std::slice::from_raw_parts_mut(out, n);
    for (pos, slot) in perm.iter_mut().enumerate() {
        *slot = pos as i64;
    }
    let mut hi = n;
    while hi > 1 {
        hi -= 1;
        let pick = (prng.next() % (hi as u64 + 1)) as usize;
        perm.swap(hi, pick);
    }
}
/// Marks the class selected by each index with `1.0` in a one-hot output.
///
/// `out` is treated as `numel` consecutive rows of `num_classes` values. For
/// element `i`, the index is read via `to_f64()` and `out[i * num_classes +
/// index]` is set to `1.0`. Indices that are negative, NaN, or
/// `>= num_classes` are skipped, leaving that row unchanged.
///
/// Only the selected positions are written; every other entry keeps its
/// previous value, so callers presumably zero-initialise `out` beforehand
/// (confirm at call sites).
///
/// # Safety
///
/// `indices` must be valid for reads of `numel` elements of `T`, and `out`
/// must be valid for reads and writes of `numel * num_classes` initialized
/// `f32` values. Both must be non-null and aligned.
pub unsafe fn one_hot_kernel<T: Element>(
    indices: *const T,
    out: *mut f32,
    numel: usize,
    num_classes: usize,
) {
    let indices_slice = std::slice::from_raw_parts(indices, numel);
    let out_slice = std::slice::from_raw_parts_mut(out, numel * num_classes);
    // No classes means no slot to mark (and a zero chunk size is not allowed).
    if num_classes == 0 {
        return;
    }
    for (row, index) in out_slice.chunks_exact_mut(num_classes).zip(indices_slice) {
        let raw = index.to_f64();
        // A float-to-usize `as` cast saturates negatives and NaN to 0, which
        // would wrongly mark class 0; reject them before converting.
        if !(raw >= 0.0) {
            continue;
        }
        if let Some(slot) = row.get_mut(raw as usize) {
            *slot = 1.0;
        }
    }
}