use super::stats::{compute_correlation_matrix, correlation};
use super::FeatureGenerator;
/// Elementwise operation used to combine the two values of a feature pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InteractionType {
    /// Product `a * b`.
    Multiply,
    /// Sum `a + b`.
    Add,
    /// Absolute difference `|a - b|` (symmetric in its arguments).
    Subtract,
    /// Smaller of the two values.
    Min,
    /// Larger of the two values.
    Max,
}
impl InteractionType {
pub fn suffix(&self) -> &'static str {
match self {
Self::Multiply => "mul",
Self::Add => "add",
Self::Subtract => "sub",
Self::Min => "min",
Self::Max => "max",
}
}
#[inline]
pub fn compute(&self, a: f32, b: f32) -> f32 {
match self {
Self::Multiply => a * b,
Self::Add => a + b,
Self::Subtract => (a - b).abs(),
Self::Min => a.min(b),
Self::Max => a.max(b),
}
}
pub fn all() -> Vec<Self> {
vec![
Self::Multiply,
Self::Add,
Self::Subtract,
Self::Min,
Self::Max,
]
}
pub fn default_types() -> Vec<Self> {
vec![Self::Multiply]
}
}
/// Strategy used by [`InteractionGenerator::fit`] to choose which column pairs to combine.
#[derive(Debug, Clone)]
pub enum PairSelection {
    /// Every pair `(i, j)` with `i < j`, or `i <= j` when self-interactions are enabled.
    AllPairs,
    /// Exactly these `(i, j)` column pairs; pairs with an index `>= num_features` are
    /// dropped when fitting.
    Explicit(Vec<(usize, usize)>),
    /// Up to `max_pairs` pairs ranked by absolute feature-feature correlation, keeping
    /// only pairs at or above the generator's minimum correlation. With fewer than two
    /// rows, the first `max_pairs` pairs of the all-pairs ordering are used instead.
    TopCorrelated { max_pairs: usize },
    /// Up to `max_pairs` pairs whose product correlates (in absolute value) with
    /// `targets` more strongly than either feature alone, ranked by that gain.
    /// `targets` holds one value per row; with fewer than two rows or a length
    /// mismatch, the first `max_pairs` pairs of the all-pairs ordering are used instead.
    TargetBased { max_pairs: usize, targets: Vec<f32> },
}
/// Generates pairwise interaction features (e.g. `a * b`) from a row-major feature matrix.
///
/// Construct with `from_pairs`, `all_pairs`, `top_correlated` or `target_based`,
/// optionally adjust with the `with_*` builders, call `fit` to choose the pairs, then
/// `generate` to build the new columns.
#[derive(Debug, Clone)]
pub struct InteractionGenerator {
    /// How pairs are chosen during `fit`.
    selection: PairSelection,
    /// Interactions computed for every selected pair; each yields one output column.
    interaction_types: Vec<InteractionType>,
    /// Whether automatic selection may pair a column with itself (`(i, i)`).
    include_self: bool,
    /// Minimum absolute correlation; only consulted by `PairSelection::TopCorrelated`.
    min_correlation: f32,
    /// Pairs chosen by the last `fit`; `None` until the generator has been fitted.
    pairs: Option<Vec<(usize, usize)>>,
}
impl InteractionGenerator {
pub fn from_pairs(pairs: Vec<(usize, usize)>) -> Self {
Self {
selection: PairSelection::Explicit(pairs),
interaction_types: InteractionType::default_types(),
include_self: false,
min_correlation: 0.0,
pairs: None,
}
}
pub fn all_pairs() -> Self {
Self {
selection: PairSelection::AllPairs,
interaction_types: InteractionType::default_types(),
include_self: false,
min_correlation: 0.0,
pairs: None,
}
}
pub fn top_correlated(max_pairs: usize) -> Self {
Self {
selection: PairSelection::TopCorrelated { max_pairs },
interaction_types: InteractionType::default_types(),
include_self: false,
min_correlation: 0.1,
pairs: None,
}
}
pub fn target_based(max_pairs: usize, targets: Vec<f32>) -> Self {
Self {
selection: PairSelection::TargetBased { max_pairs, targets },
interaction_types: InteractionType::default_types(),
include_self: false,
min_correlation: 0.0,
pairs: None,
}
}
pub fn with_types(mut self, types: Vec<InteractionType>) -> Self {
self.interaction_types = types;
self
}
pub fn with_self_interactions(mut self, include: bool) -> Self {
self.include_self = include;
self
}
pub fn with_min_correlation(mut self, threshold: f32) -> Self {
self.min_correlation = threshold;
self
}
pub fn pairs(&self) -> Option<&[(usize, usize)]> {
self.pairs.as_deref()
}
pub fn interactions_per_pair(&self) -> usize {
self.interaction_types.len()
}
pub fn fit(&mut self, data: &[f32], num_features: usize) {
if num_features == 0 || data.is_empty() {
self.pairs = Some(Vec::new());
return;
}
let num_rows = data.len() / num_features;
let pairs = match &self.selection {
PairSelection::AllPairs => generate_all_pairs(num_features, self.include_self),
PairSelection::Explicit(p) => {
p.iter()
.filter(|(i, j)| *i < num_features && *j < num_features)
.cloned()
.collect()
}
PairSelection::TopCorrelated { max_pairs } => {
if num_rows < 2 {
generate_all_pairs(num_features, self.include_self)
.into_iter()
.take(*max_pairs)
.collect()
} else {
select_top_correlated(
data,
num_features,
num_rows,
*max_pairs,
self.min_correlation,
self.include_self,
)
}
}
PairSelection::TargetBased { max_pairs, targets } => {
if num_rows < 2 || targets.len() != num_rows {
generate_all_pairs(num_features, self.include_self)
.into_iter()
.take(*max_pairs)
.collect()
} else {
select_target_based(
data,
num_features,
num_rows,
targets,
*max_pairs,
self.include_self,
)
}
}
};
self.pairs = Some(pairs);
}
pub fn is_fitted(&self) -> bool {
self.pairs.is_some()
}
pub fn n_output_features(&self) -> usize {
self.pairs
.as_ref()
.map(|p| p.len() * self.interaction_types.len())
.unwrap_or(0)
}
}
impl FeatureGenerator for InteractionGenerator {
    /// Builds one output column per (selected pair, interaction type) combination.
    ///
    /// `data` is read as a row-major matrix with `num_features` columns (row `r`, column
    /// `c` is `data[r * num_features + c]`); a trailing partial row is ignored. If the
    /// generator has not been fitted, pairs are chosen by fitting a temporary clone, so
    /// `self` is left unchanged.
    ///
    /// The returned matrix is row-major, and its columns are ordered pair-major: every
    /// interaction type for the first pair, then the next pair. Column names are
    /// `"{left}_{suffix}_{right}"`, using `f{index}` when `feature_names` has no entry.
    ///
    /// Pairs referencing a column `>= num_features` (e.g. fitted on a wider matrix) are
    /// skipped rather than reading another row's values or panicking, mirroring how
    /// `fit` filters explicit pairs.
    fn generate(
        &self,
        data: &[f32],
        num_features: usize,
        feature_names: &[String],
    ) -> (Vec<f32>, Vec<String>) {
        if num_features == 0 || data.is_empty() {
            return (Vec::new(), Vec::new());
        }

        let in_range = |&(i, j): &(usize, usize)| i < num_features && j < num_features;
        let pairs: Vec<(usize, usize)> = match &self.pairs {
            Some(p) => p.iter().copied().filter(in_range).collect(),
            None => {
                let mut temp = self.clone();
                temp.fit(data, num_features);
                temp.pairs
                    .unwrap_or_default()
                    .into_iter()
                    .filter(in_range)
                    .collect()
            }
        };
        let types = &self.interaction_types;
        if pairs.is_empty() || types.is_empty() {
            return (Vec::new(), Vec::new());
        }

        let num_rows = data.len() / num_features;
        let total_features = pairs.len() * types.len();
        let mut new_data = vec![0.0f32; num_rows * total_features];
        let mut new_names = Vec::with_capacity(total_features);
        let column_name = |idx: usize| -> String {
            feature_names
                .get(idx)
                .cloned()
                .unwrap_or_else(|| format!("f{}", idx))
        };

        for (pair_idx, &(i, j)) in pairs.iter().enumerate() {
            let name_i = column_name(i);
            let name_j = column_name(j);
            for (type_idx, interaction_type) in types.iter().enumerate() {
                let feature_idx = pair_idx * types.len() + type_idx;
                new_names.push(format!(
                    "{}_{}_{}",
                    name_i,
                    interaction_type.suffix(),
                    name_j
                ));
                // Both iterators yield exactly `num_rows` full rows.
                let rows_in = data.chunks_exact(num_features);
                let rows_out = new_data.chunks_exact_mut(total_features);
                for (row_in, row_out) in rows_in.zip(rows_out) {
                    row_out[feature_idx] = interaction_type.compute(row_in[i], row_in[j]);
                }
            }
        }
        (new_data, new_names)
    }

