use lace_stats::mh::mh_prior;
use lace_stats::rv::dist::Gamma;
use lace_stats::rv::traits::Rv;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::misc::crp_draw;
/// Evaluates to `true` when `$asgn` passes `validate().is_valid()`, or
/// unconditionally when validation has been switched off.
///
/// Validation is switched off by building with `LACE_NOCHECK=1`. Note that
/// `option_env!` is resolved by the compiler, so this is a *compile-time*
/// switch, not a runtime environment variable. When validation is off,
/// `$asgn` is not evaluated at all.
macro_rules! validate_assignment {
    ($asgn:expr) => {{
        let skip_validation = matches!(option_env!("LACE_NOCHECK"), Some("1"));
        skip_validation || $asgn.validate().is_valid()
    }};
}
/// A partition of `n` data into `n_cats` categories, together with the
/// concentration parameter `alpha` and its prior.
// NOTE(review): every field is `pub`, so it is unclear what `dead_code`
// lint this allow is suppressing — confirm whether it is still needed.
#[allow(dead_code)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct Assignment {
    /// Concentration parameter; used as the CRP `alpha` in `lcrp` and
    /// resampled from `prior` by `update_alpha`.
    pub alpha: f64,
    /// `asgn[i]` is the category index of datum `i`. `usize::MAX` marks an
    /// entry that has been unassigned (see `unassign` / `push_unassigned`).
    pub asgn: Vec<usize>,
    /// `counts[k]` is the number of assigned data in category `k`.
    pub counts: Vec<usize>,
    /// Number of categories; should equal `counts.len()`.
    pub n_cats: usize,
    /// Prior on `alpha`.
    pub prior: Gamma,
}
/// Result of each consistency check on an [`Assignment`]; `true` means the
/// check passed. Built by `AssignmentDiagnostics::new`.
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Clone)]
pub struct AssignmentDiagnostics {
    /// The smallest label is 0 (vacuously true when empty).
    asgn_min_is_zero: bool,
    /// The largest label is `n_cats - 1` (vacuously true when empty).
    asgn_max_is_n_cats_minus_one: bool,
    /// Every label in `0..n_cats` occurs at least once.
    asgn_contains_0_through_n_cats_minus_1: bool,
    /// No entry of `counts` is zero.
    no_zero_counts: bool,
    /// `n_cats == counts.len()`.
    n_cats_cmp_counts_len: bool,
    /// The sum of `counts` equals `asgn.len()`.
    sum_counts_cmp_n: bool,
    /// `counts[k]` equals the number of entries of `asgn` equal to `k`.
    asgn_agrees_with_counts: bool,
}
/// Reasons an [`Assignment`] can be invalid, plus a length-mismatch error
/// for `Assignment::set_asgn`.
#[derive(Debug, Error, PartialEq)]
pub enum AssignmentError {
    /// The smallest label is not 0.
    #[error("Minimum assignment index is not 0")]
    MinAssignmentIndexNotZero,
    /// The largest label is not `n_cats - 1`.
    #[error("Max assignment index is not n_cats - 1")]
    MaxAssignmentIndexNotNCatsMinusOne,
    /// Some label in `0..n_cats` never occurs.
    #[error("The assignment is missing one or more indices")]
    AssignmentDoesNotContainAllIndices,
    /// Some category count is zero.
    #[error("One or more of the counts is zero")]
    ZeroCounts,
    /// The counts do not sum to the number of data.
    #[error("The sum of counts does not equal the number of data")]
    SumCountsNotEqualToAssignmentLength,
    /// Some `counts[k]` differs from the number of data labelled `k`.
    #[error("The counts do not agree with the assignment")]
    AssignmentAndCountsDisagree,
    /// `n_cats != counts.len()`.
    #[error(
        "The length of the counts does not equal the number of categories"
    )]
    NCatsIsNotCountsLength,
    /// `set_asgn` was given a vector whose length differs from the current one.
    #[error("Attempting to set assignment with a different-length assignment")]
    NewAssignmentLengthMismatch,
}
impl AssignmentDiagnostics {
    /// Run every consistency check on `asgn` and record the outcome of each.
    ///
    /// Never panics on a malformed assignment (e.g. `n_cats == 0` with
    /// non-empty data, or `usize::MAX` "unassigned" markers); the affected
    /// checks simply record `false`. Runs in O(n + k) time where
    /// `n = asgn.asgn.len()` and `k = asgn.counts.len()`.
    pub fn new(asgn: &Assignment) -> Self {
        let labels = &asgn.asgn;
        let n = labels.len();
        let n_cats = asgn.n_cats;
        let counts = &asgn.counts;

        // One histogram of the labels serves both the coverage check and the
        // counts-agreement check. It must span every index of `counts`. It
        // only needs to span `0..n_cats` when `n_cats <= n`: with more
        // categories than data, some category is necessarily empty. This also
        // keeps a corrupt, huge `n_cats` from triggering a huge allocation.
        // Labels outside the histogram (e.g. `usize::MAX`) are ignored.
        let hist_len = counts.len().max(n_cats.min(n));
        let mut hist = vec![0_usize; hist_len];
        for &k in labels {
            if let Some(ct) = hist.get_mut(k) {
                *ct += 1;
            }
        }

        AssignmentDiagnostics {
            asgn_min_is_zero: labels.iter().min().map_or(true, |&m| m == 0),
            // `checked_sub` avoids the `n_cats - 1` underflow (a panic in
            // debug builds) when `n_cats == 0` but there are labels.
            asgn_max_is_n_cats_minus_one: labels
                .iter()
                .max()
                .map_or(true, |&m| n_cats.checked_sub(1) == Some(m)),
            // Short-circuit keeps the slice in bounds: when `n_cats <= n`,
            // `hist_len >= n_cats`.
            asgn_contains_0_through_n_cats_minus_1: n_cats <= n
                && hist[..n_cats].iter().all(|&ct| ct > 0),
            no_zero_counts: counts.iter().all(|&ct| ct > 0),
            n_cats_cmp_counts_len: n_cats == counts.len(),
            sum_counts_cmp_n: counts.iter().sum::<usize>() == n,
            // `hist` is at least as long as `counts`, so the zip covers
            // every entry of `counts`.
            asgn_agrees_with_counts: counts
                .iter()
                .zip(hist.iter())
                .all(|(&ct, &h)| ct == h),
        }
    }

