use bytemuck::{Pod, Zeroable};
use rkyv::{Archive, Deserialize, Serialize};
use std::sync::OnceLock;
/// Accumulated statistics for one histogram bin: gradient sum, hessian sum,
/// and sample count.
///
/// `#[repr(C)]` together with the `Pod`/`Zeroable` derives lets bytemuck
/// reinterpret slices of `BinEntry` as raw bytes (e.g. for fast zeroing).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, Pod, Zeroable, Archive, Serialize, Deserialize)]
pub struct BinEntry {
// Sum of gradients of all samples that fell into this bin.
pub sum_gradients: f32,
// Sum of hessians of all samples that fell into this bin.
pub sum_hessians: f32,
// Number of samples in this bin.
pub count: u32,
}
impl BinEntry {
    /// Returns a zeroed entry; identical to `BinEntry::default()`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds a single sample's (gradient, hessian) pair into the bin.
    #[inline]
    pub fn accumulate(&mut self, gradient: f32, hessian: f32) {
        self.accumulate_with_count(gradient, hessian, 1);
    }

    /// Folds a pre-aggregated batch of `count` samples into the bin.
    #[inline]
    pub fn accumulate_with_count(&mut self, gradient: f32, hessian: f32, count: u32) {
        self.sum_gradients += gradient;
        self.sum_hessians += hessian;
        self.count += count;
    }

    /// Adds another bin's totals into this one.
    #[inline]
    pub fn merge(&mut self, other: &BinEntry) {
        self.accumulate_with_count(other.sum_gradients, other.sum_hessians, other.count);
    }

    /// Removes another bin's totals from this one. The `u32` count
    /// subtraction panics in debug builds if `other.count > self.count`.
    #[inline]
    pub fn subtract(&mut self, other: &BinEntry) {
        self.sum_gradients -= other.sum_gradients;
        self.sum_hessians -= other.sum_hessians;
        self.count -= other.count;
    }
}
/// Kind of a feature column.
///
/// Derives both rkyv serialization (`Archive`/`Serialize`/`Deserialize`) and
/// serde serialization; the serde derives are path-qualified to disambiguate
/// from the rkyv ones imported at the top of the file.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Archive,
Serialize,
Deserialize,
serde::Serialize,
serde::Deserialize,
)]
pub enum FeatureType {
// Continuous feature, binned via ordered boundaries.
Numeric,
// Discrete/categorical feature.
Categorical,
}
/// Per-feature binning metadata: name, kind, bin count and bin boundaries.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct FeatureInfo {
// Human-readable feature name.
pub name: String,
// Numeric vs. categorical (see `FeatureType`).
pub feature_type: FeatureType,
// Number of distinct bins; bin ids must fit in a `u8`.
pub num_bins: u8,
// Ordered thresholds separating adjacent bins; may be empty
// (e.g. for features without real-valued boundaries).
pub bin_boundaries: Vec<f64>,
}
/// A column is kept in compressed form when at least this fraction of its
/// entries equal the default bin (see `SparseColumn::is_sparse`).
pub const SPARSITY_THRESHOLD: f32 = 0.9;
/// Bin value treated as the implicit default when building sparse columns.
pub const DEFAULT_BIN: u8 = 0;
/// Compressed representation of a feature column in which most entries equal
/// a default bin: only the exceptional (row, bin) pairs are stored.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct SparseColumn {
// Row indices of the non-default entries, in ascending order
// (as produced by `from_dense`).
pub indices: Vec<u32>,
// Bin value for each corresponding entry in `indices`.
pub values: Vec<u8>,
// Logical length of the original dense column.
pub num_rows: usize,
}
impl SparseColumn {
    /// Builds a sparse column from a dense one, recording only the entries
    /// that differ from `default_bin`. Row indices come out in ascending
    /// order; rows are narrowed to `u32`.
    pub fn from_dense(dense: &[u8], default_bin: u8) -> Self {
        let (indices, values): (Vec<u32>, Vec<u8>) = dense
            .iter()
            .enumerate()
            .filter(|&(_, &bin)| bin != default_bin)
            .map(|(row, &bin)| (row as u32, bin))
            .unzip();
        Self {
            indices,
            values,
            num_rows: dense.len(),
        }
    }

