#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::impl_display;
use crate::misc::vec_to_string;
use std::fmt;
/// An assignment of `len()` elements to `k()` groups ("partitions").
///
/// `from_z`, `append` and `remove` keep `z` and `counts` consistent with each
/// other and keep every count positive; `new_unchecked` and the `*_mut`
/// accessors perform no such checks.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Partition {
/// Assignment vector: `z[i]` is the partition index of element `i`.
z: Vec<usize>,
/// `counts[j]` is the number of elements assigned to partition `j`.
counts: Vec<usize>,
}
impl Default for Partition {
fn default() -> Self {
Partition::new()
}
}
impl From<&Partition> for String {
    /// Renders a short multi-line summary: a header with the number of
    /// elements and partitions, then the assignment vector, then the counts.
    ///
    /// The assignment is rendered with `vec_to_string(.., 15)` and the counts
    /// with `vec_to_string(.., k)`; presumably the second argument limits how
    /// many entries are shown (confirm against `crate::misc::vec_to_string`).
    fn from(part: &Partition) -> String {
        let n_parts = part.k();
        format!(
            "Partition (n: {}, k: {})\n assignment: {}\n counts: {}\n",
            part.len(),
            n_parts,
            vec_to_string(&part.z, 15),
            vec_to_string(&part.counts, n_parts),
        )
    }
}
// NOTE(review): `impl_display!` is a crate-local macro. Presumably it implements
// `fmt::Display` for `Partition` in terms of the `From<&Partition> for String`
// conversion; confirm against the macro definition.
impl_display!(Partition);
/// Errors produced when building or modifying a `Partition`.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum PartitionError {
    /// The input assignment was empty, or it left at least one partition
    /// without any members (its labels were not contiguous from zero).
    EmptyInputPartition,
    /// An appended assignment `zi` exceeded the number of partitions `nparts`.
    ///
    /// Valid assignments are `0..=nparts`: values below `nparts` join an
    /// existing partition and `nparts` itself opens a new one.
    IndicatorHigherThanNumberOfPartitions {
        /// The rejected partition index.
        zi: usize,
        /// The number of partitions at the time of the call.
        nparts: usize,
    },
}

impl std::error::Error for PartitionError {}

#[cfg_attr(coverage_nightly, coverage(off))]
impl fmt::Display for PartitionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyInputPartition => {
                write!(f, "the input partition was empty")
            }
            Self::IndicatorHigherThanNumberOfPartitions { zi, nparts } => {
                // `zi == nparts` is accepted (it opens a new partition), so the
                // valid range is inclusive of `nparts`.
                write!(
                    f,
                    "tried to append z = {zi} to a partition with {nparts} \
                     partitions. z must be in 0..=n_parts (or 0..={nparts})"
                )
            }
        }
    }
}
impl Partition {
    /// Creates an empty partition with no elements and no partitions.
    #[must_use]
    pub fn new() -> Partition {
        Partition {
            z: vec![],
            counts: vec![],
        }
    }

    /// Builds a partition from an assignment vector and its counts without
    /// any validation.
    ///
    /// The caller must keep the two consistent: every `z[i]` must be a valid
    /// index into `counts`, and `counts[j]` must equal the number of entries
    /// of `z` equal to `j`. [`Partition::remove`] indexes `counts` with values
    /// taken from `z` and will panic if this does not hold.
    #[must_use]
    pub fn new_unchecked(z: Vec<usize>, counts: Vec<usize>) -> Self {
        Partition { z, counts }
    }

    /// Returns the assignment vector: entry `i` is the partition of element `i`.
    #[must_use]
    pub fn z(&self) -> &Vec<usize> {
        &self.z
    }

    /// Mutable access to the assignment vector.
    ///
    /// Edits made through this reference do not update `counts`; the caller
    /// is responsible for keeping them consistent.
    pub fn z_mut(&mut self) -> &mut Vec<usize> {
        &mut self.z
    }

    /// Returns the number of elements in each partition.
    #[must_use]
    pub fn counts(&self) -> &Vec<usize> {
        &self.counts
    }

    /// Mutable access to the per-partition counts.
    ///
    /// Edits made through this reference do not update `z`; the caller is
    /// responsible for keeping them consistent.
    pub fn counts_mut(&mut self) -> &mut Vec<usize> {
        &mut self.counts
    }

