Skip to main content

prodef/
multivariate.rs

//! A multivariate distribution formed as the product of independent univariate densities.
2
3use crate::{
4    Density, SamplingMode, UnivariateDensity, domain::Domain, macros::tval,
5    univariate::match_univariate,
6};
7use derive_more::IntoIterator;
8use nalgebra::{
9    DefaultAllocator, Dim, OVector, RealField, SVector, Scalar, U1, VectorView,
10    allocator::Allocator,
11};
12use rand::RngExt;
13use rand_distr::{Distribution, StandardNormal, uniform::SampleUniform};
14use serde::{Deserialize, Serialize};
15use std::{f64, fmt::Debug, iter::repeat_with};
16
17/// A `D`-dimensional distribution where each dimension is
18/// **independent** with potentially different univariate distributions. This is a **product distribution**:
19/// - Each marginal follows one of the available univariate distributions (Normal, Uniform, Cosine, etc.)
20/// - The joint density is the product of marginals: f(x₁, ..., xₐ) = f₁(x₁) × ... × fₐ(xₐ)
21///
22/// # Construction & Examples
23///
24/// Create a mixed 3D distribution (Normal × Uniform × Constant):
25/// ```
26/// # use nalgebra::{Const, SVector};
27/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density};
28/// let marginals = SVector::from([
29///     NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
30///     UniformDensity::new(-1.0, 1.0).unwrap().into(),
31///     ConstantDensity::new(2.0).into(),
32/// ]);
33/// let _dist = MultivariateDensity::<f64, Const<3>>::new(marginals);
34/// ```
35///
36/// Create a 5D distribution with mixed univariates:
37/// ```
38/// # use nalgebra::{Const, SVector};
39/// # use prodef::{ConstantDensity, CosineDensity, LogUniformDensity, MultivariateDensity, NormalDensity, UniformDensity};
40/// let mvpdf = MultivariateDensity::<f64, Const<5>>::new(SVector::from([
41///    ConstantDensity::new(1.0).into(),
42///    CosineDensity::new(0.1, 0.2).unwrap().into(),
43///    LogUniformDensity::new(0.1, 0.5).unwrap().into(),
44///    NormalDensity::new(0.1, 0.25, Some(-0.5), Some(1.5)).unwrap().into(),
45///    UniformDensity::new(1.0, 2.0).unwrap().into(),
46/// ]));
47/// ```
48///
49/// Evaluate density at a point:
50/// ```
51/// # use nalgebra::{U1, U2, SVector};
52/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density};
53/// let marginals = SVector::from([
54///     NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
55///     UniformDensity::new(-1.0, 1.0).unwrap().into(),
56/// ]);
57/// let dist = MultivariateDensity::<f64, U2>::new(marginals);
58/// let sample = SVector::from([0.0, 0.5]);
59/// // Use the Density trait to evaluate - see crate::Density for usage patterns
60/// if let Some(dens) = (&dist).density::<U1, U2>(&sample.as_view()) {
61///     println!("Joint density: {}", dens);
62/// }
63/// ```
64///
65/// Sample from the distribution:
66/// ```
67/// # use nalgebra::{U2, SVector};
68/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density, SamplingMode};
69/// # use rand::{SeedableRng, rngs::StdRng};
70/// let marginals = SVector::from([
71///     NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
72///     UniformDensity::new(-1.0, 1.0).unwrap().into(),
73/// ]);
74/// let dist = MultivariateDensity::<f64, U2>::new(marginals);
75/// let mut rng = StdRng::seed_from_u64(42);
76/// if let Some(sample) = (&dist).sample(&mut rng, &SamplingMode::default()) {
77///     println!("Generated sample: {:?}", sample);
78/// }
79/// ```
80#[derive(Clone, Debug, Deserialize, IntoIterator, Serialize)]
81#[serde(bound(serialize = "OVector<UnivariateDensity<T>, D>: Serialize"))]
82#[serde(bound(deserialize = "OVector<UnivariateDensity<T>, D>: Deserialize<'de>"))]
83pub struct MultivariateDensity<T, D>(#[into_iterator(owned, ref)] OVector<UnivariateDensity<T>, D>)
84where
85    T: Scalar,
86    D: Dim,
87    DefaultAllocator: Allocator<D>;
88
89impl<T, D> MultivariateDensity<T, D>
90where
91    T: RealField,
92    D: Dim,
93    DefaultAllocator: Allocator<D>,
94{
95    /// Create a new [`MultivariateDensity`] from a vector of [`UnivariateDensity`]s.
96    pub fn new(domains: OVector<UnivariateDensity<T>, D>) -> Self {
97        Self(domains)
98    }
99
100    /// Return a reference to the underlying vector of [`UnivariateDensity`]s.
101    pub fn marginals(&self) -> &OVector<UnivariateDensity<T>, D> {
102        &self.0
103    }
104}
105
106impl<T, D> Density<T, D> for MultivariateDensity<T, D>
107where
108    T: RealField + SampleUniform,
109    D: Dim,
110    StandardNormal: Distribution<T>,
111    DefaultAllocator: Allocator<D>,
112{
113    fn density<RStride: Dim, CStride: Dim>(
114        &self,
115        sample: &VectorView<T, D, RStride, CStride>,
116    ) -> Option<T> {
117        if !self.domain().contains(sample) {
118            return None;
119        }
120
121        let mut rlh = T::one();
122
123        self.0.iter().zip(sample.iter()).for_each(|(uvpdf, value)| {
124            let vec = SVector::from([value.clone()]);
125
126            rlh *= match_univariate!(uvpdf, pdf, {
127                Density::<T, U1>::density::<U1, U1>(&pdf, &vec.as_view())
128            })
129            .unwrap_or(tval!(f64::NAN, f64));
130        });
131
132        Some(rlh)
133    }
134
135    fn domain(&self) -> Domain<T, D> {
136        Domain::new_mdomain(OVector::from_iterator_generic(
137            self.0.shape_generic().0,
138            U1,
139            self.0.iter().map(|uvpdf| {
140                let (a, b) = match uvpdf {
141                    UnivariateDensity::Constant(pdf) => {
142                        (Some(pdf.constant()), Some(pdf.constant()))
143                    }
144                    UnivariateDensity::Cosine(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
145                    UnivariateDensity::Lognormal(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
146                    UnivariateDensity::Loguniform(pdf) => {
147                        (Some(pdf.minimum()), Some(pdf.maximum()))
148                    }
149                    UnivariateDensity::Normal(pdf) => (pdf.minimum(), pdf.maximum()),
150                    UnivariateDensity::Uniform(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
151                };
152
153                (a, b)
154            }),
155        ))
156    }
157
158    fn mean(&self) -> OVector<T, D> {
159        OVector::from_iterator_generic(
160            self.0.shape_generic().0,
161            U1,
162            self.0
163                .iter()
164                .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.mean() })[0].clone()),
165        )
166    }
167
168    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>> {
169        let mut draw = OVector::<T, D>::zeros_generic(self.0.shape_generic().0, U1);
170
171        for i in 0..self.0.shape_generic().0.value() {
172            draw[i] = match_univariate!(&self.0[i], pdf, {
173                match Density::<T, U1>::sample(&pdf, rng, mode) {
174                    Some(sample) => sample[0].clone(),
175                    None => return None,
176                }
177            });
178        }
179
180        Some(draw)
181    }
182
183    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, D>>> {
184        let n_dim = self.0.shape_generic().0;
185
186        repeat_with(move || {
187            let draw_opts = OVector::<Option<SVector<T, 1>>, D>::from_iterator_generic(
188                n_dim,
189                U1,
190                self.into_iter()
191                    .map(|pdf| pdf.sample(rng, &SamplingMode::SingleAttempt)),
192            );
193
194            if draw_opts.iter().any(|draw| draw.is_none()) {
195                return None;
196            }
197
198            // All samples are guaranteed to be Some due to check above
199            let draw = OVector::<T, D>::from_iterator_generic(
200                n_dim,
201                U1,
202                draw_opts.iter().map(|opt_draw| {
203                    // Safe: we verified no None values exist above
204                    opt_draw.as_ref().unwrap()[0].clone()
205                }),
206            );
207
208            Some(draw)
209        })
210    }
211
212    fn variance(&self) -> OVector<T, D> {
213        OVector::from_iterator_generic(
214            self.0.shape_generic().0,
215            U1,
216            self.0
217                .iter()
218                .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.variance() })[0].clone()),
219        )
220    }
221}