    /// `true` if every check passed.
    pub fn is_valid(&self) -> bool {
        self.asgn_min_is_zero
            && self.asgn_max_is_n_cats_minus_one
            && self.asgn_contains_0_through_n_cats_minus_1
            && self.no_zero_counts
            && self.n_cats_cmp_counts_len
            && self.sum_counts_cmp_n
            && self.asgn_agrees_with_counts
    }

    /// Turn a check flag into `Ok(())` or the given error.
    fn flag_to_result(
        passed: bool,
        err: AssignmentError,
    ) -> Result<(), AssignmentError> {
        if passed {
            Ok(())
        } else {
            Err(err)
        }
    }

    /// `Ok(())` if the smallest label is 0.
    ///
    /// # Errors
    /// [`AssignmentError::MinAssignmentIndexNotZero`] otherwise.
    pub fn asgn_min_is_zero(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.asgn_min_is_zero,
            AssignmentError::MinAssignmentIndexNotZero,
        )
    }

    fn asgn_max_is_n_cats_minus_one(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.asgn_max_is_n_cats_minus_one,
            AssignmentError::MaxAssignmentIndexNotNCatsMinusOne,
        )
    }

    fn asgn_contains_0_through_n_cats_minus_1(
        &self,
    ) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.asgn_contains_0_through_n_cats_minus_1,
            AssignmentError::AssignmentDoesNotContainAllIndices,
        )
    }

    fn no_zero_counts(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(self.no_zero_counts, AssignmentError::ZeroCounts)
    }

    fn n_cats_cmp_counts_len(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.n_cats_cmp_counts_len,
            AssignmentError::NCatsIsNotCountsLength,
        )
    }

    fn sum_counts_cmp_n(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.sum_counts_cmp_n,
            AssignmentError::SumCountsNotEqualToAssignmentLength,
        )
    }

    fn asgn_agrees_with_counts(&self) -> Result<(), AssignmentError> {
        Self::flag_to_result(
            self.asgn_agrees_with_counts,
            AssignmentError::AssignmentAndCountsDisagree,
        )
    }

    /// Return the error for the first failed check, or `Ok(())` if all passed.
    ///
    /// Checks are reported in declaration order: min, max, coverage, zero
    /// counts, counts length, counts sum, counts agreement.
    ///
    /// # Errors
    /// The [`AssignmentError`] corresponding to the first failing check.
    pub fn emit_error(&self) -> Result<(), AssignmentError> {
        self.asgn_min_is_zero()?;
        self.asgn_max_is_n_cats_minus_one()?;
        self.asgn_contains_0_through_n_cats_minus_1()?;
        self.no_zero_counts()?;
        self.n_cats_cmp_counts_len()?;
        self.sum_counts_cmp_n()?;
        self.asgn_agrees_with_counts()?;
        Ok(())
    }
}
/// Builder for an [`Assignment`]. Anything left unset is defaulted or drawn
/// at `build` time.
#[derive(Clone, Debug)]
pub struct AssignmentBuilder {
    /// Number of data to assign.
    n: usize,
    /// Explicit assignment; drawn with `crp_draw` at build time if `None`.
    asgn: Option<Vec<usize>>,
    /// Explicit alpha; drawn from `prior` at build time if `None`.
    alpha: Option<f64>,
    /// Prior on alpha; `lace_consts::general_alpha_prior()` if `None`.
    prior: Option<Gamma>,
    /// Seed for the RNG used for any draws; entropy-seeded if `None`.
    seed: Option<u64>,
}
/// Errors that can occur while configuring or building an [`Assignment`].
#[derive(Debug, Error, PartialEq)]
pub enum BuildAssignmentError {
    /// The concentration parameter alpha is zero.
    #[error("alpha is zero")]
    AlphaIsZero,
    /// The concentration parameter alpha is NaN or infinite.
    #[error("non-finite alpha: {alpha}")]
    AlphaNotFinite { alpha: f64 },
    /// The assignment vector is empty.
    #[error("assignment vector is empty")]
    EmptyAssignmentVec,
    /// Requested more categories than there are data
    /// (returned by `AssignmentBuilder::with_n_cats`).
    #[error("there are {n_cats} categories but {n} data")]
    NLessThanNCats { n: usize, n_cats: usize },
    /// The built assignment failed validation.
    #[error("invalid assignment: {0}")]
    AssignmentError(#[from] AssignmentError),
}
impl AssignmentBuilder {
    /// Start building an assignment of `n` data.
    ///
    /// With no further configuration, `build` draws alpha from the prior and
    /// the partition via `crp_draw`.
    pub fn new(n: usize) -> Self {
        AssignmentBuilder {
            n,
            asgn: None,
            prior: None,
            alpha: None,
            seed: None,
        }
    }