    /// Builds a partition from an assignment vector, computing the counts.
    ///
    /// `z[i]` is the partition index of element `i`. Labels must be
    /// contiguous from zero: the number of partitions is `max(z) + 1`, and
    /// every label in `0..=max(z)` must be used at least once.
    ///
    /// # Errors
    ///
    /// Returns [`PartitionError::EmptyInputPartition`] if `z` is empty, or if
    /// some label in `0..=max(z)` has no members (the labels have a gap).
    pub fn from_z(z: Vec<usize>) -> Result<Self, PartitionError> {
        // `max` is `None` exactly when `z` is empty.
        let max_label = *z.iter().max().ok_or(PartitionError::EmptyInputPartition)?;

        // With `n` elements and no gaps, every label lies in `0..n`. A larger
        // label guarantees an empty partition, so reject it before allocating
        // `max_label + 1` counters (huge, or overflowing at `usize::MAX`).
        if max_label >= z.len() {
            return Err(PartitionError::EmptyInputPartition);
        }

        let mut counts: Vec<usize> = vec![0; max_label + 1];
        for &zi in &z {
            counts[zi] += 1;
        }

        // A zero count means a label was skipped, i.e. an empty partition.
        if counts.iter().all(|&ct| ct > 0) {
            Ok(Partition { z, counts })
        } else {
            Err(PartitionError::EmptyInputPartition)
        }
    }

    /// Removes the element at index `ix`.
    ///
    /// If the element was the only member of its partition, that partition
    /// is deleted and every label above it is shifted down by one so that
    /// labels stay contiguous. Otherwise the partition's count is decremented.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    ///
    /// # Panics
    ///
    /// Panics if `ix >= self.len()`.
    pub fn remove(&mut self, ix: usize) -> Result<(), PartitionError> {
        let zi = self.z.remove(ix);
        if self.counts[zi] == 1 {
            // Last member: drop the partition and close the gap in the labels.
            self.counts.remove(zi);
            for zj in self.z.iter_mut().filter(|zj| **zj > zi) {
                *zj -= 1;
            }
        } else {
            self.counts[zi] -= 1;
        }
        Ok(())
    }

    /// Appends a new element assigned to partition `zi`.
    ///
    /// `zi` may be an existing partition (`0..k`) or exactly `k`, in which
    /// case a new singleton partition is created.
    ///
    /// # Errors
    ///
    /// Returns [`PartitionError::IndicatorHigherThanNumberOfPartitions`] if
    /// `zi > self.k()`. The partition is left unchanged in that case.
    pub fn append(&mut self, zi: usize) -> Result<(), PartitionError> {
        let k = self.k();
        if zi > k {
            return Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
                zi,
                nparts: k,
            });
        }
        self.z.push(zi);
        if zi == k {
            self.counts.push(1);
        } else {
            self.counts[zi] += 1;
        }
        Ok(())
    }

    /// Number of partitions (the length of `counts`).
    #[must_use]
    pub fn k(&self) -> usize {
        self.counts.len()
    }

    /// Number of elements (the length of `z`).
    #[must_use]
    pub fn len(&self) -> usize {
        self.z.len()
    }

    /// Returns `true` if there are no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns each partition's share of the elements, `counts[j] / len()`.
    ///
    /// Returns an empty vector when there are no partitions.
    #[must_use]
    pub fn weights(&self) -> Vec<f64> {
        let n = self.len() as f64;
        self.counts.iter().map(|&ct| (ct as f64) / n).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        assert_eq!(part.k(), 3);
        assert_eq!(part.counts, vec![2, 1, 1]);
    }

    #[test]
    fn default_is_empty() {
        let part = Partition::default();
        assert_eq!(part, Partition::new());
        assert!(part.is_empty());
        assert_eq!(part.k(), 0);
    }

    #[test]
    fn from_z_rejects_empty_input() {
        assert_eq!(
            Partition::from_z(vec![]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn from_z_rejects_gap_in_labels() {
        // Label 1 is never used, so partition 1 would be empty.
        assert_eq!(
            Partition::from_z(vec![0, 2, 0]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn append_existing_and_new_partitions() {
        let mut part = Partition::new();
        part.append(0).unwrap();
        part.append(0).unwrap();
        part.append(1).unwrap();
        assert_eq!(part.z(), &vec![0, 0, 1]);
        assert_eq!(part.counts(), &vec![2, 1]);

        assert_eq!(
            part.append(3),
            Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
                zi: 3,
                nparts: 2
            })
        );
        // A rejected append leaves the partition untouched.
        assert_eq!(part.len(), 3);
        assert_eq!(part.counts(), &vec![2, 1]);
    }

    #[test]
    fn remove_last_member_relabels() {
        let mut part = Partition::from_z(vec![0, 1, 2, 1]).unwrap();
        part.remove(0).unwrap();
        assert_eq!(part.z(), &vec![0, 1, 0]);
        assert_eq!(part.counts(), &vec![2, 1]);
    }

    #[test]
    fn remove_decrements_count() {
        let mut part = Partition::from_z(vec![0, 1, 0]).unwrap();
        part.remove(2).unwrap();
        assert_eq!(part.z(), &vec![0, 1]);
        assert_eq!(part.counts(), &vec![1, 1]);
    }

    #[test]
    fn weights_are_fractions_of_len() {
        let part = Partition::from_z(vec![0, 1, 0, 0]).unwrap();
        assert_eq!(part.weights(), vec![0.75, 0.25]);
        assert!(Partition::new().weights().is_empty());
    }
}