use std::convert::TryInto;
use lace::cc::state::Builder;
use lace::codebook::Codebook;
use lace::codebook::ColMetadata;
use lace::codebook::ColType;
use lace::data::Datum;
use lace::stats::prior::nix::NixHyper;
use lace::AppendStrategy;
use lace::Engine;
use lace::HasData;
use lace::HasStates;
use lace::WriteMode;
use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
fn assert_rows_eq(row_a: &[Datum], row_b: &[Datum]) {
assert_eq!(row_a.len(), row_b.len());
for (ix, (a, b)) in row_a.iter().zip(row_b.iter()).enumerate() {
let xa = a.to_f64_opt().unwrap();
let xb = b.to_f64_opt().unwrap();
if (xa - xb).abs() > 1E-14 {
let msg = format!(
"Rows were different at index {}: {:?} != {:?}",
ix, a, b
);
panic!("{}\n{:?} != {:?}", msg, row_a, row_b);
}
}
}
fn assert_rows_ne(row_a: &[Datum], row_b: &[Datum]) {
assert_eq!(row_a.len(), row_b.len());
let diff = row_a.iter().zip(row_b.iter()).fold(false, |acc, (a, b)| {
if acc {
acc
} else {
let xa = a.to_f64_opt().unwrap();
let xb = b.to_f64_opt().unwrap();
(xa - xb).abs() > 1E-14
}
});
if !diff {
panic!("Rows identical\n{:?} == {:?}", row_a, row_b);
}
}
fn gen_engine() -> Engine {
let states: Vec<_> = (0..4)
.map(|_| {
Builder::new()
.n_rows(10)
.column_configs(
14,
ColType::Continuous {
hyper: Some(NixHyper::default()),
prior: None,
},
)
.n_views(1)
.n_cats(2)
.build()
.unwrap()
})
.collect();
let codebook = Codebook {
table_name: "table".into(),
state_prior_process: None,
view_prior_process: None,
col_metadata: (0..14)
.map(|i| ColMetadata {
name: format!("{}", i),
notes: None,
coltype: ColType::Continuous {
hyper: Some(NixHyper::default()),
prior: None,
},
missing_not_at_random: false,
})
.collect::<Vec<ColMetadata>>()
.try_into()
.unwrap(),
comments: None,
row_names: (0..10)
.map(|i| format!("{}", i))
.collect::<Vec<String>>()
.try_into()
.unwrap(),
};
Engine {
states,
state_ids: vec![0, 1, 2, 3],
rng: Xoshiro256Plus::from_os_rng(),
codebook,
}
}
#[test]
fn stream_insert_all_data() {
let mut engine = gen_engine();
let mut rng = rand::rng();
let mode = WriteMode {
append_strategy: AppendStrategy::Window,
..WriteMode::unrestricted()
};
for i in 10..40 {
let row = (
format!("{}", i),
(0..14)
.map(|j| {
let x = Datum::Continuous(rng.random());
(format!("{}", j), x)
})
.collect::<Vec<(String, Datum)>>(),
);
let tasks = engine.insert_data(vec![row.into()], None, mode).unwrap();
assert_eq!(tasks.new_rows().unwrap().len(), 1);
engine.run(1).unwrap();
assert_eq!(engine.n_rows(), 10);
}
}
#[test]
fn trench_insert_all_data() {
let mut engine = gen_engine();
let mut rng = rand::rng();
let mode = WriteMode {
append_strategy: AppendStrategy::Trench {
max_n_rows: 15,
trench_ix: 10,
},
..WriteMode::unrestricted()
};
let ninth_row: Vec<_> =
(0..14).map(|col_ix| engine.cell(9, col_ix)).collect();
let mut last_tenth_row: Vec<_> =
(0..14).map(|col_ix| engine.cell(9, col_ix)).collect();
for (i, ix) in (10..40).enumerate() {
let row = (
format!("{}", ix),
(0..14)
.map(|j| {
let x = Datum::Continuous(rng.random());
(format!("{}", j), x)
})
.collect::<Vec<(String, Datum)>>(),
);
let tasks = engine.insert_data(vec![row.into()], None, mode).unwrap();
let this_ninth_row: Vec<_> =
(0..14).map(|col_ix| engine.cell(9, col_ix)).collect();
let this_tenth_row: Vec<_> =
(0..14).map(|col_ix| engine.cell(10, col_ix)).collect();
engine.run(1).unwrap();
dbg!(i);
assert_eq!(tasks.new_rows().unwrap().len(), 1);
assert_eq!(engine.n_rows(), 15_usize.min(10 + i + 1));
assert_rows_eq(&ninth_row, &this_ninth_row);
if ix > 14 {
assert_rows_ne(&last_tenth_row, &this_tenth_row);
}
last_tenth_row = this_tenth_row;
}
}