#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::impl_display;
use crate::misc::vec_to_string;
use std::fmt;
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Partition {
/// The assignment of the n items to partitions 0, ..., k-1
z: Vec<usize>,
/// The number of items assigned to each partition
counts: Vec<usize>,
}
impl Default for Partition {
fn default() -> Self {
Partition::new()
}
}
impl From<&Partition> for String {
fn from(part: &Partition) -> String {
let mut out = String::new();
out.push_str(
format!("Partition (n: {}, k: {})\n", part.len(), part.k())
.as_str(),
);
out.push_str(
format!(" assignment: {}\n", vec_to_string(&part.z, 15)).as_str(),
);
out.push_str(
format!(" counts: {}\n", vec_to_string(&part.counts, part.k()))
.as_str(),
);
out
}
}
impl_display!(Partition);
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum PartitionError {
/// The input partition is an empty vector
EmptyInputPartition,
/// One or more of the indicators exceeds the number of categories
IndicatorHigherThanNumberOfPartitions {
/// The indicator
zi: usize,
/// The number of partitions
nparts: usize,
},
}
impl std::error::Error for PartitionError {}
impl fmt::Display for PartitionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::EmptyInputPartition => {
write!(f, "the input partition was empty")
}
Self::IndicatorHigherThanNumberOfPartitions { zi, nparts } => {
write!(
f,
"tried to append z = {0} to a partition with {1} \
partitions. z must be in 0..n_parts, (or 0..{1}),",
zi, nparts
)
}
}
}
}
impl Partition {
/// Empty partition
pub fn new() -> Partition {
Partition {
z: vec![],
counts: vec![],
}
}
pub fn new_unchecked(z: Vec<usize>, counts: Vec<usize>) -> Self {
Partition { z, counts }
}
pub fn z(&self) -> &Vec<usize> {
&self.z
}
pub fn z_mut(&mut self) -> &mut Vec<usize> {
&mut self.z
}
pub fn counts(&self) -> &Vec<usize> {
&self.counts
}
pub fn counts_mut(&mut self) -> &mut Vec<usize> {
&mut self.counts
}
/// Create a `Partition` with a given assignment, `z`
///
/// # Examples
///
/// ```rust
/// # use rv::data::Partition;
/// let z1 = vec![0, 1, 2, 3, 1, 2];
/// let part = Partition::from_z(z1).unwrap();
///
/// assert_eq!(*part.z(), vec![0, 1, 2, 3, 1, 2]);
/// assert_eq!(*part.counts(), vec![1, 2, 2, 1]);
///
/// // Invalid z because k=4 is empty. All partitions must be occupied.
/// let z2 = vec![0, 1, 2, 3, 1, 5];
/// assert!(Partition::from_z(z2).is_err());
/// ```
pub fn from_z(z: Vec<usize>) -> Result<Self, PartitionError> {
if z.is_empty() {
return Err(PartitionError::EmptyInputPartition);
}
let k = *z.iter().max().expect("empty z") + 1;
let mut counts: Vec<usize> = vec![0; k];
z.iter().for_each(|&zi| counts[zi] += 1);
if counts.iter().all(|&ct| ct > 0) {
let part = Partition { z, counts };
Ok(part)
} else {
Err(PartitionError::EmptyInputPartition)
}
}
/// Remove the item at index `ix`
///
/// # Example
///
/// ```
/// # use rv::data::Partition;
/// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
/// part.remove(1).expect("Could not remove");
///
/// assert_eq!(*part.z(), vec![0, 0, 1]);
/// assert_eq!(*part.counts(), vec![2, 1]);
/// ```
pub fn remove(&mut self, ix: usize) -> Result<(), PartitionError> {
// Panics on index error panics.
let zi = self.z.remove(ix);
if self.counts[zi] == 1 {
let _ct = self.counts.remove(zi);
// ensure canonical order
self.z.iter_mut().for_each(|zj| {
if *zj > zi {
*zj -= 1
}
});
Ok(())
} else {
self.counts[zi] -= 1;
Ok(())
}
}
/// Append a new item assigned to partition `zi`
///
/// # Example
///
/// ``` rust
/// # use rv::data::Partition;
/// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
/// part.append(3).expect("Could not append");
///
/// assert_eq!(*part.z(), vec![0, 1, 0, 2, 3]);
/// assert_eq!(*part.counts(), vec![2, 1, 1, 1]);
/// ```
pub fn append(&mut self, zi: usize) -> Result<(), PartitionError> {
let k = self.k();
if zi > k {
Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
zi,
nparts: k,
})
} else {
self.z.push(zi);
if zi == k {
self.counts.push(1);
} else {
self.counts[zi] += 1;
}
Ok(())
}
}
/// Returns the number of partitions, k.
///
/// # Example
///
/// ``` rust
/// # use rv::data::Partition;
/// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
///
/// assert_eq!(part.k(), 3);
/// assert_eq!(*part.counts(), vec![2, 1, 1]);
/// ```
pub fn k(&self) -> usize {
self.counts.len()
}
/// Returns the number items
pub fn len(&self) -> usize {
self.z.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Return the partition weights (normalized counts)
///
/// # Example
///
/// ``` rust
/// # use rv::data::Partition;
/// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
/// let weights = part.weights();
///
/// assert_eq!(weights, vec![0.5, 0.25, 0.25]);
/// ```
pub fn weights(&self) -> Vec<f64> {
let n = self.len() as f64;
self.counts.iter().map(|&ct| (ct as f64) / n).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new() {
let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
assert_eq!(part.k(), 3);
assert_eq!(part.counts, vec![2, 1, 1]);
}
}