1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
pub mod cluster_switch;
pub mod gibbs;
pub mod gibbs_sampler;
pub use cluster_switch::*;
pub use gibbs::*;
pub use gibbs_sampler::*;
use crate::nonparametric::*;
use crate::RandomVariable;
use crate::*;
/// Errors specific to the Pitman-Yor process code in this module.
#[derive(thiserror::Error, Debug)]
pub enum PitmanYorProcessError {
/// The `d` parameter is outside the half-open interval `[0, 1)`.
///
/// Produced by `PitmanYorProcessParams::new` when `d < 0` or `d >= 1`.
#[error("'d' must be greater than or equal to 0 and less than 1")]
DMustBeGTE0AndLT1,
/// A `remove_index` value does not address an element of `s`.
///
/// NOTE(review): the code that returns this variant is not in this part of
/// the file; presumably `s` is a cluster-assignment slice — confirm at the call site.
#[error("`remove_index` is out of range of `s`.")]
RemoveIndexOutOfRange,
/// Catch-all variant for failures without a more specific classification.
#[error("Unknown error")]
Unknown,
}
/// Parameter set for a Pitman-Yor process: two scalar parameters plus a
/// baseline measure over values of type `TH`.
///
/// The fields are private; instances are built through `new`, which validates
/// the scalar parameters, and read back through the `alpha()` / `d()` accessors.
#[derive(Clone, Debug)]
pub struct PitmanYorProcessParams<G0, TH>
where
G0: Distribution<Value = TH, Condition = ()>,
TH: RandomVariable,
{
// `new` rejects `alpha <= 0`.
// NOTE(review): presumably the concentration (strength) parameter — confirm.
alpha: f64,
// `new` rejects `d < 0` and `d >= 1`.
// NOTE(review): presumably the discount parameter — confirm.
d: f64,
// Baseline measure built on the distribution `G0`, whose values are `TH`.
g0: BaselineMeasure<G0, TH>,
}
impl<G0, TH> PitmanYorProcessParams<G0, TH>
where
    G0: Distribution<Value = TH, Condition = ()>,
    TH: RandomVariable,
{
    /// Builds a validated parameter set.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidParameters` wrapping
    /// - `DirichletProcessError::AlphaMustBePositive` when `alpha` is NaN or
    ///   `alpha <= 0`;
    /// - `PitmanYorProcessError::DMustBeGTE0AndLT1` when `d` is NaN or lies
    ///   outside the half-open interval `[0, 1)`.
    ///
    /// `alpha` is checked first, so when both arguments are invalid the
    /// `alpha` error is the one reported.
    pub fn new(
        alpha: f64,
        d: f64,
        g0: BaselineMeasure<G0, TH>,
    ) -> Result<Self, DistributionError> {
        // Every ordered comparison against NaN is false, so a plain
        // `alpha <= 0.0` test would let NaN through; reject it explicitly.
        if alpha.is_nan() || alpha <= 0.0 {
            return Err(DistributionError::InvalidParameters(
                DirichletProcessError::AlphaMustBePositive.into(),
            ));
        }
        // `Range::contains` is false for NaN, so NaN is rejected here as well.
        if !(0.0..1.0).contains(&d) {
            return Err(DistributionError::InvalidParameters(
                PitmanYorProcessError::DMustBeGTE0AndLT1.into(),
            ));
        }
        Ok(Self { alpha, d, g0 })
    }

    /// Returns the `alpha` parameter.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Returns the `d` parameter.
    pub fn d(&self) -> f64 {
        self.d
    }

    /// Collects clones of the elements of `x` whose label in `s` equals `k`.
    ///
    /// `x` and `s` are paired by position (`x[i]` carries label `s[i]`) and
    /// the relative order of the selected elements is preserved. If the two
    /// slices differ in length, only the common prefix is examined; trailing
    /// entries without a partner are ignored rather than causing an
    /// out-of-bounds panic.
    pub fn x_in_cluster<T>(x: &[T], s: &[u32], k: u32) -> Vec<T>
    where
        T: RandomVariable,
    {
        x.iter()
            .zip(s)
            .filter(|&(_, &label)| label == k)
            .map(|(value, _)| value.clone())
            .collect()
    }
}