use rayon::prelude::*;
use rkyv::{Archive, Deserialize, Serialize};
/// Physical layout used to store one feature column of bin indices.
///
/// Variant order is part of the derived archive representation, so new
/// variants should be appended rather than inserted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize, Default)]
pub enum StorageMode {
/// One byte per bin value; this is the default mode.
#[default]
U8,
/// Two bin values per byte (4 bits each); only valid when every bin is <= 15.
Packed4Bit,
}
/// A column of bin indices stored two per byte.
///
/// Even rows live in the high nibble of a byte and odd rows in the low nibble.
/// When the row count is odd, the low nibble of the final byte is zero padding.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct PackedColumn {
// Packed nibbles; holds `num_rows.div_ceil(2)` bytes when built by `from_bins`.
data: Vec<u8>,
// Logical number of bin values stored (not the number of bytes).
num_rows: usize,
}
impl PackedColumn {
    /// Packs a slice of bin indices into 4-bit storage.
    ///
    /// Each pair of consecutive bins shares one byte: the first goes into the
    /// high nibble, the second into the low nibble. A trailing unpaired bin is
    /// padded with a zero low nibble.
    ///
    /// Every value must be `<= 15`. This is only checked with `debug_assert!`,
    /// so release builds do not validate the input.
    pub fn from_bins(bins: &[u8]) -> Self {
        let mut data = Vec::with_capacity(bins.len().div_ceil(2));
        data.extend(bins.chunks(2).map(|pair| {
            let first = pair[0];
            debug_assert!(first <= 15, "Bin value {} exceeds 4-bit max (15)", first);
            let upper = first << 4;
            let lower = match pair.get(1) {
                Some(&second) => {
                    debug_assert!(second <= 15, "Bin value {} exceeds 4-bit max (15)", second);
                    second
                }
                // Odd row count: pad the final low nibble with zero.
                None => 0,
            };
            upper | lower
        }));
        Self {
            data,
            num_rows: bins.len(),
        }
    }

    /// Returns the bin stored at `row_idx`.
    ///
    /// The bound `row_idx < num_rows` is only checked in debug builds; the
    /// underlying byte access is always bounds-checked and may panic.
    #[inline]
    pub fn get(&self, row_idx: usize) -> u8 {
        debug_assert!(row_idx < self.num_rows);
        let packed = self.data[row_idx / 2];
        match row_idx % 2 {
            // Even rows occupy the high nibble.
            0 => packed >> 4,
            // Odd rows occupy the low nibble.
            _ => packed & 0x0F,
        }
    }

    /// Number of logical bin values in the column.
    #[inline]
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// Number of bytes used by the packed representation.
    #[inline]
    pub fn memory_bytes(&self) -> usize {
        self.data.len()
    }

    /// Expands the column into a freshly allocated vector with one byte per row.
    pub fn unpack(&self) -> Vec<u8> {
        let mut bins = vec![0u8; self.num_rows];
        self.unpack_to_buffer(&mut bins);
        bins
    }

    /// Expands the whole column into the first `num_rows` entries of `buffer`.
    ///
    /// `buffer` must hold at least `num_rows` elements (asserted in debug
    /// builds; too-short buffers panic on slicing otherwise).
    pub fn unpack_to_buffer(&self, buffer: &mut [u8]) {
        debug_assert!(buffer.len() >= self.num_rows);
        let paired_bytes = self.num_rows / 2;
        if paired_bytes != 0 {
            // Bytes that carry two real values are delegated to the kernel.
            // NOTE(review): `unpack_4bit` is presumed to emit high nibble then
            // low nibble for each source byte, matching `get` — confirm.
            let packed = &self.data[..paired_bytes];
            let expanded = &mut buffer[..paired_bytes * 2];
            crate::kernel::unpack_4bit(packed, expanded);
        }
        if self.num_rows % 2 != 0 {
            // The last byte holds a single value in its high nibble.
            buffer[self.num_rows - 1] = self.data[paired_bytes] >> 4;
        }
    }

