// numr/ops/cpu/multivariate.rs

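//! CPU implementation of multivariate random distribution operations.
//!
//! Most operations are thin wrappers that delegate to the generic implementations
//! in `impl_generic/multivariate.rs` to ensure numerical parity across all backends.
//! The exception is the multinomial counting step, which runs as a native CPU loop
//! over the CDF (see the `MultinomialSamplingOps` impl below).
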
6use crate::dispatch_dtype;
7use crate::dtype::DType;
8use crate::error::Result;
9use crate::ops::impl_generic::{
10    DTypeSupport, MultinomialSamplingOps, dirichlet_impl, multinomial_samples_impl,
11    multivariate_normal_impl, wishart_impl,
12};
13use crate::ops::traits::multivariate::MultivariateRandomOps;
14use crate::ops::{BinaryOps, CumulativeOps, RandomOps, ReduceOps};
15use crate::runtime::cpu::{CpuClient, CpuRuntime};
16use crate::tensor::Tensor;
17
18impl MultivariateRandomOps<CpuRuntime> for CpuClient {
19    fn multivariate_normal(
20        &self,
21        mean: &Tensor<CpuRuntime>,
22        cov: &Tensor<CpuRuntime>,
23        n_samples: usize,
24    ) -> Result<Tensor<CpuRuntime>> {
25        multivariate_normal_impl(self, mean, cov, n_samples, DTypeSupport::FULL)
26    }
27
28    fn wishart(
29        &self,
30        scale: &Tensor<CpuRuntime>,
31        df: usize,
32        n_samples: usize,
33    ) -> Result<Tensor<CpuRuntime>> {
34        wishart_impl(self, scale, df, n_samples, DTypeSupport::FULL)
35    }
36
37    fn dirichlet(
38        &self,
39        alpha: &Tensor<CpuRuntime>,
40        n_samples: usize,
41    ) -> Result<Tensor<CpuRuntime>> {
42        dirichlet_impl(self, alpha, n_samples)
43    }
44
45    fn multinomial_samples(
46        &self,
47        probs: &Tensor<CpuRuntime>,
48        n_trials: usize,
49        n_samples: usize,
50    ) -> Result<Tensor<CpuRuntime>> {
51        multinomial_samples_impl(self, probs, n_trials, n_samples)
52    }
53}
54
55/// CPU implementation of multinomial sampling kernel.
56///
57/// For CPU, we can efficiently implement this using native operations
58/// since CPU doesn't have the same kernel launch overhead as GPU.
59impl MultinomialSamplingOps<CpuRuntime> for CpuClient {
60    fn multinomial_sample_kernel(
61        &self,
62        probs: &Tensor<CpuRuntime>,
63        n_trials: usize,
64        n_samples: usize,
65    ) -> Result<Tensor<CpuRuntime>> {
66        let dtype = probs.dtype();
67        let k = probs.shape()[0];
68
69        // Step 1: Normalize probabilities (on CPU, this is just tensor ops)
70        let sum_probs = self.sum(probs, &[0], false)?;
71        let normalized = self.div(probs, &sum_probs)?;
72
73        // Step 2: Compute CDF using cumsum
74        let cdf = self.cumsum(&normalized, 0)?;
75
76        // Step 3: Generate uniform samples [n_samples, n_trials]
77        let uniforms = self.rand(&[n_samples, n_trials], dtype)?;
78
79        // Step 4: CDF lookup and counting
80        // For CPU, we implement this efficiently in native code
81        multinomial_count_kernel(&cdf, &uniforms, n_samples, n_trials, k, dtype, &self.device)
82    }
83}
84
85/// Native CPU kernel for multinomial counting.
86///
87/// Takes CDF tensor [k] and uniform samples [n_samples, n_trials],
88/// returns counts [n_samples, k].
89fn multinomial_count_kernel(
90    cdf: &Tensor<CpuRuntime>,
91    uniforms: &Tensor<CpuRuntime>,
92    n_samples: usize,
93    n_trials: usize,
94    k: usize,
95    dtype: DType,
96    device: &<CpuRuntime as crate::runtime::Runtime>::Device,
97) -> Result<Tensor<CpuRuntime>> {
98    dispatch_dtype!(dtype, T => {
99        multinomial_count_typed::<T>(cdf, uniforms, n_samples, n_trials, k, device)
100    }, "multinomial_count")
101}
102
103/// Type-specific multinomial counting implementation.
104///
105/// Generic over float types to eliminate code duplication.
106fn multinomial_count_typed<T>(
107    cdf: &Tensor<CpuRuntime>,
108    uniforms: &Tensor<CpuRuntime>,
109    n_samples: usize,
110    n_trials: usize,
111    k: usize,
112    device: &<CpuRuntime as crate::runtime::Runtime>::Device,
113) -> Result<Tensor<CpuRuntime>>
114where
115    T: crate::dtype::Element + PartialOrd,
116{
117    let cdf_data: Vec<T> = cdf.to_vec();
118    let uniform_data: Vec<T> = uniforms.to_vec();
119    let mut counts = vec![T::zero(); n_samples * k];
120
121    for s in 0..n_samples {
122        for t in 0..n_trials {
123            let u = uniform_data[s * n_trials + t];
124            // Binary search for category
125            let category = binary_search_cdf(&cdf_data, u);
126            counts[s * k + category] = counts[s * k + category] + T::one();
127        }
128    }
129
130    Ok(Tensor::<CpuRuntime>::from_slice(
131        &counts,
132        &[n_samples, k],
133        device,
134    ))
135}
136
/// Maps a uniform draw `u` to a category index using the cumulative
/// distribution `cdf` (assumed non-decreasing).
///
/// Returns the first index `i` with `cdf[i] > u`, so a draw exactly on a
/// boundary goes to the next category and zero-mass categories (equal
/// consecutive CDF values) are skipped.
///
/// Floating-point rounding can leave the final CDF value slightly below the
/// top of the uniform range. A draw at or beyond it is assigned to the first
/// index that reaches the total mass, which is the last category with
/// positive probability. Returning the final index instead could select a
/// trailing category whose probability is zero.
///
/// An empty `cdf` yields `0` instead of underflowing `len - 1`. Callers
/// still need at least one category before indexing with the result.
fn binary_search_cdf<T: PartialOrd>(cdf: &[T], u: T) -> usize {
    let last = match cdf.len().checked_sub(1) {
        Some(i) => i,
        None => return 0,
    };

    // `cdf` is non-decreasing, so `c <= u` holds on a prefix.
    let idx = cdf.partition_point(|c| *c <= u);
    if idx <= last {
        return idx;
    }

    // `u` is past the end of the CDF: choose the first category that reaches
    // the total mass. The `min` keeps the result in range even if the CDF
    // holds values (e.g. NaN) that break the monotonicity assumption.
    let total = &cdf[last];
    cdf.partition_point(|c| c < total).min(last)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Builds a client bound to the default CPU device.
    fn get_client() -> CpuClient {
        let device = CpuRuntime::default_device();
        CpuRuntime::default_client(&device)
    }

    #[test]
    fn test_multivariate_normal_basic() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &client.device);
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .multivariate_normal(&mean, &cov, 100)
            .expect("multivariate_normal should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 2]);

        // Verify samples have reasonable statistics
        let sample_data: Vec<f32> = samples.to_vec();
        let (mut mean_0, mut mean_1) = (0.0f64, 0.0f64);
        for pair in sample_data.chunks_exact(2) {
            mean_0 += f64::from(pair[0]);
            mean_1 += f64::from(pair[1]);
        }
        mean_0 /= 100.0;
        mean_1 /= 100.0;

        // With 100 samples from N(0,1) the standard error of each mean is 0.1,
        // so a 0.5 tolerance is five standard errors wide.
        assert!(mean_0.abs() < 0.5, "Mean 0 too far from 0: {}", mean_0);
        assert!(mean_1.abs() < 0.5, "Mean 1 too far from 0: {}", mean_1);
    }

    #[test]
    fn test_multivariate_normal_correlated() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0], &[2], &client.device);
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.8, 0.8, 1.0], &[2, 2], &client.device);

        let n = 1000;
        let samples = client
            .multivariate_normal(&mean, &cov, n)
            .expect("multivariate_normal should succeed with correlated covariance");
        assert_eq!(samples.shape(), &[n, 2]);

        // A shape check alone would pass even if `mean` or `cov` were ignored,
        // so also check the sample means and the sample correlation.
        let sample_data: Vec<f64> = samples.to_vec();
        let count = n as f64;
        let (mut mean_0, mut mean_1) = (0.0f64, 0.0f64);
        for pair in sample_data.chunks_exact(2) {
            mean_0 += pair[0];
            mean_1 += pair[1];
        }
        mean_0 /= count;
        mean_1 /= count;

        let (mut var_0, mut var_1, mut cov_01) = (0.0f64, 0.0f64, 0.0f64);
        for pair in sample_data.chunks_exact(2) {
            let (d0, d1) = (pair[0] - mean_0, pair[1] - mean_1);
            var_0 += d0 * d0;
            var_1 += d1 * d1;
            cov_01 += d0 * d1;
        }
        let corr = cov_01 / (var_0 * var_1).sqrt();

        // At n = 1000 the standard error of each mean is ~0.03 and of the
        // correlation ~(1 - 0.8^2)/sqrt(1000) ~= 0.011, so these bounds are
        // many standard errors wide.
        assert!((mean_0 - 1.0).abs() < 0.2, "Mean 0 too far from 1: {}", mean_0);
        assert!((mean_1 - 2.0).abs() < 0.2, "Mean 1 too far from 2: {}", mean_1);
        assert!(
            (corr - 0.8).abs() < 0.1,
            "Sample correlation too far from 0.8: {}",
            corr
        );
    }

    #[test]
    fn test_multivariate_normal_invalid_cov() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &client.device);
        // Not positive definite (eigenvalues 3 and -1)
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 2.0, 1.0], &[2, 2], &client.device);

        let result = client.multivariate_normal(&mean, &cov, 100);
        assert!(
            result.is_err(),
            "Should fail with non-positive-definite cov"
        );
    }

    #[test]
    fn test_dirichlet_basic() {
        let client = get_client();
        let alpha = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &client.device);

        let samples = client
            .dirichlet(&alpha, 100)
            .expect("dirichlet should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 3]);

        // Every sample must lie on the probability simplex.
        let sample_data: Vec<f32> = samples.to_vec();
        for (i, row) in sample_data.chunks_exact(3).enumerate() {
            assert!(
                row.iter().all(|&x| x >= 0.0),
                "Row {} has a negative entry: {:?}",
                i,
                row
            );
            let row_sum: f32 = row.iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-5,
                "Row {} sum is {}, expected 1.0",
                i,
                row_sum
            );
        }
    }

    #[test]
    fn test_dirichlet_concentrated() {
        let client = get_client();
        let alpha = Tensor::<CpuRuntime>::from_slice(&[100.0f64, 1.0, 1.0], &[3], &client.device);

        let samples = client
            .dirichlet(&alpha, 100)
            .expect("dirichlet should succeed with concentrated alpha");
        let sample_data: Vec<f64> = samples.to_vec();

        // First category should have most mass
        let mean_0 = sample_data.iter().step_by(3).sum::<f64>() / 100.0;

        // Expected: 100 / (100 + 1 + 1) ≈ 0.98
        assert!(
            mean_0 > 0.9,
            "Expected first category mean > 0.9, got {}",
            mean_0
        );
    }

    #[test]
    fn test_multinomial_samples_basic() {
        let client = get_client();
        let probs = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 6], &[6], &client.device);

        let samples = client
            .multinomial_samples(&probs, 60, 100)
            .expect("multinomial_samples should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 6]);

        // Counts must be non-negative whole numbers and each row sums to n_trials.
        let sample_data: Vec<f32> = samples.to_vec();
        for (i, row) in sample_data.chunks_exact(6).enumerate() {
            assert!(
                row.iter().all(|&c| c >= 0.0 && c.fract() == 0.0),
                "Row {} has a non-integer or negative count: {:?}",
                i,
                row
            );
            let row_sum: f32 = row.iter().sum();
            assert!(
                (row_sum - 60.0).abs() < 1e-5,
                "Row {} sum is {}, expected 60.0",
                i,
                row_sum
            );
        }
    }

    #[test]
    fn test_multinomial_samples_biased() {
        let client = get_client();
        let probs = Tensor::<CpuRuntime>::from_slice(&[0.99f64, 0.01], &[2], &client.device);

        let samples = client
            .multinomial_samples(&probs, 100, 50)
            .expect("multinomial_samples should succeed with biased probs");
        let sample_data: Vec<f64> = samples.to_vec();

        // First category should have most counts
        let mean_0 = sample_data.iter().step_by(2).sum::<f64>() / 50.0;

        // Expected: ~99 out of 100 trials
        assert!(
            mean_0 > 90.0,
            "Expected first category mean > 90, got {}",
            mean_0
        );
    }

    #[test]
    fn test_binary_search_cdf_boundaries() {
        let cdf = [0.25f64, 0.5, 1.0];
        assert_eq!(binary_search_cdf(&cdf, 0.0), 0);
        // A draw exactly on a boundary belongs to the next category.
        assert_eq!(binary_search_cdf(&cdf, 0.25), 1);
        assert_eq!(binary_search_cdf(&cdf, 0.6), 2);
        // A draw at or past the final CDF value is clamped into range.
        assert_eq!(binary_search_cdf(&cdf, 1.0), 2);
    }

    #[test]
    fn test_wishart_basic() {
        let client = get_client();
        let scale =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .wishart(&scale, 5, 10)
            .expect("wishart should succeed with valid inputs");
        assert_eq!(samples.shape(), &[10, 2, 2]);

        // Verify samples are symmetric and positive definite
        let sample_data: Vec<f32> = samples.to_vec();
        for (s, m) in sample_data.chunks_exact(4).enumerate() {
            let (a00, a01, a10, a11) = (m[0], m[1], m[2], m[3]);

            // Check symmetry
            assert!(
                (a01 - a10).abs() < 1e-4,
                "Sample {} not symmetric: a01={}, a10={}",
                s,
                a01,
                a10
            );

            // Check positive definiteness (Sylvester's criterion for 2x2)
            assert!(a00 > 0.0, "Sample {} has non-positive a00: {}", s, a00);
            assert!(a11 > 0.0, "Sample {} has non-positive a11: {}", s, a11);
            let det = a00 * a11 - a01 * a10;
            assert!(
                det > 0.0,
                "Sample {} has non-positive determinant: {}",
                s,
                det
            );
        }
    }

    #[test]
    fn test_wishart_f64() {
        let client = get_client();
        let scale =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .wishart(&scale, 5, 5)
            .expect("wishart should succeed with F64");
        assert_eq!(samples.shape(), &[5, 2, 2]);
        assert_eq!(samples.dtype(), DType::F64);
    }
}