1use crate::dispatch_dtype;
7use crate::dtype::DType;
8use crate::error::Result;
9use crate::ops::impl_generic::{
10 DTypeSupport, MultinomialSamplingOps, dirichlet_impl, multinomial_samples_impl,
11 multivariate_normal_impl, wishart_impl,
12};
13use crate::ops::traits::multivariate::MultivariateRandomOps;
14use crate::ops::{BinaryOps, CumulativeOps, RandomOps, ReduceOps};
15use crate::runtime::cpu::{CpuClient, CpuRuntime};
16use crate::tensor::Tensor;
17
impl MultivariateRandomOps<CpuRuntime> for CpuClient {
    /// Draws `n_samples` rows from a multivariate normal distribution with
    /// the given `mean` vector and covariance matrix `cov`.
    ///
    /// Delegates to the runtime-generic `multivariate_normal_impl`;
    /// `DTypeSupport::FULL` signals that this CPU backend accepts the full
    /// dtype range the generic implementation supports.
    fn multivariate_normal(
        &self,
        mean: &Tensor<CpuRuntime>,
        cov: &Tensor<CpuRuntime>,
        n_samples: usize,
    ) -> Result<Tensor<CpuRuntime>> {
        multivariate_normal_impl(self, mean, cov, n_samples, DTypeSupport::FULL)
    }

    /// Draws `n_samples` matrices from a Wishart distribution with `df`
    /// degrees of freedom and scale matrix `scale`.
    ///
    /// Delegates to the runtime-generic `wishart_impl` with full dtype
    /// support on CPU.
    fn wishart(
        &self,
        scale: &Tensor<CpuRuntime>,
        df: usize,
        n_samples: usize,
    ) -> Result<Tensor<CpuRuntime>> {
        wishart_impl(self, scale, df, n_samples, DTypeSupport::FULL)
    }

    /// Draws `n_samples` probability vectors from a Dirichlet distribution
    /// parameterized by the concentration vector `alpha`.
    /// Delegates to the runtime-generic `dirichlet_impl`.
    fn dirichlet(
        &self,
        alpha: &Tensor<CpuRuntime>,
        n_samples: usize,
    ) -> Result<Tensor<CpuRuntime>> {
        dirichlet_impl(self, alpha, n_samples)
    }