    /// Expands `count` rows starting at `start_row` into `buffer[..count]`.
    ///
    /// When both `start_row` and `count` are even, the range maps onto whole
    /// bytes and is handed to the bulk kernel; otherwise rows are read one at
    /// a time through [`PackedColumn::get`].
    pub fn unpack_range(&self, start_row: usize, count: usize, buffer: &mut [u8]) {
        debug_assert!(start_row + count <= self.num_rows);
        debug_assert!(buffer.len() >= count);
        let byte_aligned = start_row % 2 == 0 && count % 2 == 0;
        if byte_aligned {
            let first_byte = start_row / 2;
            let end_byte = first_byte + count / 2;
            crate::kernel::unpack_4bit(&self.data[first_byte..end_byte], &mut buffer[..count]);
        } else {
            for (offset, slot) in buffer[..count].iter_mut().enumerate() {
                *slot = self.get(start_row + offset);
            }
        }
    }

    /// Borrows the raw packed bytes.
    #[inline]
    pub fn packed_data(&self) -> &[u8] {
        &self.data
    }

    /// Returns an iterator that yields every bin value in row order.
    #[inline]
    pub fn iter(&self) -> PackedColumnIter<'_> {
        PackedColumnIter {
            column: self,
            idx: 0,
        }
    }
}
/// Iterator over the bin values of a [`PackedColumn`], in row order.
pub struct PackedColumnIter<'a> {
    // Column being traversed.
    column: &'a PackedColumn,
    // Index of the next row to yield.
    idx: usize,
}

impl Iterator for PackedColumnIter<'_> {
    type Item = u8;

    /// Yields the next bin value, or `None` once every row has been produced.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.idx >= self.column.num_rows {
            return None;
        }
        let bin = self.column.get(self.idx);
        self.idx += 1;
        Some(bin)
    }

    /// Reports the exact number of rows still to be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.column.num_rows - self.idx;
        (left, Some(left))
    }
}

// `size_hint` is exact, so the default `len()` implementation is valid.
impl ExactSizeIterator for PackedColumnIter<'_> {}
/// Returns `true` when every bin fits in 4 bits (i.e. is at most 15).
///
/// An empty slice is considered packable.
pub fn can_pack(bins: &[u8]) -> bool {
    !bins.iter().any(|&bin| bin > 15)
}
/// Picks the most compact storage mode able to represent `bins`.
///
/// Returns [`StorageMode::Packed4Bit`] when [`can_pack`] accepts the slice and
/// [`StorageMode::U8`] otherwise.
pub fn optimal_storage(bins: &[u8]) -> StorageMode {
    if !can_pack(bins) {
        return StorageMode::U8;
    }
    StorageMode::Packed4Bit
}
use super::{BinnedDataset, FeatureInfo};
/// Storage for a single feature column, in one of the supported layouts.
///
/// Variant order is part of the derived archive representation, so new
/// variants should be appended rather than inserted.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub enum FeatureStorage {
/// One byte per row.
U8(Vec<u8>),
/// Two rows per byte, see [`PackedColumn`].
Packed(PackedColumn),
}
impl FeatureStorage {
    /// Returns the bin stored at `row_idx`, whatever the underlying layout.
    ///
    /// Panics if the index is outside the backing storage.
    #[inline]
    pub fn get(&self, row_idx: usize) -> u8 {
        match self {
            Self::U8(bytes) => bytes[row_idx],
            Self::Packed(column) => column.get(row_idx),
        }
    }

    /// Number of rows held by this column.
    #[inline]
    pub fn num_rows(&self) -> usize {
        match self {
            Self::U8(bytes) => bytes.len(),
            Self::Packed(column) => column.num_rows(),
        }
    }

    /// Number of bytes occupied by the stored bin data.
    #[inline]
    pub fn memory_bytes(&self) -> usize {
        match self {
            Self::U8(bytes) => bytes.len(),
            Self::Packed(column) => column.memory_bytes(),
        }
    }

    /// Reports which [`StorageMode`] this column uses.
    #[inline]
    pub fn mode(&self) -> StorageMode {
        match self {
            Self::U8(_) => StorageMode::U8,
            Self::Packed(_) => StorageMode::Packed4Bit,
        }
    }

