use baracuda_driver::{DeviceSlice, DeviceSliceMut};
use baracuda_types::DeviceRepr;
/// Read-only tensor descriptor: a borrowed device slice plus rank-`N`
/// shape and stride metadata.
///
/// The type derives `Copy`/`Clone`, so the descriptor itself can be passed
/// by value.
#[derive(Debug, Copy, Clone)]
pub struct TensorRef<'a, T: DeviceRepr + Copy + 'static, const N: usize> {
    /// Borrowed device storage holding the tensor's elements.
    pub data: DeviceSlice<'a, T>,
    /// Extent of each of the `N` dimensions.
    pub shape: [i32; N],
    /// Stride of each of the `N` dimensions.
    // NOTE(review): `is_contiguous` expects the last dimension's stride to be
    // 1, which suggests strides are counted in elements rather than bytes —
    // confirm against the kernels that consume this descriptor.
    pub stride: [i64; N],
}
/// Mutable tensor descriptor: a mutably borrowed device slice plus rank-`N`
/// shape and stride metadata.
///
/// Only `Debug` is derived here, so this descriptor is neither `Copy` nor
/// `Clone`.
// NOTE(review): presumably because `DeviceSliceMut` is an exclusive borrow —
// confirm against the `baracuda_driver` definition.
#[derive(Debug)]
pub struct TensorMut<'a, T: DeviceRepr + Copy + 'static, const N: usize> {
    /// Mutably borrowed device storage holding the tensor's elements.
    pub data: DeviceSliceMut<'a, T>,
    /// Extent of each of the `N` dimensions.
    pub shape: [i32; N],
    /// Stride of each of the `N` dimensions.
    // NOTE(review): `is_contiguous` expects the last dimension's stride to be
    // 1, which suggests strides are counted in elements rather than bytes —
    // confirm against the kernels that consume this descriptor.
    pub stride: [i64; N],
}
impl<'a, T: DeviceRepr + Copy + 'static, const N: usize> TensorRef<'a, T, N> {
    /// Returns the product of all extents in `shape`, widened to `i64`.
    ///
    /// A rank-0 descriptor (`N == 0`) yields `1`. Each multiplication
    /// saturates at the `i64` bounds instead of wrapping, so very large
    /// shapes clamp rather than overflow.
    #[inline]
    pub fn numel(&self) -> i64 {
        self.shape
            .iter()
            .fold(1_i64, |count, &extent| count.saturating_mul(i64::from(extent)))
    }

    /// Returns `true` when `stride` is exactly the packed row-major layout
    /// for `shape`.
    ///
    /// Walking from the last dimension to the first, each stride must equal
    /// the (saturating) product of all extents to its right; the last
    /// dimension must therefore have stride `1`. The comparison is strict:
    /// a dimension of extent 1 must still carry its canonical stride.
    /// A rank-0 descriptor is always contiguous.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        // Stride the current dimension must have for a packed layout.
        let mut want: i64 = 1;
        for (&extent, &step) in self.shape.iter().zip(self.stride.iter()).rev() {
            if step != want {
                return false;
            }
            want = want.saturating_mul(i64::from(extent));
        }
        true
    }
}
impl<'a, T: DeviceRepr + Copy + 'static, const N: usize> TensorMut<'a, T, N> {
    /// Returns the product of all extents in `shape`, widened to `i64`.
    ///
    /// A rank-0 descriptor (`N == 0`) yields `1`. Each multiplication
    /// saturates at the `i64` bounds instead of wrapping, so very large
    /// shapes clamp rather than overflow.
    #[inline]
    pub fn numel(&self) -> i64 {
        self.shape
            .iter()
            .fold(1_i64, |count, &extent| count.saturating_mul(i64::from(extent)))
    }

    /// Returns `true` when `stride` is exactly the packed row-major layout
    /// for `shape`.
    ///
    /// Walking from the last dimension to the first, each stride must equal
    /// the (saturating) product of all extents to its right; the last
    /// dimension must therefore have stride `1`. The comparison is strict:
    /// a dimension of extent 1 must still carry its canonical stride.
    /// A rank-0 descriptor is always contiguous.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        // Stride the current dimension must have for a packed layout.
        let mut want: i64 = 1;
        for (&extent, &step) in self.shape.iter().zip(self.stride.iter()).rev() {
            if step != want {
                return false;
            }
            want = want.saturating_mul(i64::from(extent));
        }
        true
    }
}
/// Builds the packed row-major stride array for `shape`.
///
/// The last dimension receives stride `1`; every earlier dimension receives
/// the product of all extents to its right. Products saturate at the `i64`
/// bounds instead of wrapping. A rank-0 shape yields an empty array.
#[inline]
pub fn contiguous_stride<const N: usize>(shape: [i32; N]) -> [i64; N] {
    let mut stride = [0i64; N];
    // Running product of the extents already visited (innermost first).
    let mut running: i64 = 1;
    for (slot, &extent) in stride.iter_mut().zip(shape.iter()).rev() {
        *slot = running;
        running = running.saturating_mul(i64::from(extent));
    }
    stride
}
/// Returns `true` when the two stride lists have the same length and the
/// same value at every position.
#[inline]
pub fn strides_equal(a: &[i64], b: &[i64]) -> bool {
    // `Iterator::eq` compares element by element and also reports `false`
    // when one side runs out before the other.
    a.iter().eq(b.iter())
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- contiguous_stride ------------------------------------------------

    // A rank-0 shape has no dimensions, so there are no strides to report.
    #[test]
    fn contiguous_stride_rank0_is_empty() {
        let strides: [i64; 0] = contiguous_stride([]);
        assert_eq!(strides, [] as [i64; 0]);
    }

    // A single dimension is always packed with unit stride.
    #[test]
    fn contiguous_stride_rank1_is_one() {
        assert_eq!(contiguous_stride([5]), [1]);
        assert_eq!(contiguous_stride([100]), [1]);
    }

    // Rank 2: the outer stride equals the inner extent.
    #[test]
    fn contiguous_stride_rank2_row_major() {
        assert_eq!(contiguous_stride([4, 8]), [8, 1]);
        assert_eq!(contiguous_stride([3, 5]), [5, 1]);
    }

    // Rank 3: each stride is the product of the extents to its right.
    #[test]
    fn contiguous_stride_rank3() {
        assert_eq!(contiguous_stride([2, 4, 8]), [32, 8, 1]);
    }

    // Rank 4: same rule with non-power-of-two extents.
    #[test]
    fn contiguous_stride_rank4() {
        assert_eq!(contiguous_stride([2, 3, 5, 7]), [105, 35, 7, 1]);
    }

    // ---- strides_equal ----------------------------------------------------

    // Two empty stride lists compare equal.
    #[test]
    fn strides_equal_empty_slices_are_equal() {
        let lhs: [i64; 0] = [];
        let rhs: [i64; 0] = [];
        assert!(strides_equal(&lhs, &rhs));
    }

    // Element-wise identical lists compare equal.
    #[test]
    fn strides_equal_identical_arrays() {
        assert!(strides_equal(&[1, 2, 3], &[1, 2, 3]));
        assert!(strides_equal(&[8, 1], &[8, 1]));
        assert!(strides_equal(&[1024, 32, 1], &[1024, 32, 1]));
    }

    // A single differing value, or the same values reordered, is a mismatch.
    #[test]
    fn strides_equal_different_values() {
        assert!(!strides_equal(&[1, 2, 3], &[1, 2, 4]));
        assert!(!strides_equal(&[8, 1], &[1, 8]));
    }

    // Lists of different rank never compare equal, even with a shared prefix.
    #[test]
    fn strides_equal_different_lengths() {
        assert!(!strides_equal(&[1, 2], &[1, 2, 3]));
        assert!(!strides_equal(&[1], &[]));
    }

    // Output of `contiguous_stride` compares equal to itself.
    #[test]
    fn strides_equal_matches_contiguous_stride_output() {
        for &shape in &[[4, 8], [16, 32], [3, 5]] {
            let packed = contiguous_stride(shape);
            assert!(strides_equal(&packed, &packed));
        }
    }

    // Column-major strides for the same shape differ from the packed ones.
    #[test]
    fn strides_equal_detects_transpose() {
        let packed = contiguous_stride([4, 8]);
        let transposed = [1_i64, 4_i64];
        assert!(!strides_equal(&packed, &transposed));
    }

    // A zero stride on the outer dimension differs from the packed layout.
    #[test]
    fn strides_equal_detects_broadcast_zero_stride() {
        let packed = contiguous_stride([4, 8]);
        let broadcast = [0_i64, 1_i64];
        assert!(!strides_equal(&packed, &broadcast));
    }

    // A negative stride differs from the packed layout.
    #[test]
    fn strides_equal_detects_negative_stride() {
        let packed = contiguous_stride([4]);
        let flipped = [-1_i64];
        assert!(!strides_equal(&packed, &flipped));
    }

    // ---- element count ----------------------------------------------------

    // Checks that outermost stride times outermost extent equals the shape
    // product (4 * 8). This exercises `contiguous_stride` only; it does not
    // call `numel` on a tensor descriptor.
    #[test]
    fn numel_matches_shape_product() {
        let shape = [4, 8];
        let stride = contiguous_stride(shape);
        assert_eq!(stride[0] * i64::from(shape[0]), 32);
    }
}