rv/data/partition.rs
1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::impl_display;
5use crate::misc::vec_to_string;
6use std::fmt;
7
/// A Partition of data by index.
///
/// Stores the assignment `z` of `n` items to partitions labeled `0..k`,
/// together with `counts`, where `counts[j]` is the number of items assigned
/// to partition `j`. Hence `k == counts.len()` and `n == z.len()`.
///
/// [`Partition::from_z`] derives `counts` from `z` and rejects assignments
/// that leave a partition unoccupied. [`Partition::new_unchecked`],
/// [`Partition::z_mut`] and [`Partition::counts_mut`] perform no such checks.
/// Callers using them are responsible for keeping `z` and `counts`
/// consistent.
///
/// # Example
/// ```rust
/// use rv::data::Partition;
///
/// let part = Partition::new();
///
/// // It starts off empty
/// assert_eq!(part.z(), &[]);
/// assert_eq!(part.counts(), &[]);
/// assert!(part.is_empty());
///
/// // We can derive the partition from assignments
/// let part = Partition::from_z(vec![0, 0, 1, 1, 2]).expect("Non-empty assignments are valid");
/// assert_eq!(part.z(), &[0, 0, 1, 1, 2]);
/// assert_eq!(part.counts(), &[2, 2, 1]);
/// assert!(!part.is_empty());
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Partition {
    /// The assignment of the n items to partitions 0, ..., k-1
    /// (`z[i]` is the partition label of item `i`)
    z: Vec<usize>,
    /// The number of items assigned to each partition
    /// (`counts[j]` is how many entries of `z` equal `j`)
    counts: Vec<usize>,
}
36
37impl Default for Partition {
38 fn default() -> Self {
39 Partition::new()
40 }
41}
42
43impl From<&Partition> for String {
44 fn from(part: &Partition) -> String {
45 let mut out = String::new();
46 out.push_str(
47 format!("Partition (n: {}, k: {})\n", part.len(), part.k())
48 .as_str(),
49 );
50 out.push_str(
51 format!(" assignment: {}\n", vec_to_string(&part.z, 15)).as_str(),
52 );
53 out.push_str(
54 format!(" counts: {}\n", vec_to_string(&part.counts, part.k()))
55 .as_str(),
56 );
57 out
58 }
59}
60
// `impl_display!` is imported from the crate root. NOTE(review): it presumably
// implements `fmt::Display` for `Partition` in terms of the
// `From<&Partition> for String` conversion; confirm against the macro
// definition.
impl_display!(Partition);
62
/// Errors produced when building or extending a [`Partition`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum PartitionError {
    /// The input assignment is unusable. [`Partition::from_z`] returns this
    /// both for an empty vector and for an assignment that leaves some
    /// partition label in `0..=max(z)` unoccupied.
    EmptyInputPartition,
    /// An appended indicator is greater than the current number of
    /// partitions. Accepting it would leave a gap in the partition labels.
    IndicatorHigherThanNumberOfPartitions {
        /// The rejected indicator
        zi: usize,
        /// The number of partitions at the time of the append
        nparts: usize,
    },
}
77
// `std::error::Error` only requires `Debug` (derived on the enum) and
// `Display` (implemented in this file). The provided default methods are
// sufficient.
impl std::error::Error for PartitionError {}
79
80#[cfg_attr(coverage_nightly, coverage(off))]
81impl fmt::Display for PartitionError {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::EmptyInputPartition => {
85 write!(f, "the input partition was empty")
86 }
87 Self::IndicatorHigherThanNumberOfPartitions { zi, nparts } => {
88 write!(
89 f,
90 "tried to append z = {zi} to a partition with {nparts} \
91 partitions. z must be in 0..n_parts, (or 0..{nparts}),"
92 )
93 }
94 }
95 }
96}
97
98impl Partition {
99 /// Empty partition
100 #[must_use]
101 pub fn new() -> Partition {
102 Partition {
103 z: vec![],
104 counts: vec![],
105 }
106 }
107
108 #[must_use]
109 pub fn new_unchecked(z: Vec<usize>, counts: Vec<usize>) -> Self {
110 Partition { z, counts }
111 }
112
113 #[must_use]
114 pub fn z(&self) -> &Vec<usize> {
115 &self.z
116 }
117
118 pub fn z_mut(&mut self) -> &mut Vec<usize> {
119 &mut self.z
120 }
121
122 #[must_use]
123 pub fn counts(&self) -> &Vec<usize> {
124 &self.counts
125 }
126
127 pub fn counts_mut(&mut self) -> &mut Vec<usize> {
128 &mut self.counts
129 }
130
131 /// Create a `Partition` with a given assignment, `z`
132 ///
133 /// # Examples
134 ///
135 /// ```rust
136 /// # use rv::data::Partition;
137 /// let z1 = vec![0, 1, 2, 3, 1, 2];
138 /// let part = Partition::from_z(z1).unwrap();
139 ///
140 /// assert_eq!(*part.z(), vec![0, 1, 2, 3, 1, 2]);
141 /// assert_eq!(*part.counts(), vec![1, 2, 2, 1]);
142 ///
143 /// // Invalid z because k=4 is empty. All partitions must be occupied.
144 /// let z2 = vec![0, 1, 2, 3, 1, 5];
145 /// assert!(Partition::from_z(z2).is_err());
146 /// ```
147 pub fn from_z(z: Vec<usize>) -> Result<Self, PartitionError> {
148 if z.is_empty() {
149 return Err(PartitionError::EmptyInputPartition);
150 }
151
152 let k = *z.iter().max().expect("empty z") + 1;
153 let mut counts: Vec<usize> = vec![0; k];
154 z.iter().for_each(|&zi| counts[zi] += 1);
155
156 if counts.iter().all(|&ct| ct > 0) {
157 let part = Partition { z, counts };
158 Ok(part)
159 } else {
160 Err(PartitionError::EmptyInputPartition)
161 }
162 }
163
164 /// Remove the item at index `ix`
165 ///
166 /// # Example
167 ///
168 /// ```
169 /// # use rv::data::Partition;
170 /// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
171 /// part.remove(1).expect("Could not remove");
172 ///
173 /// assert_eq!(*part.z(), vec![0, 0, 1]);
174 /// assert_eq!(*part.counts(), vec![2, 1]);
175 /// ```
176 pub fn remove(&mut self, ix: usize) -> Result<(), PartitionError> {
177 // Panics on index error panics.
178 let zi = self.z.remove(ix);
179 if self.counts[zi] == 1 {
180 let _ct = self.counts.remove(zi);
181 // ensure canonical order
182 self.z.iter_mut().for_each(|zj| {
183 if *zj > zi {
184 *zj -= 1;
185 }
186 });
187 Ok(())
188 } else {
189 self.counts[zi] -= 1;
190 Ok(())
191 }
192 }
193
194 /// Append a new item assigned to partition `zi`
195 ///
196 /// # Example
197 ///
198 /// ``` rust
199 /// # use rv::data::Partition;
200 /// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
201 /// part.append(3).expect("Could not append");
202 ///
203 /// assert_eq!(*part.z(), vec![0, 1, 0, 2, 3]);
204 /// assert_eq!(*part.counts(), vec![2, 1, 1, 1]);
205 /// ```
206 pub fn append(&mut self, zi: usize) -> Result<(), PartitionError> {
207 let k = self.k();
208 if zi > k {
209 Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
210 zi,
211 nparts: k,
212 })
213 } else {
214 self.z.push(zi);
215 if zi == k {
216 self.counts.push(1);
217 } else {
218 self.counts[zi] += 1;
219 }
220 Ok(())
221 }
222 }
223
224 /// Returns the number of partitions, k.
225 ///
226 /// # Example
227 ///
228 /// ``` rust
229 /// # use rv::data::Partition;
230 /// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
231 ///
232 /// assert_eq!(part.k(), 3);
233 /// assert_eq!(*part.counts(), vec![2, 1, 1]);
234 /// ```
235 #[must_use]
236 pub fn k(&self) -> usize {
237 self.counts.len()
238 }
239
240 /// Returns the number items
241 #[must_use]
242 pub fn len(&self) -> usize {
243 self.z.len()
244 }
245
246 #[must_use]
247 pub fn is_empty(&self) -> bool {
248 self.len() == 0
249 }
250
251 /// Return the partition weights (normalized counts)
252 ///
253 /// # Example
254 ///
255 /// ``` rust
256 /// # use rv::data::Partition;
257 /// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
258 /// let weights = part.weights();
259 ///
260 /// assert_eq!(weights, vec![0.5, 0.25, 0.25]);
261 /// ```
262 #[must_use]
263 pub fn weights(&self) -> Vec<f64> {
264 let n = self.len() as f64;
265 self.counts.iter().map(|&ct| (ct as f64) / n).collect()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();

        assert_eq!(part.k(), 3);
        assert_eq!(part.counts, vec![2, 1, 1]);
    }

    #[test]
    fn empty_partition_has_no_items_or_weights() {
        let part = Partition::new();

        assert!(part.is_empty());
        assert_eq!(part.k(), 0);
        assert!(part.weights().is_empty());
    }

    #[test]
    fn from_z_rejects_empty_input() {
        assert_eq!(
            Partition::from_z(vec![]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn from_z_rejects_unoccupied_partition() {
        // Label 1 is never used, so partition 1 would be empty.
        assert_eq!(
            Partition::from_z(vec![0, 2]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn remove_last_member_relabels_higher_partitions() {
        let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        part.remove(1).unwrap();

        assert_eq!(part.z, vec![0, 0, 1]);
        assert_eq!(part.counts, vec![2, 1]);
    }

    #[test]
    fn remove_from_shared_partition_decrements_count() {
        let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        part.remove(0).unwrap();

        assert_eq!(part.z, vec![1, 0, 2]);
        assert_eq!(part.counts, vec![1, 1, 1]);
    }

    #[test]
    fn append_to_existing_and_new_partition() {
        let mut part = Partition::from_z(vec![0, 1]).unwrap();
        part.append(0).unwrap();
        part.append(2).unwrap();

        assert_eq!(part.z, vec![0, 1, 0, 2]);
        assert_eq!(part.counts, vec![2, 1, 1]);
    }

    #[test]
    fn append_rejects_gap_and_leaves_state_unchanged() {
        let mut part = Partition::from_z(vec![0, 1]).unwrap();

        assert_eq!(
            part.append(3),
            Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
                zi: 3,
                nparts: 2
            })
        );
        assert_eq!(part.z, vec![0, 1]);
        assert_eq!(part.counts, vec![1, 1]);
    }
}
280}