/// Selects between row-major and column-major element ordering.
///
/// NOTE(review): nothing in the code next to this enum consumes it, so the
/// exact meaning each variant has for callers should be confirmed at the use
/// sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutOrder {
    /// Row-major ordering.
    RowMajor,
    /// Column-major ordering.
    ColumnMajor,
}
/// Dense N-dimensional array: a flat element buffer described by a shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<T> {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Flat element storage. [`Tensor::new`] requires its length to equal the
    /// product of `shape`.
    pub data: Vec<T>,
}

/// Owned one-dimensional sequence of elements.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector<T> {
    /// The elements, in order.
    pub data: Vec<T>,
}

/// N-dimensional array that carries an explicit stride for every dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct StridedTensor<T> {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Step, counted in elements, between consecutive indices of each dimension.
    pub stride: Vec<usize>,
    /// Backing element storage.
    pub data: Vec<T>,
}

impl<T> Tensor<T> {
    /// Builds a tensor from a shape and its flat element buffer.
    ///
    /// An empty `shape` describes a single element, because the empty product
    /// is 1.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<T>) -> Self {
        let expected_len: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} doesn't match shape {:?} (expected {})",
            data.len(),
            shape,
            expected_len
        );
        Self { shape, data }
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of elements described by `shape` (the product of all extents).
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` when the shape describes zero elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> StridedTensor<T> {
    /// Builds a strided tensor from its shape, strides and backing buffer.
    ///
    /// Only the dimension counts are validated; `data` is not checked against
    /// the extent implied by `shape` and `stride`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` and `stride` have different lengths.
    pub fn new(shape: Vec<usize>, stride: Vec<usize>, data: Vec<T>) -> Self {
        assert_eq!(
            shape.len(),
            stride.len(),
            "Shape and stride must have same number of dimensions"
        );
        Self {
            shape,
            stride,
            data,
        }
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of elements described by `shape` (the product of all extents).
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` when the shape describes zero elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when the elements are laid out densely with the last
    /// dimension varying fastest.
    ///
    /// Walking the dimensions from last to first, each stride must equal the
    /// product of the extents that follow it. Two cases are treated as
    /// contiguous regardless of the stored strides, because the strides cannot
    /// influence which elements are addressed:
    ///
    /// * a dimension of extent 1 (its only valid index is 0), and
    /// * a tensor with no elements at all (some extent is 0).
    ///
    /// Returns `false` if `shape` and `stride` disagree on the number of
    /// dimensions, which is possible because both fields are public.
    pub fn is_contiguous(&self) -> bool {
        if self.shape.len() != self.stride.len() {
            return false;
        }
        if self.shape.contains(&0) {
            return true;
        }
        // `None` means the dense stride no longer fits in `usize`, so no stored
        // stride can match it.
        let mut expected_stride = Some(1usize);
        for (&extent, &stride) in self.shape.iter().zip(&self.stride).rev() {
            if extent == 1 {
                continue;
            }
            if expected_stride != Some(stride) {
                return false;
            }
            expected_stride = expected_stride.and_then(|dense| dense.checked_mul(extent));
        }
        true
    }
}
/// Tensor whose samples are stored back-to-back at a fixed bit width.
///
/// The packing constructors write each sample's bits from most significant to
/// least significant, filling every byte from bit 7 down to bit 0; unused bits
/// at the end of the last byte are left as zero.
#[derive(Debug, Clone, PartialEq)]
pub struct BitPackedTensor {
    /// Number of bits each sample occupies in `data`.
    pub bit_depth: u8,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// The packed sample bits.
    pub data: Vec<u8>,
}

/// Unsigned integer types that can be packed into a [`BitPackedTensor`].
pub trait PackableUnsigned: Copy {
    /// Packs the low `bit_depth` bits of every sample into a tensor of the
    /// given `shape`.
    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor;
}

// Generates the inherent constructor `BitPackedTensor::$fn_name`, which packs
// `$t` samples. `$work_t` is the integer type each sample is widened to before
// its bits are extracted.
macro_rules! impl_bitpack {
    ($fn_name:ident, $t:ty, $work_t:ty) => {
        impl BitPackedTensor {
            #[doc = concat!("Packs `", stringify!($t), "` samples into a bit-packed tensor.")]
            ///
            /// # Arguments
            ///
            /// * `bit_depth` - Bits per sample (1-128 supported, 0 reserved for
            ///   a future 256-bit mode)
            /// * `shape` - Tensor dimensions
            /// * `samples` - Sample values; only the low `bit_depth` bits of
            ///   each are packed and higher bits are ignored
            ///
            /// # Panics
            ///
            /// * If `samples.len()` doesn't match the product of `shape`
            /// * If `bit_depth` is 0 or greater than 128
            /// * If `bit_depth` exceeds the sample type's bit width
            pub fn $fn_name(bit_depth: u8, shape: Vec<usize>, samples: &[$t]) -> Self {
                let element_count: usize = shape.iter().product();
                assert_eq!(
                    samples.len(),
                    element_count,
                    "Sample count {} doesn't match shape {:?} (expected {})",
                    samples.len(),
                    shape,
                    element_count
                );
                assert!(
                    bit_depth != 0,
                    "bit_depth=0 (256-bit) not yet supported - use 1-128"
                );
                let width = usize::from(bit_depth);
                assert!(
                    width <= 128,
                    "bit_depth > 128 not yet supported (waiting for native u256 support)"
                );
                assert!(
                    width <= <$t>::BITS as usize,
                    "Cannot pack {}-bit values into {}-bit type {}",
                    width,
                    <$t>::BITS,
                    std::any::type_name::<$t>()
                );

                // Round up so a trailing partial byte is still allocated.
                let mut data = vec![0u8; (element_count * width + 7) / 8];
                // Index of the next output bit, counted from bit 7 of byte 0.
                let mut cursor = 0usize;
                for &sample in samples {
                    let widened = sample as $work_t;
                    for shift in (0..width).rev() {
                        if (widened >> shift) & 1 != 0 {
                            data[cursor / 8] |= 0x80u8 >> (cursor % 8);
                        }
                        cursor += 1;
                    }
                }

                BitPackedTensor {
                    bit_depth,
                    shape,
                    data,
                }
            }
        }
    };
}

impl_bitpack!(pack_u8, u8, u64);
impl_bitpack!(pack_u16, u16, u64);
impl_bitpack!(pack_u32, u32, u64);
impl_bitpack!(pack_u64, u64, u64);
impl_bitpack!(pack_u128, u128, u128);
impl_bitpack!(pack_usize, usize, u64);

// Implements `PackableUnsigned` for `$t` by forwarding to the matching
// inherent constructor on `BitPackedTensor`.
macro_rules! impl_packable_unsigned {
    ($t:ty, $fn_name:ident) => {
        impl PackableUnsigned for $t {
            fn pack_samples(
                bit_depth: u8,
                shape: Vec<usize>,
                samples: &[Self],
            ) -> BitPackedTensor {
                BitPackedTensor::$fn_name(bit_depth, shape, samples)
            }
        }
    };
}

impl_packable_unsigned!(u8, pack_u8);
impl_packable_unsigned!(u16, pack_u16);
impl_packable_unsigned!(u32, pack_u32);
impl_packable_unsigned!(u64, pack_u64);
impl_packable_unsigned!(u128, pack_u128);
impl_packable_unsigned!(usize, pack_usize);
/// Unpacked sample values, tagged with the integer width that holds them.
#[derive(Debug, Clone, PartialEq)]
pub enum UnpackedSamples {
    /// Samples held as `u8`.
    U8(Vec<u8>),
    /// Samples held as `u16`.
    U16(Vec<u16>),
    /// Samples held as `u32`.
    U32(Vec<u32>),
    /// Samples held as `u64`.
    U64(Vec<u64>),
    /// Samples held as `u128`.
    U128(Vec<u128>),
}

impl UnpackedSamples {
    /// Consumes the samples and widens them to `u64`.
    ///
    /// # Panics
    ///
    /// Panics for the [`UnpackedSamples::U128`] variant, since narrowing could
    /// lose bits.
    pub fn into_u64(self) -> Vec<u64> {
        match self {
            Self::U8(narrow) => narrow.into_iter().map(u64::from).collect(),
            Self::U16(narrow) => narrow.into_iter().map(u64::from).collect(),
            Self::U32(narrow) => narrow.into_iter().map(u64::from).collect(),
            Self::U64(exact) => exact,
            Self::U128(_) => {
                panic!("Cannot convert >64 bit samples to u64 (would truncate)")
            }
        }
    }

    /// Consumes the samples and widens them to `u128`; this never loses bits.
    pub fn into_u128(self) -> Vec<u128> {
        match self {
            Self::U8(narrow) => narrow.into_iter().map(u128::from).collect(),
            Self::U16(narrow) => narrow.into_iter().map(u128::from).collect(),
            Self::U32(narrow) => narrow.into_iter().map(u128::from).collect(),
            Self::U64(narrow) => narrow.into_iter().map(u128::from).collect(),
            Self::U128(exact) => exact,
        }
    }

    /// Number of samples held, whichever variant is active.
    pub fn len(&self) -> usize {
        match self {
            Self::U8(items) => items.len(),
            Self::U16(items) => items.len(),
            Self::U32(items) => items.len(),
            Self::U64(items) => items.len(),
            Self::U128(items) => items.len(),
        }
    }

    /// Returns `true` when no samples are held.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl BitPackedTensor {
    /// Packs `samples` into a new tensor by delegating to the sample type's
    /// [`PackableUnsigned::pack_samples`] implementation.
    pub fn pack<T: PackableUnsigned>(bit_depth: u8, shape: Vec<usize>, samples: &[T]) -> Self {
        T::pack_samples(bit_depth, shape, samples)
    }

    /// Unpacks every sample into the narrowest of `u8`/`u16`/`u32`/`u64`/`u128`
    /// that can hold `bit_depth` bits.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is 0 or greater than 128
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack(&self) -> UnpackedSamples {
        match self.bit_depth {
            1..=8 => UnpackedSamples::U8(self.unpack_to_u8()),
            9..=16 => UnpackedSamples::U16(self.unpack_to_u16()),
            17..=32 => UnpackedSamples::U32(self.unpack_to_u32()),
            33..=64 => UnpackedSamples::U64(self.unpack_to_u64()),
            65..=128 => UnpackedSamples::U128(self.unpack_to_u128()),
            _ => panic!("bit_depth {} not supported (max 128)", self.bit_depth),
        }
    }

    /// Unpacks the samples as `u8`.
    ///
    /// A `bit_depth` of 0 is not rejected here; every sample then decodes to 0.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is greater than 8
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack_u8(&self) -> Vec<u8> {
        self.ensure_depth_fits(8, "u8");
        self.unpack_to_u8()
    }

    /// Unpacks the samples as `u16`.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is greater than 16
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack_u16(&self) -> Vec<u16> {
        self.ensure_depth_fits(16, "u16");
        self.unpack_to_u16()
    }

    /// Unpacks the samples as `u32`.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is greater than 32
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack_u32(&self) -> Vec<u32> {
        self.ensure_depth_fits(32, "u32");
        self.unpack_to_u32()
    }

    /// Unpacks the samples as `u64`.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is greater than 64
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack_u64(&self) -> Vec<u64> {
        self.ensure_depth_fits(64, "u64");
        self.unpack_to_u64()
    }

    /// Unpacks the samples as `u128`.
    ///
    /// # Panics
    ///
    /// * If `bit_depth` is greater than 128 (the narrower `unpack_*` methods
    ///   refuse to truncate in the same way)
    /// * If `data` is too short for `shape` at this `bit_depth`
    pub fn unpack_u128(&self) -> Vec<u128> {
        self.ensure_depth_fits(128, "u128");
        self.unpack_to_u128()
    }

    /// Panics with the shared "would truncate" message when `bit_depth`
    /// exceeds `max_bits`, the width of the integer type named by `target`.
    fn ensure_depth_fits(&self, max_bits: u8, target: &str) {
        if self.bit_depth > max_bits {
            panic!(
                "Cannot unpack {}-bit data into {} (would truncate)",
                self.bit_depth, target
            );
        }
    }

    /// Decodes every sample into `T`, reading bits most-significant first and
    /// consuming each byte from bit 7 down to bit 0.
    ///
    /// No check is made that `bit_depth` fits in `T`; bits shifted past the top
    /// of `T` are lost, so callers validate the depth beforehand.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer bytes than `shape` and `bit_depth` require.
    fn unpack_samples_as<T>(&self) -> Vec<T>
    where
        T: Default + From<u8> + std::ops::Shl<u32, Output = T> + std::ops::BitOr<Output = T>,
    {
        let total_elements: usize = self.shape.iter().product();
        let bits_per_sample = usize::from(self.bit_depth);

        // Validate once, before allocating the output, so malformed tensors
        // fail with a descriptive message instead of an index panic mid-decode.
        let required_bits = total_elements
            .checked_mul(bits_per_sample)
            .expect("shape and bit_depth describe more bits than usize can address");
        let required_bytes = required_bits / 8 + usize::from(required_bits % 8 != 0);
        assert!(
            self.data.len() >= required_bytes,
            "Packed data is {} bytes but shape {:?} at {} bits per sample needs {} bytes",
            self.data.len(),
            self.shape,
            self.bit_depth,
            required_bytes
        );

        let mut samples = Vec::with_capacity(total_elements);
        let mut bit_offset = 0usize;
        for _ in 0..total_elements {
            let mut sample = T::default();
            for _ in 0..bits_per_sample {
                let byte = self.data[bit_offset / 8];
                let bit = (byte >> (7 - bit_offset % 8)) & 1;
                sample = (sample << 1u32) | T::from(bit);
                bit_offset += 1;
            }
            samples.push(sample);
        }
        samples
    }

    /// Decodes the samples into `u8` without checking that `bit_depth` fits.
    fn unpack_to_u8(&self) -> Vec<u8> {
        self.unpack_samples_as()
    }

    /// Decodes the samples into `u16` without checking that `bit_depth` fits.
    fn unpack_to_u16(&self) -> Vec<u16> {
        self.unpack_samples_as()
    }

    /// Decodes the samples into `u32` without checking that `bit_depth` fits.
    fn unpack_to_u32(&self) -> Vec<u32> {
        self.unpack_samples_as()
    }

    /// Decodes the samples into `u64` without checking that `bit_depth` fits.
    fn unpack_to_u64(&self) -> Vec<u64> {
        self.unpack_samples_as()
    }

    /// Decodes the samples into `u128` without checking that `bit_depth` fits.
    fn unpack_to_u128(&self) -> Vec<u128> {
        self.unpack_samples_as()
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of samples described by `shape` (the product of all extents).
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` when the shape describes zero samples.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}