    /// Start from an explicit assignment vector; `n` is its length.
    pub fn from_vec(asgn: Vec<usize>) -> Self {
        AssignmentBuilder {
            n: asgn.len(),
            asgn: Some(asgn),
            prior: None,
            alpha: None,
            seed: None,
        }
    }

    /// Use `prior` as the prior on alpha.
    #[must_use]
    pub fn with_prior(mut self, prior: Gamma) -> Self {
        self.prior = Some(prior);
        self
    }

    /// Use `lace_consts::geweke_alpha_prior()` as the prior on alpha.
    #[must_use]
    pub fn with_geweke_prior(mut self) -> Self {
        self.prior = Some(lace_consts::geweke_alpha_prior());
        self
    }

    /// Fix alpha instead of drawing it from the prior.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Seed the RNG used for any draws made by `build`.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Seed the RNG used by `build` with one `u64` taken from `rng`.
    #[must_use]
    pub fn seed_from_rng<R: rand::Rng>(mut self, rng: &mut R) -> Self {
        self.seed = Some(rng.next_u64());
        self
    }

    /// Put every datum in category 0.
    #[must_use]
    pub fn flat(mut self) -> Self {
        self.asgn = Some(vec![0; self.n]);
        self
    }

    /// Assign datum `i` to category `i % n_cats` (round-robin).
    ///
    /// # Errors
    /// [`BuildAssignmentError::NLessThanNCats`] if `n_cats > n`, or if
    /// `n_cats == 0` while there is at least one datum. The error message,
    /// "there are 0 categories but N data", describes the second case
    /// accurately.
    pub fn with_n_cats(
        mut self,
        n_cats: usize,
    ) -> Result<Self, BuildAssignmentError> {
        // `n_cats == 0` with data would otherwise panic on `i % 0` below.
        if n_cats > self.n || (n_cats == 0 && self.n > 0) {
            Err(BuildAssignmentError::NLessThanNCats { n: self.n, n_cats })
        } else {
            let asgn: Vec<usize> = (0..self.n).map(|i| i % n_cats).collect();
            self.asgn = Some(asgn);
            Ok(self)
        }
    }