    /// Draws `n_samples` multinomial count vectors, each summing to
    /// `n_trials`, from the categorical distribution given by `probs`.
    ///
    /// Delegates to the runtime-generic `multinomial_samples_impl`, which
    /// presumably routes back through `MultinomialSamplingOps` below —
    /// confirm against the generic implementation.
    fn multinomial_samples(
        &self,
        probs: &Tensor<CpuRuntime>,
        n_trials: usize,
        n_samples: usize,
    ) -> Result<Tensor<CpuRuntime>> {
        multinomial_samples_impl(self, probs, n_trials, n_samples)
    }
}
54
55impl MultinomialSamplingOps<CpuRuntime> for CpuClient {
60 fn multinomial_sample_kernel(
61 &self,
62 probs: &Tensor<CpuRuntime>,
63 n_trials: usize,
64 n_samples: usize,
65 ) -> Result<Tensor<CpuRuntime>> {
66 let dtype = probs.dtype();
67 let k = probs.shape()[0];
68
69 let sum_probs = self.sum(probs, &[0], false)?;
71 let normalized = self.div(probs, &sum_probs)?;
72
73 let cdf = self.cumsum(&normalized, 0)?;
75
76 let uniforms = self.rand(&[n_samples, n_trials], dtype)?;
78
79 multinomial_count_kernel(&cdf, &uniforms, n_samples, n_trials, k, dtype, &self.device)
82 }
83}
84
/// Runs the typed multinomial counting kernel for the element type that
/// corresponds to `dtype`.
///
/// `cdf` holds the cumulative category probabilities and `uniforms` the
/// per-(sample, trial) draws; the result is an `[n_samples, k]` count
/// tensor allocated on `device`.
fn multinomial_count_kernel(
    cdf: &Tensor<CpuRuntime>,
    uniforms: &Tensor<CpuRuntime>,
    n_samples: usize,
    n_trials: usize,
    k: usize,
    dtype: DType,
    device: &<CpuRuntime as crate::runtime::Runtime>::Device,
) -> Result<Tensor<CpuRuntime>> {
    // dispatch_dtype! monomorphizes the call per supported dtype; the
    // trailing string presumably labels the op in unsupported-dtype
    // errors — confirm against the macro definition.
    dispatch_dtype!(dtype, T => {
        multinomial_count_typed::<T>(cdf, uniforms, n_samples, n_trials, k, device)
    }, "multinomial_count")
}
102
103fn multinomial_count_typed<T>(
107 cdf: &Tensor<CpuRuntime>,
108 uniforms: &Tensor<CpuRuntime>,
109 n_samples: usize,
110 n_trials: usize,
111 k: usize,
112 device: &<CpuRuntime as crate::runtime::Runtime>::Device,
113) -> Result<Tensor<CpuRuntime>>
114where
115 T: crate::dtype::Element + PartialOrd,
116{
117 let cdf_data: Vec<T> = cdf.to_vec();
118 let uniform_data: Vec<T> = uniforms.to_vec();
119 let mut counts = vec![T::zero(); n_samples * k];
120
121 for s in 0..n_samples {
122 for t in 0..n_trials {
123 let u = uniform_data[s * n_trials + t];
124 let category = binary_search_cdf(&cdf_data, u);
126 counts[s * k + category] = counts[s * k + category] + T::one();
127 }
128 }
129
130 Ok(Tensor::<CpuRuntime>::from_slice(
131 &counts,
132 &[n_samples, k],
133 device,
134 ))
135}
136
/// Maps a uniform draw `u` to a category index against a cumulative
/// distribution `cdf`: the first index whose CDF value strictly exceeds
/// `u`, clamped to the last valid index so draws at or above the final
/// CDF entry still land in the last category.
fn binary_search_cdf<T: PartialOrd>(cdf: &[T], u: T) -> usize {
    // partition_point counts the leading elements with cdf[i] <= u,
    // i.e. it returns the first index with cdf[i] > u — the same
    // upper-bound search the hand-rolled binary-search loop performs.
    let idx = cdf.partition_point(|edge| *edge <= u);
    idx.min(cdf.len() - 1)
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Builds a CPU client on the default device; shared setup for every test.
    fn get_client() -> CpuClient {
        let device = CpuRuntime::default_device();
        CpuRuntime::default_client(&device)
    }

    // Standard 2-D normal: checks output shape and that each coordinate's
    // empirical mean is near 0. The 0.5 tolerance is ~5 standard errors for
    // 100 unit-variance samples, so spurious failures should be rare.
    #[test]
    fn test_multivariate_normal_basic() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &client.device);
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .multivariate_normal(&mean, &cov, 100)
            .expect("multivariate_normal should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 2]);

        // Samples are read back row-major: index [i * 2 + d] is sample i, dim d.
        let sample_data: Vec<f32> = samples.to_vec();
        let (mut mean_0, mut mean_1) = (0.0f64, 0.0f64);
        for i in 0..100 {
            mean_0 += sample_data[i * 2] as f64;
            mean_1 += sample_data[i * 2 + 1] as f64;
        }
        mean_0 /= 100.0;
        mean_1 /= 100.0;

        assert!(mean_0.abs() < 0.5, "Mean 0 too far from 0: {}", mean_0);
        assert!(mean_1.abs() < 0.5, "Mean 1 too far from 0: {}", mean_1);
    }

    // Correlated covariance (off-diagonal 0.8) in f64: only checks that
    // sampling succeeds and the output shape is correct.
    #[test]
    fn test_multivariate_normal_correlated() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0], &[2], &client.device);
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.8, 0.8, 1.0], &[2, 2], &client.device);

        let samples = client
            .multivariate_normal(&mean, &cov, 1000)
            .expect("multivariate_normal should succeed with correlated covariance");
        assert_eq!(samples.shape(), &[1000, 2]);
    }

    // [[1, 2], [2, 1]] has determinant -3, so it is not positive-definite;
    // sampling must report an error rather than produce garbage.
    #[test]
    fn test_multivariate_normal_invalid_cov() {
        let client = get_client();
        let mean = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &client.device);
        let cov =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 2.0, 1.0], &[2, 2], &client.device);

        let result = client.multivariate_normal(&mean, &cov, 100);
        assert!(
            result.is_err(),
            "Should fail with non-positive-definite cov"
        );
    }

    // Uniform Dirichlet (alpha = 1 everywhere): every sampled row must lie
    // on the probability simplex, i.e. sum to 1 within float tolerance.
    #[test]
    fn test_dirichlet_basic() {
        let client = get_client();
        let alpha = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &client.device);

        let samples = client
            .dirichlet(&alpha, 100)
            .expect("dirichlet should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 3]);

        let sample_data: Vec<f32> = samples.to_vec();
        for i in 0..100 {
            let row_sum: f32 = sample_data[i * 3..i * 3 + 3].iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-5,
                "Row {} sum is {}, expected 1.0",
                i,
                row_sum
            );
        }
    }

    // With alpha = [100, 1, 1] the first category's expected weight is
    // 100/102 ≈ 0.98, so its empirical mean should comfortably exceed 0.9.
    #[test]
    fn test_dirichlet_concentrated() {
        let client = get_client();
        let alpha = Tensor::<CpuRuntime>::from_slice(&[100.0f64, 1.0, 1.0], &[3], &client.device);

        let samples = client
            .dirichlet(&alpha, 100)
            .expect("dirichlet should succeed with concentrated alpha");
        let sample_data: Vec<f64> = samples.to_vec();

        let mut mean_0 = 0.0;
        for i in 0..100 {
            mean_0 += sample_data[i * 3];
        }
        mean_0 /= 100.0;

        assert!(
            mean_0 > 0.9,
            "Expected first category mean > 0.9, got {}",
            mean_0
        );
    }

    // Unnormalized uniform weights over 6 categories: each sampled count
    // row must sum exactly to n_trials (60); tolerance covers f32 rounding.
    #[test]
    fn test_multinomial_samples_basic() {
        let client = get_client();
        let probs = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 6], &[6], &client.device);

        let samples = client
            .multinomial_samples(&probs, 60, 100)
            .expect("multinomial_samples should succeed with valid inputs");
        assert_eq!(samples.shape(), &[100, 6]);

        let sample_data: Vec<f32> = samples.to_vec();
        for i in 0..100 {
            let row_sum: f32 = sample_data[i * 6..i * 6 + 6].iter().sum();
            assert!(
                (row_sum - 60.0).abs() < 1e-5,
                "Row {} sum is {}, expected 60.0",
                i,
                row_sum
            );
        }
    }

    // Heavily biased probs (0.99 vs 0.01) over 100 trials: the first
    // category's mean count should be ~99, so > 90 is a loose sanity bound.
    #[test]
    fn test_multinomial_samples_biased() {
        let client = get_client();
        let probs = Tensor::<CpuRuntime>::from_slice(&[0.99f64, 0.01], &[2], &client.device);

        let samples = client
            .multinomial_samples(&probs, 100, 50)
            .expect("multinomial_samples should succeed with biased probs");
        let sample_data: Vec<f64> = samples.to_vec();

        let mut mean_0 = 0.0;
        for i in 0..50 {
            mean_0 += sample_data[i * 2];
        }
        mean_0 /= 50.0;

        assert!(
            mean_0 > 90.0,
            "Expected first category mean > 90, got {}",
            mean_0
        );
    }

    // Wishart samples are symmetric positive-definite by construction;
    // for 2x2 matrices, positive diagonal entries plus a positive
    // determinant are sufficient to verify positive-definiteness.
    #[test]
    fn test_wishart_basic() {
        let client = get_client();
        let scale =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .wishart(&scale, 5, 10)
            .expect("wishart should succeed with valid inputs");
        assert_eq!(samples.shape(), &[10, 2, 2]);

        let sample_data: Vec<f32> = samples.to_vec();
        for s in 0..10 {
            // Each 2x2 sample occupies 4 consecutive row-major entries.
            let offset = s * 4;
            let a00 = sample_data[offset];
            let a01 = sample_data[offset + 1];
            let a10 = sample_data[offset + 2];
            let a11 = sample_data[offset + 3];

            assert!(
                (a01 - a10).abs() < 1e-4,
                "Sample {} not symmetric: a01={}, a10={}",
                s,
                a01,
                a10
            );

            assert!(a00 > 0.0, "Sample {} has non-positive a00: {}", s, a00);
            assert!(a11 > 0.0, "Sample {} has non-positive a11: {}", s, a11);
            let det = a00 * a11 - a01 * a10;
            assert!(
                det > 0.0,
                "Sample {} has non-positive determinant: {}",
                s,
                det
            );
        }
    }

    // f64 path: checks the output shape and that the dtype round-trips.
    #[test]
    fn test_wishart_f64() {
        let client = get_client();
        let scale =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &client.device);

        let samples = client
            .wishart(&scale, 5, 5)
            .expect("wishart should succeed with F64");
        assert_eq!(samples.shape(), &[5, 2, 2]);
        assert_eq!(samples.dtype(), crate::dtype::DType::F64);
    }
}