mixturs/params/
options.rs1use crate::stats::{NormalConjugatePrior, PriorHyperParams};
2
3#[derive(Debug, Clone, PartialEq)]
5pub struct OutlierRemoval<P: NormalConjugatePrior> {
6 pub weight: f64,
8 pub dist: P::HyperParams,
10}
11
12#[derive(Debug, Clone, PartialEq)]
14pub struct ModelOptions<P: NormalConjugatePrior> {
15 pub data_dist: P::HyperParams,
17 pub alpha: f64,
19 pub dim: usize,
21 pub burnout_period: usize,
23 pub outlier: Option<OutlierRemoval<P>>,
25 pub hard_assignment: bool,
27}
28
29impl<P: NormalConjugatePrior> ModelOptions<P> {
30 pub fn default(dim: usize) -> Self {
31 Self {
32 data_dist: P::HyperParams::default(dim),
33 alpha: 10.0,
34 dim,
35 burnout_period: 20,
36 outlier: Some(OutlierRemoval {
37 weight: 0.05,
38 dist: P::HyperParams::default(dim),
39 }),
40 hard_assignment: false,
41 }
42 }
43}
44
/// Runtime settings for a single fitting run.
///
/// Default values come from the [`Default`] implementation. `PartialEq`/`Eq`
/// are derived so configurations can be compared, consistent with the other
/// option types in this module. Every field is `Eq`, so the derive is exact.
///
/// NOTE(review): the field semantics below are inferred from their names;
/// confirm them against the fitting routine that consumes this struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FitOptions {
    /// Seed for the random number generator (presumably).
    pub seed: u64,
    /// Presumably whether to reuse state from a previous fit — TODO confirm.
    pub reuse: bool,
    /// Number of clusters to start from (presumably).
    pub init_clusters: usize,
    /// Upper bound on the number of clusters. The default is `usize::MAX`,
    /// which places no practical limit.
    pub max_clusters: usize,
    /// Number of fitting iterations (presumably).
    pub iters: usize,
    /// Presumably the number of final iterations that switch from sampling to
    /// argmax assignment — TODO confirm.
    pub argmax_sample_stop: usize,
    /// Presumably the number of final iterations during which cluster splits
    /// are no longer proposed — TODO confirm.
    pub iter_split_stop: usize,
    /// Number of workers. Stored as a signed `i32`; whether non-positive values
    /// carry a special meaning is not visible here — confirm with the consumer.
    pub workers: i32,
}
65
66impl Default for FitOptions {
67 fn default() -> Self {
68 Self {
69 seed: 42,
70 reuse: false,
71 init_clusters: 1,
72 max_clusters: usize::MAX,
73 iters: 100,
74 argmax_sample_stop: 5,
75 iter_split_stop: 5,
76 workers: 1,
77 }
78 }
79}