    /// Number of explicitly stored (non-default) entries.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Fraction of entries equal to the default bin.
    /// An empty column counts as fully sparse (1.0).
    #[inline]
    pub fn sparsity(&self) -> f32 {
        match self.num_rows {
            0 => 1.0,
            n => 1.0 - (self.nnz() as f32 / n as f32),
        }
    }

    /// Whether the column is sparse enough to be worth keeping compressed.
    #[inline]
    pub fn is_sparse(&self) -> bool {
        self.sparsity() >= SPARSITY_THRESHOLD
    }

    /// Iterates the stored entries as `(row_index, bin)` pairs.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = (usize, u8)> + '_ {
        self.indices
            .iter()
            .map(|&row| row as usize)
            .zip(self.values.iter().copied())
    }
}
/// Column-major binned training dataset with lazily built row-major views.
#[derive(Archive, Serialize, Deserialize)]
pub struct BinnedDataset {
// Number of samples (rows).
num_rows: usize,
// Column-major bins: feature `f`, row `r` lives at `f * num_rows + r`.
features: Vec<u8>,
// One target value per row.
targets: Vec<f32>,
// Per-feature metadata; its length defines `num_features()`.
feature_info: Vec<FeatureInfo>,
// Compressed copy per feature, `Some` only for sufficiently sparse columns.
sparse_columns: Vec<Option<SparseColumn>>,
// Optional per-row era ids.
era_indices: Option<Vec<u16>>,
// Number of distinct eras (max era id + 1), 0 when no eras are set.
num_eras: usize,
// Lazily built row-major copy of `features`; excluded from rkyv
// serialization via `Skip` and rebuilt on demand.
#[rkyv(with = rkyv::with::Skip)]
row_major_cache: OnceLock<Vec<u8>>,
// Lazily built 4-bit-packed row-major copy; also skipped by rkyv.
#[rkyv(with = rkyv::with::Skip)]
row_major_4bit_cache: OnceLock<Vec<u8>>,
}
// Manual `Debug`: reports sizes and cache status instead of dumping the raw
// feature/target buffers, which can be huge.
impl std::fmt::Debug for BinnedDataset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BinnedDataset")
.field("num_rows", &self.num_rows)
.field("num_features", &self.num_features())
.field("features_len", &self.features.len())
.field("sparse_features", &self.num_sparse_features())
.field("max_bins", &self.max_bins())
.field("supports_4bit", &self.supports_4bit())
.field("has_eras", &self.era_indices.is_some())
.field("num_eras", &self.num_eras)
.field("row_major_cached", &self.row_major_cache.get().is_some())
.field(
"row_major_4bit_cached",
&self.row_major_4bit_cache.get().is_some(),
)
.finish()
}
}
// Manual `Clone`: all data fields are cloned, but the lazily built
// row-major caches are deliberately reset to empty `OnceLock`s so the
// clone rebuilds them on demand instead of copying them.
impl Clone for BinnedDataset {
fn clone(&self) -> Self {
Self {
num_rows: self.num_rows,
features: self.features.clone(),
targets: self.targets.clone(),
feature_info: self.feature_info.clone(),
sparse_columns: self.sparse_columns.clone(),
era_indices: self.era_indices.clone(),
num_eras: self.num_eras,
row_major_cache: OnceLock::new(),
row_major_4bit_cache: OnceLock::new(),
}
}
}
impl BinnedDataset {
    /// Scans each column of the column-major `features` buffer and keeps a
    /// compressed copy for every column that is sparse enough.
    ///
    /// Shared by `new`, `new_with_eras` (via `new`) and `subset_by_indices`,
    /// which previously duplicated this loop.
    fn build_sparse_columns(
        features: &[u8],
        num_rows: usize,
        num_features: usize,
    ) -> Vec<Option<SparseColumn>> {
        (0..num_features)
            .map(|f| {
                let start = f * num_rows;
                let column = &features[start..start + num_rows];
                let sparse = SparseColumn::from_dense(column, DEFAULT_BIN);
                if sparse.is_sparse() {
                    Some(sparse)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Number of distinct eras implied by a set of era ids:
    /// `max(id) + 1`, or 0 for an empty slice.
    fn count_eras(era_indices: &[u16]) -> usize {
        era_indices
            .iter()
            .copied()
            .max()
            .map(|m| m as usize + 1)
            .unwrap_or(0)
    }

    /// Creates a dataset from column-major bins, targets, and per-feature
    /// metadata. Sparse columns are detected and compressed eagerly.
    ///
    /// In debug builds, asserts that `features.len()` equals
    /// `num_rows * feature_info.len()` and `targets.len()` equals `num_rows`.
    pub fn new(
        num_rows: usize,
        features: Vec<u8>,
        targets: Vec<f32>,
        feature_info: Vec<FeatureInfo>,
    ) -> Self {
        debug_assert_eq!(features.len(), num_rows * feature_info.len());
        debug_assert_eq!(targets.len(), num_rows);
        let sparse_columns =
            Self::build_sparse_columns(&features, num_rows, feature_info.len());
        Self {
            num_rows,
            features,
            targets,
            feature_info,
            sparse_columns,
            era_indices: None,
            num_eras: 0,
            row_major_cache: OnceLock::new(),
            row_major_4bit_cache: OnceLock::new(),
        }
    }

    /// Like [`BinnedDataset::new`], but also attaches one era id per row.
    ///
    /// In debug builds, additionally asserts `era_indices.len() == num_rows`.
    pub fn new_with_eras(
        num_rows: usize,
        features: Vec<u8>,
        targets: Vec<f32>,
        feature_info: Vec<FeatureInfo>,
        era_indices: Vec<u16>,
    ) -> Self {
        let mut dataset = Self::new(num_rows, features, targets, feature_info);
        dataset.set_era_indices(era_indices);
        dataset
    }

    /// Attaches (or replaces) per-row era ids and recomputes `num_eras`.
    pub fn set_era_indices(&mut self, era_indices: Vec<u16>) {
        debug_assert_eq!(era_indices.len(), self.num_rows);
        self.num_eras = Self::count_eras(&era_indices);
        self.era_indices = Some(era_indices);
    }

    /// Whether era ids have been attached.
    #[inline]
    pub fn has_eras(&self) -> bool {
        self.era_indices.is_some()
    }

    /// Number of distinct eras (0 when no eras are set).
    #[inline]
    pub fn num_eras(&self) -> usize {
        self.num_eras
    }

    /// Era id of `row_idx`.
    ///
    /// # Panics
    /// Panics if no era indices are set, or if `row_idx` is out of bounds.
    #[inline]
    pub fn era(&self, row_idx: usize) -> u16 {
        self.era_indices.as_ref().expect("No era indices set")[row_idx]
    }

    /// All era ids, if set.
    #[inline]
    pub fn era_indices(&self) -> Option<&[u16]> {
        self.era_indices.as_deref()
    }

    /// Number of rows (samples).
    #[inline]
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// Number of feature columns.
    #[inline]
    pub fn num_features(&self) -> usize {
        self.feature_info.len()
    }

    /// Metadata for one feature.
    #[inline]
    pub fn feature_info(&self, feature_idx: usize) -> &FeatureInfo {
        &self.feature_info[feature_idx]
    }

    /// Metadata for all features.
    #[inline]
    pub fn all_feature_info(&self) -> &[FeatureInfo] {
        &self.feature_info
    }

    /// Bin id at (`row_idx`, `feature_idx`). Storage is column-major, so the
    /// flat index is `feature_idx * num_rows + row_idx`.
    #[inline]
    pub fn get_bin(&self, row_idx: usize, feature_idx: usize) -> u8 {
        self.features[feature_idx * self.num_rows + row_idx]
    }

    /// Dense slice of one feature's bins (length `num_rows`).
    #[inline]
    pub fn feature_column(&self, feature_idx: usize) -> &[u8] {
        let start = feature_idx * self.num_rows;
        &self.features[start..start + self.num_rows]
    }

    /// Whether a compressed representation exists for this feature.
    /// Out-of-range indices return `false` rather than panicking.
    #[inline]
    pub fn is_sparse(&self, feature_idx: usize) -> bool {
        self.sparse_columns
            .get(feature_idx)
            .map(|s| s.is_some())
            .unwrap_or(false)
    }

    /// Compressed representation of this feature, if it is sparse.
    #[inline]
    pub fn sparse_column(&self, feature_idx: usize) -> Option<&SparseColumn> {
        self.sparse_columns
            .get(feature_idx)
            .and_then(|s| s.as_ref())
    }

    /// Number of features stored in compressed form.
    pub fn num_sparse_features(&self) -> usize {
        self.sparse_columns.iter().filter(|s| s.is_some()).count()
    }

    /// Target value of one row.
    #[inline]
    pub fn target(&self, row_idx: usize) -> f32 {
        self.targets[row_idx]
    }

    /// All target values.
    #[inline]
    pub fn targets(&self) -> &[f32] {
        &self.targets
    }

    /// Mutable access to all target values.
    #[inline]
    pub fn targets_mut(&mut self) -> &mut [f32] {
        &mut self.targets
    }

    /// Returns a copy of this dataset with `new_targets` substituted.
    /// Everything else is cloned; the row-major caches start empty in the
    /// copy (see the manual `Clone` impl).
    ///
    /// # Panics
    /// Panics if `new_targets.len() != num_rows`.
    pub fn with_targets(&self, new_targets: Vec<f32>) -> Self {
        assert_eq!(
            new_targets.len(),
            self.num_rows,
            "new_targets length ({}) must match num_rows ({})",
            new_targets.len(),
            self.num_rows
        );
        // The manual Clone already resets both caches, so cloning and
        // swapping the targets is equivalent to rebuilding field by field.
        let mut copy = self.clone();
        copy.targets = new_targets;
        copy
    }

    /// Maps a raw feature value onto its bin id using the stored boundaries.
    ///
    /// A value exactly equal to a boundary goes to the upper bin
    /// (`Ok(idx) => idx + 1`); comparisons involving NaN are treated as
    /// `Less`; the result is clamped to `num_bins - 1`. Returns 0 when the
    /// feature has no boundaries.
    pub fn bin_value(&self, feature_idx: usize, value: f64) -> u8 {
        let info = &self.feature_info[feature_idx];
        if info.bin_boundaries.is_empty() {
            return 0;
        }
        match info
            .bin_boundaries
            .binary_search_by(|b| b.partial_cmp(&value).unwrap_or(std::cmp::Ordering::Less))
        {
            Ok(idx) => (idx + 1).min(info.num_bins as usize - 1) as u8,
            Err(idx) => idx.min(info.num_bins as usize - 1) as u8,
        }
    }

    /// Raw threshold value corresponding to a bin-level split threshold.
    /// Returns 0.0 when the feature has no boundaries; thresholds past the
    /// end are clamped to the last boundary.
    #[inline]
    pub fn get_split_value(&self, feature_idx: usize, bin_threshold: u8) -> f64 {
        let info = &self.feature_info[feature_idx];
        if info.bin_boundaries.is_empty() {
            return 0.0;
        }
        let idx = bin_threshold as usize;
        if idx < info.bin_boundaries.len() {
            info.bin_boundaries[idx]
        } else {
            // Unreachable fallback given the emptiness check above; kept to
            // avoid a panic path.
            info.bin_boundaries.last().copied().unwrap_or(f64::MAX)
        }
    }

    /// Reconstructs a row-major matrix of approximate raw feature values by
    /// mapping each bin id back through `get_split_value`.
    pub fn extract_raw_features_from_bins(&self) -> Vec<f32> {
        let num_rows = self.num_rows();
        let num_features = self.num_features();
        let mut raw = vec![0.0f32; num_rows * num_features];
        for row in 0..num_rows {
            for feat in 0..num_features {
                let bin = self.get_bin(row, feat);
                raw[row * num_features + feat] = self.get_split_value(feat, bin) as f32;
            }
        }
        raw
    }

    /// Row-major view of the bins (row `r`, feature `f` at
    /// `r * num_features + f`), built on first call and cached.
    pub fn as_row_major(&self) -> &[u8] {
        self.row_major_cache.get_or_init(|| {
            let num_rows = self.num_rows;
            let num_features = self.num_features();
            let mut row_major = vec![0u8; num_rows * num_features];
            for row in 0..num_rows {
                for feature in 0..num_features {
                    row_major[row * num_features + feature] =
                        self.features[feature * num_rows + row];
                }
            }
            row_major
        })
    }

    /// Largest `num_bins` across all features (0 for an empty dataset).
    #[inline]
    pub fn max_bins(&self) -> u8 {
        self.feature_info
            .iter()
            .map(|f| f.num_bins)
            .max()
            .unwrap_or(0)
    }

    /// Whether every feature's bins fit in 4 bits, enabling packed storage.
    #[inline]
    pub fn supports_4bit(&self) -> bool {
        self.max_bins() <= 16
    }

    /// Row-major view with two 4-bit bins packed per byte (even feature in
    /// the low nibble, odd feature in the high nibble; a trailing unpaired
    /// feature is padded with 0). Built on first call and cached.
    ///
    /// # Panics
    /// Panics (inside the first initialization) if any feature has more than
    /// 16 bins.
    pub fn as_row_major_4bit(&self) -> &[u8] {
        self.row_major_4bit_cache.get_or_init(|| {
            assert!(
                self.supports_4bit(),
                "4-bit packing requires all features to have ≤16 bins, max is {}",
                self.max_bins()
            );
            let num_rows = self.num_rows;
            let num_features = self.num_features();
            let bytes_per_row = num_features.div_ceil(2);
            let mut packed = vec![0u8; num_rows * bytes_per_row];
            for row in 0..num_rows {
                let row_offset = row * bytes_per_row;
                for pair in 0..bytes_per_row {
                    let f0 = pair * 2;
                    let f1 = f0 + 1;
                    let bin0 = self.features[f0 * num_rows + row] & 0x0F;
                    let bin1 = if f1 < num_features {
                        self.features[f1 * num_rows + row] & 0x0F
                    } else {
                        0
                    };
                    packed[row_offset + pair] = bin0 | (bin1 << 4);
                }
            }
            packed
        })
    }

    /// Bytes per row in the 4-bit packed layout (two features per byte).
    #[inline]
    pub fn bytes_per_row_4bit(&self) -> usize {
        self.num_features().div_ceil(2)
    }

    /// Builds a new dataset containing only the given rows, in the given
    /// order (duplicates allowed). Sparse columns and `num_eras` are
    /// recomputed for the subset; era ids are carried over when present.
    ///
    /// # Panics
    /// Panics if any index is out of bounds.
    pub fn subset_by_indices(&self, indices: &[usize]) -> Self {
        let new_num_rows = indices.len();
        let num_features = self.num_features();
        // Validate all indices up front so we fail before allocating.
        for &idx in indices {
            assert!(
                idx < self.num_rows,
                "Index {} out of bounds for dataset with {} rows",
                idx,
                self.num_rows
            );
        }
        // Gather rows column by column to preserve column-major layout.
        let mut new_features = Vec::with_capacity(new_num_rows * num_features);
        for f in 0..num_features {
            let col_start = f * self.num_rows;
            for &idx in indices {
                new_features.push(self.features[col_start + idx]);
            }
        }
        let new_targets: Vec<f32> = indices.iter().map(|&idx| self.targets[idx]).collect();
        let new_era_indices = self
            .era_indices
            .as_ref()
            .map(|eras| indices.iter().map(|&idx| eras[idx]).collect::<Vec<u16>>());
        let sparse_columns =
            Self::build_sparse_columns(&new_features, new_num_rows, num_features);
        let new_num_eras = new_era_indices
            .as_ref()
            .map(|eras| Self::count_eras(eras))
            .unwrap_or(0);
        Self {
            num_rows: new_num_rows,
            features: new_features,
            targets: new_targets,
            feature_info: self.feature_info.clone(),
            sparse_columns,
            era_indices: new_era_indices,
            num_eras: new_num_eras,
            row_major_cache: OnceLock::new(),
            row_major_4bit_cache: OnceLock::new(),
        }
    }
}
// Adapter so `BinnedDataset` can be consumed through the generic
// `crate::backend::BinStorage` trait. Methods delegate to the inherent
// implementations; Option-returning methods wrap the always-available
// inherent results in `Some`, except `as_row_major_4bit`, which returns
// `None` instead of panicking when 4-bit packing is unsupported.
impl crate::backend::BinStorage for BinnedDataset {
fn get_bin(&self, row: usize, feature: usize) -> u8 {
self.features[feature * self.num_rows + row]
}
fn num_rows(&self) -> usize {
self.num_rows
}
fn num_features(&self) -> usize {
self.feature_info.len()
}
fn feature_column(&self, feature: usize) -> Option<&[u8]> {
let start = feature * self.num_rows;
Some(&self.features[start..start + self.num_rows])
}
fn sparse_column(&self, feature: usize) -> Option<&SparseColumn> {
self.sparse_columns.get(feature).and_then(|s| s.as_ref())
}
fn as_row_major(&self) -> Option<&[u8]> {
Some(BinnedDataset::as_row_major(self))
}
fn max_bins(&self) -> u8 {
BinnedDataset::max_bins(self)
}
fn supports_4bit(&self) -> bool {
BinnedDataset::supports_4bit(self)
}
fn as_row_major_4bit(&self) -> Option<&[u8]> {
// Guard first: the inherent method asserts on >16 bins, the trait
// contract reports unsupported via None.
if self.supports_4bit() {
Some(BinnedDataset::as_row_major_4bit(self))
} else {
None
}
}
fn bytes_per_row_4bit(&self) -> usize {
BinnedDataset::bytes_per_row_4bit(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a numeric `FeatureInfo` in tests.
    fn numeric_feature(name: &str, num_bins: u8, bin_boundaries: Vec<f64>) -> FeatureInfo {
        FeatureInfo {
            name: name.to_string(),
            feature_type: FeatureType::Numeric,
            num_bins,
            bin_boundaries,
        }
    }

    #[test]
    fn test_bin_entry_accumulate() {
        let mut entry = BinEntry::new();
        entry.accumulate(1.0, 2.0);
        entry.accumulate(0.5, 1.0);
        assert_eq!(entry.sum_gradients, 1.5);
        assert_eq!(entry.sum_hessians, 3.0);
        assert_eq!(entry.count, 2);
    }

    #[test]
    fn test_bin_entry_subtract() {
        let mut parent = BinEntry {
            sum_gradients: 10.0,
            sum_hessians: 20.0,
            count: 100,
        };
        let child = BinEntry {
            sum_gradients: 3.0,
            sum_hessians: 6.0,
            count: 30,
        };
        parent.subtract(&child);
        assert_eq!(parent.sum_gradients, 7.0);
        assert_eq!(parent.sum_hessians, 14.0);
        assert_eq!(parent.count, 70);
    }

    #[test]
    fn test_binned_dataset_access() {
        // Two features, column-major: f0 = [0,1,2,3], f1 = [10,11,12,13].
        let dataset = BinnedDataset::new(
            4,
            vec![0u8, 1, 2, 3, 10, 11, 12, 13],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 4, vec![0.5, 1.5, 2.5]),
                numeric_feature("f1", 4, vec![10.5, 11.5, 12.5]),
            ],
        );
        assert_eq!(dataset.num_rows(), 4);
        assert_eq!(dataset.num_features(), 2);
        assert_eq!(dataset.get_bin(0, 0), 0);
        assert_eq!(dataset.get_bin(2, 0), 2);
        assert_eq!(dataset.get_bin(1, 1), 11);
        assert_eq!(dataset.feature_column(0), &[0, 1, 2, 3]);
        assert_eq!(dataset.feature_column(1), &[10, 11, 12, 13]);
        assert_eq!(dataset.target(0), 1.0);
        assert_eq!(dataset.targets(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_row_major_conversion() {
        let num_rows = 4;
        let num_features = 2;
        let dataset = BinnedDataset::new(
            num_rows,
            vec![0u8, 1, 2, 3, 10, 11, 12, 13],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 4, vec![0.5, 1.5, 2.5]),
                numeric_feature("f1", 4, vec![10.5, 11.5, 12.5]),
            ],
        );
        let row_major = dataset.as_row_major();
        assert_eq!(row_major.len(), num_rows * num_features);
        // Column-major [0,1,2,3 | 10,11,12,13] transposes to interleaved rows.
        assert_eq!(row_major, &[0, 10, 1, 11, 2, 12, 3, 13]);
        // A second call must return the cached buffer, not a fresh one.
        let row_major2 = dataset.as_row_major();
        assert_eq!(row_major.as_ptr(), row_major2.as_ptr());
    }

    #[test]
    fn test_row_major_via_bin_storage_trait() {
        use crate::backend::BinStorage;
        let dataset = BinnedDataset::new(
            4,
            vec![0u8, 1, 2, 3, 10, 11, 12, 13],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 4, vec![]),
                numeric_feature("f1", 4, vec![]),
            ],
        );
        let storage: &dyn BinStorage = &dataset;
        let row_major = storage.as_row_major();
        assert!(row_major.is_some());
        let data = row_major.unwrap();
        assert_eq!(data.len(), 8);
        assert_eq!(data[0], 0);
        assert_eq!(data[1], 10);
    }

    #[test]
    fn test_4bit_packing() {
        // Three features of four rows each, all bins < 16.
        let num_rows = 4;
        let dataset = BinnedDataset::new(
            num_rows,
            vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 16, vec![]),
                numeric_feature("f1", 16, vec![]),
                numeric_feature("f2", 16, vec![]),
            ],
        );
        assert!(dataset.supports_4bit());
        assert_eq!(dataset.max_bins(), 16);
        assert_eq!(dataset.bytes_per_row_4bit(), 2);
        let packed = dataset.as_row_major_4bit();
        assert_eq!(packed.len(), num_rows * 2);
        // Row 0: f0=1 low nibble, f1=5 high nibble; f2=9 alone in byte 1.
        assert_eq!(packed[0], 0x51);
        assert_eq!(packed[1], 0x09);
        assert_eq!(packed[2], 0x62);
        assert_eq!(packed[3], 0x0A);
        assert_eq!(packed[6], 0x84);
        assert_eq!(packed[7], 0x0C);
        for row in 0..num_rows {
            let row_offset = row * 2;
            let bin0 = packed[row_offset] & 0x0F;
            assert_eq!(bin0, (row + 1) as u8);
            let bin1 = (packed[row_offset] >> 4) & 0x0F;
            assert_eq!(bin1, (row + 5) as u8);
            let bin2 = packed[row_offset + 1] & 0x0F;
            assert_eq!(bin2, (row + 9) as u8);
        }
    }

    #[test]
    fn test_4bit_not_supported_for_large_bins() {
        let dataset = BinnedDataset::new(
            4,
            vec![0u8, 1, 2, 3, 100, 101, 102, 103],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 16, vec![]),
                numeric_feature("f1", 128, vec![]),
            ],
        );
        assert!(!dataset.supports_4bit());
        assert_eq!(dataset.max_bins(), 128);
    }

    #[test]
    fn test_4bit_via_bin_storage_trait() {
        use crate::backend::BinStorage;
        let dataset = BinnedDataset::new(
            2,
            vec![1u8, 2, 5, 6],
            vec![1.0f32, 2.0],
            vec![
                numeric_feature("f0", 8, vec![]),
                numeric_feature("f1", 8, vec![]),
            ],
        );
        let storage: &dyn BinStorage = &dataset;
        assert!(storage.supports_4bit());
        assert_eq!(storage.max_bins(), 8);
        let packed = storage.as_row_major_4bit();
        assert!(packed.is_some());
        let data = packed.unwrap();
        assert_eq!(data.len(), 2);
        assert_eq!(data[0], 0x51);
        assert_eq!(data[1], 0x62);
    }

    #[test]
    fn test_subset_by_indices() {
        let dataset = BinnedDataset::new(
            5,
            vec![0u8, 1, 2, 3, 4, 10, 11, 12, 13, 14],
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0],
            vec![
                numeric_feature("f0", 5, vec![0.5, 1.5, 2.5, 3.5]),
                numeric_feature("f1", 5, vec![10.5, 11.5, 12.5, 13.5]),
            ],
        );
        let subset = dataset.subset_by_indices(&[1, 3, 4]);
        assert_eq!(subset.num_rows(), 3);
        assert_eq!(subset.num_features(), 2);
        assert_eq!(subset.get_bin(0, 0), 1);
        assert_eq!(subset.get_bin(1, 0), 3);
        assert_eq!(subset.get_bin(2, 0), 4);
        assert_eq!(subset.get_bin(0, 1), 11);
        assert_eq!(subset.get_bin(1, 1), 13);
        assert_eq!(subset.get_bin(2, 1), 14);
        assert_eq!(subset.targets(), &[2.0, 4.0, 5.0]);
        assert_eq!(subset.feature_column(0), &[1, 3, 4]);
        assert_eq!(subset.feature_column(1), &[11, 13, 14]);
    }

    #[test]
    fn test_subset_by_indices_with_eras() {
        let dataset = BinnedDataset::new_with_eras(
            4,
            vec![0u8, 1, 2, 3, 10, 11, 12, 13],
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![
                numeric_feature("f0", 4, vec![]),
                numeric_feature("f1", 14, vec![]),
            ],
            vec![0u16, 1, 0, 2],
        );
        let subset = dataset.subset_by_indices(&[0, 2]);
        assert_eq!(subset.num_rows(), 2);
        assert!(subset.has_eras());
        assert_eq!(subset.era_indices().unwrap(), &[0, 0]);
        // num_eras is recomputed from the subset: max id 0 -> 1 era.
        assert_eq!(subset.num_eras(), 1);
    }
}