1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18 pub threshold: f64,
20 pub distribution: ConditionalDistributionParams,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28 Fixed { value: f64 },
30 Normal { mu: f64, sigma: f64 },
32 LogNormal { mu: f64, sigma: f64 },
34 Uniform { min: f64, max: f64 },
36 Beta {
38 alpha: f64,
39 beta: f64,
40 min: f64,
41 max: f64,
42 },
43 Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48 fn default() -> Self {
49 Self::Fixed { value: 0.0 }
50 }
51}
52
53impl ConditionalDistributionParams {
54 pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56 match self {
57 Self::Fixed { value } => *value,
58 Self::Normal { mu, sigma } => {
59 let dist =
60 Normal::new(*mu, *sigma).unwrap_or_else(|_| Normal::new(0.0, 1.0).unwrap());
61 dist.sample(rng)
62 }
63 Self::LogNormal { mu, sigma } => {
64 let dist = LogNormal::new(*mu, *sigma)
65 .unwrap_or_else(|_| LogNormal::new(0.0, 1.0).unwrap());
66 dist.sample(rng)
67 }
68 Self::Uniform { min, max } => {
69 let dist = Uniform::new(*min, *max);
70 dist.sample(rng)
71 }
72 Self::Beta {
73 alpha,
74 beta,
75 min,
76 max,
77 } => {
78 let dist =
79 Beta::new(*alpha, *beta).unwrap_or_else(|_| Beta::new(2.0, 2.0).unwrap());
80 let u = dist.sample(rng);
81 min + u * (max - min)
82 }
83 Self::Discrete { values, weights } => {
84 if values.is_empty() {
85 return 0.0;
86 }
87 if weights.is_empty() || weights.len() != values.len() {
88 return *values.choose(rng).unwrap_or(&0.0);
90 }
91 let total: f64 = weights.iter().sum();
93 let mut p: f64 = rng.gen::<f64>() * total;
94 for (i, w) in weights.iter().enumerate() {
95 p -= w;
96 if p <= 0.0 {
97 return values[i];
98 }
99 }
100 *values.last().unwrap_or(&0.0)
101 }
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ConditionalDistributionConfig {
109 pub output_field: String,
111 pub input_field: String,
113 pub breakpoints: Vec<Breakpoint>,
116 pub default_distribution: ConditionalDistributionParams,
118 #[serde(default)]
120 pub min_value: Option<f64>,
121 #[serde(default)]
123 pub max_value: Option<f64>,
124 #[serde(default = "default_decimal_places")]
126 pub decimal_places: u8,
127}
128
/// Serde fallback for `ConditionalDistributionConfig::decimal_places` when the
/// key is absent from the input.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
132
133impl Default for ConditionalDistributionConfig {
134 fn default() -> Self {
135 Self {
136 output_field: "output".to_string(),
137 input_field: "input".to_string(),
138 breakpoints: vec![],
139 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
140 min_value: None,
141 max_value: None,
142 decimal_places: 2,
143 }
144 }
145}
146
147impl ConditionalDistributionConfig {
148 pub fn new(
150 output_field: impl Into<String>,
151 input_field: impl Into<String>,
152 breakpoints: Vec<Breakpoint>,
153 default: ConditionalDistributionParams,
154 ) -> Self {
155 Self {
156 output_field: output_field.into(),
157 input_field: input_field.into(),
158 breakpoints,
159 default_distribution: default,
160 min_value: None,
161 max_value: None,
162 decimal_places: 2,
163 }
164 }
165
166 pub fn validate(&self) -> Result<(), String> {
168 for i in 1..self.breakpoints.len() {
170 if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
171 return Err(format!(
172 "Breakpoints must be in ascending order: {} is not > {}",
173 self.breakpoints[i].threshold,
174 self.breakpoints[i - 1].threshold
175 ));
176 }
177 }
178
179 if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
180 if max <= min {
181 return Err("max_value must be greater than min_value".to_string());
182 }
183 }
184
185 Ok(())
186 }
187
188 pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
190 for breakpoint in self.breakpoints.iter().rev() {
192 if input_value >= breakpoint.threshold {
193 return &breakpoint.distribution;
194 }
195 }
196 &self.default_distribution
197 }
198}
199
200pub struct ConditionalSampler {
202 rng: ChaCha8Rng,
203 config: ConditionalDistributionConfig,
204 decimal_multiplier: f64,
205}
206
207impl ConditionalSampler {
208 pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
210 config.validate()?;
211 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
212 Ok(Self {
213 rng: ChaCha8Rng::seed_from_u64(seed),
214 config,
215 decimal_multiplier,
216 })
217 }
218
219 pub fn sample(&mut self, input_value: f64) -> f64 {
221 let dist = self.config.get_distribution(input_value);
222 let mut value = dist.sample(&mut self.rng);
223
224 if let Some(min) = self.config.min_value {
226 value = value.max(min);
227 }
228 if let Some(max) = self.config.max_value {
229 value = value.min(max);
230 }
231
232 (value * self.decimal_multiplier).round() / self.decimal_multiplier
234 }
235
236 pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
238 let value = self.sample(input_value);
239 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
240 }
241
242 pub fn reset(&mut self, seed: u64) {
244 self.rng = ChaCha8Rng::seed_from_u64(seed);
245 }
246
247 pub fn config(&self) -> &ConditionalDistributionConfig {
249 &self.config
250 }
251}
252
253pub mod conditional_presets {
255 use super::*;
256
257 pub fn discount_by_amount() -> ConditionalDistributionConfig {
260 ConditionalDistributionConfig {
261 output_field: "discount_percent".to_string(),
262 input_field: "order_amount".to_string(),
263 breakpoints: vec![
264 Breakpoint {
265 threshold: 1000.0,
266 distribution: ConditionalDistributionParams::Beta {
267 alpha: 2.0,
268 beta: 8.0,
269 min: 0.01,
270 max: 0.05, },
272 },
273 Breakpoint {
274 threshold: 5000.0,
275 distribution: ConditionalDistributionParams::Beta {
276 alpha: 2.0,
277 beta: 5.0,
278 min: 0.02,
279 max: 0.08, },
281 },
282 Breakpoint {
283 threshold: 25000.0,
284 distribution: ConditionalDistributionParams::Beta {
285 alpha: 3.0,
286 beta: 3.0,
287 min: 0.05,
288 max: 0.12, },
290 },
291 Breakpoint {
292 threshold: 100000.0,
293 distribution: ConditionalDistributionParams::Beta {
294 alpha: 5.0,
295 beta: 2.0,
296 min: 0.08,
297 max: 0.15, },
299 },
300 ],
301 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
302 min_value: Some(0.0),
303 max_value: Some(0.20),
304 decimal_places: 4,
305 }
306 }
307
308 pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
310 ConditionalDistributionConfig {
311 output_field: "approval_level".to_string(),
312 input_field: "amount".to_string(),
313 breakpoints: vec![
314 Breakpoint {
315 threshold: 1000.0,
316 distribution: ConditionalDistributionParams::Discrete {
317 values: vec![1.0, 2.0],
318 weights: vec![0.9, 0.1],
319 },
320 },
321 Breakpoint {
322 threshold: 10000.0,
323 distribution: ConditionalDistributionParams::Discrete {
324 values: vec![2.0, 3.0],
325 weights: vec![0.7, 0.3],
326 },
327 },
328 Breakpoint {
329 threshold: 50000.0,
330 distribution: ConditionalDistributionParams::Discrete {
331 values: vec![3.0, 4.0],
332 weights: vec![0.6, 0.4],
333 },
334 },
335 Breakpoint {
336 threshold: 100000.0,
337 distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
338 },
339 ],
340 default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
341 min_value: Some(1.0),
342 max_value: Some(4.0),
343 decimal_places: 0,
344 }
345 }
346
347 pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
349 ConditionalDistributionConfig {
350 output_field: "processing_days".to_string(),
351 input_field: "line_item_count".to_string(),
352 breakpoints: vec![
353 Breakpoint {
354 threshold: 5.0,
355 distribution: ConditionalDistributionParams::LogNormal {
356 mu: 0.5, sigma: 0.5,
358 },
359 },
360 Breakpoint {
361 threshold: 15.0,
362 distribution: ConditionalDistributionParams::LogNormal {
363 mu: 1.0, sigma: 0.5,
365 },
366 },
367 Breakpoint {
368 threshold: 30.0,
369 distribution: ConditionalDistributionParams::LogNormal {
370 mu: 1.5, sigma: 0.6,
372 },
373 },
374 ],
375 default_distribution: ConditionalDistributionParams::LogNormal {
376 mu: 0.0, sigma: 0.4,
378 },
379 min_value: Some(0.5),
380 max_value: Some(30.0),
381 decimal_places: 1,
382 }
383 }
384
385 pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
387 ConditionalDistributionConfig {
388 output_field: "payment_terms_days".to_string(),
389 input_field: "credit_score".to_string(),
390 breakpoints: vec![
391 Breakpoint {
392 threshold: 300.0, distribution: ConditionalDistributionParams::Discrete {
394 values: vec![0.0, 15.0], weights: vec![0.7, 0.3],
396 },
397 },
398 Breakpoint {
399 threshold: 500.0, distribution: ConditionalDistributionParams::Discrete {
401 values: vec![15.0, 30.0],
402 weights: vec![0.5, 0.5],
403 },
404 },
405 Breakpoint {
406 threshold: 650.0, distribution: ConditionalDistributionParams::Discrete {
408 values: vec![30.0, 45.0, 60.0],
409 weights: vec![0.5, 0.3, 0.2],
410 },
411 },
412 Breakpoint {
413 threshold: 750.0, distribution: ConditionalDistributionParams::Discrete {
415 values: vec![30.0, 60.0, 90.0],
416 weights: vec![0.3, 0.4, 0.3],
417 },
418 },
419 ],
420 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, min_value: Some(0.0),
422 max_value: Some(90.0),
423 decimal_places: 0,
424 }
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Breakpoint whose distribution is the constant `value`.
    fn fixed_tier(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: ConditionalDistributionParams::Fixed { value },
        }
    }

    /// Config named "output"/"input" with the given tiers and a `0.0` default.
    fn fixed_config(tiers: Vec<Breakpoint>) -> ConditionalDistributionConfig {
        ConditionalDistributionConfig::new(
            "output",
            "input",
            tiers,
            ConditionalDistributionParams::Fixed { value: 0.0 },
        )
    }

    /// Average of 100 consecutive samples drawn for `input`.
    fn mean_of_100(sampler: &mut ConditionalSampler, input: f64) -> f64 {
        let drawn: Vec<f64> = (0..100).map(|_| sampler.sample(input)).collect();
        drawn.iter().sum::<f64>() / 100.0
    }

    #[test]
    fn test_conditional_config_validation() {
        // Ascending thresholds are accepted.
        let ascending = fixed_config(vec![fixed_tier(100.0, 1.0), fixed_tier(200.0, 2.0)]);
        assert!(ascending.validate().is_ok());

        // Descending thresholds are rejected.
        let descending = fixed_config(vec![fixed_tier(200.0, 2.0), fixed_tier(100.0, 1.0)]);
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = fixed_config(vec![fixed_tier(100.0, 10.0), fixed_tier(200.0, 20.0)]);
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold, between the two, and above the last.
        for (input, expected) in [(50.0, 0.0), (150.0, 10.0), (250.0, 20.0)] {
            assert_eq!(sampler.sample(input), expected);
        }
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        let avg_small = mean_of_100(&mut sampler, 500.0);
        assert!(avg_small < 0.01);

        sampler.reset(42);
        let avg_medium = mean_of_100(&mut sampler, 3000.0);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        sampler.reset(42);
        let avg_large = mean_of_100(&mut sampler, 150000.0);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold the level is always 1.
        assert_eq!(sampler.sample(500.0), 1.0);

        sampler.reset(42);
        let avg_level = mean_of_100(&mut sampler, 75000.0);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        const DRAWS: usize = 1000;
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        // Normal: the sample mean should be close to mu.
        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_draws: Vec<f64> = (0..DRAWS).map(|_| normal.sample(&mut rng)).collect();
        let mean = normal_draws.iter().sum::<f64>() / 1000.0;
        assert!((mean - 10.0).abs() < 0.5);

        // Beta: every sample stays within [min, max].
        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_draws: Vec<f64> = (0..DRAWS).map(|_| beta.sample(&mut rng)).collect();
        assert!(beta_draws.iter().all(|&x| (0.0..=1.0).contains(&x)));

        // Discrete: the first value carries weight 0.5.
        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let discrete_draws: Vec<f64> = (0..DRAWS).map(|_| discrete.sample(&mut rng)).collect();
        let count_1 = discrete_draws.iter().filter(|&&x| x == 1.0).count();
        assert!(count_1 > 400 && count_1 < 600);
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        // Identical seeds must give identical sequences.
        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            let from_first = first.sample(amount);
            let from_second = second.sample(amount);
            assert_eq!(from_first, from_second);
        }
    }
}