mixturs/params/
options.rs

1use crate::stats::{NormalConjugatePrior, PriorHyperParams};
2
3/// Outlier removal options
4#[derive(Debug, Clone, PartialEq)]
5pub struct OutlierRemoval<P: NormalConjugatePrior> {
6    /// Weight of the outlier prior
7    pub weight: f64,
8    /// Outlier prior
9    pub dist: P::HyperParams,
10}
11
12/// Options for the DPMMSC model
13#[derive(Debug, Clone, PartialEq)]
14pub struct ModelOptions<P: NormalConjugatePrior> {
15    /// Prior for the complete data distribution
16    pub data_dist: P::HyperParams,
17    /// Concentration parameter for the Dirichlet process
18    pub alpha: f64,
19    /// Dimensionality of the data
20    pub dim: usize,
21    /// Burnout period for the supercluster log likelihood history (in iterations)
22    pub burnout_period: usize,
23    /// Outlier removal options
24    pub outlier: Option<OutlierRemoval<P>>,
25    /// Whether to use hard assignment during expectation phase
26    pub hard_assignment: bool,
27}
28
29impl<P: NormalConjugatePrior> ModelOptions<P> {
30    pub fn default(dim: usize) -> Self {
31        Self {
32            data_dist: P::HyperParams::default(dim),
33            alpha: 10.0,
34            dim,
35            burnout_period: 20,
36            outlier: Some(OutlierRemoval {
37                weight: 0.05,
38                dist: P::HyperParams::default(dim),
39            }),
40            hard_assignment: false,
41        }
42    }
43}
44
/// Options for the DPMMSC model fit method.
///
/// Start from [`FitOptions::default`] and override individual settings with
/// struct-update syntax, e.g. `FitOptions { iters: 500, ..FitOptions::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FitOptions {
    /// Seed for the random number generator (default: `42`).
    pub seed: u64,
    /// Whether to reuse the previous model parameters (default: `false`).
    pub reuse: bool,
    /// Number of initial clusters (default: `1`).
    pub init_clusters: usize,
    /// Maximum number of clusters (default: `usize::MAX`, i.e. no practical limit).
    pub max_clusters: usize,
    /// Maximum number of iterations (default: `100`).
    pub iters: usize,
    /// Number of iterations before the maximum iteration at which to start
    /// using the argmax label sampling strategy (default: `5`).
    pub argmax_sample_stop: usize,
    /// Number of iterations before the maximum iteration at which to stop
    /// split/merge proposals (default: `5`).
    pub iter_split_stop: usize,
    /// Number of workers (threads) for parallelization; `-1` means the number
    /// of CPUs (default: `1`).
    pub workers: i32,
}

impl Default for FitOptions {
    /// Returns the default fit options documented on each field.
    fn default() -> Self {
        Self {
            seed: 42,
            reuse: false,
            init_clusters: 1,
            max_clusters: usize::MAX,
            iters: 100,
            argmax_sample_stop: 5,
            iter_split_stop: 5,
            workers: 1,
        }
    }
}