use std::alloc::{Layout, alloc_zeroed, dealloc};
use std::mem::size_of;
use std::ops::{Index, IndexMut};
use std::ptr::NonNull;
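// All buffers are aligned to 32 bytes (8 f32 lanes), presumably so that rows
// can be loaded with aligned 256-bit SIMD instructions.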
const ALIGNMENT: usize = 32;
#[inline]
fn dangling_aligned_f32() -> NonNull<f32> {
debug_assert_eq!(ALIGNMENT % std::mem::align_of::<f32>(), 0);
NonNull::new(ALIGNMENT as *mut u8)
.expect("aligned dangling pointer must be non-null")
.cast()
}
#[inline]
fn layout_for_f32_elems(len: usize) -> Layout {
let bytes = len
.checked_mul(size_of::<f32>())
.expect("tensor allocation overflow");
Layout::from_size_align(bytes, ALIGNMENT).expect("Invalid layout")
}
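/// Allocates a zero-initialized, `ALIGNMENT`-aligned buffer of `len` `f32`s.
/// For `len == 0` nothing is allocated and an aligned dangling pointer is
/// returned instead.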
#[inline]
fn alloc_f32_buffer(len: usize) -> NonNull<f32> {
if len == 0 {
return dangling_aligned_f32();
}
let layout = layout_for_f32_elems(len);
let ptr = unsafe { alloc_zeroed(layout) };
// On allocation failure, defer to the global allocation-error handler rather
// than panicking, since panicking on OOM may itself try to allocate.
match NonNull::new(ptr) {
Some(p) => p.cast(),
None => std::alloc::handle_alloc_error(layout),
}
}
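/// # Safety
///
/// `ptr` must have been returned by `alloc_f32_buffer(len)` with the same
/// `len`, and the buffer must not be accessed after this call.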
#[inline]
unsafe fn dealloc_f32_buffer(ptr: NonNull<f32>, len: usize) {
if len == 0 {
return;
}
let layout = layout_for_f32_elems(len);
unsafe {
dealloc(ptr.as_ptr() as *mut u8, layout);
}
}
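/// Rounds `cols` up to the next multiple of 8 `f32` elements (32 bytes) so
/// that every row of a `Tensor2D` starts on an `ALIGNMENT` boundary.
/// Illustrative values: 0 -> 0, 1 -> 8, 8 -> 8, 9 -> 16.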
#[inline]
fn padded_stride(cols: usize) -> usize {
cols.checked_add(7).expect("tensor stride overflow") & !7
}
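/// Owned, heap-allocated 1-D buffer of `f32` with `ALIGNMENT`-byte alignment.
/// Zero-length tensors hold a dangling aligned pointer and never allocate.
/// `Index`/`IndexMut` check bounds only in debug builds.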
#[repr(C)]
pub struct Tensor1D {
data: NonNull<f32>,
len: usize,
}
impl Tensor1D {
pub fn zeros(len: usize) -> Self {
Self {
data: alloc_f32_buffer(len),
len,
}
}
pub fn from_vec(v: Vec<f32>) -> Self {
let mut t = Self::zeros(v.len());
t.as_mut_slice().copy_from_slice(&v);
t
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.data.as_ptr()
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
self.data.as_ptr()
}
#[inline]
pub fn as_slice(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [f32] {
unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
}
#[inline]
pub fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.data.as_ptr(), 0, self.len);
}
}
#[inline]
pub fn copy_from(&mut self, other: &Tensor1D) {
debug_assert_eq!(self.len, other.len);
self.as_mut_slice().copy_from_slice(other.as_slice());
}
#[inline]
pub fn copy_from_slice(&mut self, slice: &[f32]) {
debug_assert_eq!(self.len, slice.len());
self.as_mut_slice().copy_from_slice(slice);
}
}
impl Clone for Tensor1D {
fn clone(&self) -> Self {
let mut new = Self::zeros(self.len);
new.as_mut_slice().copy_from_slice(self.as_slice());
new
}
}
impl Drop for Tensor1D {
fn drop(&mut self) {
unsafe {
dealloc_f32_buffer(self.data, self.len);
}
}
}
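// SAFETY: Tensor1D uniquely owns its allocation and has no interior
// mutability; the raw pointer is just a handle to that owned memory, so
// moving or sharing a Tensor1D across threads is sound.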
unsafe impl Send for Tensor1D {}
unsafe impl Sync for Tensor1D {}
impl Index<usize> for Tensor1D {
type Output = f32;
#[inline]
fn index(&self, i: usize) -> &f32 {
debug_assert!(i < self.len);
unsafe { &*self.data.as_ptr().add(i) }
}
}
impl IndexMut<usize> for Tensor1D {
#[inline]
fn index_mut(&mut self, i: usize) -> &mut f32 {
debug_assert!(i < self.len);
unsafe { &mut *self.data.as_ptr().add(i) }
}
}
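/// Owned, row-major 2-D buffer of `f32`. Each row is padded to `stride`
/// elements (a multiple of 8, see `padded_stride`) so that every row starts
/// on an `ALIGNMENT` boundary. Row accessors check bounds only in debug builds.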
#[repr(C)]
pub struct Tensor2D {
data: NonNull<f32>,
rows: usize,
cols: usize,
stride: usize,
}
impl Tensor2D {
pub fn zeros(rows: usize, cols: usize) -> Self {
let stride = padded_stride(cols);
let total = rows
.checked_mul(stride)
.expect("tensor allocation overflow");
Self {
data: alloc_f32_buffer(total),
rows,
cols,
stride,
}
}
pub fn from_vec(v: Vec<f32>, rows: usize, cols: usize) -> Self {
assert_eq!(v.len(), rows * cols);
let mut t = Self::zeros(rows, cols);
for r in 0..rows {
let src_start = r * cols;
let src_end = src_start + cols;
t.row_mut(r).copy_from_slice(&v[src_start..src_end]);
}
t
}
#[inline]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
pub fn cols(&self) -> usize {
self.cols
}
#[inline]
pub fn stride(&self) -> usize {
self.stride
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.data.as_ptr()
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
self.data.as_ptr()
}
#[inline]
pub fn row(&self, r: usize) -> &[f32] {
debug_assert!(r < self.rows);
unsafe {
let ptr = self.data.as_ptr().add(r * self.stride);
std::slice::from_raw_parts(ptr, self.cols)
}
}
#[inline]
pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
debug_assert!(r < self.rows);
unsafe {
let ptr = self.data.as_ptr().add(r * self.stride);
std::slice::from_raw_parts_mut(ptr, self.cols)
}
}
#[inline]
pub fn row_ptr(&self, r: usize) -> *const f32 {
debug_assert!(r < self.rows);
unsafe { self.data.as_ptr().add(r * self.stride) }
}
#[inline]
pub fn row_ptr_mut(&mut self, r: usize) -> *mut f32 {
debug_assert!(r < self.rows);
unsafe { self.data.as_ptr().add(r * self.stride) }
}
pub fn zero(&mut self) {
let total = self
.rows
.checked_mul(self.stride)
.expect("tensor allocation overflow");
unsafe {
std::ptr::write_bytes(self.data.as_ptr(), 0, total);
}
}
}
impl Clone for Tensor2D {
fn clone(&self) -> Self {
let total = self
.rows
.checked_mul(self.stride)
.expect("tensor allocation overflow");
let data = alloc_f32_buffer(total);
unsafe {
std::ptr::copy_nonoverlapping(self.data.as_ptr(), data.as_ptr(), total);
}
Self {
data,
rows: self.rows,
cols: self.cols,
stride: self.stride,
}
}
}
impl Drop for Tensor2D {
fn drop(&mut self) {
let total = self
.rows
.checked_mul(self.stride)
.expect("tensor allocation overflow");
unsafe {
dealloc_f32_buffer(self.data, total);
}
}
}
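// SAFETY: Tensor2D uniquely owns its allocation and has no interior
// mutability, so moving or sharing it across threads is sound.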
unsafe impl Send for Tensor2D {}
unsafe impl Sync for Tensor2D {}
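/// Borrowed, read-only view over a contiguous `&[f32]`.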
#[derive(Clone, Copy)]
pub struct TensorView1D<'a> {
data: &'a [f32],
}
impl<'a> TensorView1D<'a> {
#[inline]
pub fn new(data: &'a [f32]) -> Self {
Self { data }
}
#[inline]
pub fn len(&self) -> usize {
self.data.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.data.as_ptr()
}
#[inline]
pub fn as_slice(&self) -> &[f32] {
self.data
}
}
impl<'a> Index<usize> for TensorView1D<'a> {
type Output = f32;
#[inline]
fn index(&self, i: usize) -> &f32 {
&self.data[i]
}
}
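/// Borrowed, read-only, row-major view over a contiguous `&[f32]`. Unlike
/// `Tensor2D`, rows are packed with no padding (the row stride equals `cols`).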
#[derive(Clone, Copy)]
pub struct TensorView2D<'a> {
data: &'a [f32],
rows: usize,
cols: usize,
}
impl<'a> TensorView2D<'a> {
#[inline]
pub fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
debug_assert_eq!(data.len(), rows * cols);
Self { data, rows, cols }
}
#[inline]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
pub fn cols(&self) -> usize {
self.cols
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.data.as_ptr()
}
#[inline]
pub fn row(&self, r: usize) -> &[f32] {
debug_assert!(r < self.rows);
let start = r * self.cols;
&self.data[start..start + self.cols]
}
#[inline]
pub fn row_ptr(&self, r: usize) -> *const f32 {
debug_assert!(r < self.rows);
unsafe { self.data.as_ptr().add(r * self.cols) }
}
pub fn t(&self) -> TransposedView2D<'a> {
TransposedView2D {
data: self.data,
rows: self.cols,
cols: self.rows,
orig_cols: self.cols,
}
}
}
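/// Logical transpose of a `TensorView2D`: no data is moved, and `get(r, c)`
/// reads element `(c, r)` of the original row-major storage.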
#[derive(Clone, Copy)]
pub struct TransposedView2D<'a> {
data: &'a [f32],
rows: usize,
cols: usize,
orig_cols: usize,
}
impl<'a> TransposedView2D<'a> {
#[inline]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
pub fn cols(&self) -> usize {
self.cols
}
#[inline]
pub fn get(&self, r: usize, c: usize) -> f32 {
self.data[c * self.orig_cols + r]
}
#[inline]
pub fn orig_row(&self, r: usize) -> &[f32] {
// `r` indexes rows of the original matrix, of which there are `self.cols`.
debug_assert!(r < self.cols);
let start = r * self.orig_cols;
&self.data[start..start + self.orig_cols]
}
}
#[cfg(test)]
mod tests {
use super::*;
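// A concrete sketch of the padding rule: padded_stride rounds the column
// count up to the next multiple of 8 f32 elements.
#[test]
fn padded_stride_rounds_up_to_multiples_of_eight() {
assert_eq!(padded_stride(0), 0);
assert_eq!(padded_stride(1), 8);
assert_eq!(padded_stride(8), 8);
assert_eq!(padded_stride(9), 16);
}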
#[test]
fn zero_len_tensor1d_uses_aligned_non_allocating_sentinel() {
let mut t = Tensor1D::zeros(0);
assert_eq!(t.len(), 0);
assert!(t.is_empty());
assert!(t.as_slice().is_empty());
assert!(t.as_mut_slice().is_empty());
assert_eq!((t.as_ptr() as usize) % ALIGNMENT, 0);
t.zero();
}
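// A small usage sketch for Tensor1D: round-trip a Vec, index elements, and
// clone the buffer.
#[test]
fn tensor1d_from_vec_round_trips_and_indexes() {
let v = vec![1.0f32, 2.0, 3.0, 4.0];
let mut t = Tensor1D::from_vec(v.clone());
assert_eq!(t.as_slice(), v.as_slice());
assert_eq!(t[2], 3.0);
t[2] = 7.5;
assert_eq!(t[2], 7.5);
let c = t.clone();
assert_eq!(c.as_slice(), t.as_slice());
}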
#[test]
fn zero_sized_tensor2d_is_safe() {
let mut t = Tensor2D::zeros(3, 0);
assert_eq!(t.rows(), 3);
assert_eq!(t.cols(), 0);
assert_eq!(t.stride(), 0);
assert_eq!((t.as_ptr() as usize) % ALIGNMENT, 0);
for row in 0..t.rows() {
assert!(t.row(row).is_empty());
assert!(t.row_mut(row).is_empty());
}
t.zero();
}
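// Usage sketches for the 2-D types: Tensor2D row padding/alignment and the
// zero-copy transpose exposed by TensorView2D::t().
#[test]
fn tensor2d_rows_are_padded_and_aligned() {
let v: Vec<f32> = (0..6).map(|x| x as f32).collect();
let t = Tensor2D::from_vec(v, 2, 3);
assert_eq!(t.stride(), 8);
assert_eq!(t.row(0), &[0.0, 1.0, 2.0]);
assert_eq!(t.row(1), &[3.0, 4.0, 5.0]);
for r in 0..t.rows() {
assert_eq!((t.row_ptr(r) as usize) % ALIGNMENT, 0);
}
}
#[test]
fn transposed_view_reads_original_storage() {
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = TensorView2D::new(&data, 2, 3);
let tr = view.t();
assert_eq!(tr.rows(), 3);
assert_eq!(tr.cols(), 2);
assert_eq!(tr.get(0, 1), 4.0);
assert_eq!(tr.get(2, 0), 3.0);
assert_eq!(tr.orig_row(1), &[4.0, 5.0, 6.0]);
}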
}