Skip to main content

rv/dist/
crp.rs

//! Chinese Restaurant Process
//!
//! [The Chinese Restaurant Process](https://en.wikipedia.org/wiki/Chinese_restaurant_process) (CRP)
//! is a distribution over partitions of items. The CRP defines a process by
//! which entities are assigned to an unknown number of partitions.
//!
//! The CRP is parameterized as CRP(α, n), where n is the number of items being
//! partitioned and α ∈ (0, ∞) is the concentration parameter (referred to as
//! the 'discount' parameter throughout this module). Higher α produces more
//! partitions, because each new entry is more likely to start a new partition.
10#[cfg(feature = "serde1")]
11use serde::{Deserialize, Serialize};
12
13use crate::data::Partition;
14use crate::impl_display;
15use crate::misc::ln_gammafn;
16use crate::misc::pflip;
17use crate::traits::{HasDensity, Parameterized, Sampleable, Support};
18use rand::Rng;
19use std::fmt;
20
/// Parameters for the Chinese Restaurant Process distribution.
///
/// Produced by `Parameterized::emit_params` on a [`Crp`] and consumed by
/// `Parameterized::from_params`, which builds the `Crp` via
/// `Crp::new_unchecked` and therefore does not re-validate these values.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct CrpParameters {
    /// Concentration ("discount") parameter, α; valid values lie in (0, ∞)
    pub alpha: f64,
    /// Number of items in the partition; valid values are non-zero
    pub n: usize,
}
31
/// [Chinese Restaurant Process](https://en.wikipedia.org/wiki/Chinese_restaurant_process),
/// a distribution over partitions of `n` items.
///
/// Construct with [`Crp::new`], which rejects `n == 0` and any `alpha`
/// outside (0, ∞). [`Crp::new_unchecked`] skips those checks.
///
/// # Example
///
/// ```
/// use rv::prelude::*;
///
/// let mut rng = rand::rng();
///
/// let crp = Crp::new(1.0, 10).expect("Invalid parameters");
/// let partition = crp.draw(&mut rng);
///
/// assert_eq!(partition.len(), 10);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Crp {
    /// Concentration ("discount") parameter, α, in (0, ∞) when validated
    alpha: f64,
    /// Number of items in the partition (non-zero when validated)
    n: usize,
}
56
/// Error returned when [`Crp`] parameters fail validation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum CrpError {
    /// n parameter is zero
    NIsZero,
    /// alpha parameter is less than or equal to zero (this includes -∞);
    /// carries the rejected value
    AlphaTooLow { alpha: f64 },
    /// alpha parameter is +∞ or NaN; carries the rejected value
    AlphaNotFinite { alpha: f64 },
}
68
69impl Crp {
70    /// Create an empty `Crp` with parameter alpha
71    ///
72    /// # Arguments
73    /// - alpha: Discount parameter in (0, Infinity)
74    /// - n: the number of items in the partition
75    pub fn new(alpha: f64, n: usize) -> Result<Self, CrpError> {
76        if n == 0 {
77            Err(CrpError::NIsZero)
78        } else if alpha <= 0.0 {
79            Err(CrpError::AlphaTooLow { alpha })
80        } else if !alpha.is_finite() {
81            Err(CrpError::AlphaNotFinite { alpha })
82        } else {
83            Ok(Crp { alpha, n })
84        }
85    }
86
87    /// Create a new Crp without checking whether the parameters are valid.
88    ///
89    /// ```rust
90    /// use rv::dist::Crp;
91    ///
92    /// let crp = Crp::new_unchecked(3.0, 10);
93    ///
94    /// assert_eq!(crp.alpha(), 3.0);
95    /// assert_eq!(crp.n(), 10);
96    /// ```
97    #[inline]
98    #[must_use]
99    pub fn new_unchecked(alpha: f64, n: usize) -> Self {
100        Crp { alpha, n }
101    }
102
103    /// Get the discount parameter, `alpha`.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// # use rv::dist::Crp;
109    /// let crp = Crp::new(1.0, 12).unwrap();
110    /// assert_eq!(crp.alpha(), 1.0);
111    /// ```
112    #[inline]
113    #[must_use]
114    pub fn alpha(&self) -> f64 {
115        self.alpha
116    }
117
118    /// Set the value of alpha
119    ///
120    /// # Example
121    /// ```rust
122    /// # use rv::dist::Crp;
123    /// let mut crp = Crp::new(1.1, 20).unwrap();
124    /// assert_eq!(crp.alpha(), 1.1);
125    ///
126    /// crp.set_alpha(2.3).unwrap();
127    /// assert_eq!(crp.alpha(), 2.3);
128    /// ```
129    ///
130    /// Will error for invalid parameters
131    ///
132    /// ```rust
133    /// # use rv::dist::Crp;
134    /// # let mut crp = Crp::new(1.1, 20).unwrap();
135    /// assert!(crp.set_alpha(0.5).is_ok());
136    /// assert!(crp.set_alpha(0.0).is_err());
137    /// assert!(crp.set_alpha(-1.0).is_err());
138    /// assert!(crp.set_alpha(f64::INFINITY).is_err());
139    /// assert!(crp.set_alpha(f64::NEG_INFINITY).is_err());
140    /// assert!(crp.set_alpha(f64::NAN).is_err());
141    /// ```
142    #[inline]
143    pub fn set_alpha(&mut self, alpha: f64) -> Result<(), CrpError> {
144        if alpha <= 0.0 {
145            Err(CrpError::AlphaTooLow { alpha })
146        } else if !alpha.is_finite() {
147            Err(CrpError::AlphaNotFinite { alpha })
148        } else {
149            self.set_alpha_unchecked(alpha);
150            Ok(())
151        }
152    }
153
154    /// Set the value of alpha without input validation
155    #[inline]
156    pub fn set_alpha_unchecked(&mut self, alpha: f64) {
157        self.alpha = alpha;
158    }
159
160    /// Get the number of entries in the partition, `n`.
161    ///
162    /// # Example
163    ///
164    /// ```rust
165    /// # use rv::dist::Crp;
166    /// let crp = Crp::new(1.0, 12).unwrap();
167    /// assert_eq!(crp.n(), 12);
168    /// ```
169    #[inline]
170    #[must_use]
171    pub fn n(&self) -> usize {
172        self.n
173    }
174
175    /// Set the value of n
176    ///
177    /// # Example
178    /// ```rust
179    /// # use rv::dist::Crp;
180    /// let mut crp = Crp::new(1.1, 20).unwrap();
181    /// assert_eq!(crp.n(), 20);
182    ///
183    /// crp.set_n(11).unwrap();
184    /// assert_eq!(crp.n(), 11);
185    /// ```
186    ///
187    /// Will error for invalid parameters
188    ///
189    /// ```rust
190    /// # use rv::dist::Crp;
191    /// # let mut crp = Crp::new(1.1, 20).unwrap();
192    /// assert!(crp.set_n(5).is_ok());
193    /// assert!(crp.set_n(1).is_ok());
194    /// assert!(crp.set_n(0).is_err());
195    /// ```
196    #[inline]
197    pub fn set_n(&mut self, n: usize) -> Result<(), CrpError> {
198        if n == 0 {
199            Err(CrpError::NIsZero)
200        } else {
201            self.set_n_unchecked(n);
202            Ok(())
203        }
204    }
205
206    /// Set the value of alpha without input validation
207    #[inline]
208    pub fn set_n_unchecked(&mut self, n: usize) {
209        self.n = n;
210    }
211}
212
213impl From<&Crp> for String {
214    fn from(crp: &Crp) -> String {
215        format!("CRP({}; α: {})", crp.n, crp.alpha)
216    }
217}
218
// Display impl for `Crp`. NOTE(review): `impl_display!` is defined elsewhere in
// the crate; presumably it formats through the `From<&Crp> for String` impl
// above — confirm against the macro definition.
impl_display!(Crp);
220
221impl HasDensity<Partition> for Crp {
222    fn ln_f(&self, x: &Partition) -> f64 {
223        let gsum = x
224            .counts()
225            .iter()
226            .fold(0.0, |acc, ct| acc + ln_gammafn(*ct as f64));
227
228        // TODO: could cache ln(alpha) and ln_gamma(alpha)
229        (x.k() as f64).mul_add(self.alpha.ln(), gsum) + ln_gammafn(self.alpha)
230            - ln_gammafn(x.len() as f64 + self.alpha)
231    }
232}
233
234impl Sampleable<Partition> for Crp {
235    fn draw<R: Rng>(&self, rng: &mut R) -> Partition {
236        let mut k = 1;
237        // TODO: Set capacity according to
238        // https://www.cs.princeton.edu/courses/archive/fall07/cos597C/scribe/20070921.pdf
239        let mut weights: Vec<f64> = vec![1.0];
240        let mut sum = 1.0 + self.alpha;
241        let mut z: Vec<usize> = Vec::with_capacity(self.n);
242        z.push(0);
243
244        for _ in 1..self.n {
245            weights.push(self.alpha);
246            let zi = pflip(&weights, Some(sum), rng);
247            z.push(zi);
248
249            if zi == k {
250                weights[zi] = 1.0;
251                k += 1;
252            } else {
253                weights.truncate(k);
254                weights[zi] += 1.0;
255            }
256            sum += 1.0;
257        }
258        // convert weights to counts, correcting for possible floating point
259        // errors
260        // TODO: Is this right? Wouldn't this be the _expected_ counts?
261        let counts: Vec<usize> =
262            weights.iter().map(|w| (w + 0.5) as usize).collect();
263
264        Partition::new_unchecked(z, counts)
265    }
266}
267
268impl Support<Partition> for Crp {
269    #[inline]
270    fn supports(&self, _x: &Partition) -> bool {
271        true
272    }
273}
274
// `CrpError` derives `Debug` and has a `Display` impl in this module, so the
// default `std::error::Error` methods (no source/cause) are sufficient.
impl std::error::Error for CrpError {}
276
277#[cfg_attr(coverage_nightly, coverage(off))]
278impl fmt::Display for CrpError {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        match self {
281            Self::AlphaTooLow { alpha } => {
282                write!(f, "alpha ({alpha}) must be greater than zero")
283            }
284            Self::AlphaNotFinite { alpha } => {
285                write!(f, "alpha ({alpha}) was non-finite")
286            }
287            Self::NIsZero => write!(f, "n must be greater than zero"),
288        }
289    }
290}
291
292impl Parameterized for Crp {
293    type Parameters = CrpParameters;
294
295    fn emit_params(&self) -> Self::Parameters {
296        CrpParameters {
297            alpha: self.alpha,
298            n: self.n,
299        }
300    }
301
302    fn from_params(params: Self::Parameters) -> Self {
303        Self::new_unchecked(params.alpha, params.n)
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1E-12;
    // Looser tolerance for identities that pass through several lnΓ calls.
    const LN_TOL: f64 = 1E-10;

    #[test]
    fn new() {
        let crp = Crp::new(1.2, 808).unwrap();
        assert::close(crp.alpha, 1.2, TOL);
        assert_eq!(crp.n, 808);
    }

    #[test]
    fn new_rejects_invalid_parameters() {
        assert_eq!(Crp::new(1.0, 0), Err(CrpError::NIsZero));
        // `n` is validated before `alpha`.
        assert_eq!(Crp::new(-1.0, 0), Err(CrpError::NIsZero));
        assert_eq!(
            Crp::new(0.0, 10),
            Err(CrpError::AlphaTooLow { alpha: 0.0 })
        );
        assert_eq!(
            Crp::new(-2.5, 10),
            Err(CrpError::AlphaTooLow { alpha: -2.5 })
        );
        assert_eq!(
            Crp::new(f64::INFINITY, 10),
            Err(CrpError::AlphaNotFinite {
                alpha: f64::INFINITY
            })
        );
        // -∞ is caught by the `<= 0.0` check first.
        assert!(matches!(
            Crp::new(f64::NEG_INFINITY, 10),
            Err(CrpError::AlphaTooLow { .. })
        ));
        // NaN != NaN, so match on the variant rather than comparing values.
        assert!(matches!(
            Crp::new(f64::NAN, 10),
            Err(CrpError::AlphaNotFinite { .. })
        ));
    }

    #[test]
    fn params() {
        let crp = Crp::new(1.2, 808).unwrap();
        let params = crp.emit_params();

        let new_crp = Crp::from_params(params);
        assert_eq!(crp, new_crp);
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Crp::new(1.5, 710).unwrap();
        let dist_b = Crp::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }

    #[test]
    fn ln_f_single_item_is_zero() {
        // With one item there is exactly one partition, so P = 1 and ln P = 0.
        let crp = Crp::new(1.5, 1).unwrap();
        let partition = Partition::new_unchecked(vec![0], vec![1]);
        assert::close(crp.ln_f(&partition), 0.0, LN_TOL);
    }

    #[test]
    fn ln_f_two_items_matches_closed_form() {
        let alpha: f64 = 2.0;
        let crp = Crp::new(alpha, 2).unwrap();
        let together = Partition::new_unchecked(vec![0, 0], vec![2]);
        let apart = Partition::new_unchecked(vec![0, 1], vec![1, 1]);
        // P(together) = 1 / (1 + α), P(apart) = α / (1 + α)
        assert::close(crp.ln_f(&together), -(1.0 + alpha).ln(), LN_TOL);
        assert::close(
            crp.ln_f(&apart),
            alpha.ln() - (1.0 + alpha).ln(),
            LN_TOL,
        );
    }

    #[test]
    fn draw_produces_consistent_partition() {
        let mut rng = rand::rng();
        let crp = Crp::new(1.5, 50).unwrap();
        for _ in 0..20 {
            let partition = crp.draw(&mut rng);
            assert_eq!(partition.len(), 50);
            assert_eq!(partition.counts().iter().sum::<usize>(), 50);
            assert!(partition.counts().iter().all(|&ct| ct > 0));
        }
    }
}