prodef/multivariate.rs
//! A multivariate distribution whose density is the product of independent univariate marginal PDFs.
2
3use crate::{
4 Density, SamplingMode, UnivariateDensity, domain::Domain, macros::tval,
5 univariate::match_univariate,
6};
7use derive_more::IntoIterator;
8use nalgebra::{
9 DefaultAllocator, Dim, OVector, RealField, SVector, Scalar, U1, VectorView,
10 allocator::Allocator,
11};
12use rand::RngExt;
13use rand_distr::{Distribution, StandardNormal, uniform::SampleUniform};
14use serde::{Deserialize, Serialize};
15use std::{f64, fmt::Debug, iter::repeat_with};
16
17/// A `D`-dimensional distribution where each dimension is
18/// **independent** with potentially different univariate distributions. This is a **product distribution**:
19/// - Each marginal follows one of the available univariate distributions (Normal, Uniform, Cosine, etc.)
20/// - The joint density is the product of marginals: f(x₁, ..., xₐ) = f₁(x₁) × ... × fₐ(xₐ)
21///
22/// # Construction & Examples
23///
24/// Create a mixed 3D distribution (Normal × Uniform × Constant):
25/// ```
26/// # use nalgebra::{Const, SVector};
27/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density};
28/// let marginals = SVector::from([
29/// NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
30/// UniformDensity::new(-1.0, 1.0).unwrap().into(),
31/// ConstantDensity::new(2.0).into(),
32/// ]);
33/// let _dist = MultivariateDensity::<f64, Const<3>>::new(marginals);
34/// ```
35///
36/// Create a 5D distribution with mixed univariates:
37/// ```
38/// # use nalgebra::{Const, SVector};
39/// # use prodef::{ConstantDensity, CosineDensity, LogUniformDensity, MultivariateDensity, NormalDensity, UniformDensity};
40/// let mvpdf = MultivariateDensity::<f64, Const<5>>::new(SVector::from([
41/// ConstantDensity::new(1.0).into(),
42/// CosineDensity::new(0.1, 0.2).unwrap().into(),
43/// LogUniformDensity::new(0.1, 0.5).unwrap().into(),
44/// NormalDensity::new(0.1, 0.25, Some(-0.5), Some(1.5)).unwrap().into(),
45/// UniformDensity::new(1.0, 2.0).unwrap().into(),
46/// ]));
47/// ```
48///
49/// Evaluate density at a point:
50/// ```
51/// # use nalgebra::{U1, U2, SVector};
52/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density};
53/// let marginals = SVector::from([
54/// NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
55/// UniformDensity::new(-1.0, 1.0).unwrap().into(),
56/// ]);
57/// let dist = MultivariateDensity::<f64, U2>::new(marginals);
58/// let sample = SVector::from([0.0, 0.5]);
59/// // Use the Density trait to evaluate - see crate::Density for usage patterns
60/// if let Some(dens) = (&dist).density::<U1, U2>(&sample.as_view()) {
61/// println!("Joint density: {}", dens);
62/// }
63/// ```
64///
65/// Sample from the distribution:
66/// ```
67/// # use nalgebra::{U2, SVector};
68/// # use prodef::{ConstantDensity, MultivariateDensity, NormalDensity, UniformDensity, Density, SamplingMode};
69/// # use rand::{SeedableRng, rngs::StdRng};
70/// let marginals = SVector::from([
71/// NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
72/// UniformDensity::new(-1.0, 1.0).unwrap().into(),
73/// ]);
74/// let dist = MultivariateDensity::<f64, U2>::new(marginals);
75/// let mut rng = StdRng::seed_from_u64(42);
76/// if let Some(sample) = (&dist).sample(&mut rng, &SamplingMode::default()) {
77/// println!("Generated sample: {:?}", sample);
78/// }
79/// ```
80#[derive(Clone, Debug, Deserialize, IntoIterator, Serialize)]
81#[serde(bound(serialize = "OVector<UnivariateDensity<T>, D>: Serialize"))]
82#[serde(bound(deserialize = "OVector<UnivariateDensity<T>, D>: Deserialize<'de>"))]
83pub struct MultivariateDensity<T, D>(#[into_iterator(owned, ref)] OVector<UnivariateDensity<T>, D>)
84where
85 T: Scalar,
86 D: Dim,
87 DefaultAllocator: Allocator<D>;
88
89impl<T, D> MultivariateDensity<T, D>
90where
91 T: RealField,
92 D: Dim,
93 DefaultAllocator: Allocator<D>,
94{
95 /// Create a new [`MultivariateDensity`] from a vector of [`UnivariateDensity`]s.
96 pub fn new(domains: OVector<UnivariateDensity<T>, D>) -> Self {
97 Self(domains)
98 }
99
100 /// Return a reference to the underlying vector of [`UnivariateDensity`]s.
101 pub fn marginals(&self) -> &OVector<UnivariateDensity<T>, D> {
102 &self.0
103 }
104}
105
106impl<T, D> Density<T, D> for MultivariateDensity<T, D>
107where
108 T: RealField + SampleUniform,
109 D: Dim,
110 StandardNormal: Distribution<T>,
111 DefaultAllocator: Allocator<D>,
112{
113 fn density<RStride: Dim, CStride: Dim>(
114 &self,
115 sample: &VectorView<T, D, RStride, CStride>,
116 ) -> Option<T> {
117 if !self.domain().contains(sample) {
118 return None;
119 }
120
121 let mut rlh = T::one();
122
123 self.0.iter().zip(sample.iter()).for_each(|(uvpdf, value)| {
124 let vec = SVector::from([value.clone()]);
125
126 rlh *= match_univariate!(uvpdf, pdf, {
127 Density::<T, U1>::density::<U1, U1>(&pdf, &vec.as_view())
128 })
129 .unwrap_or(tval!(f64::NAN, f64));
130 });
131
132 Some(rlh)
133 }
134
135 fn domain(&self) -> Domain<T, D> {
136 Domain::new_mdomain(OVector::from_iterator_generic(
137 self.0.shape_generic().0,
138 U1,
139 self.0.iter().map(|uvpdf| {
140 let (a, b) = match uvpdf {
141 UnivariateDensity::Constant(pdf) => {
142 (Some(pdf.constant()), Some(pdf.constant()))
143 }
144 UnivariateDensity::Cosine(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
145 UnivariateDensity::Lognormal(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
146 UnivariateDensity::Loguniform(pdf) => {
147 (Some(pdf.minimum()), Some(pdf.maximum()))
148 }
149 UnivariateDensity::Normal(pdf) => (pdf.minimum(), pdf.maximum()),
150 UnivariateDensity::Uniform(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
151 };
152
153 (a, b)
154 }),
155 ))
156 }
157
158 fn mean(&self) -> OVector<T, D> {
159 OVector::from_iterator_generic(
160 self.0.shape_generic().0,
161 U1,
162 self.0
163 .iter()
164 .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.mean() })[0].clone()),
165 )
166 }
167
168 fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>> {
169 let mut draw = OVector::<T, D>::zeros_generic(self.0.shape_generic().0, U1);
170
171 for i in 0..self.0.shape_generic().0.value() {
172 draw[i] = match_univariate!(&self.0[i], pdf, {
173 match Density::<T, U1>::sample(&pdf, rng, mode) {
174 Some(sample) => sample[0].clone(),
175 None => return None,
176 }
177 });
178 }
179
180 Some(draw)
181 }
182
183 fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, D>>> {
184 let n_dim = self.0.shape_generic().0;
185
186 repeat_with(move || {
187 let draw_opts = OVector::<Option<SVector<T, 1>>, D>::from_iterator_generic(
188 n_dim,
189 U1,
190 self.into_iter()
191 .map(|pdf| pdf.sample(rng, &SamplingMode::SingleAttempt)),
192 );
193
194 if draw_opts.iter().any(|draw| draw.is_none()) {
195 return None;
196 }
197
198 // All samples are guaranteed to be Some due to check above
199 let draw = OVector::<T, D>::from_iterator_generic(
200 n_dim,
201 U1,
202 draw_opts.iter().map(|opt_draw| {
203 // Safe: we verified no None values exist above
204 opt_draw.as_ref().unwrap()[0].clone()
205 }),
206 );
207
208 Some(draw)
209 })
210 }
211
212 fn variance(&self) -> OVector<T, D> {
213 OVector::from_iterator_generic(
214 self.0.shape_generic().0,
215 U1,
216 self.0
217 .iter()
218 .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.variance() })[0].clone()),
219 )
220 }
221}