    /// Build the assignment, consuming the builder.
    ///
    /// Draws alpha from the prior and/or the partition via `crp_draw` if
    /// they were not set. Counts and `n_cats` are derived from the partition.
    ///
    /// # Errors
    /// - [`BuildAssignmentError::AlphaNotFinite`] if alpha is NaN or infinite
    /// - [`BuildAssignmentError::AlphaIsZero`] if alpha is zero
    /// - [`BuildAssignmentError::AssignmentError`] if the partition fails
    ///   validation (skipped when compiled with `LACE_NOCHECK=1`)
    pub fn build(self) -> Result<Assignment, BuildAssignmentError> {
        let prior = self.prior.unwrap_or_else(lace_consts::general_alpha_prior);

        // Only construct an RNG if something actually has to be drawn.
        let mut rng_opt = if self.alpha.is_none() || self.asgn.is_none() {
            let rng = match self.seed {
                Some(seed) => Xoshiro256Plus::seed_from_u64(seed),
                None => Xoshiro256Plus::from_entropy(),
            };
            Some(rng)
        } else {
            None
        };

        let alpha = match self.alpha {
            Some(alpha) => alpha,
            None => prior.draw(
                &mut rng_opt.as_mut().expect("rng exists when alpha is unset"),
            ),
        };

        if !alpha.is_finite() {
            return Err(BuildAssignmentError::AlphaNotFinite { alpha });
        }
        if alpha == 0.0 {
            return Err(BuildAssignmentError::AlphaIsZero);
        }

        let n = self.n;
        let asgn = match self.asgn {
            Some(asgn) => asgn,
            None => {
                crp_draw(
                    n,
                    alpha,
                    &mut rng_opt
                        .as_mut()
                        .expect("rng exists when asgn is unset"),
                )
                .asgn
            }
        };

        let n_cats: usize = asgn.iter().max().map_or(0, |&m| m + 1);
        let mut counts: Vec<usize> = vec![0; n_cats];
        for &z in &asgn {
            counts[z] += 1;
        }

        let asgn_out = Assignment {
            alpha,
            asgn,
            counts,
            n_cats,
            prior,
        };

        if !validate_assignment!(asgn_out) {
            // Report the first failed check; `?` wraps it via `#[from]`.
            asgn_out.validate().emit_error()?;
        }
        Ok(asgn_out)
    }
}
impl Assignment {
    /// Replace the partition with `asgn`, recomputing `counts` and `n_cats`.
    ///
    /// # Errors
    /// - [`AssignmentError::NewAssignmentLengthMismatch`] if `asgn` has a
    ///   different length than the current partition (nothing is changed)
    /// - the first failed validation check otherwise (skipped when compiled
    ///   with `LACE_NOCHECK=1`). In this case `self` already holds the new,
    ///   invalid partition.
    pub fn set_asgn(
        &mut self,
        asgn: Vec<usize>,
    ) -> Result<(), AssignmentError> {
        if asgn.len() != self.asgn.len() {
            return Err(AssignmentError::NewAssignmentLengthMismatch);
        }

        // An empty partition has zero categories; this previously panicked
        // on `max().unwrap()`.
        let n_cats: usize = asgn.iter().max().map_or(0, |&m| m + 1);
        let mut counts: Vec<usize> = vec![0; n_cats];
        for &z in &asgn {
            counts[z] += 1;
        }

        self.asgn = asgn;
        self.counts = counts;
        self.n_cats = n_cats;

        if validate_assignment!(self) {
            Ok(())
        } else {
            self.validate().emit_error()
        }
    }

    /// Iterate over the category label of each datum.
    pub fn iter(&self) -> impl Iterator<Item = &usize> {
        self.asgn.iter()
    }

    /// Number of data, including unassigned entries.
    pub fn len(&self) -> usize {
        self.asgn.len()
    }

    /// `true` if there are no data.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Category counts as `f64`, optionally followed by `alpha`.
    pub fn dirvec(&self, append_alpha: bool) -> Vec<f64> {
        let mut dv: Vec<f64> = self.counts.iter().map(|&x| x as f64).collect();
        if append_alpha {
            dv.push(self.alpha);
        }
        dv
    }

    /// Natural log of `dirvec(append_alpha)`.
    pub fn log_dirvec(&self, append_alpha: bool) -> Vec<f64> {
        let mut dv: Vec<f64> =
            self.counts.iter().map(|&x| (x as f64).ln()).collect();
        if append_alpha {
            dv.push(self.alpha.ln());
        }
        dv
    }

    /// Mark datum `ix` as unassigned (label `usize::MAX`).
    ///
    /// If `ix` was the only member of its category, that category is removed
    /// and every higher label is shifted down by one. Already-unassigned
    /// entries are left as they are. Does nothing if `ix` is already
    /// unassigned.
    ///
    /// # Panics
    /// If `ix` is out of bounds.
    pub fn unassign(&mut self, ix: usize) {
        let k = self.asgn[ix];
        if k == usize::MAX {
            return;
        }

        if self.counts[k] == 1 {
            // Skip `usize::MAX` markers: decrementing one would turn another
            // unassigned entry into a bogus assignment to category MAX - 1.
            self.asgn.iter_mut().for_each(|z| {
                if *z > k && *z != usize::MAX {
                    *z -= 1;
                }
            });
            let _removed = self.counts.remove(k);
            self.n_cats -= 1;
        } else {
            self.counts[k] -= 1;
        }
        self.asgn[ix] = usize::MAX;
    }

