rv/dist/crp.rs
//! Chinese Restaurant Process
//!
//! [The Chinese Restaurant Process](https://en.wikipedia.org/wiki/Chinese_restaurant_process) (CRP)
//! is a distribution over partitions of items. The CRP defines a process by
//! which entities are assigned to an unknown number of partitions.
//!
//! The CRP is parameterized as CRP(α) where α is the 'discount' parameter in
//! (0, ∞). Higher α causes there to be more partitions, as it encourages new
//! entries to create new partitions.
10#[cfg(feature = "serde1")]
11use serde::{Deserialize, Serialize};
12
13use crate::data::Partition;
14use crate::impl_display;
15use crate::misc::ln_gammafn;
16use crate::misc::pflip;
17use crate::traits::{HasDensity, Parameterized, Sampleable, Support};
18use rand::Rng;
19use std::fmt;
20
/// Plain-data parameter set `(alpha, n)` of a Chinese Restaurant Process
/// distribution.
///
/// This is the `Parameters` type that `Crp` emits and is rebuilt from.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct CrpParameters {
    /// Discount parameter, `α`
    pub alpha: f64,
    /// Number of items in the partition
    pub n: usize,
}
31
/// A [Chinese Restaurant Process](https://en.wikipedia.org/wiki/Chinese_restaurant_process)
/// distribution over the partitions of `n` items.
///
/// # Example
///
/// ```
/// use rv::prelude::*;
///
/// let mut rng = rand::rng();
///
/// let crp = Crp::new(1.0, 10).expect("Invalid parameters");
/// let partition = crp.draw(&mut rng);
///
/// assert_eq!(partition.len(), 10);
/// ```
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Crp {
    /// Discount parameter, `α`
    alpha: f64,
    /// Number of items in the partition
    n: usize,
}
56
/// Reasons a set of `Crp` parameters can be rejected.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum CrpError {
    /// The `n` parameter is zero
    NIsZero,
    /// The `alpha` parameter is less than or equal to zero
    AlphaTooLow { alpha: f64 },
    /// The `alpha` parameter is infinite or NaN
    AlphaNotFinite { alpha: f64 },
}
68
69impl Crp {
70 /// Create an empty `Crp` with parameter alpha
71 ///
72 /// # Arguments
73 /// - alpha: Discount parameter in (0, Infinity)
74 /// - n: the number of items in the partition
75 pub fn new(alpha: f64, n: usize) -> Result<Self, CrpError> {
76 if n == 0 {
77 Err(CrpError::NIsZero)
78 } else if alpha <= 0.0 {
79 Err(CrpError::AlphaTooLow { alpha })
80 } else if !alpha.is_finite() {
81 Err(CrpError::AlphaNotFinite { alpha })
82 } else {
83 Ok(Crp { alpha, n })
84 }
85 }
86
87 /// Create a new Crp without checking whether the parameters are valid.
88 ///
89 /// ```rust
90 /// use rv::dist::Crp;
91 ///
92 /// let crp = Crp::new_unchecked(3.0, 10);
93 ///
94 /// assert_eq!(crp.alpha(), 3.0);
95 /// assert_eq!(crp.n(), 10);
96 /// ```
97 #[inline]
98 #[must_use]
99 pub fn new_unchecked(alpha: f64, n: usize) -> Self {
100 Crp { alpha, n }
101 }
102
103 /// Get the discount parameter, `alpha`.
104 ///
105 /// # Example
106 ///
107 /// ```rust
108 /// # use rv::dist::Crp;
109 /// let crp = Crp::new(1.0, 12).unwrap();
110 /// assert_eq!(crp.alpha(), 1.0);
111 /// ```
112 #[inline]
113 #[must_use]
114 pub fn alpha(&self) -> f64 {
115 self.alpha
116 }
117
118 /// Set the value of alpha
119 ///
120 /// # Example
121 /// ```rust
122 /// # use rv::dist::Crp;
123 /// let mut crp = Crp::new(1.1, 20).unwrap();
124 /// assert_eq!(crp.alpha(), 1.1);
125 ///
126 /// crp.set_alpha(2.3).unwrap();
127 /// assert_eq!(crp.alpha(), 2.3);
128 /// ```
129 ///
130 /// Will error for invalid parameters
131 ///
132 /// ```rust
133 /// # use rv::dist::Crp;
134 /// # let mut crp = Crp::new(1.1, 20).unwrap();
135 /// assert!(crp.set_alpha(0.5).is_ok());
136 /// assert!(crp.set_alpha(0.0).is_err());
137 /// assert!(crp.set_alpha(-1.0).is_err());
138 /// assert!(crp.set_alpha(f64::INFINITY).is_err());
139 /// assert!(crp.set_alpha(f64::NEG_INFINITY).is_err());
140 /// assert!(crp.set_alpha(f64::NAN).is_err());
141 /// ```
142 #[inline]
143 pub fn set_alpha(&mut self, alpha: f64) -> Result<(), CrpError> {
144 if alpha <= 0.0 {
145 Err(CrpError::AlphaTooLow { alpha })
146 } else if !alpha.is_finite() {
147 Err(CrpError::AlphaNotFinite { alpha })
148 } else {
149 self.set_alpha_unchecked(alpha);
150 Ok(())
151 }
152 }
153
154 /// Set the value of alpha without input validation
155 #[inline]
156 pub fn set_alpha_unchecked(&mut self, alpha: f64) {
157 self.alpha = alpha;
158 }
159
160 /// Get the number of entries in the partition, `n`.
161 ///
162 /// # Example
163 ///
164 /// ```rust
165 /// # use rv::dist::Crp;
166 /// let crp = Crp::new(1.0, 12).unwrap();
167 /// assert_eq!(crp.n(), 12);
168 /// ```
169 #[inline]
170 #[must_use]
171 pub fn n(&self) -> usize {
172 self.n
173 }
174
175 /// Set the value of n
176 ///
177 /// # Example
178 /// ```rust
179 /// # use rv::dist::Crp;
180 /// let mut crp = Crp::new(1.1, 20).unwrap();
181 /// assert_eq!(crp.n(), 20);
182 ///
183 /// crp.set_n(11).unwrap();
184 /// assert_eq!(crp.n(), 11);
185 /// ```
186 ///
187 /// Will error for invalid parameters
188 ///
189 /// ```rust
190 /// # use rv::dist::Crp;
191 /// # let mut crp = Crp::new(1.1, 20).unwrap();
192 /// assert!(crp.set_n(5).is_ok());
193 /// assert!(crp.set_n(1).is_ok());
194 /// assert!(crp.set_n(0).is_err());
195 /// ```
196 #[inline]
197 pub fn set_n(&mut self, n: usize) -> Result<(), CrpError> {
198 if n == 0 {
199 Err(CrpError::NIsZero)
200 } else {
201 self.set_n_unchecked(n);
202 Ok(())
203 }
204 }
205
206 /// Set the value of alpha without input validation
207 #[inline]
208 pub fn set_n_unchecked(&mut self, n: usize) {
209 self.n = n;
210 }
211}
212
213impl From<&Crp> for String {
214 fn from(crp: &Crp) -> String {
215 format!("CRP({}; α: {})", crp.n, crp.alpha)
216 }
217}
218
// NOTE(review): presumably generates the `fmt::Display` impl for `Crp` from
// its `From<&Crp> for String` conversion — confirm against the macro's
// definition in the crate root.
impl_display!(Crp);
220
221impl HasDensity<Partition> for Crp {
222 fn ln_f(&self, x: &Partition) -> f64 {
223 let gsum = x
224 .counts()
225 .iter()
226 .fold(0.0, |acc, ct| acc + ln_gammafn(*ct as f64));
227
228 // TODO: could cache ln(alpha) and ln_gamma(alpha)
229 (x.k() as f64).mul_add(self.alpha.ln(), gsum) + ln_gammafn(self.alpha)
230 - ln_gammafn(x.len() as f64 + self.alpha)
231 }
232}
233
234impl Sampleable<Partition> for Crp {
235 fn draw<R: Rng>(&self, rng: &mut R) -> Partition {
236 let mut k = 1;
237 // TODO: Set capacity according to
238 // https://www.cs.princeton.edu/courses/archive/fall07/cos597C/scribe/20070921.pdf
239 let mut weights: Vec<f64> = vec![1.0];
240 let mut sum = 1.0 + self.alpha;
241 let mut z: Vec<usize> = Vec::with_capacity(self.n);
242 z.push(0);
243
244 for _ in 1..self.n {
245 weights.push(self.alpha);
246 let zi = pflip(&weights, Some(sum), rng);
247 z.push(zi);
248
249 if zi == k {
250 weights[zi] = 1.0;
251 k += 1;
252 } else {
253 weights.truncate(k);
254 weights[zi] += 1.0;
255 }
256 sum += 1.0;
257 }
258 // convert weights to counts, correcting for possible floating point
259 // errors
260 // TODO: Is this right? Wouldn't this be the _expected_ counts?
261 let counts: Vec<usize> =
262 weights.iter().map(|w| (w + 0.5) as usize).collect();
263
264 Partition::new_unchecked(z, counts)
265 }
266}
267
268impl Support<Partition> for Crp {
269 #[inline]
270 fn supports(&self, _x: &Partition) -> bool {
271 true
272 }
273}
274
// Marker impl with no overridden methods; it relies on the `Debug` derive and
// the `fmt::Display` impl that `CrpError` has elsewhere in this file.
impl std::error::Error for CrpError {}
276
277#[cfg_attr(coverage_nightly, coverage(off))]
278impl fmt::Display for CrpError {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 match self {
281 Self::AlphaTooLow { alpha } => {
282 write!(f, "alpha ({alpha}) must be greater than zero")
283 }
284 Self::AlphaNotFinite { alpha } => {
285 write!(f, "alpha ({alpha}) was non-finite")
286 }
287 Self::NIsZero => write!(f, "n must be greater than zero"),
288 }
289 }
290}
291
292impl Parameterized for Crp {
293 type Parameters = CrpParameters;
294
295 fn emit_params(&self) -> Self::Parameters {
296 CrpParameters {
297 alpha: self.alpha,
298 n: self.n,
299 }
300 }
301
302 fn from_params(params: Self::Parameters) -> Self {
303 Self::new_unchecked(params.alpha, params.n)
304 }
305}
306
#[cfg(test)]
mod tests {
    //! Unit tests for `Crp` construction, validation, parameter round-trips
    //! and string conversion.
    use super::*;

