use lace::cc::alg::RowAssignAlg;
use lace::cc::feature::ColModel;
use lace::cc::feature::Column;
use lace::cc::feature::Feature;
use lace::cc::view::Builder;
use lace::cc::view::View;
use lace::data::SparseContainer;
use lace::stats::prior::nix::NixHyper;
use rand::Rng;
use rv::dist::Gaussian;
use rv::dist::NormalInvChiSquared;
use rv::traits::Sampleable;
/// Builds a continuous column with feature id `id` holding `n` draws from a
/// `Gaussian(0.0, 1.0)`.
///
/// The column is paired with a `NixHyper::default()` hyperprior and a
/// `NormalInvChiSquared` prior built via `new_unchecked(0.0, 1.0, 1.0, 1.0)`
/// (presumably skipping parameter validation — confirm against `rv` docs).
fn gen_col<R: Rng>(id: usize, n: usize, rng: &mut R) -> ColModel {
    let std_normal = Gaussian::new(0.0, 1.0).unwrap();

    // Draw the `n` samples sequentially from the caller's RNG.
    let samples: Vec<f64> = std::iter::repeat_with(|| std_normal.draw(&mut *rng))
        .take(n)
        .collect();

    let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 1.0, 1.0);
    let column = Column::new(
        id,
        SparseContainer::from(samples),
        prior,
        NixHyper::default(),
    );
    ColModel::Continuous(column)
}
/// Builds a `View` over `n` rows containing four columns from `gen_col`
/// (feature ids 0 through 3). The view's builder is then seeded from `rng`.
fn gen_gauss_view<R: Rng>(n: usize, rng: &mut R) -> View {
    // Generate columns in ascending id order so the RNG is consumed in the
    // same sequence as listing the four calls out by hand.
    let features: Vec<ColModel> = (0..4).map(|id| gen_col(id, n, &mut *rng)).collect();

    Builder::new(n)
        .features(features)
        .seed_from_rng(&mut *rng)
        .build()
}
/// A freshly generated view reports the requested row count and the four
/// columns produced by `gen_gauss_view`.
#[test]
fn create_view_smoke() {
    let n_rows = 10;
    let mut rng = rand::rng();

    let view = gen_gauss_view(n_rows, &mut rng);

    assert_eq!(view.n_rows(), n_rows);
    assert_eq!(view.n_cols(), 4);
}
/// Calling `reassign_rows_finite_cpu` directly leaves the view with an
/// assignment whose validation reports it as valid.
#[test]
fn finite_reassign_direct_call() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);

    view.reassign_rows_finite_cpu(&mut rng);

    let diagnostics = view.asgn().validate();
    assert!(diagnostics.is_valid());
}
/// Dispatching through `reassign` with `RowAssignAlg::FiniteCpu` leaves the
/// view with an assignment whose validation reports it as valid.
#[test]
fn finite_reassign_from_reassign() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);

    let alg = RowAssignAlg::FiniteCpu;
    view.reassign(alg, &mut rng);

    let diagnostics = view.asgn().validate();
    assert!(diagnostics.is_valid());
}
/// Inserting a column with an unused id (4) grows the view from four to five
/// columns.
#[test]
fn insert_feature() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);
    let n_cols_before = view.n_cols();
    assert_eq!(n_cols_before, 4);

    let extra_column = gen_col(4, 10, &mut rng);
    view.insert_feature(extra_column, &mut rng);

    assert_eq!(view.n_cols(), 5);
}
/// Inserting a column whose id (2) already exists in the view must panic.
///
/// The expected panic is scoped to the `insert_feature` call alone. With a
/// function-wide `#[should_panic]`, a panic during setup (view or column
/// generation) or a failing precondition assert would also make the test
/// pass, hiding real regressions.
#[test]
fn insert_feature_with_existing_id_panics() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);
    assert_eq!(view.n_cols(), 4);
    let new_ftr = gen_col(2, 10, &mut rng);

    // `AssertUnwindSafe` is acceptable here: after the expected panic, `view`
    // and `rng` are never observed again.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        view.insert_feature(new_ftr, &mut rng);
    }));

    assert!(
        outcome.is_err(),
        "inserting a feature with an existing id should panic"
    );
}
/// Removing an existing feature (id 2) returns that feature and shrinks the
/// view from four to three columns.
#[test]
fn remove_feature() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);
    assert_eq!(view.n_cols(), 4);

    // A single `expect` replaces the `is_some()` check plus a later `unwrap()`,
    // and names the violated expectation if removal ever returns `None`.
    let removed = view
        .remove_feature(2)
        .expect("feature 2 exists in the generated view and should be removed");

    assert_eq!(view.n_cols(), 3);
    assert_eq!(removed.id(), 2);
}
/// Removing an id that is not in the view (14) yields `None` and leaves the
/// column count unchanged.
#[test]
fn remove_non_existent_feature_returns_none() {
    let mut rng = rand::rng();
    let mut view = gen_gauss_view(10, &mut rng);
    let n_cols_before = view.n_cols();
    assert_eq!(n_cols_before, 4);

    let removed = view.remove_feature(14);

    assert!(removed.is_none());
    assert_eq!(view.n_cols(), n_cols_before);
}