    /// Assign the unassigned datum `ix` to category `k`.
    ///
    /// `k == n_cats` creates a new category. `ix == len()` first appends a
    /// new unassigned entry.
    ///
    /// # Panics
    /// If entry `ix` is currently assigned, or if `k > n_cats`.
    pub fn reassign(&mut self, ix: usize, k: usize) {
        if ix == self.len() {
            self.asgn.push(usize::MAX);
        }

        if self.asgn[ix] != usize::MAX {
            panic!("Entry {} is assigned. Use assign instead", ix);
        } else if k < self.n_cats {
            self.asgn[ix] = k;
            self.counts[k] += 1;
        } else if k == self.n_cats {
            self.asgn[ix] = k;
            self.n_cats += 1;
            self.counts.push(1);
        } else {
            panic!("k ({}) larger than n_cats ({})", k, self.n_cats);
        }
    }

    /// Append a new, unassigned datum.
    pub fn push_unassigned(&mut self) {
        self.asgn.push(usize::MAX)
    }

    /// Category counts divided by `len()`.
    ///
    /// NOTE(review): `len()` includes unassigned entries, which `counts` do
    /// not, so the weights sum to less than 1 while any entry is unassigned.
    pub fn weights(&self) -> Vec<f64> {
        let z: f64 = self.len() as f64;
        self.dirvec(false).iter().map(|&w| w / z).collect()
    }

    /// Natural log of `weights()`.
    pub fn log_weights(&self) -> Vec<f64> {
        self.weights().into_iter().map(f64::ln).collect()
    }

    /// Run `n_iter` prior-proposal Metropolis-Hastings steps on `alpha`,
    /// targeting the CRP likelihood `lcrp`. Stores the new alpha and returns
    /// the score reported by `mh_prior` for it.
    pub fn update_alpha<R: rand::Rng>(
        &mut self,
        n_iter: usize,
        rng: &mut R,
    ) -> f64 {
        let cts = &self.counts;
        let n: usize = self.len();
        let loglike = |alpha: &f64| lcrp(n, cts, *alpha);
        let prior_ref = &self.prior;
        let prior_draw = |rng: &mut R| prior_ref.draw(rng);
        let mh_result = mh_prior(self.alpha, loglike, prior_draw, n_iter, rng);
        self.alpha = mh_result.x;
        mh_result.score_x
    }

    /// Run all consistency checks on this assignment.
    pub fn validate(&self) -> AssignmentDiagnostics {
        AssignmentDiagnostics::new(self)
    }
}
pub fn lcrp(n: usize, cts: &[usize], alpha: f64) -> f64 {
let k: f64 = cts.len() as f64;
let gsum = cts.iter().fold(0.0, |acc, ct| {
acc + ::special::Gamma::ln_gamma(*ct as f64).0
});
let cpnt_2 = ::special::Gamma::ln_gamma(alpha).0
- ::special::Gamma::ln_gamma(n as f64 + alpha).0;
gsum + k.mul_add(alpha.ln(), cpnt_2)
}
// Unit tests for validation, the builder, Dirichlet vectors/weights, the CRP
// log-likelihood, unassign/reassign bookkeeping, and seed reproducibility.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::*;
    use lace_stats::rv::dist::Gamma;

    // counts has an extra zero entry, so the length, zero-count and agreement
    // checks fail while the label-only checks pass.
    #[test]
    fn zero_count_fails_validation() {
        let asgn = Assignment {
            alpha: 1.0,
            asgn: vec![0, 0, 0, 0],
            counts: vec![0, 4],
            n_cats: 1,
            prior: Gamma::new(1.0, 1.0).unwrap(),
        };

        let diagnostic = asgn.validate();

        assert!(!diagnostic.is_valid());

        assert!(diagnostic.asgn_min_is_zero);
        assert!(diagnostic.asgn_max_is_n_cats_minus_one);
        assert!(diagnostic.asgn_contains_0_through_n_cats_minus_1);
        assert!(diagnostic.sum_counts_cmp_n);

        assert!(!diagnostic.n_cats_cmp_counts_len);
        assert!(!diagnostic.no_zero_counts);
        assert!(!diagnostic.asgn_agrees_with_counts);
    }

