extern crate alloc;
use core::marker::PhantomData;
use core::ptr::NonNull;
use crate::tensor::{Allocator, Global, Tensor, TensorError, SIMD_ALIGNMENT};
use crate::types::{DimMut, DimRef, FloatConvertible, NumberLike, StorageElement};
mod private {
pub trait Sealed {}
}
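/// Index types accepted by the vector indexing and `try_get`/`try_set` methods.
///
/// Unsigned indices are bounds-checked directly; signed indices may be
/// negative and count back from the end of the vector, resolving to `None`
/// when out of range. The trait is sealed, so only the integer types listed
/// below implement it.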
pub trait VectorIndex: private::Sealed + Copy {
fn resolve(self, len: usize) -> Option<usize>;
}
macro_rules! impl_vec_index_unsigned {
($($t:ty),*) => {$(
impl private::Sealed for $t {}
impl VectorIndex for $t {
#[inline]
fn resolve(self, len: usize) -> Option<usize> {
let index = self as usize;
if index < len { Some(index) } else { None }
}
}
)*};
}
macro_rules! impl_vec_index_signed {
($($t:ty),*) => {$(
impl private::Sealed for $t {}
impl VectorIndex for $t {
#[inline]
fn resolve(self, len: usize) -> Option<usize> {
let index = if self >= 0 {
self as usize
} else {
// unsigned_abs avoids overflow when negating the type's minimum value.
let neg = self.unsigned_abs() as usize;
if neg > len { return None; }
len - neg
};
if index < len { Some(index) } else { None }
}
}
)*};
}
impl_vec_index_unsigned!(usize, u8, u16, u32, u64);
impl_vec_index_signed!(isize, i8, i16, i32, i64);
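/// Read-only reference to a single 4-bit value packed into one half of a byte.
/// `high` selects the upper nibble; the signed accessor sign-extends bit 3.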
pub struct NibbleRef<'a> {
byte: *const u8,
high: bool,
_marker: PhantomData<&'a u8>,
}
impl<'a> NibbleRef<'a> {
#[inline]
pub fn get_unsigned(&self) -> u8 {
let byte_value = unsafe { *self.byte };
if self.high {
byte_value >> 4
} else {
byte_value & 0x0F
}
}
#[inline]
pub fn get_signed(&self) -> i8 {
let nibble = self.get_unsigned();
if nibble & 0x08 != 0 {
nibble as i8 | !0x0Fi8
} else {
nibble as i8
}
}
}
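/// Mutable reference to a single 4-bit value packed into one half of a byte.
/// Setters mask the input to 4 bits and rewrite only the addressed nibble.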
pub struct NibbleRefMut<'a> {
byte: *mut u8,
high: bool,
_marker: PhantomData<&'a mut u8>,
}
impl<'a> NibbleRefMut<'a> {
#[inline]
pub fn get_unsigned(&self) -> u8 {
let byte_value = unsafe { *self.byte };
if self.high {
byte_value >> 4
} else {
byte_value & 0x0F
}
}
#[inline]
pub fn get_signed(&self) -> i8 {
let nibble = self.get_unsigned();
if nibble & 0x08 != 0 {
nibble as i8 | !0x0Fi8
} else {
nibble as i8
}
}
#[inline]
pub fn set_unsigned(&self, value: u8) {
unsafe {
let byte_value = *self.byte;
if self.high {
*self.byte = (byte_value & 0x0F) | ((value & 0x0F) << 4);
} else {
*self.byte = (byte_value & 0xF0) | (value & 0x0F);
}
}
}
#[inline]
pub fn set_signed(&self, value: i8) {
self.set_unsigned(value as u8);
}
}
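/// Read-only reference to a single bit within a byte, selected by `mask`.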
pub struct BitRef<'a> {
byte: *const u8,
mask: u8,
_marker: PhantomData<&'a u8>,
}
impl<'a> BitRef<'a> {
#[inline]
pub fn get(&self) -> bool {
(unsafe { *self.byte } & self.mask) != 0
}
}
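/// Mutable reference to a single bit within a byte, selected by `mask`.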
pub struct BitRefMut<'a> {
byte: *mut u8,
mask: u8,
_marker: PhantomData<&'a mut u8>,
}
impl<'a> BitRefMut<'a> {
#[inline]
pub fn get(&self) -> bool {
(unsafe { *self.byte } & self.mask) != 0
}
#[inline]
pub fn set(&self, value: bool) {
unsafe {
if value {
*self.byte |= self.mask;
} else {
*self.byte &= !self.mask;
}
}
}
}
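/// Heap-allocated, SIMD-aligned one-dimensional vector of storage elements.
///
/// `dims` counts logical dimensions and `values` counts stored `Scalar`
/// elements; the two differ for packed sub-byte types (e.g. `i4x2`, `u1x8`),
/// where one storage element holds several dimensions.
///
/// A minimal usage sketch (not compiled as a doc-test):
///
/// ```ignore
/// let mut v = Vector::<f32>::try_zeros(4)?;
/// v.try_set(0_usize, 1.0)?;
/// assert_eq!(v.try_get(-1_i32)?, 0.0); // negative indices count from the end
/// ```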
pub struct Vector<Scalar: StorageElement, Alloc: Allocator = Global> {
data: NonNull<Scalar>,
dims: usize,
values: usize,
alloc: Alloc,
}
unsafe impl<Scalar: StorageElement + Send, Alloc: Allocator + Send> Send for Vector<Scalar, Alloc> {}
unsafe impl<Scalar: StorageElement + Sync, Alloc: Allocator + Sync> Sync for Vector<Scalar, Alloc> {}
impl<Scalar: StorageElement, Alloc: Allocator> Drop for Vector<Scalar, Alloc> {
fn drop(&mut self) {
if self.values > 0 {
let layout = alloc::alloc::Layout::from_size_align(
self.values * core::mem::size_of::<Scalar>(),
SIMD_ALIGNMENT,
)
.unwrap();
unsafe {
self.alloc.deallocate(
NonNull::new_unchecked(self.data.as_ptr() as *mut u8),
layout,
);
}
}
}
}
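/// Number of `Scalar` storage elements required to hold `dims` logical dimensions.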
#[inline]
fn dims_to_values<Scalar: StorageElement>(dims: usize) -> usize {
dims.div_ceil(Scalar::dimensions_per_value())
}
impl<Scalar: StorageElement, Alloc: Allocator> Vector<Scalar, Alloc> {
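/// Assembles a `Vector` from raw parts without any validation.
///
/// # Safety
///
/// `data` must point to an allocation obtained from `alloc` using the layout
/// this type itself would use for `values` elements (SIMD-aligned), and
/// `dims` must be consistent with `values`; the vector assumes ownership and
/// deallocates the memory on drop.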
pub unsafe fn from_raw_parts(
data: NonNull<Scalar>,
dims: usize,
values: usize,
alloc: Alloc,
) -> Self {
Self {
data,
dims,
values,
alloc,
}
}
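/// Allocates a zero-initialized vector with `dims` dimensions in `alloc`.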
pub fn try_zeros_in(dims: usize, alloc: Alloc) -> Result<Self, TensorError> {
let values = dims_to_values::<Scalar>(dims);
if values == 0 {
return Ok(Self {
data: NonNull::dangling(),
dims: 0,
values: 0,
alloc,
});
}
let size = values * core::mem::size_of::<Scalar>();
let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
.map_err(|_| TensorError::AllocationFailed)?;
let ptr = alloc
.allocate(layout)
.ok_or(TensorError::AllocationFailed)?;
unsafe { core::ptr::write_bytes(ptr.as_ptr(), 0, size) };
Ok(Self {
data: unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut Scalar) },
dims,
values,
alloc,
})
}
pub fn try_full_in(dims: usize, value: Scalar, alloc: Alloc) -> Result<Self, TensorError> {
let v = Self::try_zeros_in(dims, alloc)?;
if v.values > 0 {
let ptr = v.data.as_ptr();
for i in 0..v.values {
unsafe { ptr.add(i).write(value) };
}
}
Ok(v)
}
pub fn try_ones_in(dims: usize, alloc: Alloc) -> Result<Self, TensorError>
where
Scalar: NumberLike,
{
Self::try_full_in(dims, Scalar::one(), alloc)
}
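/// Allocates a vector with `dims` dimensions in `alloc` without initializing
/// its contents.
///
/// # Safety
///
/// The returned storage is uninitialized; every element must be written
/// before it is read.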
pub unsafe fn try_empty_in(dims: usize, alloc: Alloc) -> Result<Self, TensorError> {
let values = dims_to_values::<Scalar>(dims);
if values == 0 {
return Ok(Self {
data: NonNull::dangling(),
dims: 0,
values: 0,
alloc,
});
}
let size = values * core::mem::size_of::<Scalar>();
let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
.map_err(|_| TensorError::AllocationFailed)?;
let ptr = alloc
.allocate(layout)
.ok_or(TensorError::AllocationFailed)?;
Ok(Self {
data: unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut Scalar) },
dims,
values,
alloc,
})
}
pub fn try_from_scalars_in(scalars: &[f32], alloc: Alloc) -> Result<Self, TensorError>
where
Scalar: FloatConvertible,
{
let element_count = scalars.len();
let mut v = Self::try_zeros_in(element_count, alloc)?;
for (i, &s) in scalars.iter().enumerate() {
v.try_set(i, Scalar::DimScalar::from_f32(s))?;
}
Ok(v)
}
pub fn try_from_dims_in(
dim_values: &[Scalar::DimScalar],
alloc: Alloc,
) -> Result<Self, TensorError>
where
Scalar: FloatConvertible,
{
let element_count = dim_values.len();
let mut v = Self::try_zeros_in(element_count, alloc)?;
for (i, &d) in dim_values.iter().enumerate() {
v.try_set(i, d)?;
}
Ok(v)
}
#[inline]
pub fn dims(&self) -> usize {
self.dims
}
#[inline]
pub fn size(&self) -> usize {
self.dims
}
#[inline]
pub fn size_values(&self) -> usize {
self.values
}
#[inline]
pub fn is_empty(&self) -> bool {
self.dims == 0
}
#[inline]
pub fn as_ptr(&self) -> *const Scalar {
self.data.as_ptr()
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut Scalar {
self.data.as_ptr()
}
#[inline]
pub fn size_bytes(&self) -> usize {
self.values * core::mem::size_of::<Scalar>()
}
#[inline]
pub fn view(&self) -> VectorView<'_, Scalar> {
VectorView {
data: self.data.as_ptr() as *const Scalar,
dims: self.dims,
stride_bytes: core::mem::size_of::<Scalar>() as isize,
_marker: PhantomData,
}
}
#[inline]
pub fn span(&mut self) -> VectorSpan<'_, Scalar> {
VectorSpan {
data: self.data.as_ptr(),
dims: self.dims,
stride_bytes: core::mem::size_of::<Scalar>() as isize,
_marker: PhantomData,
}
}
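/// Returns the dimension at `index`, unpacking the containing storage element
/// for sub-byte types. Negative indices count back from the end.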
#[inline]
pub fn try_get<AnyIndex: VectorIndex>(
&self,
index: AnyIndex,
) -> Result<Scalar::DimScalar, TensorError>
where
Scalar: FloatConvertible,
{
let i = index
.resolve(self.dims)
.ok_or(TensorError::IndexOutOfBounds {
index: 0,
size: self.dims,
})?;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = i / dims_per_value;
let sub_index = i % dims_per_value;
let packed = unsafe { *self.data.as_ptr().add(value_index) };
Ok(packed.unpack().as_ref()[sub_index])
}
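/// Writes the dimension at `index`, repacking the containing storage element
/// for sub-byte types. Negative indices count back from the end.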
#[inline]
pub fn try_set<AnyIndex: VectorIndex>(
&mut self,
index: AnyIndex,
value: Scalar::DimScalar,
) -> Result<(), TensorError>
where
Scalar: FloatConvertible,
{
let i = index
.resolve(self.dims)
.ok_or(TensorError::IndexOutOfBounds {
index: 0,
size: self.dims,
})?;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = i / dims_per_value;
let sub_index = i % dims_per_value;
let ptr = unsafe { self.data.as_ptr().add(value_index) };
let mut unpacked = unsafe { *ptr }.unpack();
unpacked.as_mut()[sub_index] = value;
unsafe { ptr.write(Scalar::pack(unpacked)) };
Ok(())
}
#[inline]
pub fn as_slice(&self) -> &[Scalar] {
// The backing storage always holds `values` elements of `Scalar`,
// whether or not each element packs several dimensions.
unsafe { core::slice::from_raw_parts(self.data.as_ptr(), self.values) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [Scalar] {
unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.values) }
}
pub fn iter(&self) -> VectorViewIterator<'_, Scalar>
where
Scalar: FloatConvertible,
{
self.view().iter()
}
pub fn iter_mut(&mut self) -> VectorSpanIterator<'_, Scalar>
where
Scalar: FloatConvertible,
{
VectorSpanIterator {
data: self.data.as_ptr(),
stride_bytes: core::mem::size_of::<Scalar>() as isize,
front: 0,
back: self.dims,
_marker: PhantomData,
}
}
}
impl<Scalar: StorageElement, Alloc: Allocator> Vector<Scalar, Alloc> {
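/// Consumes the vector and reinterprets it as a rank-1 `Tensor` without
/// copying; the allocation is handed over to the tensor. Fails when
/// `MAX_RANK` is zero.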
pub fn try_into_tensor<const MAX_RANK: usize>(
self,
) -> Result<Tensor<Scalar, Alloc, MAX_RANK>, TensorError> {
if MAX_RANK == 0 {
return Err(TensorError::TooManyRanks { got: 1 });
}
let mut shape = [0usize; MAX_RANK];
shape[0] = self.dims;
let mut strides = [0isize; MAX_RANK];
strides[0] = core::mem::size_of::<Scalar>() as isize;
let alloc_bytes = self.values * core::mem::size_of::<Scalar>();
let data = self.data;
let alloc = unsafe { core::ptr::read(&self.alloc) };
core::mem::forget(self);
let tensor = unsafe { Tensor::from_raw_parts(data, alloc_bytes, shape, strides, 1, alloc) };
Ok(tensor)
}
}
impl<Scalar: StorageElement> Vector<Scalar, Global> {
pub fn try_zeros(dims: usize) -> Result<Self, TensorError> {
Self::try_zeros_in(dims, Global)
}
pub fn try_full(dims: usize, value: Scalar) -> Result<Self, TensorError> {
Self::try_full_in(dims, value, Global)
}
pub fn try_ones(dims: usize) -> Result<Self, TensorError>
where
Scalar: NumberLike,
{
Self::try_full(dims, Scalar::one())
}
pub unsafe fn try_empty(dims: usize) -> Result<Self, TensorError> {
unsafe { Self::try_empty_in(dims, Global) }
}
pub fn try_from_scalars(scalars: &[f32]) -> Result<Self, TensorError>
where
Scalar: FloatConvertible,
{
Self::try_from_scalars_in(scalars, Global)
}
pub fn try_from_dims(dims: &[Scalar::DimScalar]) -> Result<Self, TensorError>
where
Scalar: FloatConvertible,
{
Self::try_from_dims_in(dims, Global)
}
}
impl<AnyIndex: VectorIndex, Scalar: StorageElement, Alloc: Allocator> core::ops::Index<AnyIndex>
for Vector<Scalar, Alloc>
{
type Output = Scalar;
#[inline]
fn index(&self, index: AnyIndex) -> &Scalar {
let i = index
.resolve(self.dims)
.expect("vector index out of bounds");
debug_assert_eq!(
Scalar::dimensions_per_value(),
1,
"Index trait not supported for sub-byte types"
);
unsafe { &*self.data.as_ptr().add(i) }
}
}
impl<AnyIndex: VectorIndex, Scalar: StorageElement, Alloc: Allocator> core::ops::IndexMut<AnyIndex>
for Vector<Scalar, Alloc>
{
#[inline]
fn index_mut(&mut self, index: AnyIndex) -> &mut Scalar {
let i = index
.resolve(self.dims)
.expect("vector index out of bounds");
debug_assert_eq!(
Scalar::dimensions_per_value(),
1,
"IndexMut trait not supported for sub-byte types"
);
unsafe { &mut *self.data.as_ptr().add(i) }
}
}
impl<Scalar: StorageElement + Clone, Alloc: Allocator + Clone> Vector<Scalar, Alloc> {
pub fn try_clone(&self) -> Result<Self, TensorError> {
if self.values == 0 {
return Ok(Self {
data: NonNull::dangling(),
dims: 0,
values: 0,
alloc: self.alloc.clone(),
});
}
let size = self.values * core::mem::size_of::<Scalar>();
let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
.map_err(|_| TensorError::AllocationFailed)?;
let ptr = self
.alloc
.allocate(layout)
.ok_or(TensorError::AllocationFailed)?;
unsafe {
core::ptr::copy_nonoverlapping(self.data.as_ptr() as *const u8, ptr.as_ptr(), size);
}
Ok(Self {
data: unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut Scalar) },
dims: self.dims,
values: self.values,
alloc: self.alloc.clone(),
})
}
}
impl<Scalar: StorageElement + Clone, Alloc: Allocator + Clone> Clone for Vector<Scalar, Alloc> {
fn clone(&self) -> Self {
self.try_clone().expect("vector clone allocation failed")
}
}
impl<Scalar: StorageElement> Default for Vector<Scalar, Global> {
fn default() -> Self {
Self {
data: NonNull::dangling(),
dims: 0,
values: 0,
alloc: Global,
}
}
}
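/// Borrowed, possibly strided, read-only view over vector data.
///
/// `stride_bytes` is the byte distance between consecutive elements and may
/// be negative for reversed views.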
pub struct VectorView<'a, Scalar: StorageElement> {
data: *const Scalar,
dims: usize,
stride_bytes: isize,
_marker: PhantomData<&'a Scalar>,
}
unsafe impl<'a, Scalar: StorageElement + Sync> Send for VectorView<'a, Scalar> {}
unsafe impl<'a, Scalar: StorageElement + Sync> Sync for VectorView<'a, Scalar> {}
impl<'a, Scalar: StorageElement> Clone for VectorView<'a, Scalar> {
fn clone(&self) -> Self {
*self
}
}
impl<'a, Scalar: StorageElement> Copy for VectorView<'a, Scalar> {}
impl<'a, Scalar: StorageElement> VectorView<'a, Scalar> {
#[inline]
pub unsafe fn from_raw_parts(data: *const Scalar, dims: usize, stride_bytes: isize) -> Self {
Self {
data,
dims,
stride_bytes,
_marker: PhantomData,
}
}
#[inline]
pub fn dims(&self) -> usize {
self.dims
}
#[inline]
pub fn size(&self) -> usize {
self.dims
}
#[inline]
pub fn is_empty(&self) -> bool {
self.dims == 0
}
#[inline]
pub fn stride_bytes(&self) -> isize {
self.stride_bytes
}
#[inline]
pub fn is_contiguous(&self) -> bool {
self.stride_bytes == core::mem::size_of::<Scalar>() as isize
}
#[inline]
pub fn as_ptr(&self) -> *const Scalar {
self.data
}
#[inline]
pub fn as_contiguous_slice(&self) -> Option<&'a [Scalar]> {
if self.is_contiguous() && Scalar::dimensions_per_value() == 1 {
Some(unsafe { core::slice::from_raw_parts(self.data, self.dims) })
} else {
None
}
}
#[inline]
pub fn try_get<AnyIndex: VectorIndex>(
&self,
index: AnyIndex,
) -> Result<Scalar::DimScalar, TensorError>
where
Scalar: FloatConvertible,
{
let i = index
.resolve(self.dims)
.ok_or(TensorError::IndexOutOfBounds {
index: 0,
size: self.dims,
})?;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = i / dims_per_value;
let sub_index = i % dims_per_value;
let ptr = unsafe {
(self.data as *const u8).offset(self.stride_bytes * value_index as isize)
as *const Scalar
};
Ok(unsafe { *ptr }.unpack().as_ref()[sub_index])
}
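/// Returns a view over the same data in reverse order, by pointing at the
/// last element and negating the stride.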
pub fn rev(&self) -> Self {
if self.dims == 0 {
return *self;
}
let last_offset = self.stride_bytes * (self.dims as isize - 1);
Self {
data: unsafe { (self.data as *const u8).offset(last_offset) as *const Scalar },
dims: self.dims,
stride_bytes: -self.stride_bytes,
_marker: PhantomData,
}
}
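/// Returns a sub-view over `start..end` with the given `step`; negative steps
/// walk backwards from `start` toward `end`. Fails if a bound exceeds the
/// view length or `step` is zero.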
pub fn try_strided(&self, start: usize, end: usize, step: isize) -> Result<Self, TensorError> {
if start > self.dims || end > self.dims || step == 0 {
return Err(TensorError::IndexOutOfBounds {
index: start.max(end),
size: self.dims,
});
}
let count = if step > 0 {
if end > start {
(end - start).div_ceil(step as usize)
} else {
0
}
} else if start > end {
// step < 0: walk backwards from `start` toward `end`.
(start - end).div_ceil(step.unsigned_abs())
} else {
0
};
let new_data = unsafe {
(self.data as *const u8).offset(self.stride_bytes * start as isize) as *const Scalar
};
Ok(Self {
data: new_data,
dims: count,
stride_bytes: self.stride_bytes * step,
_marker: PhantomData,
})
}
pub fn iter(&self) -> VectorViewIterator<'a, Scalar>
where
Scalar: FloatConvertible,
{
VectorViewIterator {
data: self.data,
stride_bytes: self.stride_bytes,
front: 0,
back: self.dims,
_marker: PhantomData,
}
}
}
impl<'a, AnyIndex: VectorIndex, Scalar: StorageElement> core::ops::Index<AnyIndex>
for VectorView<'a, Scalar>
{
type Output = Scalar;
#[inline]
fn index(&self, index: AnyIndex) -> &Scalar {
let i = index.resolve(self.dims).expect("view index out of bounds");
debug_assert_eq!(
Scalar::dimensions_per_value(),
1,
"Index trait not supported for sub-byte types"
);
unsafe {
&*((self.data as *const u8).offset(self.stride_bytes * i as isize) as *const Scalar)
}
}
}
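/// Borrowed, possibly strided, mutable view over vector data.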
pub struct VectorSpan<'a, Scalar: StorageElement> {
data: *mut Scalar,
dims: usize,
stride_bytes: isize,
_marker: PhantomData<&'a mut Scalar>,
}
unsafe impl<'a, Scalar: StorageElement + Send> Send for VectorSpan<'a, Scalar> {}
unsafe impl<'a, Scalar: StorageElement + Sync> Sync for VectorSpan<'a, Scalar> {}
impl<'a, Scalar: StorageElement> VectorSpan<'a, Scalar> {
#[inline]
pub unsafe fn from_raw_parts(data: *mut Scalar, dims: usize, stride_bytes: isize) -> Self {
Self {
data,
dims,
stride_bytes,
_marker: PhantomData,
}
}
#[inline]
pub fn dims(&self) -> usize {
self.dims
}
#[inline]
pub fn size(&self) -> usize {
self.dims
}
#[inline]
pub fn is_empty(&self) -> bool {
self.dims == 0
}
#[inline]
pub fn stride_bytes(&self) -> isize {
self.stride_bytes
}
#[inline]
pub fn is_contiguous(&self) -> bool {
self.stride_bytes == core::mem::size_of::<Scalar>() as isize
}
#[inline]
pub fn as_ptr(&self) -> *const Scalar {
self.data
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut Scalar {
self.data
}
pub fn as_view(&self) -> VectorView<'_, Scalar> {
VectorView {
data: self.data,
dims: self.dims,
stride_bytes: self.stride_bytes,
_marker: PhantomData,
}
}
#[inline]
pub fn try_get<AnyIndex: VectorIndex>(
&self,
index: AnyIndex,
) -> Result<Scalar::DimScalar, TensorError>
where
Scalar: FloatConvertible,
{
self.as_view().try_get(index)
}
#[inline]
pub fn try_set<AnyIndex: VectorIndex>(
&mut self,
index: AnyIndex,
value: Scalar::DimScalar,
) -> Result<(), TensorError>
where
Scalar: FloatConvertible,
{
let i = index
.resolve(self.dims)
.ok_or(TensorError::IndexOutOfBounds {
index: 0,
size: self.dims,
})?;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = i / dims_per_value;
let sub_index = i % dims_per_value;
let ptr = unsafe {
(self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut Scalar
};
let mut unpacked = unsafe { *ptr }.unpack();
unpacked.as_mut()[sub_index] = value;
unsafe { ptr.write(Scalar::pack(unpacked)) };
Ok(())
}
pub fn fill(&mut self, value: Scalar) {
let values = self.dims.div_ceil(Scalar::dimensions_per_value());
for i in 0..values {
let ptr = unsafe {
(self.data as *mut u8).offset(self.stride_bytes * i as isize) as *mut Scalar
};
unsafe { ptr.write(value) };
}
}
#[inline]
pub fn as_contiguous_slice(&self) -> Option<&[Scalar]> {
if self.is_contiguous() && Scalar::dimensions_per_value() == 1 {
Some(unsafe { core::slice::from_raw_parts(self.data, self.dims) })
} else {
None
}
}
#[inline]
pub fn as_contiguous_slice_mut(&mut self) -> Option<&mut [Scalar]> {
if self.is_contiguous() && Scalar::dimensions_per_value() == 1 {
Some(unsafe { core::slice::from_raw_parts_mut(self.data, self.dims) })
} else {
None
}
}
pub fn iter(&self) -> VectorViewIterator<'_, Scalar>
where
Scalar: FloatConvertible,
{
VectorViewIterator {
data: self.data,
stride_bytes: self.stride_bytes,
front: 0,
back: self.dims,
_marker: PhantomData,
}
}
pub fn iter_mut(&mut self) -> VectorSpanIterator<'_, Scalar>
where
Scalar: FloatConvertible,
{
VectorSpanIterator {
data: self.data,
stride_bytes: self.stride_bytes,
front: 0,
back: self.dims,
_marker: PhantomData,
}
}
}
impl<'a, AnyIndex: VectorIndex, Scalar: StorageElement> core::ops::Index<AnyIndex>
for VectorSpan<'a, Scalar>
{
type Output = Scalar;
#[inline]
fn index(&self, index: AnyIndex) -> &Scalar {
let i = index.resolve(self.dims).expect("span index out of bounds");
debug_assert_eq!(
Scalar::dimensions_per_value(),
1,
"Index trait not supported for sub-byte types"
);
unsafe {
&*((self.data as *const u8).offset(self.stride_bytes * i as isize) as *const Scalar)
}
}
}
impl<'a, AnyIndex: VectorIndex, Scalar: StorageElement> core::ops::IndexMut<AnyIndex>
for VectorSpan<'a, Scalar>
{
#[inline]
fn index_mut(&mut self, index: AnyIndex) -> &mut Scalar {
let i = index.resolve(self.dims).expect("span index out of bounds");
debug_assert_eq!(
Scalar::dimensions_per_value(),
1,
"IndexMut trait not supported for sub-byte types"
);
unsafe {
&mut *((self.data as *mut u8).offset(self.stride_bytes * i as isize) as *mut Scalar)
}
}
}
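/// Double-ended, exact-size iterator over the dimensions of a view, yielding
/// `DimRef` values.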
pub struct VectorViewIterator<'a, Scalar: FloatConvertible> {
data: *const Scalar,
stride_bytes: isize,
front: usize,
back: usize,
_marker: PhantomData<&'a Scalar>,
}
pub type VectorIterator<'a, Scalar> = VectorViewIterator<'a, Scalar>;
impl<'a, Scalar: FloatConvertible> Iterator for VectorViewIterator<'a, Scalar> {
type Item = DimRef<'a, Scalar>;
#[inline]
fn next(&mut self) -> Option<DimRef<'a, Scalar>> {
if self.front >= self.back {
return None;
}
let dims_per_value = Scalar::dimensions_per_value();
let value_index = self.front / dims_per_value;
let sub_index = self.front % dims_per_value;
let ptr = unsafe {
(self.data as *const u8).offset(self.stride_bytes * value_index as isize)
as *const Scalar
};
let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
self.front += 1;
Some(DimRef::new(scalar))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let element_count = self.back - self.front;
(element_count, Some(element_count))
}
}
impl<'a, Scalar: FloatConvertible> ExactSizeIterator for VectorViewIterator<'a, Scalar> {}
impl<'a, Scalar: FloatConvertible> core::iter::FusedIterator for VectorViewIterator<'a, Scalar> {}
impl<'a, Scalar: FloatConvertible> DoubleEndedIterator for VectorViewIterator<'a, Scalar> {
#[inline]
fn next_back(&mut self) -> Option<DimRef<'a, Scalar>> {
if self.front >= self.back {
return None;
}
self.back -= 1;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = self.back / dims_per_value;
let sub_index = self.back % dims_per_value;
let ptr = unsafe {
(self.data as *const u8).offset(self.stride_bytes * value_index as isize)
as *const Scalar
};
Some(DimRef::new(unsafe { *ptr }.unpack().as_ref()[sub_index]))
}
}
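/// Mutable counterpart of [`VectorViewIterator`], yielding `DimMut` handles
/// that write changes back into the underlying storage.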
pub struct VectorSpanIterator<'a, Scalar: FloatConvertible> {
data: *mut Scalar,
stride_bytes: isize,
front: usize,
back: usize,
_marker: PhantomData<&'a mut Scalar>,
}
impl<'a, Scalar: FloatConvertible> Iterator for VectorSpanIterator<'a, Scalar> {
type Item = DimMut<'a, Scalar>;
#[inline]
fn next(&mut self) -> Option<DimMut<'a, Scalar>> {
if self.front >= self.back {
return None;
}
let dims_per_value = Scalar::dimensions_per_value();
let value_index = self.front / dims_per_value;
let sub_index = self.front % dims_per_value;
let ptr = unsafe {
(self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut Scalar
};
let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
self.front += 1;
Some(unsafe { DimMut::new(ptr, sub_index, scalar) })
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let element_count = self.back - self.front;
(element_count, Some(element_count))
}
}
impl<'a, Scalar: FloatConvertible> ExactSizeIterator for VectorSpanIterator<'a, Scalar> {}
impl<'a, Scalar: FloatConvertible> core::iter::FusedIterator for VectorSpanIterator<'a, Scalar> {}
impl<'a, Scalar: FloatConvertible> DoubleEndedIterator for VectorSpanIterator<'a, Scalar> {
#[inline]
fn next_back(&mut self) -> Option<DimMut<'a, Scalar>> {
if self.front >= self.back {
return None;
}
self.back -= 1;
let dims_per_value = Scalar::dimensions_per_value();
let value_index = self.back / dims_per_value;
let sub_index = self.back % dims_per_value;
let ptr = unsafe {
(self.data as *mut u8).offset(self.stride_bytes * value_index as isize) as *mut Scalar
};
let scalar = unsafe { *ptr }.unpack().as_ref()[sub_index];
Some(unsafe { DimMut::new(ptr, sub_index, scalar) })
}
}
impl<'a, Scalar: FloatConvertible, Alloc: Allocator> IntoIterator for &'a Vector<Scalar, Alloc> {
type Item = DimRef<'a, Scalar>;
type IntoIter = VectorViewIterator<'a, Scalar>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, Scalar: FloatConvertible> IntoIterator for &'a VectorView<'a, Scalar> {
type Item = DimRef<'a, Scalar>;
type IntoIter = VectorViewIterator<'a, Scalar>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, Scalar: FloatConvertible> IntoIterator for &'a VectorSpan<'a, Scalar> {
type Item = DimRef<'a, Scalar>;
type IntoIter = VectorViewIterator<'a, Scalar>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, Scalar: FloatConvertible, Alloc: Allocator> IntoIterator
for &'a mut Vector<Scalar, Alloc>
{
type Item = DimMut<'a, Scalar>;
type IntoIter = VectorSpanIterator<'a, Scalar>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
impl<'a, Scalar: FloatConvertible> IntoIterator for &'a mut VectorSpan<'a, Scalar> {
type Item = DimMut<'a, Scalar>;
type IntoIter = VectorSpanIterator<'a, Scalar>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
impl<Scalar: StorageElement, Alloc: Allocator> AsRef<[Scalar]> for Vector<Scalar, Alloc> {
fn as_ref(&self) -> &[Scalar] {
self.as_slice()
}
}
impl<Scalar: FloatConvertible, Alloc: Allocator> PartialEq for Vector<Scalar, Alloc>
where
Scalar::DimScalar: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
}
}
impl<Scalar: FloatConvertible, Alloc: Allocator> PartialEq<[Scalar::DimScalar]>
for Vector<Scalar, Alloc>
where
Scalar::DimScalar: PartialEq,
{
fn eq(&self, other: &[Scalar::DimScalar]) -> bool {
self.dims == other.len() && self.iter().zip(other.iter()).all(|(a, b)| *a == *b)
}
}
impl<'a, Scalar: FloatConvertible> PartialEq for VectorView<'a, Scalar>
where
Scalar::DimScalar: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
}
}
impl<'a, Scalar: FloatConvertible> PartialEq for VectorSpan<'a, Scalar>
where
Scalar::DimScalar: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.dims == other.dims && self.iter().zip(other.iter()).all(|(a, b)| a == b)
}
}
impl<Scalar: FloatConvertible, Alloc: Allocator> Vector<Scalar, Alloc>
where
Scalar::DimScalar: NumberLike,
{
pub fn allclose<OtherAlloc: Allocator>(
&self,
other: &Vector<Scalar, OtherAlloc>,
atol: f64,
rtol: f64,
) -> bool {
self.dims == other.dims
&& self
.iter()
.zip(other.iter())
.all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
}
}
impl<'a, Scalar: FloatConvertible> VectorView<'a, Scalar>
where
Scalar::DimScalar: NumberLike,
{
pub fn allclose(&self, other: &Self, atol: f64, rtol: f64) -> bool {
self.dims == other.dims
&& self
.iter()
.zip(other.iter())
.all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
}
}
impl<'a, Scalar: FloatConvertible> VectorSpan<'a, Scalar>
where
Scalar::DimScalar: NumberLike,
{
pub fn allclose(&self, other: &Self, atol: f64, rtol: f64) -> bool {
self.dims == other.dims
&& self
.iter()
.zip(other.iter())
.all(|(a, b)| crate::types::is_close(a.to_f64(), b.to_f64(), atol, rtol))
}
}
fn fmt_debug_list<I: Iterator>(
f: &mut core::fmt::Formatter<'_>,
name: &str,
dims: usize,
iter: I,
limit: usize,
) -> core::fmt::Result
where
I::Item: core::fmt::Debug,
{
write!(f, "{}(dims={}, [", name, dims)?;
for (i, value) in iter.enumerate() {
if i >= limit {
write!(f, ", ...")?;
break;
}
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", value)?;
}
write!(f, "])")
}
fn fmt_display_list<I: Iterator>(
f: &mut core::fmt::Formatter<'_>,
iter: I,
limit: usize,
) -> core::fmt::Result
where
I::Item: core::fmt::Display,
{
let prec = f.precision();
write!(f, "[")?;
for (i, value) in iter.enumerate() {
if i >= limit {
write!(f, ", ...")?;
break;
}
if i > 0 {
write!(f, ", ")?;
}
if let Some(p) = prec {
write!(f, "{:.p$}", value)?;
} else {
write!(f, "{}", value)?;
}
}
write!(f, "]")
}
impl<Scalar: FloatConvertible, Alloc: Allocator> core::fmt::Debug for Vector<Scalar, Alloc>
where
Scalar::DimScalar: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
fmt_debug_list(f, "Vector", self.dims, self.iter(), 8)
}
}
impl<'a, Scalar: FloatConvertible> core::fmt::Debug for VectorView<'a, Scalar>
where
Scalar::DimScalar: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
fmt_debug_list(f, "VectorView", self.dims, self.iter(), 8)
}
}
impl<'a, Scalar: FloatConvertible> core::fmt::Debug for VectorSpan<'a, Scalar>
where
Scalar::DimScalar: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
fmt_debug_list(f, "VectorSpan", self.dims, self.iter(), 8)
}
}
impl<Scalar: FloatConvertible, Alloc: Allocator> core::fmt::Display for Vector<Scalar, Alloc>
where
Scalar::DimScalar: core::fmt::Display,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
fmt_display_list(f, self.iter(), 20)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{bf16, f16, i4x2, u1x8, u4x2};
use alloc::{format, vec, vec::Vec};
fn check_vector_roundtrip<Scalar: FloatConvertible>() {
let dims_per_value = Scalar::dimensions_per_value();
let test_dims = 16 * dims_per_value;
let v = Vector::<Scalar>::try_zeros(test_dims).unwrap();
assert_eq!(v.dims(), test_dims);
assert_eq!(v.size_values(), test_dims / dims_per_value);
let count = v.iter().count();
assert_eq!(count, test_dims);
}
fn check_vector_try_get_set<Scalar: FloatConvertible>()
where
Scalar::DimScalar: core::fmt::Debug,
{
let dims_per_value = Scalar::dimensions_per_value();
let test_dims = 4 * dims_per_value;
let mut v = Vector::<Scalar>::try_zeros(test_dims).unwrap();
let one = Scalar::DimScalar::from_f32(1.0);
v.try_set(0_usize, one).unwrap();
v.try_set((test_dims - 1) as i32, one).unwrap();
let first = v.try_get(0_usize).unwrap();
let last = v.try_get(-1_i32).unwrap();
assert!(
first.to_f32() >= 0.5,
"first dim should be ~1.0, got {:?}",
first
);
assert!(
last.to_f32() >= 0.5,
"last dim should be ~1.0, got {:?}",
last
);
}
#[test]
fn vector_roundtrip_all_types() {
check_vector_roundtrip::<f32>();
check_vector_roundtrip::<f64>();
check_vector_roundtrip::<f16>();
check_vector_roundtrip::<bf16>();
check_vector_roundtrip::<i4x2>();
check_vector_roundtrip::<u4x2>();
check_vector_roundtrip::<u1x8>();
}
#[test]
fn vector_try_get_set_all_types() {
check_vector_try_get_set::<f32>();
check_vector_try_get_set::<f64>();
check_vector_try_get_set::<i4x2>();
check_vector_try_get_set::<u4x2>();
check_vector_try_get_set::<u1x8>();
}
#[test]
fn vec_index_signed() {
let v = Vector::<f32>::try_from_dims(&[10.0, 20.0, 30.0, 40.0, 50.0]).unwrap();
assert_eq!(v[0], 10.0);
assert_eq!(v[4], 50.0);
assert_eq!(v[-1_i32], 50.0);
assert_eq!(v[-2_i32], 40.0);
assert_eq!(v[-5_i32], 10.0);
}
#[test]
fn vector_view_stride() {
let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
let view = v.view();
assert!(view.is_contiguous());
assert_eq!(view.size(), 5);
let rev = view.rev();
assert_eq!(rev.try_get(0_usize).unwrap(), 5.0);
assert_eq!(rev.try_get(4_usize).unwrap(), 1.0);
}
#[test]
fn vector_span_fill() {
let mut v = Vector::<f32>::try_zeros(4).unwrap();
{
let mut span = v.span();
span.fill(42.0);
}
assert_eq!(v[0], 42.0);
assert_eq!(v[3], 42.0);
}
#[test]
fn vector_iter() {
let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
let values: Vec<f32> = v.iter().map(|x| *x).collect();
assert_eq!(values, vec![1.0, 2.0, 3.0]);
let reversed_values: Vec<f32> = v.iter().rev().map(|x| *x).collect();
assert_eq!(reversed_values, vec![3.0, 2.0, 1.0]);
}
#[test]
fn view_strided_iteration() {
let v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
let view = v.view();
let strided = view.try_strided(0, 5, 2).unwrap();
assert_eq!(strided.size(), 3);
let values: Vec<f32> = strided.iter().map(|x| *x).collect();
assert_eq!(values, vec![1.0, 3.0, 5.0]);
}
#[test]
fn vector_filled() {
let v = Vector::<f32>::try_full(3, 7.5).unwrap();
assert_eq!(v[0], 7.5);
assert_eq!(v[1], 7.5);
assert_eq!(v[2], 7.5);
}
#[test]
fn empty_vector() {
let v = Vector::<f32>::try_zeros(0).unwrap();
assert!(v.is_empty());
assert_eq!(v.size(), 0);
}
#[test]
#[should_panic]
fn index_out_of_bounds() {
let v = Vector::<f32>::try_zeros(3).unwrap();
let _ = v[3_usize];
}
#[test]
fn vector_allclose_matching() {
let a = Vector::<f32>::try_full(4, 1.0).unwrap();
let b = Vector::<f32>::try_full(4, 1.0 + 1e-7).unwrap();
assert!(a.allclose(&b, 1e-6, 0.0));
}
#[test]
fn vector_allclose_mismatching() {
let a = Vector::<f32>::try_full(4, 1.0).unwrap();
let b = Vector::<f32>::try_full(4, 2.0).unwrap();
assert!(!a.allclose(&b, 1e-6, 0.0));
}
#[test]
fn vector_allclose_different_dims() {
let a = Vector::<f32>::try_full(3, 1.0).unwrap();
let b = Vector::<f32>::try_full(4, 1.0).unwrap();
assert!(!a.allclose(&b, 1e-6, 1e-6));
}
#[test]
fn display_precision_forwarding() {
let v = Vector::<f32>::try_full(3, 1.0).unwrap();
let s = format!("{:.2}", v);
assert_eq!(s, "[1.00, 1.00, 1.00]");
}
#[test]
fn vector_span_iter_mut_f32() {
let mut v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
for mut value in &mut v {
*value += 10.0;
}
let values: Vec<f32> = v.iter().map(|x| *x).collect();
assert_eq!(values, vec![11.0, 12.0, 13.0]);
}
#[test]
fn vector_span_iter_mut_i4x2() {
let mut v = Vector::<i4x2>::try_zeros(4).unwrap();
{
let mut span = v.span();
for (i, mut value) in span.iter_mut().enumerate() {
*value = (i + 1) as i8;
}
}
assert_eq!(v.try_get(0_usize).unwrap(), 1);
assert_eq!(v.try_get(1_usize).unwrap(), 2);
assert_eq!(v.try_get(2_usize).unwrap(), 3);
assert_eq!(v.try_get(3_usize).unwrap(), 4);
}
#[test]
fn vector_span_iter_mut_u1x8() {
let mut v = Vector::<u1x8>::try_zeros(8).unwrap();
for (i, mut value) in v.iter_mut().enumerate() {
if i % 2 == 0 {
*value = 1;
}
}
assert_eq!(v.try_get(0_usize).unwrap(), 1);
assert_eq!(v.try_get(1_usize).unwrap(), 0);
assert_eq!(v.try_get(2_usize).unwrap(), 1);
assert_eq!(v.try_get(3_usize).unwrap(), 0);
}
#[test]
fn vector_span_iter_double_ended() {
let mut v = Vector::<f32>::try_from_scalars(&[1.0, 2.0, 3.0]).unwrap();
let mut span = v.span();
let mut it = span.iter_mut();
let mut first = it.next().unwrap();
*first = 10.0;
drop(first);
let mut last = it.next_back().unwrap();
*last = 30.0;
drop(last);
drop(it);
drop(span);
assert_eq!(v.try_get(0_usize).unwrap(), 10.0);
assert_eq!(v.try_get(1_usize).unwrap(), 2.0);
assert_eq!(v.try_get(2_usize).unwrap(), 30.0);
}
#[test]
fn vector_iterator_alias_compat() {
let v = Vector::<f32>::try_from_scalars(&[1.0]).unwrap();
let _it: VectorIterator<'_, f32> = v.iter();
}
}