Skip to main content

rv/data/
partition.rs

1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::impl_display;
5use crate::misc::vec_to_string;
6use std::fmt;
7
/// A partition of `n` items, identified by index, into `k` groups.
///
/// Item `i` is assigned to group `z[i]`, and `counts[j]` records how many
/// items are assigned to group `j`.
///
/// # Example
/// ```rust
/// use rv::data::Partition;
///
/// let part = Partition::new();
///
/// // It starts off empty
/// assert_eq!(part.z(), &[]);
/// assert_eq!(part.counts(), &[]);
/// assert!(part.is_empty());
///
/// // We can derive the partition from assignments
/// let part = Partition::from_z(vec![0, 0, 1, 1, 2]).expect("Non-empty assignments are valid");
/// assert_eq!(part.z(), &[0, 0, 1, 1, 2]);
/// assert_eq!(part.counts(), &[2, 2, 1]);
/// assert!(!part.is_empty());
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Partition {
    /// Group index of every item: `z[i]` is the group that item `i` is in.
    z: Vec<usize>,
    /// Group sizes: `counts[j]` is the number of items assigned to group `j`.
    counts: Vec<usize>,
}
36
37impl Default for Partition {
38    fn default() -> Self {
39        Partition::new()
40    }
41}
42
43impl From<&Partition> for String {
44    fn from(part: &Partition) -> String {
45        let mut out = String::new();
46        out.push_str(
47            format!("Partition (n: {}, k: {})\n", part.len(), part.k())
48                .as_str(),
49        );
50        out.push_str(
51            format!("  assignment: {}\n", vec_to_string(&part.z, 15)).as_str(),
52        );
53        out.push_str(
54            format!("  counts: {}\n", vec_to_string(&part.counts, part.k()))
55                .as_str(),
56        );
57        out
58    }
59}
60
// Implements `Display` for `Partition` via the crate's `impl_display!` macro.
// NOTE(review): presumably it formats through the `From<&Partition> for
// String` conversion defined above — confirm against `crate::impl_display`.
impl_display!(Partition);
62
/// Errors returned when building or modifying a [`Partition`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum PartitionError {
    /// The input assignment was empty, or it left one or more partitions in
    /// `0..=max(z)` without any items.
    EmptyInputPartition,
    /// An appended indicator was larger than the number of existing
    /// partitions.
    IndicatorHigherThanNumberOfPartitions {
        /// The offending indicator.
        zi: usize,
        /// The number of partitions at the time of the call.
        nparts: usize,
    },
}
77
78impl std::error::Error for PartitionError {}
79
80#[cfg_attr(coverage_nightly, coverage(off))]
81impl fmt::Display for PartitionError {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            Self::EmptyInputPartition => {
85                write!(f, "the input partition was empty")
86            }
87            Self::IndicatorHigherThanNumberOfPartitions { zi, nparts } => {
88                write!(
89                    f,
90                    "tried to append z = {zi} to a partition with {nparts} \
91                     partitions.  z must be in 0..n_parts, (or 0..{nparts}),"
92                )
93            }
94        }
95    }
96}
97
98impl Partition {
99    /// Empty partition
100    #[must_use]
101    pub fn new() -> Partition {
102        Partition {
103            z: vec![],
104            counts: vec![],
105        }
106    }
107
108    #[must_use]
109    pub fn new_unchecked(z: Vec<usize>, counts: Vec<usize>) -> Self {
110        Partition { z, counts }
111    }
112
113    #[must_use]
114    pub fn z(&self) -> &Vec<usize> {
115        &self.z
116    }
117
118    pub fn z_mut(&mut self) -> &mut Vec<usize> {
119        &mut self.z
120    }
121
122    #[must_use]
123    pub fn counts(&self) -> &Vec<usize> {
124        &self.counts
125    }
126
127    pub fn counts_mut(&mut self) -> &mut Vec<usize> {
128        &mut self.counts
129    }
130
131    /// Create a `Partition` with a given assignment, `z`
132    ///
133    /// # Examples
134    ///
135    /// ```rust
136    /// # use rv::data::Partition;
137    /// let z1 = vec![0, 1, 2, 3, 1, 2];
138    /// let part = Partition::from_z(z1).unwrap();
139    ///
140    /// assert_eq!(*part.z(), vec![0, 1, 2, 3, 1, 2]);
141    /// assert_eq!(*part.counts(), vec![1, 2, 2, 1]);
142    ///
143    /// // Invalid z because k=4 is empty. All partitions must be occupied.
144    /// let z2 = vec![0, 1, 2, 3, 1, 5];
145    /// assert!(Partition::from_z(z2).is_err());
146    /// ```
147    pub fn from_z(z: Vec<usize>) -> Result<Self, PartitionError> {
148        if z.is_empty() {
149            return Err(PartitionError::EmptyInputPartition);
150        }
151
152        let k = *z.iter().max().expect("empty z") + 1;
153        let mut counts: Vec<usize> = vec![0; k];
154        z.iter().for_each(|&zi| counts[zi] += 1);
155
156        if counts.iter().all(|&ct| ct > 0) {
157            let part = Partition { z, counts };
158            Ok(part)
159        } else {
160            Err(PartitionError::EmptyInputPartition)
161        }
162    }
163
164    /// Remove the item at index `ix`
165    ///
166    /// # Example
167    ///
168    /// ```
169    /// # use rv::data::Partition;
170    /// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
171    /// part.remove(1).expect("Could not remove");
172    ///
173    /// assert_eq!(*part.z(), vec![0, 0, 1]);
174    /// assert_eq!(*part.counts(), vec![2, 1]);
175    /// ```
176    pub fn remove(&mut self, ix: usize) -> Result<(), PartitionError> {
177        // Panics  on index error panics.
178        let zi = self.z.remove(ix);
179        if self.counts[zi] == 1 {
180            let _ct = self.counts.remove(zi);
181            // ensure canonical order
182            self.z.iter_mut().for_each(|zj| {
183                if *zj > zi {
184                    *zj -= 1;
185                }
186            });
187            Ok(())
188        } else {
189            self.counts[zi] -= 1;
190            Ok(())
191        }
192    }
193
194    /// Append a new item assigned to partition `zi`
195    ///
196    /// # Example
197    ///
198    /// ``` rust
199    /// # use rv::data::Partition;
200    /// let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
201    /// part.append(3).expect("Could not append");
202    ///
203    /// assert_eq!(*part.z(), vec![0, 1, 0, 2, 3]);
204    /// assert_eq!(*part.counts(), vec![2, 1, 1, 1]);
205    /// ```
206    pub fn append(&mut self, zi: usize) -> Result<(), PartitionError> {
207        let k = self.k();
208        if zi > k {
209            Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
210                zi,
211                nparts: k,
212            })
213        } else {
214            self.z.push(zi);
215            if zi == k {
216                self.counts.push(1);
217            } else {
218                self.counts[zi] += 1;
219            }
220            Ok(())
221        }
222    }
223
224    /// Returns the number of partitions, k.
225    ///
226    /// # Example
227    ///
228    /// ``` rust
229    /// # use rv::data::Partition;
230    /// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
231    ///
232    /// assert_eq!(part.k(), 3);
233    /// assert_eq!(*part.counts(), vec![2, 1, 1]);
234    /// ```
235    #[must_use]
236    pub fn k(&self) -> usize {
237        self.counts.len()
238    }
239
240    /// Returns the number items
241    #[must_use]
242    pub fn len(&self) -> usize {
243        self.z.len()
244    }
245
246    #[must_use]
247    pub fn is_empty(&self) -> bool {
248        self.len() == 0
249    }
250
251    /// Return the partition weights (normalized counts)
252    ///
253    /// # Example
254    ///
255    /// ``` rust
256    /// # use rv::data::Partition;
257    /// let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
258    /// let weights = part.weights();
259    ///
260    /// assert_eq!(weights, vec![0.5, 0.25, 0.25]);
261    /// ```
262    #[must_use]
263    pub fn weights(&self) -> Vec<f64> {
264        let n = self.len() as f64;
265        self.counts.iter().map(|&ct| (ct as f64) / n).collect()
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_z_counts_items_per_partition() {
        let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();

        assert_eq!(part.k(), 3);
        assert_eq!(part.len(), 4);
        assert_eq!(part.counts, vec![2, 1, 1]);
    }

    #[test]
    fn new_and_default_are_empty() {
        let part = Partition::new();

        assert!(part.is_empty());
        assert_eq!(part.k(), 0);
        assert_eq!(part, Partition::default());
        assert!(part.weights().is_empty());
    }

    #[test]
    fn from_z_rejects_empty_input() {
        assert_eq!(
            Partition::from_z(vec![]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn from_z_rejects_unoccupied_partition() {
        // Partition 1 has no items.
        assert_eq!(
            Partition::from_z(vec![0, 2]),
            Err(PartitionError::EmptyInputPartition)
        );
    }

    #[test]
    fn append_to_existing_and_new_partition() {
        let mut part = Partition::from_z(vec![0, 1]).unwrap();

        part.append(1).unwrap();
        assert_eq!(part.counts, vec![1, 2]);

        // `zi == k` opens a new partition.
        part.append(2).unwrap();
        assert_eq!(part.z, vec![0, 1, 1, 2]);
        assert_eq!(part.counts, vec![1, 2, 1]);
    }

    #[test]
    fn append_rejects_index_past_new_partition() {
        let mut part = Partition::from_z(vec![0, 1]).unwrap();

        assert_eq!(
            part.append(3),
            Err(PartitionError::IndicatorHigherThanNumberOfPartitions {
                zi: 3,
                nparts: 2
            })
        );
        // A rejected append leaves the partition untouched.
        assert_eq!(part.z, vec![0, 1]);
        assert_eq!(part.counts, vec![1, 1]);
    }

    #[test]
    fn remove_last_member_relabels_higher_partitions() {
        let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        part.remove(1).unwrap();

        assert_eq!(part.z, vec![0, 0, 1]);
        assert_eq!(part.counts, vec![2, 1]);
    }

    #[test]
    fn remove_shared_member_only_decrements() {
        let mut part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        part.remove(0).unwrap();

        assert_eq!(part.z, vec![1, 0, 2]);
        assert_eq!(part.counts, vec![1, 1, 1]);
    }

    #[test]
    fn weights_are_normalized_counts() {
        let part = Partition::from_z(vec![0, 1, 0, 2]).unwrap();
        assert_eq!(part.weights(), vec![0.5, 0.25, 0.25]);
    }
}