    // counts sums to 5 for 4 data, and counts[1] disagrees with the labels.
    #[test]
    fn bad_counts_fails_validation() {
        let asgn = Assignment {
            alpha: 1.0,
            asgn: vec![1, 1, 0, 0],
            counts: vec![2, 3],
            n_cats: 2,
            prior: Gamma::new(1.0, 1.0).unwrap(),
        };

        let diagnostic = asgn.validate();

        assert!(!diagnostic.is_valid());

        assert!(diagnostic.asgn_min_is_zero);
        assert!(diagnostic.asgn_max_is_n_cats_minus_one);
        assert!(diagnostic.asgn_contains_0_through_n_cats_minus_1);
        assert!(!diagnostic.sum_counts_cmp_n);
        assert!(diagnostic.n_cats_cmp_counts_len);
        assert!(diagnostic.no_zero_counts);
        assert!(!diagnostic.asgn_agrees_with_counts);
    }

    // n_cats is one too small: max label and counts length checks fail.
    #[test]
    fn low_n_cats_fails_validation() {
        let asgn = Assignment {
            alpha: 1.0,
            asgn: vec![1, 1, 0, 0],
            counts: vec![2, 2],
            n_cats: 1,
            prior: Gamma::new(1.0, 1.0).unwrap(),
        };

        let diagnostic = asgn.validate();

        assert!(!diagnostic.is_valid());

        assert!(diagnostic.asgn_min_is_zero);
        assert!(!diagnostic.asgn_max_is_n_cats_minus_one);
        assert!(diagnostic.asgn_contains_0_through_n_cats_minus_1);
        assert!(diagnostic.sum_counts_cmp_n);
        assert!(!diagnostic.n_cats_cmp_counts_len);
        assert!(diagnostic.no_zero_counts);
        assert!(diagnostic.asgn_agrees_with_counts);
    }

    // n_cats is one too large: category 2 is never used.
    #[test]
    fn high_n_cats_fails_validation() {
        let asgn = Assignment {
            alpha: 1.0,
            asgn: vec![1, 1, 0, 0],
            counts: vec![2, 2],
            n_cats: 3,
            prior: Gamma::new(1.0, 1.0).unwrap(),
        };

        let diagnostic = asgn.validate();

        assert!(!diagnostic.is_valid());

        assert!(diagnostic.asgn_min_is_zero);
        assert!(!diagnostic.asgn_max_is_n_cats_minus_one);
        assert!(!diagnostic.asgn_contains_0_through_n_cats_minus_1);
        assert!(diagnostic.sum_counts_cmp_n);
        assert!(!diagnostic.n_cats_cmp_counts_len);
        assert!(diagnostic.no_zero_counts);
        assert!(diagnostic.asgn_agrees_with_counts);
    }

    // Labels are 1-based, so category 0 is missing.
    #[test]
    fn no_zero_cat_fails_validation() {
        let asgn = Assignment {
            alpha: 1.0,
            asgn: vec![1, 1, 2, 2],
            counts: vec![2, 2],
            n_cats: 2,
            prior: Gamma::new(1.0, 1.0).unwrap(),
        };

        let diagnostic = asgn.validate();

        assert!(!diagnostic.is_valid());

        assert!(!diagnostic.asgn_min_is_zero);
        assert!(!diagnostic.asgn_max_is_n_cats_minus_one);
        assert!(!diagnostic.asgn_contains_0_through_n_cats_minus_1);
        assert!(diagnostic.sum_counts_cmp_n);
        assert!(diagnostic.n_cats_cmp_counts_len);
        assert!(diagnostic.no_zero_counts);
        assert!(!diagnostic.asgn_agrees_with_counts);
    }

    // Randomly drawn partitions must always validate.
    #[test]
    fn drawn_assignment_should_have_valid_partition() {
        let n: usize = 50;

        // do the test 100 times because it's random
        for _ in 0..100 {
            let asgn = AssignmentBuilder::new(n).build().unwrap();
            assert!(asgn.validate().is_valid());
        }
    }

    #[test]
    fn from_prior_should_have_valid_alpha_and_proper_length() {
        let n: usize = 50;
        let asgn = AssignmentBuilder::new(n)
            .with_prior(Gamma::new(1.0, 1.0).unwrap())
            .build()
            .unwrap();

        assert!(!asgn.is_empty());
        assert_eq!(asgn.len(), n);
        assert!(asgn.validate().is_valid());
        assert!(asgn.alpha > 0.0);
    }

    #[test]
    fn flat_partition_validation() {
        let n: usize = 50;
        let asgn = AssignmentBuilder::new(n).flat().build().unwrap();

        assert_eq!(asgn.n_cats, 1);
        assert_eq!(asgn.counts.len(), 1);
        assert_eq!(asgn.counts[0], n);
        assert!(asgn.asgn.iter().all(|&z| z == 0));
    }