    fn name(&self) -> &'static str {
        "interaction"
    }
}
impl Default for InteractionGenerator {
fn default() -> Self {
Self::top_correlated(20)
}
}
/// Lists every column pair `(i, j)` with `i < j` (or `i <= j` when `include_self`),
/// ordered by `i` and then `j`.
fn generate_all_pairs(num_features: usize, include_self: bool) -> Vec<(usize, usize)> {
    // Offset of the first partner column: 0 allows (i, i), 1 starts at i + 1.
    let offset = usize::from(!include_self);
    (0..num_features)
        .flat_map(|i| (i + offset..num_features).map(move |j| (i, j)))
        .collect()
}
/// Ranks candidate pairs by absolute correlation and keeps the strongest `max_pairs`.
///
/// Correlations come from `compute_correlation_matrix`, read as a
/// `num_features × num_features` row-major matrix (assumed from the
/// `i * num_features + j` indexing — confirm against the stats module). Pairs below
/// `min_correlation` (or whose correlation is NaN) are discarded. Ties keep their
/// `(i, j)` enumeration order because the sort is stable.
fn select_top_correlated(
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    max_pairs: usize,
    min_correlation: f32,
    include_self: bool,
) -> Vec<(usize, usize)> {
    let matrix = compute_correlation_matrix(data, num_features, num_rows);
    let first_offset = if include_self { 0 } else { 1 };
    let mut ranked: Vec<((usize, usize), f32)> = (0..num_features)
        .flat_map(|i| (i + first_offset..num_features).map(move |j| (i, j)))
        .filter_map(|(i, j)| {
            let strength = matrix[i * num_features + j].abs();
            if strength >= min_correlation {
                Some(((i, j), strength))
            } else {
                None
            }
        })
        .collect();
    // Strongest first.
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(max_pairs);
    ranked.into_iter().map(|(pair, _)| pair).collect()
}
/// Selects pairs whose product explains `targets` better than either feature alone.
///
/// For each candidate `(i, j)`, the gain is `|corr(x_i * x_j, targets)|` minus the larger
/// of `|corr(x_i, targets)|` and `|corr(x_j, targets)|`. Only pairs with a strictly
/// positive gain are kept. They are ranked by gain (ties keep enumeration order), and
/// at most `max_pairs` are returned. Only the multiplicative interaction is scored,
/// whatever interaction types the generator will later emit.
///
/// `data` is row-major with `num_features` columns and at least `num_rows` rows.
/// Returns an empty list when `targets.len() != num_rows`.
fn select_target_based(
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    targets: &[f32],
    max_pairs: usize,
    include_self: bool,
) -> Vec<(usize, usize)> {
    if targets.len() != num_rows {
        return Vec::new();
    }
    // Gather each column once so the O(F²) pair loop reads contiguous memory instead of
    // re-striding through `data` for every pair.
    let columns: Vec<Vec<f32>> = (0..num_features)
        .map(|f| (0..num_rows).map(|r| data[r * num_features + f]).collect())
        .collect();
    let feature_target_corrs: Vec<f32> = columns
        .iter()
        .map(|column| correlation(column, targets).abs())
        .collect();

    // One scratch buffer reused for every candidate product column.
    let mut product: Vec<f32> = Vec::with_capacity(num_rows);
    let mut pair_scores: Vec<((usize, usize), f32)> = Vec::new();
    for i in 0..num_features {
        let start = if include_self { i } else { i + 1 };
        for j in start..num_features {
            product.clear();
            product.extend(
                columns[i]
                    .iter()
                    .zip(&columns[j])
                    .map(|(&vi, &vj)| vi * vj),
            );
            let interaction_corr = correlation(&product, targets).abs();
            let max_individual = feature_target_corrs[i].max(feature_target_corrs[j]);
            let gain = interaction_corr - max_individual;
            // `>` also rejects NaN gains.
            if gain > 0.0 {
                pair_scores.push(((i, j), gain));
            }
        }
    }
    pair_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    pair_scores
        .into_iter()
        .take(max_pairs)
        .map(|(pair, _)| pair)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned column names for the generator's `feature_names` argument.
    fn col_names(cols: &[&str]) -> Vec<String> {
        cols.iter().map(|c| c.to_string()).collect()
    }

    #[test]
    fn test_interaction_type_compute() {
        assert!((InteractionType::Multiply.compute(3.0, 4.0) - 12.0).abs() < 1e-6);
        assert!((InteractionType::Add.compute(3.0, 4.0) - 7.0).abs() < 1e-6);
        assert!((InteractionType::Subtract.compute(3.0, 4.0) - 1.0).abs() < 1e-6);
        assert!((InteractionType::Min.compute(3.0, 4.0) - 3.0).abs() < 1e-6);
        assert!((InteractionType::Max.compute(3.0, 4.0) - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_interaction_type_subtract_absolute() {
        // `Subtract` is |a - b|, so argument order does not matter.
        assert!((InteractionType::Subtract.compute(4.0, 3.0) - 1.0).abs() < 1e-6);
        assert!((InteractionType::Subtract.compute(3.0, 4.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_interaction_type_suffixes() {
        assert_eq!(InteractionType::Multiply.suffix(), "mul");
        assert_eq!(InteractionType::Add.suffix(), "add");
        assert_eq!(InteractionType::Subtract.suffix(), "sub");
        assert_eq!(InteractionType::Min.suffix(), "min");
        assert_eq!(InteractionType::Max.suffix(), "max");
    }

    #[test]
    fn test_from_pairs_basic() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1), (1, 2)]);
        // 3 rows x 3 features, row-major.
        let data = vec![
            1.0, 2.0, 3.0, //
            4.0, 5.0, 6.0, //
            7.0, 8.0, 9.0,
        ];
        let names = col_names(&["a", "b", "c"]);
        let (new_data, new_names) = generator.generate(&data, 3, &names);
        assert_eq!(new_names.len(), 2);
        assert_eq!(new_names[0], "a_mul_b");
        assert_eq!(new_names[1], "b_mul_c");
        // Output is row-major: [a*b, b*c] for each input row.
        assert_eq!(new_data.len(), 6);
        assert!((new_data[0] - 2.0).abs() < 1e-6);
        assert!((new_data[1] - 6.0).abs() < 1e-6);
        assert!((new_data[2] - 20.0).abs() < 1e-6);
        assert!((new_data[3] - 30.0).abs() < 1e-6);
    }

    #[test]
    fn test_multiple_interaction_types() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1)]).with_types(vec![
            InteractionType::Multiply,
            InteractionType::Add,
            InteractionType::Subtract,
        ]);
        let data = vec![3.0, 5.0];
        let names = col_names(&["a", "b"]);
        let (new_data, new_names) = generator.generate(&data, 2, &names);
        assert_eq!(new_names.len(), 3);
        assert_eq!(new_names[0], "a_mul_b");
        assert_eq!(new_names[1], "a_add_b");
        assert_eq!(new_names[2], "a_sub_b");
        assert!((new_data[0] - 15.0).abs() < 1e-6);
        assert!((new_data[1] - 8.0).abs() < 1e-6);
        assert!((new_data[2] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_all_pairs_generator() {
        let mut generator = InteractionGenerator::all_pairs();
        let data = vec![1.0, 2.0, 3.0];
        let names = col_names(&["a", "b", "c"]);
        generator.fit(&data, 3);
        assert_eq!(generator.pairs().unwrap().len(), 3);
        let (_, new_names) = generator.generate(&data, 3, &names);
        assert_eq!(new_names.len(), 3);
    }

    #[test]
    fn test_self_interactions() {
        let mut generator = InteractionGenerator::all_pairs().with_self_interactions(true);
        let data = vec![2.0, 3.0];
        generator.fit(&data, 2);
        // (0, 0), (0, 1), (1, 1)
        assert_eq!(generator.pairs().unwrap().len(), 3);
        let names = col_names(&["a", "b"]);
        let (new_data, new_names) = generator.generate(&data, 2, &names);
        assert_eq!(new_names.len(), 3);
        assert!((new_data[0] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_data() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1)]);
        let (new_data, new_names) = generator.generate(&[], 0, &[]);
        assert!(new_data.is_empty());
        assert!(new_names.is_empty());
    }

    #[test]
    fn test_explicit_pairs_out_of_range_dropped_on_fit() {
        let mut generator = InteractionGenerator::from_pairs(vec![(0, 1), (0, 5)]);
        generator.fit(&[1.0, 2.0, 3.0], 3);
        assert_eq!(generator.pairs().unwrap().to_vec(), vec![(0, 1)]);
    }

    #[test]
    fn test_top_correlated_selection() {
        // Three perfectly correlated columns.
        let data = vec![
            1.0, 2.0, 10.0, //
            2.0, 4.0, 20.0, //
            3.0, 6.0, 30.0, //
            4.0, 8.0, 40.0,
        ];
        let mut generator = InteractionGenerator::top_correlated(2).with_min_correlation(0.5);
        generator.fit(&data, 3);
        let pairs = generator.pairs().unwrap();
        assert!(!pairs.is_empty());
        assert!(pairs.len() <= 2);
    }

    #[test]
    fn test_top_correlated_single_row_falls_back_to_first_pairs() {
        let mut generator = InteractionGenerator::top_correlated(2);
        generator.fit(&[1.0, 2.0, 3.0], 3);
        assert_eq!(generator.pairs().unwrap().to_vec(), vec![(0, 1), (0, 2)]);
    }

    #[test]
    fn test_target_based_selection() {
        // Columns a, b and a constant zero column; the target is exactly a * b.
        let data = vec![
            1.0, 1.0, 0.0, //
            2.0, 3.0, 0.0, //
            3.0, 2.0, 0.0, //
            4.0, 4.0, 0.0,
        ];
        let targets = vec![1.0, 6.0, 6.0, 16.0];
        let mut generator = InteractionGenerator::target_based(5, targets);
        generator.fit(&data, 3);
        let pairs = generator.pairs().unwrap();
        assert!(pairs.len() <= 5);
        // Only a*b explains the target better than a or b alone (|r| ~0.92 each).
        assert_eq!(pairs.to_vec(), vec![(0, 1)]);
    }

    #[test]
    fn test_n_output_features() {
        let mut generator =
            InteractionGenerator::all_pairs().with_types(InteractionType::all());
        assert!(!generator.is_fitted());
        assert_eq!(generator.n_output_features(), 0);
        generator.fit(&[1.0, 2.0, 3.0], 3);
        assert!(generator.is_fitted());
        assert_eq!(generator.interactions_per_pair(), 5);
        assert_eq!(generator.n_output_features(), 3 * 5);
    }

    #[test]
    fn test_default_is_top_correlated() {
        let mut generator = InteractionGenerator::default();
        // A single row triggers the all-pairs fallback, capped at 20 pairs.
        generator.fit(&[1.0, 2.0, 3.0], 3);
        assert_eq!(generator.pairs().unwrap().len(), 3);
    }

    #[test]
    fn test_correlation_perfect() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        let corr = correlation(&x, &y);
        assert!((corr - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_correlation_negative() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![4.0, 3.0, 2.0, 1.0];
        let corr = correlation(&x, &y);
        assert!((corr + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_correlation_zero() {
        let x = vec![1.0, -1.0, 1.0, -1.0];
        let y = vec![1.0, 1.0, -1.0, -1.0];
        let corr = correlation(&x, &y);
        assert!(corr.abs() < 1e-6);
    }

    #[test]
    fn test_generate_all_pairs() {
        let pairs = generate_all_pairs(4, false);
        assert_eq!(pairs.len(), 6);
        let pairs_with_self = generate_all_pairs(4, true);
        assert_eq!(pairs_with_self.len(), 10);
    }

    #[test]
    fn test_feature_generator_trait() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1)]);
        assert_eq!(generator.name(), "interaction");
    }

    #[test]
    fn test_min_max_interactions() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1)])
            .with_types(vec![InteractionType::Min, InteractionType::Max]);
        // 2 rows x 2 features.
        let data = vec![
            3.0, 5.0, //
            7.0, 2.0,
        ];
        let names = col_names(&["a", "b"]);
        let (new_data, new_names) = generator.generate(&data, 2, &names);
        assert_eq!(new_names.len(), 2);
        assert_eq!(new_names[0], "a_min_b");
        assert_eq!(new_names[1], "a_max_b");
        assert!((new_data[0] - 3.0).abs() < 1e-6);
        assert!((new_data[1] - 5.0).abs() < 1e-6);
        assert!((new_data[2] - 2.0).abs() < 1e-6);
        assert!((new_data[3] - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_nan_handling() {
        let generator = InteractionGenerator::from_pairs(vec![(0, 1)]);
        let data = vec![f32::NAN, 5.0];
        let names = col_names(&["a", "b"]);
        let (new_data, _) = generator.generate(&data, 2, &names);
        assert!(new_data[0].is_nan());
    }

    #[test]
    fn test_large_dataset() {
        let num_rows = 1000;
        let num_features = 10;
        let data: Vec<f32> = (0..num_rows * num_features)
            .map(|i| (i % 100) as f32)
            .collect();
        let names: Vec<String> = (0..num_features).map(|i| format!("f{}", i)).collect();
        let mut generator = InteractionGenerator::top_correlated(20);
        generator.fit(&data, num_features);
        let (new_data, new_names) = generator.generate(&data, num_features, &names);
        assert_eq!(new_data.len(), num_rows * new_names.len());
        assert!(new_names.len() <= 20);
    }
}