    const TOL: f64 = 1E-12;

    #[test]
    fn new() {
        let crp = Crp::new(1.2, 808).unwrap();
        assert::close(crp.alpha, 1.2, TOL);
        assert_eq!(crp.n, 808);
    }

    #[test]
    fn new_rejects_zero_n() {
        assert_eq!(Crp::new(1.0, 0), Err(CrpError::NIsZero));
        // `n` is validated before `alpha`
        assert_eq!(Crp::new(-1.0, 0), Err(CrpError::NIsZero));
    }

    #[test]
    fn new_rejects_invalid_alpha() {
        assert_eq!(
            Crp::new(0.0, 1),
            Err(CrpError::AlphaTooLow { alpha: 0.0 })
        );
        assert_eq!(
            Crp::new(-1.5, 1),
            Err(CrpError::AlphaTooLow { alpha: -1.5 })
        );
        // Negative infinity trips the `<= 0` check first
        assert_eq!(
            Crp::new(f64::NEG_INFINITY, 1),
            Err(CrpError::AlphaTooLow {
                alpha: f64::NEG_INFINITY
            })
        );
        assert_eq!(
            Crp::new(f64::INFINITY, 1),
            Err(CrpError::AlphaNotFinite {
                alpha: f64::INFINITY
            })
        );
        // NaN != NaN, so match on the variant instead of comparing values
        assert!(matches!(
            Crp::new(f64::NAN, 1),
            Err(CrpError::AlphaNotFinite { .. })
        ));
    }