    #[test]
    fn from_vec() {
        let z = vec![0, 1, 2, 0, 1, 0];
        let asgn = AssignmentBuilder::from_vec(z).build().unwrap();
        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts[0], 3);
        assert_eq!(asgn.counts[1], 2);
        assert_eq!(asgn.counts[2], 1);
    }

    // Round-robin labels: 100 data over 5 categories gives 20 each.
    #[test]
    fn with_n_cats_n_cats_evenly_divides_n() {
        let asgn = AssignmentBuilder::new(100)
            .with_n_cats(5)
            .expect("Whoops!")
            .build()
            .unwrap();
        assert!(asgn.validate().is_valid());
        assert_eq!(asgn.n_cats, 5);
        assert_eq!(asgn.counts[0], 20);
        assert_eq!(asgn.counts[1], 20);
        assert_eq!(asgn.counts[2], 20);
        assert_eq!(asgn.counts[3], 20);
        assert_eq!(asgn.counts[4], 20);
    }

    // Round-robin labels: the 3 leftover data go to categories 0, 1 and 2.
    #[test]
    fn with_n_cats_n_cats_doesnt_divides_n() {
        let asgn = AssignmentBuilder::new(103)
            .with_n_cats(5)
            .expect("Whoops!")
            .build()
            .unwrap();
        assert!(asgn.validate().is_valid());
        assert_eq!(asgn.n_cats, 5);
        assert_eq!(asgn.counts[0], 21);
        assert_eq!(asgn.counts[1], 21);
        assert_eq!(asgn.counts[2], 21);
        assert_eq!(asgn.counts[3], 20);
        assert_eq!(asgn.counts[4], 20);
    }

    #[test]
    fn dirvec_with_alpha_1() {
        let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0])
            .with_alpha(1.0)
            .build()
            .unwrap();
        let dv = asgn.dirvec(false);

        assert_eq!(dv.len(), 3);
        assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10);
        assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10);
        assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10);
    }

    // With append_alpha = true, alpha is the last element.
    #[test]
    fn dirvec_with_alpha_15() {
        let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0])
            .with_alpha(1.5)
            .build()
            .unwrap();
        let dv = asgn.dirvec(true);

        assert_eq!(dv.len(), 4);
        assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10);
        assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10);
        assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10);
        assert_relative_eq!(dv[3], 1.5, epsilon = 10E-10);
    }

    #[test]
    fn log_dirvec_with_alpha_1() {
        let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0])
            .with_alpha(1.0)
            .build()
            .unwrap();
        let ldv = asgn.log_dirvec(false);

        assert_eq!(ldv.len(), 3);
        assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10);
        assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10);
        assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10);
    }

    #[test]
    fn log_dirvec_with_alpha_15() {
        let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0])
            .with_alpha(1.5)
            .build()
            .unwrap();

        let ldv = asgn.log_dirvec(true);

        assert_eq!(ldv.len(), 4);
        assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10);
        assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10);
        assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10);
        assert_relative_eq!(ldv[3], 1.5_f64.ln(), epsilon = 10E-10);
    }

    #[test]
    fn weights() {
        let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0])
            .with_alpha(1.0)
            .build()
            .unwrap();
        let weights = asgn.weights();

        assert_eq!(weights.len(), 3);
        assert_relative_eq!(weights[0], 3.0 / 6.0, epsilon = 10E-10);
        assert_relative_eq!(weights[1], 2.0 / 6.0, epsilon = 10E-10);
        assert_relative_eq!(weights[2], 1.0 / 6.0, epsilon = 10E-10);
    }

    // Reference values for four singleton categories at two alphas.
    #[test]
    fn lcrp_all_ones() {
        let lcrp_1 = lcrp(4, &[1, 1, 1, 1], 1.0);
        assert_relative_eq!(lcrp_1, -3.178_053_830_347_95, epsilon = 10E-8);

        let lcrp_2 = lcrp(4, &[1, 1, 1, 1], 2.1);
        assert_relative_eq!(lcrp_2, -1.945_817_590_743_51, epsilon = 10E-8);
    }

    // Unassigning from a multi-member category only decrements its count.
    #[test]
    fn unassign_non_singleton() {
        let z: Vec<usize> = vec![0, 1, 1, 1, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 3, 2]);

        asgn.unassign(1);

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 2, 2]);
        assert_eq!(asgn.asgn, vec![0, usize::max_value(), 1, 1, 2, 2]);
    }

    // Removing singleton category 0 shifts all higher labels down.
    #[test]
    fn unassign_singleton_low() {
        let z: Vec<usize> = vec![0, 1, 1, 1, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 3, 2]);

        asgn.unassign(0);

        assert_eq!(asgn.n_cats, 2);
        assert_eq!(asgn.counts, vec![3, 2]);
        assert_eq!(asgn.asgn, vec![usize::max_value(), 0, 0, 0, 1, 1]);
    }

    // Removing the highest singleton category leaves other labels unchanged.
    #[test]
    fn unassign_singleton_high() {
        let z: Vec<usize> = vec![0, 0, 1, 1, 1, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![2, 3, 1]);

        asgn.unassign(5);

        assert_eq!(asgn.n_cats, 2);
        assert_eq!(asgn.counts, vec![2, 3]);
        assert_eq!(asgn.asgn, vec![0, 0, 1, 1, 1, usize::max_value()]);
    }

    // Removing a middle singleton category shifts only higher labels.
    #[test]
    fn unassign_singleton_middle() {
        let z: Vec<usize> = vec![0, 0, 1, 2, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![2, 1, 3]);

        asgn.unassign(2);

        assert_eq!(asgn.n_cats, 2);
        assert_eq!(asgn.counts, vec![2, 3]);
        assert_eq!(asgn.asgn, vec![0, 0, usize::max_value(), 1, 1, 1]);
    }

    // Unassign then reassign to the same category restores the original state.
    #[test]
    fn reassign_to_existing_cat() {
        let z: Vec<usize> = vec![0, 1, 1, 1, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 3, 2]);

        asgn.unassign(1);

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 2, 2]);
        assert_eq!(asgn.asgn, vec![0, usize::max_value(), 1, 1, 2, 2]);

        asgn.reassign(1, 1);

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 3, 2]);
        assert_eq!(asgn.asgn, vec![0, 1, 1, 1, 2, 2]);
    }

    // Reassigning with k == n_cats creates a new category at the end.
    #[test]
    fn reassign_to_new_cat() {
        let z: Vec<usize> = vec![0, 1, 1, 1, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap();

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![1, 3, 2]);

        asgn.unassign(0);

        assert_eq!(asgn.n_cats, 2);
        assert_eq!(asgn.counts, vec![3, 2]);
        assert_eq!(asgn.asgn, vec![usize::max_value(), 0, 0, 0, 1, 1]);

        asgn.reassign(0, 2);

        assert_eq!(asgn.n_cats, 3);
        assert_eq!(asgn.counts, vec![3, 2, 1]);
        assert_eq!(asgn.asgn, vec![2, 0, 0, 0, 1, 1]);
    }

    // dirvec reflects counts only, so the unassigned entry is excluded.
    #[test]
    fn dirvec_with_unassigned_entry() {
        let z: Vec<usize> = vec![0, 1, 1, 1, 2, 2];
        let mut asgn = AssignmentBuilder::from_vec(z)
            .with_alpha(1.0)
            .build()
            .unwrap();

        asgn.unassign(5);

        let dv = asgn.dirvec(false);

        assert_eq!(dv.len(), 3);
        assert_relative_eq!(dv[0], 1.0, epsilon = 10e-10);
        assert_relative_eq!(dv[1], 3.0, epsilon = 10e-10);
        assert_relative_eq!(dv[2], 1.0, epsilon = 10e-10);
    }

    // Same seed -> same result. asgn_3 is entropy-seeded, so it is expected
    // (though not strictly guaranteed) to differ.
    #[test]
    fn manual_seed_control_works() {
        let asgn_1 = AssignmentBuilder::new(25).with_seed(17_834_795).build();
        let asgn_2 = AssignmentBuilder::new(25).with_seed(17_834_795).build();
        let asgn_3 = AssignmentBuilder::new(25).build();
        assert_eq!(asgn_1, asgn_2);
        assert_ne!(asgn_1, asgn_3);
    }

    // Identically seeded source RNGs give identical builder seeds.
    #[test]
    fn from_rng_seed_control_works() {
        let mut rng_1 = Xoshiro256Plus::seed_from_u64(17_834_795);
        let mut rng_2 = Xoshiro256Plus::seed_from_u64(17_834_795);
        let asgn_1 =
            AssignmentBuilder::new(25).seed_from_rng(&mut rng_1).build();
        let asgn_2 =
            AssignmentBuilder::new(25).seed_from_rng(&mut rng_2).build();
        let asgn_3 = AssignmentBuilder::new(25).build();
        assert_eq!(asgn_1, asgn_2);
        assert_ne!(asgn_1, asgn_3);
    }
}