extern crate alloc;
use core::marker::PhantomData;
use core::ptr::NonNull;
use crate::tensor::{Allocator, Global, Tensor, TensorError, SIMD_ALIGNMENT};
use crate::types::{DimMut, DimRef, FloatConvertible, NumberLike, StorageElement};
/// Sealed-trait support module: downstream crates can name `VectorIndex`
/// but cannot implement it, because `Sealed` is not exported.
mod private {
    pub trait Sealed {}
}
/// Index types accepted by vector accessors (all primitive integers).
///
/// Signed values index from the end when negative, Python-style.
/// Sealed: only the impls generated in this file exist.
pub trait VectorIndex: private::Sealed + Copy {
    /// Maps `self` to a concrete in-bounds index for a vector of `len`
    /// dimensions, or `None` when out of range.
    fn resolve(self, len: usize) -> Option<usize>;
}
/// Implements `VectorIndex` for unsigned integer types.
///
/// Uses `usize::try_from` rather than an `as` cast so that index types
/// wider than `usize` (e.g. `u64` on a 32-bit target) cannot silently
/// truncate into a spuriously in-bounds index.
macro_rules! impl_vec_index_unsigned {
    ($($t:ty),*) => {$(
        impl private::Sealed for $t {}
        impl VectorIndex for $t {
            #[inline]
            fn resolve(self, len: usize) -> Option<usize> {
                // A value that does not fit in usize is certainly >= len.
                match usize::try_from(self) {
                    Ok(idx) if idx < len => Some(idx),
                    _ => None,
                }
            }
        }
    )*};
}
/// Implements `VectorIndex` for signed integer types; negative values
/// count from the end (`-1` is the last dimension).
///
/// Uses `unsigned_abs` instead of negation, so `<$t>::MIN` no longer
/// panics in debug builds (or wraps in release), and `usize::try_from`
/// instead of `as` casts so wide types cannot truncate on 32-bit targets.
macro_rules! impl_vec_index_signed {
    ($($t:ty),*) => {$(
        impl private::Sealed for $t {}
        impl VectorIndex for $t {
            #[inline]
            fn resolve(self, len: usize) -> Option<usize> {
                if self >= 0 {
                    // Non-negative: behave exactly like the unsigned impls.
                    match usize::try_from(self) {
                        Ok(idx) if idx < len => Some(idx),
                        _ => None,
                    }
                } else {
                    // Negative: offset from the end. `unsigned_abs` cannot
                    // overflow, unlike `-self` for the minimum value.
                    match usize::try_from(self.unsigned_abs()) {
                        // `neg >= 1` here, so `len - neg < len` always holds.
                        Ok(neg) if neg <= len => Some(len - neg),
                        _ => None,
                    }
                }
            }
        }
    )*};
}
impl_vec_index_unsigned!(usize, u8, u16, u32, u64);
impl_vec_index_signed!(isize, i8, i16, i32, i64);
/// Read-only handle to one 4-bit nibble inside a byte.
///
/// `high` selects bits 4-7 (upper nibble) versus bits 0-3 (lower nibble)
/// of the byte at `byte`; `PhantomData` ties the handle to the lifetime
/// of the borrowed storage.
pub struct NibbleRef<'a> {
    byte: *const u8,
    high: bool,
    _marker: PhantomData<&'a u8>,
}
impl<'a> NibbleRef<'a> {
    /// Returns the nibble as an unsigned value in `0..=15`.
    #[inline]
    pub fn get_unsigned(&self) -> u8 {
        // SAFETY: the lifetime parameter guarantees `byte` points into
        // storage that outlives this handle.
        let raw = unsafe { self.byte.read() };
        let shift = if self.high { 4 } else { 0 };
        (raw >> shift) & 0x0F
    }
    /// Returns the nibble as a two's-complement 4-bit value in `-8..=7`.
    #[inline]
    pub fn get_signed(&self) -> i8 {
        // Shift the nibble into the top half, then arithmetic-shift back
        // down so bit 3 becomes the sign bit.
        ((self.get_unsigned() << 4) as i8) >> 4
    }
}
/// Mutable handle to one 4-bit nibble inside a byte.
///
/// Writes are read-modify-write: the sibling nibble of the shared byte is
/// preserved.
pub struct NibbleRefMut<'a> {
    byte: *mut u8,
    high: bool,
    _marker: PhantomData<&'a mut u8>,
}
impl<'a> NibbleRefMut<'a> {
    /// Returns the nibble as an unsigned value in `0..=15`.
    #[inline]
    pub fn get_unsigned(&self) -> u8 {
        // SAFETY: the exclusive-borrow lifetime keeps `byte` valid.
        let raw = unsafe { self.byte.read() };
        let shift = if self.high { 4 } else { 0 };
        (raw >> shift) & 0x0F
    }
    /// Returns the nibble as a two's-complement 4-bit value in `-8..=7`.
    #[inline]
    pub fn get_signed(&self) -> i8 {
        // Sign-extend from 4 bits via shift up / arithmetic shift down.
        ((self.get_unsigned() << 4) as i8) >> 4
    }
    /// Stores the low 4 bits of `val`, leaving the other nibble intact.
    #[inline]
    pub fn set_unsigned(&self, val: u8) {
        let nibble = val & 0x0F;
        // SAFETY: the exclusive-borrow lifetime guarantees the byte is
        // valid and not mutably aliased elsewhere.
        unsafe {
            let old = self.byte.read();
            let new = if self.high {
                (old & 0x0F) | (nibble << 4)
            } else {
                (old & 0xF0) | nibble
            };
            self.byte.write(new);
        }
    }
    /// Stores `val` as a 4-bit two's-complement value (low 4 bits kept).
    #[inline]
    pub fn set_signed(&self, val: i8) {
        self.set_unsigned(val as u8)
    }
}
/// Read-only handle to a single bit inside a byte, selected by `mask`.
pub struct BitRef<'a> {
    byte: *const u8,
    mask: u8,
    _marker: PhantomData<&'a u8>,
}
impl<'a> BitRef<'a> {
    /// True when the masked bit is set.
    #[inline]
    pub fn get(&self) -> bool {
        // SAFETY: the lifetime parameter keeps `byte` valid while this
        // handle exists.
        let raw = unsafe { self.byte.read() };
        raw & self.mask != 0
    }
}
/// Mutable handle to a single bit inside a byte, selected by `mask`.
///
/// Writes are read-modify-write so sibling bits of the byte survive.
pub struct BitRefMut<'a> {
    byte: *mut u8,
    mask: u8,
    _marker: PhantomData<&'a mut u8>,
}
impl<'a> BitRefMut<'a> {
    /// True when the masked bit is set.
    #[inline]
    pub fn get(&self) -> bool {
        // SAFETY: the exclusive-borrow lifetime keeps `byte` valid.
        let raw = unsafe { self.byte.read() };
        raw & self.mask != 0
    }
    /// Sets or clears the masked bit.
    #[inline]
    pub fn set(&self, val: bool) {
        // SAFETY: exclusive borrow — no other mutable access to the byte.
        unsafe {
            let cur = self.byte.read();
            let new = if val { cur | self.mask } else { cur & !self.mask };
            self.byte.write(new);
        }
    }
}
/// Heap-allocated, SIMD-aligned 1-D array of packed storage elements.
///
/// `dims` counts logical dimensions (scalar lanes); `values` counts
/// packed `T` values actually stored — for sub-byte element types several
/// dims share one value.
pub struct Vector<T: StorageElement, A: Allocator = Global> {
    data: NonNull<T>, // dangling when `values == 0`
    dims: usize,
    values: usize,
    alloc: A,
}
// SAFETY: the vector uniquely owns its buffer, so sending/sharing it only
// requires the element and allocator to be Send/Sync themselves.
unsafe impl<T: StorageElement + Send, A: Allocator + Send> Send for Vector<T, A> {}
unsafe impl<T: StorageElement + Sync, A: Allocator + Sync> Sync for Vector<T, A> {}
impl<T: StorageElement, A: Allocator> Drop for Vector<T, A> {
    fn drop(&mut self) {
        // `values == 0` means `data` is dangling and nothing was allocated.
        if self.values > 0 {
            // Must mirror the layout used at allocation time exactly.
            let layout = alloc::alloc::Layout::from_size_align(
                self.values * core::mem::size_of::<T>(),
                SIMD_ALIGNMENT,
            )
            .unwrap();
            // SAFETY: `data` was obtained from `self.alloc` with this
            // layout and has not been freed (Drop runs at most once).
            unsafe {
                self.alloc.deallocate(
                    NonNull::new_unchecked(self.data.as_ptr() as *mut u8),
                    layout,
                );
            }
        }
    }
}
/// Number of packed `T` values needed to hold `dims` dimensions —
/// ceiling division by `T::dimensions_per_value()`.
#[inline]
fn dims_to_values<T: StorageElement>(dims: usize) -> usize {
    let dims_per_value = T::dimensions_per_value();
    // Overflow-free ceiling division: the naive `(dims + dpv - 1) / dpv`
    // wraps when `dims` is near `usize::MAX`.
    dims / dims_per_value + usize::from(dims % dims_per_value != 0)
}
impl<T: StorageElement, A: Allocator> Vector<T, A> {
    /// Reassembles a vector from raw parts.
    ///
    /// # Safety
    /// `data` must come from `alloc` with `SIMD_ALIGNMENT` and room for
    /// `values` elements (or be dangling when `values == 0`), and
    /// `dims`/`values` must be consistent with `T::dimensions_per_value()`.
    pub unsafe fn from_raw_parts(data: NonNull<T>, dims: usize, values: usize, alloc: A) -> Self {
        Self {
            data,
            dims,
            values,
            alloc,
        }
    }
    /// Allocates a `dims`-dimensional vector and zero-fills its storage.
    pub fn try_zeros_in(dims: usize, alloc: A) -> Result<Self, TensorError> {
        let values = dims_to_values::<T>(dims);
        if values == 0 {
            // Empty vectors carry a dangling pointer; `Drop` skips them.
            return Ok(Self {
                data: NonNull::dangling(),
                dims: 0,
                values: 0,
                alloc,
            });
        }
        let size = values * core::mem::size_of::<T>();
        let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
            .map_err(|_| TensorError::AllocationFailed)?;
        let ptr = alloc
            .allocate(layout)
            .ok_or(TensorError::AllocationFailed)?;
        // SAFETY: the fresh allocation is valid for `size` bytes of writes.
        unsafe { core::ptr::write_bytes(ptr.as_ptr(), 0, size) };
        Ok(Self {
            // SAFETY: `allocate` succeeded, so the pointer is non-null.
            data: unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut T) },
            dims,
            values,
            alloc,
        })
    }
    /// Allocates a vector with every *packed value* set to `value`
    /// (all lanes at once for multi-dim element types).
    pub fn try_full_in(dims: usize, value: T, alloc: A) -> Result<Self, TensorError> {
        let v = Self::try_zeros_in(dims, alloc)?;
        if v.values > 0 {
            let ptr = v.data.as_ptr();
            for i in 0..v.values {
                // SAFETY: `i < v.values`, inside the allocation.
                unsafe { ptr.add(i).write(value) };
            }
        }
        Ok(v)
    }
    /// Allocates a vector filled with `T::one()`.
    pub fn try_ones_in(dims: usize, alloc: A) -> Result<Self, TensorError>
    where
        T: NumberLike,
    {
        Self::try_full_in(dims, T::one(), alloc)
    }
    /// Allocates a vector without initializing its contents.
    ///
    /// # Safety
    /// The storage must be fully written before it is read.
    pub unsafe fn try_empty_in(dims: usize, alloc: A) -> Result<Self, TensorError> {
        let values = dims_to_values::<T>(dims);
        if values == 0 {
            return Ok(Self {
                data: NonNull::dangling(),
                dims: 0,
                values: 0,
                alloc,
            });
        }
        let size = values * core::mem::size_of::<T>();
        let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
            .map_err(|_| TensorError::AllocationFailed)?;
        let ptr = alloc
            .allocate(layout)
            .ok_or(TensorError::AllocationFailed)?;
        Ok(Self {
            // SAFETY: `allocate` succeeded, so the pointer is non-null.
            data: unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut T) },
            dims,
            values,
            alloc,
        })
    }
    /// Builds a vector by converting each `f32` into the element's
    /// dimension scalar type.
    pub fn try_from_scalars_in(scalars: &[f32], alloc: A) -> Result<Self, TensorError>
    where
        T: FloatConvertible,
    {
        let n = scalars.len();
        let mut v = Self::try_zeros_in(n, alloc)?;
        for (i, &s) in scalars.iter().enumerate() {
            v.try_set(i, T::DimScalar::from_f32(s))?;
        }
        Ok(v)
    }
    /// Builds a vector from already-converted dimension scalars.
    pub fn try_from_dims_in(dim_values: &[T::DimScalar], alloc: A) -> Result<Self, TensorError>
    where
        T: FloatConvertible,
    {
        let n = dim_values.len();
        let mut v = Self::try_zeros_in(n, alloc)?;
        for (i, &d) in dim_values.iter().enumerate() {
            v.try_set(i, d)?;
        }
        Ok(v)
    }
    /// Number of logical dimensions (scalar lanes).
    #[inline]
    pub fn dims(&self) -> usize {
        self.dims
    }
    /// Alias for [`dims`](Self::dims).
    #[inline]
    pub fn size(&self) -> usize {
        self.dims
    }
    /// Number of packed `T` values backing the vector.
    #[inline]
    pub fn size_values(&self) -> usize {
        self.values
    }
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.dims == 0
    }
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_ptr()
    }
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.data.as_ptr()
    }
    /// Size of the backing storage in bytes.
    #[inline]
    pub fn size_bytes(&self) -> usize {
        self.values * core::mem::size_of::<T>()
    }
    /// Contiguous read-only view of the whole vector.
    #[inline]
    pub fn view(&self) -> VectorView<'_, T> {
        VectorView {
            data: self.data.as_ptr() as *const T,
            dims: self.dims,
            stride_bytes: core::mem::size_of::<T>() as isize,
            _marker: PhantomData,
        }
    }
    /// Contiguous mutable view of the whole vector.
    #[inline]
    pub fn span(&mut self) -> VectorSpan<'_, T> {
        VectorSpan {
            data: self.data.as_ptr(),
            dims: self.dims,
            stride_bytes: core::mem::size_of::<T>() as isize,
            _marker: PhantomData,
        }
    }
    /// Reads dimension `idx` (negative indices count from the end).
    #[inline]
    pub fn try_get<I: VectorIndex>(&self, idx: I) -> Result<T::DimScalar, TensorError>
    where
        T: FloatConvertible,
    {
        // NOTE(review): the reported `index` is always 0 because
        // `VectorIndex` cannot expose the raw value; extend the trait if
        // accurate error indices are ever needed.
        let i = idx
            .resolve(self.dims)
            .ok_or(TensorError::IndexOutOfBounds {
                index: 0,
                size: self.dims,
            })?;
        let dims_per_value = T::dimensions_per_value();
        let value_index = i / dims_per_value;
        let sub_index = i % dims_per_value;
        // SAFETY: `i < dims` implies `value_index < values`.
        let packed = unsafe { *self.data.as_ptr().add(value_index) };
        Ok(packed.unpack().as_ref()[sub_index])
    }
    /// Writes dimension `idx` (negative indices count from the end).
    ///
    /// For packed types this is a read-modify-write of the whole value.
    #[inline]
    pub fn try_set<I: VectorIndex>(&mut self, idx: I, val: T::DimScalar) -> Result<(), TensorError>
    where
        T: FloatConvertible,
    {
        let i = idx
            .resolve(self.dims)
            .ok_or(TensorError::IndexOutOfBounds {
                index: 0,
                size: self.dims,
            })?;
        let dims_per_value = T::dimensions_per_value();
        let value_index = i / dims_per_value;
        let sub_index = i % dims_per_value;
        // SAFETY: `i < dims` implies `value_index < values`.
        let ptr = unsafe { self.data.as_ptr().add(value_index) };
        let mut unpacked = unsafe { *ptr }.unpack();
        unpacked.as_mut()[sub_index] = val;
        unsafe { ptr.write(T::pack(unpacked)) };
        Ok(())
    }
    /// Backing storage as a slice of packed values (`size_values()` long).
    /// For multi-dim-per-value types this exposes packed values, not
    /// individual dimensions.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // The slice always covers `values` packed elements regardless of
        // `dimensions_per_value()` — the previous `if`/`else` here had two
        // byte-identical branches, so the branch is removed.
        // SAFETY: `data` is valid for `values` elements while `self` lives.
        unsafe { core::slice::from_raw_parts(self.data.as_ptr(), self.values) }
    }
    /// Mutable counterpart of [`as_slice`](Self::as_slice).
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: unique access via `&mut self`.
        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.values) }
    }
    /// Iterates over all dimensions as scalar values.
    pub fn iter(&self) -> VectorViewIterator<'_, T>
    where
        T: FloatConvertible,
    {
        self.view().iter()
    }
    /// Iterates over all dimensions with write access.
    pub fn iter_mut(&mut self) -> VectorSpanIterator<'_, T>
    where
        T: FloatConvertible,
    {
        // Built directly (not via `span()`) so the iterator borrows `self`
        // rather than a temporary span.
        VectorSpanIterator {
            data: self.data.as_ptr(),
            stride_bytes: core::mem::size_of::<T>() as isize,
            front: 0,
            back: self.dims,
            _marker: PhantomData,
        }
    }
}
impl<T: StorageElement, A: Allocator> Vector<T, A> {
    /// Converts the vector into a rank-1 tensor, transferring buffer
    /// ownership without copying.
    pub fn try_into_tensor<const MAX_RANK: usize>(
        self,
    ) -> Result<Tensor<T, A, MAX_RANK>, TensorError> {
        if MAX_RANK == 0 {
            return Err(TensorError::TooManyRanks { got: 1 });
        }
        let mut shape = [0usize; MAX_RANK];
        shape[0] = self.dims;
        let mut strides = [0isize; MAX_RANK];
        strides[0] = core::mem::size_of::<T>() as isize;
        let alloc_bytes = self.values * core::mem::size_of::<T>();
        let data = self.data;
        // SAFETY: `self` is forgotten immediately below, so moving the
        // allocator out with `ptr::read` cannot cause a double use, and
        // skipping Drop prevents the buffer from being freed here.
        let alloc = unsafe { core::ptr::read(&self.alloc) };
        core::mem::forget(self);
        // SAFETY: `data`/`alloc_bytes` describe the live allocation whose
        // ownership was just transferred.
        let tensor = unsafe { Tensor::from_raw_parts(data, alloc_bytes, shape, strides, 1, alloc) };
        Ok(tensor)
    }
}
impl<T: StorageElement> Vector<T, Global> {
pub fn try_zeros(dims: usize) -> Result<Self, TensorError> { Self::try_zeros_in(dims, Global) }
pub fn try_full(dims: usize, value: T) -> Result<Self, TensorError> {
Self::try_full_in(dims, value, Global)
}
pub fn try_ones(dims: usize) -> Result<Self, TensorError>
where
T: NumberLike,
{
Self::try_full(dims, T::one())
}
pub unsafe fn try_empty(dims: usize) -> Result<Self, TensorError> {
unsafe { Self::try_empty_in(dims, Global) }
}
pub fn try_from_scalars(scalars: &[f32]) -> Result<Self, TensorError>
where
T: FloatConvertible,
{
Self::try_from_scalars_in(scalars, Global)
}
pub fn try_from_dims(dims: &[T::DimScalar]) -> Result<Self, TensorError>
where
T: FloatConvertible,
{
Self::try_from_dims_in(dims, Global)
}
}
impl<I: VectorIndex, T: StorageElement, A: Allocator> core::ops::Index<I> for Vector<T, A> {
    type Output = T;
    /// Panicking positional access to a packed value.
    ///
    /// NOTE(review): the pointer math assumes one dimension per value; the
    /// guard below is only a `debug_assert`, so release builds would read
    /// the wrong element for sub-byte packed types — confirm callers never
    /// use `Index` with such types.
    #[inline]
    fn index(&self, idx: I) -> &T {
        let i = idx.resolve(self.dims).expect("vector index out of bounds");
        debug_assert_eq!(
            T::dimensions_per_value(),
            1,
            "Index trait not supported for sub-byte types"
        );
        // SAFETY: `i < dims == values` when dimensions_per_value() == 1.
        unsafe { &*self.data.as_ptr().add(i) }
    }
}
impl<I: VectorIndex, T: StorageElement, A: Allocator> core::ops::IndexMut<I> for Vector<T, A> {
    /// Panicking positional mutable access to a packed value.
    ///
    /// NOTE(review): same sub-byte caveat as `Index` — the
    /// `debug_assert` does not protect release builds.
    #[inline]
    fn index_mut(&mut self, idx: I) -> &mut T {
        let i = idx.resolve(self.dims).expect("vector index out of bounds");
        debug_assert_eq!(
            T::dimensions_per_value(),
            1,
            "IndexMut trait not supported for sub-byte types"
        );
        // SAFETY: `i < dims == values` when dimensions_per_value() == 1,
        // and `&mut self` guarantees exclusive access.
        unsafe { &mut *self.data.as_ptr().add(i) }
    }
}
impl<T: StorageElement + Clone, A: Allocator + Clone> Vector<T, A> {
    /// Deep-copies the vector into a freshly allocated buffer, reporting
    /// allocation failure instead of panicking.
    pub fn try_clone(&self) -> Result<Self, TensorError> {
        // Empty vectors own no allocation; hand back a dangling buffer.
        if self.values == 0 {
            return Ok(Self {
                data: NonNull::dangling(),
                dims: 0,
                values: 0,
                alloc: self.alloc.clone(),
            });
        }
        let byte_len = self.values * core::mem::size_of::<T>();
        let layout = alloc::alloc::Layout::from_size_align(byte_len, SIMD_ALIGNMENT)
            .map_err(|_| TensorError::AllocationFailed)?;
        let dst = self
            .alloc
            .allocate(layout)
            .ok_or(TensorError::AllocationFailed)?;
        // SAFETY: source and destination are distinct allocations of at
        // least `byte_len` bytes each.
        unsafe {
            core::ptr::copy_nonoverlapping(self.data.as_ptr() as *const u8, dst.as_ptr(), byte_len);
        }
        Ok(Self {
            // SAFETY: `allocate` returned a non-null pointer.
            data: unsafe { NonNull::new_unchecked(dst.as_ptr() as *mut T) },
            dims: self.dims,
            values: self.values,
            alloc: self.alloc.clone(),
        })
    }
}
impl<T: StorageElement + Clone, A: Allocator + Clone> Clone for Vector<T, A> {
    /// Panicking wrapper over [`Vector::try_clone`].
    fn clone(&self) -> Self {
        self.try_clone().expect("vector clone allocation failed")
    }
}
impl<T: StorageElement> Default for Vector<T, Global> {
    /// Empty vector; performs no allocation (dangling data pointer).
    fn default() -> Self {
        Self {
            data: NonNull::dangling(),
            dims: 0,
            values: 0,
            alloc: Global,
        }
    }
}
/// Borrowed, possibly strided read-only view over vector storage.
pub struct VectorView<'a, T: StorageElement> {
    data: *const T,
    stride_bytes: isize, // byte step between packed values; may be negative
    dims: usize,
    _marker: PhantomData<&'a T>,
}
// SAFETY: a view behaves like `&T`: sending or sharing it across threads
// only requires `T: Sync`, since it grants shared read access.
unsafe impl<'a, T: StorageElement + Sync> Send for VectorView<'a, T> {}
unsafe impl<'a, T: StorageElement + Sync> Sync for VectorView<'a, T> {}
impl<'a, T: StorageElement> Clone for VectorView<'a, T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<'a, T: StorageElement> Copy for VectorView<'a, T> {}
impl<'a, T: StorageElement> VectorView<'a, T> {
    /// Builds a view from raw parts.
    ///
    /// # Safety
    /// `data` must be valid for reads of `dims` values spaced
    /// `stride_bytes` apart for the whole lifetime `'a`.
    #[inline]
    pub unsafe fn from_raw_parts(data: *const T, dims: usize, stride_bytes: isize) -> Self {
        Self {
            data,
            dims,
            stride_bytes,
            _marker: PhantomData,
        }
    }
    /// Number of logical dimensions viewed.
    #[inline]
    pub fn dims(&self) -> usize {
        self.dims
    }
    /// Alias for [`dims`](Self::dims).
    #[inline]
    pub fn size(&self) -> usize {
        self.dims
    }
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.dims == 0
    }
    /// Byte distance between consecutive values; negative for reversed views.
    #[inline]
    pub fn stride_bytes(&self) -> isize {
        self.stride_bytes
    }
    /// True when values are densely laid out in increasing address order.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        self.stride_bytes == core::mem::size_of::<T>() as isize
    }
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data
    }
    /// Dense single-dim-per-value views can be exposed as a plain slice.
    #[inline]
    pub fn as_contiguous_slice(&self) -> Option<&'a [T]> {
        if self.is_contiguous() && T::dimensions_per_value() == 1 {
            // SAFETY: contiguity was just checked and `data` is valid for
            // `dims` reads for `'a`.
            Some(unsafe { core::slice::from_raw_parts(self.data, self.dims) })
        } else {
            None
        }
    }
    /// Reads dimension `idx` (negative indices count from the end).
    #[inline]
    pub fn try_get<I: VectorIndex>(&self, idx: I) -> Result<T::DimScalar, TensorError>
    where
        T: FloatConvertible,
    {
        let i = idx
            .resolve(self.dims)
            .ok_or(TensorError::IndexOutOfBounds {
                index: 0, // the raw index is not recoverable from `I`
                size: self.dims,
            })?;
        let dims_per_value = T::dimensions_per_value();
        let value_index = i / dims_per_value;
        let sub_index = i % dims_per_value;
        // The stride steps between whole packed values.
        let ptr = unsafe {
            (self.data as *const u8).offset(self.stride_bytes * value_index as isize) as *const T
        };
        Ok(unsafe { *ptr }.unpack().as_ref()[sub_index])
    }
    /// Reversed view: same elements, opposite order, no copying.
    ///
    /// NOTE(review): for packed multi-dim types this reverses value order
    /// only, not lane order within a value — presumably such views always
    /// have one dim per value; confirm.
    pub fn rev(&self) -> Self {
        if self.dims == 0 {
            return *self;
        }
        // Point at the last element and negate the stride.
        let last_offset = self.stride_bytes * (self.dims as isize - 1);
        Self {
            data: unsafe { (self.data as *const u8).offset(last_offset) as *const T },
            dims: self.dims,
            stride_bytes: -self.stride_bytes,
            _marker: PhantomData,
        }
    }
    /// Sub-view selecting every `step`-th element of `start..end`
    /// (backwards from `start` toward `end` when `step` is negative).
    pub fn try_strided(&self, start: usize, end: usize, step: isize) -> Result<Self, TensorError> {
        if start > self.dims || end > self.dims || step == 0 {
            return Err(TensorError::IndexOutOfBounds {
                index: start.max(end),
                size: self.dims,
            });
        }
        // Ceiling division yields the number of selected elements.
        let count = if step > 0 {
            if end > start {
                (end - start + step as usize - 1) / step as usize
            } else {
                0
            }
        } else if start > end {
            let abs_step = (-step) as usize;
            (start - end + abs_step - 1) / abs_step
        } else {
            0
        };
        let new_data = unsafe {
            (self.data as *const u8).offset(self.stride_bytes * start as isize) as *const T
        };
        Ok(Self {
            data: new_data,
            dims: count,
            stride_bytes: self.stride_bytes * step,
            _marker: PhantomData,
        })
    }
    /// Iterates over all viewed dimensions as scalars.
    pub fn iter(&self) -> VectorViewIterator<'a, T>
    where
        T: FloatConvertible,
    {
        VectorViewIterator {
            data: self.data,
            stride_bytes: self.stride_bytes,
            front: 0,
            back: self.dims,
            _marker: PhantomData,
        }
    }
}
impl<'a, I: VectorIndex, T: StorageElement> core::ops::Index<I> for VectorView<'a, T> {
    type Output = T;
    /// Panicking positional access honoring the view's byte stride.
    ///
    /// NOTE(review): sub-byte types are only guarded by a `debug_assert`;
    /// release builds would index the wrong element — confirm callers.
    #[inline]
    fn index(&self, idx: I) -> &T {
        let i = idx.resolve(self.dims).expect("view index out of bounds");
        debug_assert_eq!(
            T::dimensions_per_value(),
            1,
            "Index trait not supported for sub-byte types"
        );
        // SAFETY: `i < dims` and the view invariant makes the strided
        // address valid for `'a`.
        unsafe { &*((self.data as *const u8).offset(self.stride_bytes * i as isize) as *const T) }
    }
}
/// Borrowed, possibly strided mutable view over vector storage.
pub struct VectorSpan<'a, T: StorageElement> {
    data: *mut T,
    dims: usize,
    stride_bytes: isize, // byte step between packed values; may be negative
    _marker: PhantomData<&'a mut T>,
}
// SAFETY: a span behaves like `&mut T`: it is Send when `T: Send` and
// Sync when `T: Sync`, because it holds exclusive access to the elements.
unsafe impl<'a, T: StorageElement + Send> Send for VectorSpan<'a, T> {}
unsafe impl<'a, T: StorageElement + Sync> Sync for VectorSpan<'a, T> {}
impl<'a, T: StorageElement> VectorSpan<'a, T> {
    /// Builds a mutable span from raw parts.
    ///
    /// # Safety
    /// `data` must be valid for reads and writes of `dims` values spaced
    /// `stride_bytes` apart for `'a`, with no aliasing access elsewhere.
    #[inline]
    pub unsafe fn from_raw_parts(data: *mut T, dims: usize, stride_bytes: isize) -> Self {
        Self {
            data,
            dims,
            stride_bytes,
            _marker: PhantomData,
        }
    }
    /// Number of logical dimensions spanned.
    #[inline]
    pub fn dims(&self) -> usize {
        self.dims
    }
    /// Alias for [`dims`](Self::dims).
    #[inline]
    pub fn size(&self) -> usize {
        self.dims
    }
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.dims == 0
    }
    /// Byte distance between consecutive values; negative for reversed spans.
    #[inline]
    pub fn stride_bytes(&self) -> isize {
        self.stride_bytes
    }
    /// True when values are densely laid out in increasing address order.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        self.stride_bytes == core::mem::size_of::<T>() as isize
    }
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data
    }
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.data
    }
    /// Read-only view of the same elements.
    pub fn as_view(&self) -> VectorView<'_, T> {
        VectorView {
            data: self.data,
            dims: self.dims,
            stride_bytes: self.stride_bytes,
            _marker: PhantomData,
        }
    }
    /// Reads dimension `idx`; see [`VectorView::try_get`].
    #[inline]
    pub fn try_get<I: VectorIndex>(&self, idx: I) -> Result<T::DimScalar, TensorError>
    where
        T: FloatConvertible,
    {
        self.as_view().try_get(idx)
    }
    /// Writes dimension `idx` via read-modify-write of its packed value.
    #[inline]
    pub fn try_set<I: VectorIndex>(&mut self, idx: I, val: T::DimScalar) -> Result<(), TensorError>
    where
        T: FloatConvertible,
    {
        let i = idx
            .resolve(self.dims)
            .ok_or(TensorError::IndexOutOfBounds {
                index: 0, // the raw index is not recoverable from `I`
                size: self.dims,
            })?;
        let dims_per_value = T::dimensions_per_value();
        let value_index = i / dims_per_value;
        let sub_index = i % dims_per_value;
        let ptr = unsafe {
            (self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut T
        };
        let mut unpacked = unsafe { *ptr }.unpack();
        unpacked.as_mut()[sub_index] = val;
        unsafe { ptr.write(T::pack(unpacked)) };
        Ok(())
    }
    /// Overwrites every packed value with `val` (all lanes at once for
    /// multi-dim element types).
    pub fn fill(&mut self, val: T) {
        let dims_per_value = T::dimensions_per_value();
        // Ceiling division: number of packed values covering `dims`.
        let values = (self.dims + dims_per_value - 1) / dims_per_value;
        for i in 0..values {
            let ptr =
                unsafe { (self.data as *mut u8).offset(self.stride_bytes * i as isize) as *mut T };
            unsafe { ptr.write(val) };
        }
    }
    /// Dense single-dim-per-value spans as a plain slice.
    #[inline]
    pub fn as_contiguous_slice(&self) -> Option<&[T]> {
        if self.is_contiguous() && T::dimensions_per_value() == 1 {
            Some(unsafe { core::slice::from_raw_parts(self.data, self.dims) })
        } else {
            None
        }
    }
    /// Mutable counterpart of [`as_contiguous_slice`](Self::as_contiguous_slice).
    #[inline]
    pub fn as_contiguous_slice_mut(&mut self) -> Option<&mut [T]> {
        if self.is_contiguous() && T::dimensions_per_value() == 1 {
            Some(unsafe { core::slice::from_raw_parts_mut(self.data, self.dims) })
        } else {
            None
        }
    }
    /// Read-only iteration over the spanned dimensions.
    pub fn iter(&self) -> VectorViewIterator<'_, T>
    where
        T: FloatConvertible,
    {
        VectorViewIterator {
            data: self.data,
            stride_bytes: self.stride_bytes,
            front: 0,
            back: self.dims,
            _marker: PhantomData,
        }
    }
    /// Mutable iteration over the spanned dimensions.
    pub fn iter_mut(&mut self) -> VectorSpanIterator<'_, T>
    where
        T: FloatConvertible,
    {
        VectorSpanIterator {
            data: self.data,
            stride_bytes: self.stride_bytes,
            front: 0,
            back: self.dims,
            _marker: PhantomData,
        }
    }
}
impl<'a, I: VectorIndex, T: StorageElement> core::ops::Index<I> for VectorSpan<'a, T> {
    type Output = T;
    /// Panicking positional access honoring the span's byte stride.
    ///
    /// NOTE(review): sub-byte types are only guarded by a `debug_assert`.
    #[inline]
    fn index(&self, idx: I) -> &T {
        let i = idx.resolve(self.dims).expect("span index out of bounds");
        debug_assert_eq!(
            T::dimensions_per_value(),
            1,
            "Index trait not supported for sub-byte types"
        );
        // SAFETY: `i < dims` and the span invariant make the address valid.
        unsafe { &*((self.data as *const u8).offset(self.stride_bytes * i as isize) as *const T) }
    }
}
impl<'a, I: VectorIndex, T: StorageElement> core::ops::IndexMut<I> for VectorSpan<'a, T> {
    /// Panicking positional mutable access honoring the byte stride.
    ///
    /// NOTE(review): sub-byte types are only guarded by a `debug_assert`.
    #[inline]
    fn index_mut(&mut self, idx: I) -> &mut T {
        let i = idx.resolve(self.dims).expect("span index out of bounds");
        debug_assert_eq!(
            T::dimensions_per_value(),
            1,
            "IndexMut trait not supported for sub-byte types"
        );
        // SAFETY: `i < dims`, and `&mut self` guarantees exclusivity.
        unsafe { &mut *((self.data as *mut u8).offset(self.stride_bytes * i as isize) as *mut T) }
    }
}
/// Double-ended iterator yielding read-only dimension scalars.
pub struct VectorViewIterator<'a, T: FloatConvertible> {
    data: *const T,
    stride_bytes: isize,
    front: usize, // next index taken from the front
    back: usize,  // one past the next index taken from the back
    _marker: PhantomData<&'a T>,
}
/// Backwards-compatible alias for [`VectorViewIterator`].
pub type VectorIterator<'a, T> = VectorViewIterator<'a, T>;
impl<'a, T: FloatConvertible> Iterator for VectorViewIterator<'a, T> {
    type Item = DimRef<'a, T>;
    #[inline]
    fn next(&mut self) -> Option<DimRef<'a, T>> {
        if self.front >= self.back {
            return None;
        }
        // Locate the packed value holding dimension `front`, then pick out
        // its lane.
        let dims_per_value = T::dimensions_per_value();
        let value_index = self.front / dims_per_value;
        let sub_index = self.front % dims_per_value;
        // SAFETY: `front < back <= dims` keeps the access in bounds.
        let ptr = unsafe {
            (self.data as *const u8).offset(self.stride_bytes * value_index as isize) as *const T
        };
        let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
        self.front += 1;
        Some(DimRef::new(scalar))
    }
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: the remaining range is `front..back`.
        let n = self.back - self.front;
        (n, Some(n))
    }
}
// The size hint is exact, and `next` stays `None` once exhausted.
impl<'a, T: FloatConvertible> ExactSizeIterator for VectorViewIterator<'a, T> {}
impl<'a, T: FloatConvertible> core::iter::FusedIterator for VectorViewIterator<'a, T> {}
impl<'a, T: FloatConvertible> DoubleEndedIterator for VectorViewIterator<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<DimRef<'a, T>> {
        if self.front >= self.back {
            return None;
        }
        // `back` is one past the end, so decrement before reading.
        self.back -= 1;
        let dims_per_value = T::dimensions_per_value();
        let value_index = self.back / dims_per_value;
        let sub_index = self.back % dims_per_value;
        // SAFETY: `front <= back < dims` keeps the access in bounds.
        let ptr = unsafe {
            (self.data as *const u8).offset(self.stride_bytes * value_index as isize) as *const T
        };
        Some(DimRef::new(unsafe { *ptr }.unpack().as_ref()[sub_index]))
    }
}
/// Double-ended iterator yielding mutable dimension handles (`DimMut`).
pub struct VectorSpanIterator<'a, T: FloatConvertible> {
    data: *mut T,
    stride_bytes: isize,
    front: usize, // next index taken from the front
    back: usize,  // one past the next index taken from the back
    _marker: PhantomData<&'a mut T>,
}
impl<'a, T: FloatConvertible> Iterator for VectorSpanIterator<'a, T> {
    type Item = DimMut<'a, T>;
    #[inline]
    fn next(&mut self) -> Option<DimMut<'a, T>> {
        if self.front >= self.back {
            return None;
        }
        let dims_per_value = T::dimensions_per_value();
        let value_index = self.front / dims_per_value;
        let sub_index = self.front % dims_per_value;
        // SAFETY: `front < back <= dims` keeps the access in bounds.
        let ptr = unsafe {
            (self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut T
        };
        let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
        self.front += 1;
        // SAFETY: the span's exclusive borrow backs each handle; DimMut
        // presumably writes the scalar back into lane `sub_index` of the
        // packed value at `ptr` — confirm in the `types` module.
        Some(unsafe { DimMut::new(ptr, sub_index, scalar) })
    }
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: the remaining range is `front..back`.
        let n = self.back - self.front;
        (n, Some(n))
    }
}
// The size hint is exact, and `next` stays `None` once exhausted.
impl<'a, T: FloatConvertible> ExactSizeIterator for VectorSpanIterator<'a, T> {}
impl<'a, T: FloatConvertible> core::iter::FusedIterator for VectorSpanIterator<'a, T> {}
impl<'a, T: FloatConvertible> DoubleEndedIterator for VectorSpanIterator<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<DimMut<'a, T>> {
        if self.front >= self.back {
            return None;
        }
        // `back` is one past the end, so decrement before reading.
        self.back -= 1;
        let dims_per_value = T::dimensions_per_value();
        let value_index = self.back / dims_per_value;
        let sub_index = self.back % dims_per_value;
        // SAFETY: `front <= back < dims` keeps the access in bounds.
        let ptr = unsafe {
            (self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut T
        };
        let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
        // SAFETY: backed by the span's exclusive borrow, as in `next`.
        Some(unsafe { DimMut::new(ptr, sub_index, scalar) })
    }
}
// IntoIterator plumbing so `for d in &v` / `for d in &mut v` work on
// vectors, views, and spans.
impl<'a, T: FloatConvertible, A: Allocator> IntoIterator for &'a Vector<T, A> {
    type Item = DimRef<'a, T>;
    type IntoIter = VectorViewIterator<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
impl<'a, T: FloatConvertible> IntoIterator for &'a VectorView<'a, T> {
    type Item = DimRef<'a, T>;
    type IntoIter = VectorViewIterator<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
// A shared reference to a span only grants read-only iteration.
impl<'a, T: FloatConvertible> IntoIterator for &'a VectorSpan<'a, T> {
    type Item = DimRef<'a, T>;
    type IntoIter = VectorViewIterator<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
impl<'a, T: FloatConvertible, A: Allocator> IntoIterator for &'a mut Vector<T, A> {
    type Item = DimMut<'a, T>;
    type IntoIter = VectorSpanIterator<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter_mut()
    }
}
impl<'a, T: FloatConvertible> IntoIterator for &'a mut VectorSpan<'a, T> {
    type Item = DimMut<'a, T>;
    type IntoIter = VectorSpanIterator<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter_mut()
    }
}
impl<T: StorageElement, A: Allocator> AsRef<[T]> for Vector<T, A> {
    /// Packed-value slice view; see [`Vector::as_slice`].
    fn as_ref(&self) -> &[T] {
        self.as_slice()
    }
}
// Element-wise equality over unpacked scalars; dimension counts must match.
impl<T: FloatConvertible, A: Allocator> PartialEq for Vector<T, A>
where
    T::DimScalar: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
    }
}
// Comparison against a plain scalar slice.
impl<T: FloatConvertible, A: Allocator> PartialEq<[T::DimScalar]> for Vector<T, A>
where
    T::DimScalar: PartialEq,
{
    fn eq(&self, other: &[T::DimScalar]) -> bool {
        self.dims == other.len() && self.iter().zip(other.iter()).all(|(a, b)| *a == *b)
    }
}
// Element-wise equality for views and spans, mirroring Vector's PartialEq.
impl<'a, T: FloatConvertible> PartialEq for VectorView<'a, T>
where
    T::DimScalar: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
    }
}
impl<'a, T: FloatConvertible> PartialEq for VectorSpan<'a, T>
where
    T::DimScalar: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
    }
}
impl<T: FloatConvertible, A: Allocator> Vector<T, A>
where
    T::DimScalar: NumberLike,
{
    /// Approximate equality: dims must match and every pair of scalars
    /// must satisfy `is_close` with absolute/relative tolerances.
    pub fn allclose<OA: Allocator>(&self, other: &Vector<T, OA>, atol: f64, rtol: f64) -> bool {
        self.dims == other.dims
            && self
                .iter()
                .zip(other.iter())
                .all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
    }
}
impl<'a, T: FloatConvertible> VectorView<'a, T>
where
    T::DimScalar: NumberLike,
{
    /// Approximate equality between views; see [`Vector::allclose`].
    pub fn allclose(&self, other: &Self, atol: f64, rtol: f64) -> bool {
        self.dims == other.dims
            && self
                .iter()
                .zip(other.iter())
                .all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
    }
}
impl<'a, T: FloatConvertible> VectorSpan<'a, T>
where
    T::DimScalar: NumberLike,
{
    /// Approximate equality between spans; see [`Vector::allclose`].
    pub fn allclose(&self, other: &Self, atol: f64, rtol: f64) -> bool {
        self.dims == other.dims
            && self
                .iter()
                .zip(other.iter())
                .all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
    }
}
/// Writes `name(dims=N, [a, b, ...])`, truncating the element list with
/// `...` once `limit` items have been printed.
fn fmt_debug_list<I: Iterator>(
    f: &mut core::fmt::Formatter<'_>,
    name: &str,
    dims: usize,
    iter: I,
    limit: usize,
) -> core::fmt::Result
where
    I::Item: core::fmt::Debug,
{
    write!(f, "{}(dims={}, [", name, dims)?;
    // Separator is empty before the first element, ", " afterwards.
    let mut sep = "";
    for (i, val) in iter.enumerate() {
        if i >= limit {
            write!(f, ", ...")?;
            break;
        }
        write!(f, "{}{:?}", sep, val)?;
        sep = ", ";
    }
    write!(f, "])")
}
/// Writes `[a, b, ...]`, truncating after `limit` items and forwarding
/// the caller's requested precision to each element.
fn fmt_display_list<I: Iterator>(
    f: &mut core::fmt::Formatter<'_>,
    iter: I,
    limit: usize,
) -> core::fmt::Result
where
    I::Item: core::fmt::Display,
{
    // Capture the precision requested on the outer formatter (`{:.N}`).
    let precision = f.precision();
    write!(f, "[")?;
    // Separator is empty before the first element, ", " afterwards.
    let mut sep = "";
    for (i, val) in iter.enumerate() {
        if i >= limit {
            write!(f, ", ...")?;
            break;
        }
        write!(f, "{}", sep)?;
        match precision {
            Some(p) => write!(f, "{:.p$}", val)?,
            None => write!(f, "{}", val)?,
        }
        sep = ", ";
    }
    write!(f, "]")
}
impl<T: FloatConvertible, A: Allocator> core::fmt::Debug for Vector<T, A>
where
    T::DimScalar: core::fmt::Debug,
{
    /// Debug output truncated to the first 8 dimensions.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        fmt_debug_list(f, "Vector", self.dims, self.iter(), 8)
    }
}
impl<'a, T: FloatConvertible> core::fmt::Debug for VectorView<'a, T>
where
    T::DimScalar: core::fmt::Debug,
{
    /// Debug output truncated to the first 8 dimensions.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        fmt_debug_list(f, "VectorView", self.dims, self.iter(), 8)
    }
}
impl<'a, T: FloatConvertible> core::fmt::Debug for VectorSpan<'a, T>
where
    T::DimScalar: core::fmt::Debug,
{
    /// Debug output truncated to the first 8 dimensions.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        fmt_debug_list(f, "VectorSpan", self.dims, self.iter(), 8)
    }
}
impl<T: FloatConvertible, A: Allocator> core::fmt::Display for Vector<T, A>
where
    T::DimScalar: core::fmt::Display,
{
    /// Display output truncated to the first 20 dimensions; honors the
    /// caller's precision (`{:.N}`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        fmt_display_list(f, self.iter(), 20)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{bf16, f16, i4x2, u1x8, u4x2};

    // Allocation/size/iteration round-trip for an arbitrary storage type.
    fn check_vector_roundtrip<T: FloatConvertible>() {
        let dims_per_value = T::dimensions_per_value();
        let test_dims = 16 * dims_per_value;
        let v = Vector::<T>::try_zeros(test_dims).unwrap();
        assert_eq!(v.dims(), test_dims);
        assert_eq!(v.size_values(), test_dims / dims_per_value);
        let count = v.iter().count();
        assert_eq!(count, test_dims);
    }

    // Set/get on the first and last dims, including a negative index.
    fn check_vector_try_get_set<T: FloatConvertible>()
    where
        T::DimScalar: core::fmt::Debug,
    {
        let dims_per_value = T::dimensions_per_value();
        let test_dims = 4 * dims_per_value;
        let mut v = Vector::<T>::try_zeros(test_dims).unwrap();
        let one = T::DimScalar::from_f32(1.0);
        v.try_set(0_usize, one).unwrap();
        v.try_set((test_dims - 1) as i32, one).unwrap();
        let first = v.try_get(0_usize).unwrap();
        let last = v.try_get(-1_i32).unwrap();
        // Lossy formats only need to land near 1.0, hence the >= 0.5 check.
        assert!(
            first.to_f32() >= 0.5,
            "first dim should be ~1.0, got {:?}",
            first
        );
        assert!(
            last.to_f32() >= 0.5,
            "last dim should be ~1.0, got {:?}",
            last
        );
    }

    #[test]
    fn vector_roundtrip_all_types() {
        check_vector_roundtrip::<f32>();
        check_vector_roundtrip::<f64>();
        check_vector_roundtrip::<f16>();
        check_vector_roundtrip::<bf16>();
        check_vector_roundtrip::<i4x2>();
        check_vector_roundtrip::<u4x2>();
        check_vector_roundtrip::<u1x8>();
    }

    #[test]
    fn vector_try_get_set_all_types() {
        check_vector_try_get_set::<f32>();
        check_vector_try_get_set::<f64>();
        check_vector_try_get_set::<i4x2>();
        check_vector_try_get_set::<u4x2>();
        check_vector_try_get_set::<u1x8>();
    }

    #[test]
    fn vec_index_signed() {
        let v = Vector::<f32>::try_from_dims(&[10.0, 20.0, 30.0, 40.0, 50.0]).unwrap();
        assert_eq!(v[0], 10.0);
        assert_eq!(v[4], 50.0);
        // Negative indices count back from the end.
        assert_eq!(v[-1_i32], 50.0);
        assert_eq!(v[-2_i32], 40.0);
        assert_eq!(v[-5_i32], 10.0);
    }

    #[test]
    fn vector_view_stride() {
        let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let view = v.view();
        assert!(view.is_contiguous());
        assert_eq!(view.size(), 5);
        // Reversed view sees the same data in opposite order.
        let rev = view.rev();
        assert_eq!(rev.try_get(0_usize).unwrap(), 5.0);
        assert_eq!(rev.try_get(4_usize).unwrap(), 1.0);
    }

    #[test]
    fn vector_span_fill() {
        let mut v = Vector::<f32>::try_zeros(4).unwrap();
        {
            let mut span = v.span();
            span.fill(42.0);
        }
        assert_eq!(v[0], 42.0);
        assert_eq!(v[3], 42.0);
    }

    #[test]
    fn vector_iter() {
        let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
        let vals: Vec<f32> = v.iter().map(|x| *x).collect();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
        // Double-ended iteration.
        let rev_vals: Vec<f32> = v.iter().rev().map(|x| *x).collect();
        assert_eq!(rev_vals, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn view_strided_iteration() {
        let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let view = v.view();
        // Every second element of 0..5 -> 3 elements.
        let strided = view.try_strided(0, 5, 2).unwrap();
        assert_eq!(strided.size(), 3);
        let vals: Vec<f32> = strided.iter().map(|x| *x).collect();
        assert_eq!(vals, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn vector_filled() {
        let v = Vector::<f32>::try_full(3, 7.5).unwrap();
        assert_eq!(v[0], 7.5);
        assert_eq!(v[1], 7.5);
        assert_eq!(v[2], 7.5);
    }

    #[test]
    fn empty_vector() {
        let v = Vector::<f32>::try_zeros(0).unwrap();
        assert!(v.is_empty());
        assert_eq!(v.size(), 0);
    }

    #[test]
    #[should_panic]
    fn index_out_of_bounds() {
        let v = Vector::<f32>::try_zeros(3).unwrap();
        let _ = v[3_usize];
    }

    #[test]
    fn vector_allclose_matching() {
        let a = Vector::<f32>::try_full(4, 1.0).unwrap();
        let b = Vector::<f32>::try_full(4, 1.0 + 1e-7).unwrap();
        assert!(a.allclose(&b, 1e-6, 0.0));
    }

    #[test]
    fn vector_allclose_mismatching() {
        let a = Vector::<f32>::try_full(4, 1.0).unwrap();
        let b = Vector::<f32>::try_full(4, 2.0).unwrap();
        assert!(!a.allclose(&b, 1e-6, 0.0));
    }

    #[test]
    fn vector_allclose_different_dims() {
        let a = Vector::<f32>::try_full(3, 1.0).unwrap();
        let b = Vector::<f32>::try_full(4, 1.0).unwrap();
        assert!(!a.allclose(&b, 1e-6, 1e-6));
    }

    #[test]
    fn display_precision_forwarding() {
        let v = Vector::<f32>::try_full(3, 1.0).unwrap();
        let s = format!("{:.2}", v);
        assert_eq!(s, "[1.00, 1.00, 1.00]");
    }

    #[test]
    fn vector_span_iter_mut_f32() {
        let mut v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
        for mut val in &mut v {
            *val += 10.0;
        }
        let vals: Vec<f32> = v.iter().map(|x| *x).collect();
        assert_eq!(vals, vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn vector_span_iter_mut_i4x2() {
        // Writes through DimMut must round-trip into packed nibbles.
        let mut v = Vector::<i4x2>::try_zeros(4).unwrap();
        {
            let mut span = v.span();
            for (i, mut val) in span.iter_mut().enumerate() {
                *val = (i + 1) as i8;
            }
        }
        assert_eq!(v.try_get(0_usize).unwrap(), 1);
        assert_eq!(v.try_get(1_usize).unwrap(), 2);
        assert_eq!(v.try_get(2_usize).unwrap(), 3);
        assert_eq!(v.try_get(3_usize).unwrap(), 4);
    }

    #[test]
    fn vector_span_iter_mut_u1x8() {
        // Same round-trip through 1-bit packed lanes.
        let mut v = Vector::<u1x8>::try_zeros(8).unwrap();
        for (i, mut val) in v.iter_mut().enumerate() {
            if i % 2 == 0 {
                *val = 1;
            }
        }
        assert_eq!(v.try_get(0_usize).unwrap(), 1);
        assert_eq!(v.try_get(1_usize).unwrap(), 0);
        assert_eq!(v.try_get(2_usize).unwrap(), 1);
        assert_eq!(v.try_get(3_usize).unwrap(), 0);
    }

    #[test]
    fn vector_span_iter_double_ended() {
        let mut v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
        let mut span = v.span();
        let mut it = span.iter_mut();
        let mut first = it.next().unwrap();
        *first = 10.0;
        drop(first);
        let mut last = it.next_back().unwrap();
        *last = 30.0;
        drop(last);
        drop(it);
        drop(span);
        assert_eq!(v.try_get(0_usize).unwrap(), 10.0);
        assert_eq!(v.try_get(1_usize).unwrap(), 2.0);
        assert_eq!(v.try_get(2_usize).unwrap(), 30.0);
    }

    #[test]
    fn vector_iterator_alias_compat() {
        // The deprecated alias must keep naming the view iterator type.
        let v = Vector::<f32>::try_from_scalars(&[1.0]).unwrap();
        let _it: VectorIterator<'_, f32> = v.iter();
    }
}