    /// Builds storage for `bins`, packing to 4 bits whenever all values allow it.
    pub fn from_bins(bins: Vec<u8>) -> Self {
        if !can_pack(&bins) {
            // At least one value needs more than 4 bits: keep the bytes as-is.
            return Self::U8(bins);
        }
        Self::Packed(PackedColumn::from_bins(&bins))
    }

    /// Builds storage for `bins` using an explicitly requested `mode`.
    ///
    /// Requesting [`StorageMode::Packed4Bit`] requires every value to be
    /// `<= 15`; this precondition is only asserted in debug builds.
    pub fn from_bins_with_mode(bins: Vec<u8>, mode: StorageMode) -> Self {
        match mode {
            StorageMode::Packed4Bit => {
                debug_assert!(can_pack(&bins), "Cannot pack bins: values exceed 15");
                Self::Packed(PackedColumn::from_bins(&bins))
            }
            StorageMode::U8 => Self::U8(bins),
        }
    }

    /// Returns the column as one byte per row, allocating a new vector.
    pub fn to_u8(&self) -> Vec<u8> {
        match self {
            Self::U8(bytes) => bytes.to_vec(),
            Self::Packed(column) => column.unpack(),
        }
    }
}
/// A binned dataset whose feature columns are each stored in their own
/// [`FeatureStorage`] layout.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct PackedDataset {
// Number of rows shared by every feature column and by `targets`.
num_rows: usize,
// One storage entry per feature, indexed by feature index.
feature_data: Vec<FeatureStorage>,
// Target values copied from the source dataset.
targets: Vec<f32>,
// Per-feature metadata copied from the source dataset.
feature_info: Vec<FeatureInfo>,
}
impl PackedDataset {
    /// Builds a packed dataset from `dataset`, choosing the storage mode of
    /// each feature automatically via [`FeatureStorage::from_bins`].
    pub fn from_binned(dataset: &BinnedDataset) -> Self {
        let num_rows = dataset.num_rows();
        let num_features = dataset.num_features();
        let feature_data: Vec<FeatureStorage> = (0..num_features)
            .map(|f| FeatureStorage::from_bins(dataset.feature_column(f).to_vec()))
            .collect();
        Self {
            num_rows,
            feature_data,
            targets: dataset.targets().to_vec(),
            feature_info: dataset.all_feature_info().to_vec(),
        }
    }

    /// Builds a packed dataset from `dataset` using one explicit storage mode
    /// per feature.
    ///
    /// `modes` must contain exactly one entry per feature. The length is
    /// asserted in debug builds only; in release builds a shorter slice
    /// panics on indexing and extra entries are ignored.
    pub fn from_binned_with_modes(dataset: &BinnedDataset, modes: &[StorageMode]) -> Self {
        let num_rows = dataset.num_rows();
        let num_features = dataset.num_features();
        debug_assert_eq!(modes.len(), num_features);
        let feature_data: Vec<FeatureStorage> = (0..num_features)
            .map(|f| {
                FeatureStorage::from_bins_with_mode(dataset.feature_column(f).to_vec(), modes[f])
            })
            .collect();
        Self {
            num_rows,
            feature_data,
            targets: dataset.targets().to_vec(),
            feature_info: dataset.all_feature_info().to_vec(),
        }
    }

    /// Number of rows in the dataset.
    #[inline]
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// Number of feature columns in the dataset.
    #[inline]
    pub fn num_features(&self) -> usize {
        self.feature_data.len()
    }

    /// Returns the bin of feature `feature_idx` at row `row_idx`.
    ///
    /// Panics if `feature_idx` is out of range.
    #[inline]
    pub fn get_bin(&self, row_idx: usize, feature_idx: usize) -> u8 {
        self.feature_data[feature_idx].get(row_idx)
    }

    /// Borrows the storage backing feature `feature_idx`.
    ///
    /// Panics if `feature_idx` is out of range.
    #[inline]
    pub fn feature_storage(&self, feature_idx: usize) -> &FeatureStorage {
        &self.feature_data[feature_idx]
    }

    /// Returns the target value of row `row_idx`.
    ///
    /// Panics if `row_idx` is out of range.
    #[inline]
    pub fn target(&self, row_idx: usize) -> f32 {
        self.targets[row_idx]
    }