    #[test]
    fn set_alpha_validates_and_keeps_old_value_on_error() {
        let mut crp = Crp::new(1.1, 20).unwrap();

        assert!(crp.set_alpha(2.5).is_ok());
        assert_eq!(crp.alpha(), 2.5);

        assert_eq!(
            crp.set_alpha(0.0),
            Err(CrpError::AlphaTooLow { alpha: 0.0 })
        );
        assert!(crp.set_alpha(f64::INFINITY).is_err());
        assert!(crp.set_alpha(f64::NAN).is_err());
        assert_eq!(crp.alpha(), 2.5);
    }

    #[test]
    fn set_n_validates_and_keeps_old_value_on_error() {
        let mut crp = Crp::new(1.1, 20).unwrap();

        assert!(crp.set_n(11).is_ok());
        assert_eq!(crp.n(), 11);

        assert_eq!(crp.set_n(0), Err(CrpError::NIsZero));
        assert_eq!(crp.n(), 11);
    }

    #[test]
    fn params() {
        let crp = Crp::new(1.2, 808).unwrap();
        let params = crp.emit_params();

        // The emitted parameters carry the stored values verbatim
        assert_eq!(params.alpha, crp.alpha);
        assert_eq!(params.n, 808);

        let new_crp = Crp::from_params(params);
        assert_eq!(crp, new_crp);
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Crp::new(1.5, 710).unwrap();
        let dist_b = Crp::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }

    #[test]
    fn string_conversion() {
        let crp = Crp::new(1.2, 808).unwrap();
        assert_eq!(String::from(&crp), "CRP(808; α: 1.2)");
    }
}