use super::{BinnedDataset, FeatureInfo};
use rkyv::{Archive, Deserialize, Serialize};
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Archive,
Serialize,
Deserialize,
serde::Serialize,
serde::Deserialize,
Default,
)]
pub enum OrderingStrategy {
#[default]
Original,
ByImportance,
ByAccessFrequency,
}
#[derive(Debug, Clone)]
pub struct AccessTracker {
access_counts: Vec<u64>,
}
impl AccessTracker {
pub fn new(num_features: usize) -> Self {
Self {
access_counts: vec![0; num_features],
}
}
#[inline]
pub fn record(&mut self, feature_idx: usize) {
if feature_idx < self.access_counts.len() {
self.access_counts[feature_idx] += 1;
}
}
pub fn counts(&self) -> &[u64] {
&self.access_counts
}
pub fn reset(&mut self) {
self.access_counts.fill(0);
}
pub fn optimal_order(&self) -> Vec<usize> {
let mut indices: Vec<usize> = (0..self.access_counts.len()).collect();
indices.sort_by(|&a, &b| self.access_counts[b].cmp(&self.access_counts[a]));
indices
}
}
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct ColumnPermutation {
new_to_original: Vec<usize>,
original_to_new: Vec<usize>,
}
impl ColumnPermutation {
pub fn identity(num_features: usize) -> Self {
let indices: Vec<usize> = (0..num_features).collect();
Self {
new_to_original: indices.clone(),
original_to_new: indices,
}
}
pub fn from_importances(importances: &[f32]) -> Self {
let mut indices: Vec<usize> = (0..importances.len()).collect();
indices.sort_by(|&a, &b| {
importances[b]
.partial_cmp(&importances[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut original_to_new = vec![0; importances.len()];
for (new_idx, &orig_idx) in indices.iter().enumerate() {
original_to_new[orig_idx] = new_idx;
}
Self {
new_to_original: indices,
original_to_new,
}
}
pub fn from_access_tracker(tracker: &AccessTracker) -> Self {
let new_to_original = tracker.optimal_order();
let mut original_to_new = vec![0; new_to_original.len()];
for (new_idx, &orig_idx) in new_to_original.iter().enumerate() {
original_to_new[orig_idx] = new_idx;
}
Self {
new_to_original,
original_to_new,
}
}
#[inline]
pub fn to_original(&self, new_idx: usize) -> usize {
self.new_to_original[new_idx]
}
#[inline]
pub fn to_new(&self, original_idx: usize) -> usize {
self.original_to_new[original_idx]
}
pub fn len(&self) -> usize {
self.new_to_original.len()
}
pub fn is_empty(&self) -> bool {
self.new_to_original.is_empty()
}
pub fn is_identity(&self) -> bool {
self.new_to_original
.iter()
.enumerate()
.all(|(i, &orig)| i == orig)
}
pub fn new_to_original(&self) -> &[usize] {
&self.new_to_original
}
}
pub fn reorder_dataset(dataset: &BinnedDataset, permutation: &ColumnPermutation) -> BinnedDataset {
if permutation.is_identity() {
return dataset.clone();
}
let num_rows = dataset.num_rows();
let num_features = dataset.num_features();
let mut new_features = Vec::with_capacity(num_rows * num_features);
for new_idx in 0..num_features {
let orig_idx = permutation.to_original(new_idx);
let column = dataset.feature_column(orig_idx);
new_features.extend_from_slice(column);
}
let new_feature_info: Vec<FeatureInfo> = (0..num_features)
.map(|new_idx| {
let orig_idx = permutation.to_original(new_idx);
dataset.feature_info(orig_idx).clone()
})
.collect();
BinnedDataset::new(
num_rows,
new_features,
dataset.targets().to_vec(),
new_feature_info,
)
}
pub struct ReorderBuilder {
strategy: OrderingStrategy,
importances: Option<Vec<f32>>,
access_tracker: Option<AccessTracker>,
}
impl Default for ReorderBuilder {
fn default() -> Self {
Self::new()
}
}
impl ReorderBuilder {
pub fn new() -> Self {
Self {
strategy: OrderingStrategy::Original,
importances: None,
access_tracker: None,
}
}
pub fn with_strategy(mut self, strategy: OrderingStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn with_importances(mut self, importances: Vec<f32>) -> Self {
self.importances = Some(importances);
self
}
pub fn with_access_tracker(mut self, tracker: AccessTracker) -> Self {
self.access_tracker = Some(tracker);
self
}
pub fn build_permutation(&self, num_features: usize) -> ColumnPermutation {
match self.strategy {
OrderingStrategy::Original => ColumnPermutation::identity(num_features),
OrderingStrategy::ByImportance => {
if let Some(ref importances) = self.importances {
ColumnPermutation::from_importances(importances)
} else {
ColumnPermutation::identity(num_features)
}
}
OrderingStrategy::ByAccessFrequency => {
if let Some(ref tracker) = self.access_tracker {
ColumnPermutation::from_access_tracker(tracker)
} else {
ColumnPermutation::identity(num_features)
}
}
}
}
pub fn reorder(&self, dataset: &BinnedDataset) -> BinnedDataset {
let permutation = self.build_permutation(dataset.num_features());
reorder_dataset(dataset, &permutation)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::FeatureType;
fn create_test_dataset() -> BinnedDataset {
let num_rows = 4;
let features = vec![0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23];
let targets = vec![1.0, 2.0, 3.0, 4.0];
let feature_info = vec![
FeatureInfo {
name: "f0".to_string(),
feature_type: FeatureType::Numeric,
num_bins: 4,
bin_boundaries: vec![],
},
FeatureInfo {
name: "f1".to_string(),
feature_type: FeatureType::Numeric,
num_bins: 14,
bin_boundaries: vec![],
},
FeatureInfo {
name: "f2".to_string(),
feature_type: FeatureType::Numeric,
num_bins: 24,
bin_boundaries: vec![],
},
];
BinnedDataset::new(num_rows, features, targets, feature_info)
}
#[test]
fn test_identity_permutation() {
let perm = ColumnPermutation::identity(5);
assert!(perm.is_identity());
assert_eq!(perm.len(), 5);
for i in 0..5 {
assert_eq!(perm.to_original(i), i);
assert_eq!(perm.to_new(i), i);
}
}
#[test]
fn test_permutation_from_importances() {
let importances = vec![0.1, 0.5, 0.3, 0.05, 0.05];
let perm = ColumnPermutation::from_importances(&importances);
assert!(!perm.is_identity());
assert_eq!(perm.to_original(0), 1); assert_eq!(perm.to_original(1), 2); assert_eq!(perm.to_original(2), 0); }
#[test]
fn test_access_tracker() {
let mut tracker = AccessTracker::new(3);
tracker.record(0);
tracker.record(2);
tracker.record(2);
tracker.record(2);
tracker.record(1);
tracker.record(1);
assert_eq!(tracker.counts(), &[1, 2, 3]);
let order = tracker.optimal_order();
assert_eq!(order, vec![2, 1, 0]); }
#[test]
fn test_reorder_dataset() {
let dataset = create_test_dataset();
let importances = vec![0.1, 0.5, 0.3]; let perm = ColumnPermutation::from_importances(&importances);
let reordered = reorder_dataset(&dataset, &perm);
assert_eq!(reordered.feature_column(0), &[10, 11, 12, 13]); assert_eq!(reordered.feature_column(1), &[20, 21, 22, 23]); assert_eq!(reordered.feature_column(2), &[0, 1, 2, 3]);
assert_eq!(reordered.feature_info(0).name, "f1");
assert_eq!(reordered.feature_info(1).name, "f2");
assert_eq!(reordered.feature_info(2).name, "f0");
assert_eq!(reordered.targets(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_reorder_builder() {
let dataset = create_test_dataset();
let importances = vec![0.3, 0.1, 0.6];
let reordered = ReorderBuilder::new()
.with_strategy(OrderingStrategy::ByImportance)
.with_importances(importances)
.reorder(&dataset);
assert_eq!(reordered.feature_info(0).name, "f2");
}
#[test]
fn test_identity_reorder_no_copy() {
let dataset = create_test_dataset();
let perm = ColumnPermutation::identity(3);
let reordered = reorder_dataset(&dataset, &perm);
assert_eq!(reordered.feature_column(0), dataset.feature_column(0));
assert_eq!(reordered.feature_column(1), dataset.feature_column(1));
assert_eq!(reordered.feature_column(2), dataset.feature_column(2));
}
}