    /// Borrows all target values.
    #[inline]
    pub fn targets(&self) -> &[f32] {
        &self.targets
    }

    /// Borrows the metadata of feature `feature_idx`.
    ///
    /// Panics if `feature_idx` is out of range.
    #[inline]
    pub fn feature_info(&self, feature_idx: usize) -> &FeatureInfo {
        &self.feature_info[feature_idx]
    }

    /// Borrows the metadata of every feature.
    #[inline]
    pub fn all_feature_info(&self) -> &[FeatureInfo] {
        &self.feature_info
    }

    /// Total number of bytes used by the stored bin data of all features.
    pub fn feature_memory_bytes(&self) -> usize {
        self.feature_data.iter().map(FeatureStorage::memory_bytes).sum()
    }

    /// Fraction of bin memory saved relative to one byte per cell.
    ///
    /// Computed as `1 - stored_bytes / (num_rows * num_features)`; returns
    /// `0.0` when the dataset has no cells.
    pub fn memory_savings(&self) -> f64 {
        let packed_size = self.feature_memory_bytes();
        let u8_size = self.num_rows * self.num_features();
        if u8_size == 0 {
            0.0
        } else {
            1.0 - (packed_size as f64 / u8_size as f64)
        }
    }

    /// Returns the storage mode of each feature, in feature order.
    pub fn storage_modes(&self) -> Vec<StorageMode> {
        self.feature_data.iter().map(FeatureStorage::mode).collect()
    }

    /// Expands the dataset into a [`BinnedDataset`] with one byte per cell.
    ///
    /// The feature buffer is laid out feature after feature: feature `f`
    /// occupies `features[f * num_rows..(f + 1) * num_rows]`. Datasets with
    /// many features are expanded in parallel.
    pub fn to_binned(&self) -> BinnedDataset {
        // Minimum number of features before parallel expansion is used.
        const PARALLEL_FEATURE_THRESHOLD: usize = 8;

        let num_features = self.num_features();
        let total_size = self.num_rows * num_features;
        let mut features = vec![0u8; total_size];

        // `par_chunks_mut` panics when given a chunk size of zero, so a
        // dataset with no rows must take the sequential path (which then
        // only touches empty slices).
        if num_features >= PARALLEL_FEATURE_THRESHOLD && self.num_rows > 0 {
            features
                .par_chunks_mut(self.num_rows)
                .zip(self.feature_data.par_iter())
                .for_each(|(dest, storage)| Self::unpack_column_into(storage, dest));
        } else {
            for (f, storage) in self.feature_data.iter().enumerate() {
                let start = f * self.num_rows;
                Self::unpack_column_into(storage, &mut features[start..start + self.num_rows]);
            }
        }

        BinnedDataset::new(
            self.num_rows,
            features,
            self.targets.clone(),
            self.feature_info.clone(),
        )
    }

    /// Writes one feature column into `dest` as one byte per row.
    fn unpack_column_into(storage: &FeatureStorage, dest: &mut [u8]) {
        match storage {
            FeatureStorage::U8(data) => dest.copy_from_slice(data),
            FeatureStorage::Packed(packed) => packed.unpack_to_buffer(dest),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every row of `column` reads back as the matching `expected` entry.
    fn assert_reads_back(column: &PackedColumn, expected: &[u8]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(column.get(i), want, "Mismatch at index {}", i);
        }
    }

    /// An even number of rows packs into exactly half as many bytes.
    #[test]
    fn test_packed_column_basic() {
        let source = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let column = PackedColumn::from_bins(&source);
        assert_eq!(column.num_rows(), 8);
        assert_eq!(column.memory_bytes(), 4);
        assert_reads_back(&column, &source);
    }

    /// An odd number of rows rounds the byte count up and still reads back.
    #[test]
    fn test_packed_column_odd_rows() {
        let source = vec![15, 8, 3, 11, 7];
        let column = PackedColumn::from_bins(&source);
        assert_eq!(column.num_rows(), 5);
        assert_eq!(column.memory_bytes(), 3);
        assert_reads_back(&column, &source);
    }

    /// Packing followed by `unpack` is the identity.
    #[test]
    fn test_packed_column_unpack() {
        let source = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let round_trip = PackedColumn::from_bins(&source).unpack();
        assert_eq!(round_trip, source);
    }

    /// The iterator yields every value in row order.
    #[test]
    fn test_packed_column_iterator() {
        let source = vec![0, 15, 7, 8, 3];
        let column = PackedColumn::from_bins(&source);
        let yielded: Vec<u8> = column.iter().collect();
        assert_eq!(yielded, source);
    }

    /// `can_pack` accepts values up to 15 and rejects anything larger.
    #[test]
    fn test_can_pack() {
        assert!(can_pack(&[0, 1, 2, 3, 15]));
        assert!(can_pack(&[0, 15, 8, 7]));
        assert!(!can_pack(&[0, 1, 16]));
        assert!(!can_pack(&[255, 0]));
    }

    /// `optimal_storage` follows the 4-bit limit.
    #[test]
    fn test_optimal_storage() {
        assert_eq!(optimal_storage(&[0, 1, 2, 3]), StorageMode::Packed4Bit);
        assert_eq!(optimal_storage(&[0, 1, 2, 255]), StorageMode::U8);
    }

    /// Packed storage takes half the bytes of the unpacked column.
    #[test]
    fn test_memory_savings() {
        let source: Vec<u8> = (0..1000).map(|i| (i % 16) as u8).collect();
        let unpacked_size = source.len();
        let packed_size = PackedColumn::from_bins(&source).memory_bytes();
        assert_eq!(packed_size, 500);
        assert_eq!(unpacked_size / packed_size, 2);
    }

    /// Automatic mode selection packs small bins and keeps large ones as bytes.
    #[test]
    fn test_feature_storage_auto() {
        let small = vec![0, 1, 2, 3, 4, 5];
        let small_storage = FeatureStorage::from_bins(small.clone());
        assert_eq!(small_storage.mode(), StorageMode::Packed4Bit);
        for (i, &want) in small.iter().enumerate() {
            assert_eq!(small_storage.get(i), want);
        }

        let large = vec![0, 1, 100, 200, 255];
        let large_storage = FeatureStorage::from_bins(large.clone());
        assert_eq!(large_storage.mode(), StorageMode::U8);
        for (i, &want) in large.iter().enumerate() {
            assert_eq!(large_storage.get(i), want);
        }
    }

    /// End-to-end round trip through `PackedDataset` with one packable and one
    /// non-packable feature.
    #[test]
    fn test_packed_dataset() {
        use crate::dataset::FeatureType;

        let num_rows = 100;

        // Feature 0 stays below 16 (packable); feature 1 spans the full byte range.
        let mut features = Vec::with_capacity(num_rows * 2);
        features.extend((0..num_rows).map(|r| (r % 16) as u8));
        features.extend((0..num_rows).map(|r| (r % 256) as u8));

        let targets: Vec<f32> = (0..num_rows).map(|i| i as f32).collect();
        let feature_info = vec![
            FeatureInfo {
                name: "f0".to_string(),
                feature_type: FeatureType::Categorical,
                num_bins: 16,
                bin_boundaries: vec![],
            },
            FeatureInfo {
                name: "f1".to_string(),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: vec![],
            },
        ];

        let binned = BinnedDataset::new(num_rows, features, targets, feature_info);
        let packed = PackedDataset::from_binned(&binned);

        let modes = packed.storage_modes();
        assert_eq!(modes[0], StorageMode::Packed4Bit);
        assert_eq!(modes[1], StorageMode::U8);

        for r in 0..num_rows {
            assert_eq!(packed.get_bin(r, 0), binned.get_bin(r, 0));
            assert_eq!(packed.get_bin(r, 1), binned.get_bin(r, 1));
        }

        let savings = packed.memory_savings();
        assert!(savings > 0.2 && savings < 0.3, "Savings: {}", savings);

        let unpacked = packed.to_binned();
        for r in 0..num_rows {
            for f in 0..2 {
                assert_eq!(unpacked.get_bin(r, f), binned.get_bin(r, f));
            }
        }
    }
}