1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18 pub threshold: f64,
20 pub distribution: ConditionalDistributionParams,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28 Fixed { value: f64 },
30 Normal { mu: f64, sigma: f64 },
32 LogNormal { mu: f64, sigma: f64 },
34 Uniform { min: f64, max: f64 },
36 Beta {
38 alpha: f64,
39 beta: f64,
40 min: f64,
41 max: f64,
42 },
43 Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48 fn default() -> Self {
49 Self::Fixed { value: 0.0 }
50 }
51}
52
53impl ConditionalDistributionParams {
54 pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56 match self {
57 Self::Fixed { value } => *value,
58 Self::Normal { mu, sigma } => {
59 let dist = Normal::new(*mu, *sigma).unwrap_or_else(|_| {
60 Normal::new(0.0, 1.0).expect("valid fallback distribution params")
61 });
62 dist.sample(rng)
63 }
64 Self::LogNormal { mu, sigma } => {
65 let dist = LogNormal::new(*mu, *sigma).unwrap_or_else(|_| {
66 LogNormal::new(0.0, 1.0).expect("valid fallback distribution params")
67 });
68 dist.sample(rng)
69 }
70 Self::Uniform { min, max } => {
71 let dist = Uniform::new(*min, *max).expect("valid uniform params");
72 dist.sample(rng)
73 }
74 Self::Beta {
75 alpha,
76 beta,
77 min,
78 max,
79 } => {
80 let dist = Beta::new(*alpha, *beta).unwrap_or_else(|_| {
81 Beta::new(2.0, 2.0).expect("valid fallback distribution params")
82 });
83 let u = dist.sample(rng);
84 min + u * (max - min)
85 }
86 Self::Discrete { values, weights } => {
87 if values.is_empty() {
88 return 0.0;
89 }
90 if weights.is_empty() || weights.len() != values.len() {
91 return *values.choose(rng).unwrap_or(&0.0);
93 }
94 let total: f64 = weights.iter().sum();
96 let mut p: f64 = rng.random::<f64>() * total;
97 for (i, w) in weights.iter().enumerate() {
98 p -= w;
99 if p <= 0.0 {
100 return values[i];
101 }
102 }
103 *values.last().unwrap_or(&0.0)
104 }
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ConditionalDistributionConfig {
112 pub output_field: String,
114 pub input_field: String,
116 pub breakpoints: Vec<Breakpoint>,
119 pub default_distribution: ConditionalDistributionParams,
121 #[serde(default)]
123 pub min_value: Option<f64>,
124 #[serde(default)]
126 pub max_value: Option<f64>,
127 #[serde(default = "default_decimal_places")]
129 pub decimal_places: u8,
130}
131
/// Serde default for `ConditionalDistributionConfig::decimal_places`: two decimal places.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
135
136impl Default for ConditionalDistributionConfig {
137 fn default() -> Self {
138 Self {
139 output_field: "output".to_string(),
140 input_field: "input".to_string(),
141 breakpoints: vec![],
142 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
143 min_value: None,
144 max_value: None,
145 decimal_places: 2,
146 }
147 }
148}
149
150impl ConditionalDistributionConfig {
151 pub fn new(
153 output_field: impl Into<String>,
154 input_field: impl Into<String>,
155 breakpoints: Vec<Breakpoint>,
156 default: ConditionalDistributionParams,
157 ) -> Self {
158 Self {
159 output_field: output_field.into(),
160 input_field: input_field.into(),
161 breakpoints,
162 default_distribution: default,
163 min_value: None,
164 max_value: None,
165 decimal_places: 2,
166 }
167 }
168
169 pub fn validate(&self) -> Result<(), String> {
171 for i in 1..self.breakpoints.len() {
173 if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
174 return Err(format!(
175 "Breakpoints must be in ascending order: {} is not > {}",
176 self.breakpoints[i].threshold,
177 self.breakpoints[i - 1].threshold
178 ));
179 }
180 }
181
182 if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
183 if max <= min {
184 return Err("max_value must be greater than min_value".to_string());
185 }
186 }
187
188 Ok(())
189 }
190
191 pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
193 for breakpoint in self.breakpoints.iter().rev() {
195 if input_value >= breakpoint.threshold {
196 return &breakpoint.distribution;
197 }
198 }
199 &self.default_distribution
200 }
201}
202
203#[derive(Clone)]
205pub struct ConditionalSampler {
206 rng: ChaCha8Rng,
207 config: ConditionalDistributionConfig,
208 decimal_multiplier: f64,
209}
210
211impl ConditionalSampler {
212 pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
214 config.validate()?;
215 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
216 Ok(Self {
217 rng: ChaCha8Rng::seed_from_u64(seed),
218 config,
219 decimal_multiplier,
220 })
221 }
222
223 pub fn sample(&mut self, input_value: f64) -> f64 {
225 let dist = self.config.get_distribution(input_value);
226 let mut value = dist.sample(&mut self.rng);
227
228 if let Some(min) = self.config.min_value {
230 value = value.max(min);
231 }
232 if let Some(max) = self.config.max_value {
233 value = value.min(max);
234 }
235
236 (value * self.decimal_multiplier).round() / self.decimal_multiplier
238 }
239
240 pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
242 let value = self.sample(input_value);
243 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
244 }
245
246 pub fn reset(&mut self, seed: u64) {
248 self.rng = ChaCha8Rng::seed_from_u64(seed);
249 }
250
251 pub fn config(&self) -> &ConditionalDistributionConfig {
253 &self.config
254 }
255}
256
257pub mod conditional_presets {
259 use super::*;
260
261 pub fn discount_by_amount() -> ConditionalDistributionConfig {
264 ConditionalDistributionConfig {
265 output_field: "discount_percent".to_string(),
266 input_field: "order_amount".to_string(),
267 breakpoints: vec![
268 Breakpoint {
269 threshold: 1000.0,
270 distribution: ConditionalDistributionParams::Beta {
271 alpha: 2.0,
272 beta: 8.0,
273 min: 0.01,
274 max: 0.05, },
276 },
277 Breakpoint {
278 threshold: 5000.0,
279 distribution: ConditionalDistributionParams::Beta {
280 alpha: 2.0,
281 beta: 5.0,
282 min: 0.02,
283 max: 0.08, },
285 },
286 Breakpoint {
287 threshold: 25000.0,
288 distribution: ConditionalDistributionParams::Beta {
289 alpha: 3.0,
290 beta: 3.0,
291 min: 0.05,
292 max: 0.12, },
294 },
295 Breakpoint {
296 threshold: 100000.0,
297 distribution: ConditionalDistributionParams::Beta {
298 alpha: 5.0,
299 beta: 2.0,
300 min: 0.08,
301 max: 0.15, },
303 },
304 ],
305 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
306 min_value: Some(0.0),
307 max_value: Some(0.20),
308 decimal_places: 4,
309 }
310 }
311
312 pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
314 ConditionalDistributionConfig {
315 output_field: "approval_level".to_string(),
316 input_field: "amount".to_string(),
317 breakpoints: vec![
318 Breakpoint {
319 threshold: 1000.0,
320 distribution: ConditionalDistributionParams::Discrete {
321 values: vec![1.0, 2.0],
322 weights: vec![0.9, 0.1],
323 },
324 },
325 Breakpoint {
326 threshold: 10000.0,
327 distribution: ConditionalDistributionParams::Discrete {
328 values: vec![2.0, 3.0],
329 weights: vec![0.7, 0.3],
330 },
331 },
332 Breakpoint {
333 threshold: 50000.0,
334 distribution: ConditionalDistributionParams::Discrete {
335 values: vec![3.0, 4.0],
336 weights: vec![0.6, 0.4],
337 },
338 },
339 Breakpoint {
340 threshold: 100000.0,
341 distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
342 },
343 ],
344 default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
345 min_value: Some(1.0),
346 max_value: Some(4.0),
347 decimal_places: 0,
348 }
349 }
350
351 pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
353 ConditionalDistributionConfig {
354 output_field: "processing_days".to_string(),
355 input_field: "line_item_count".to_string(),
356 breakpoints: vec![
357 Breakpoint {
358 threshold: 5.0,
359 distribution: ConditionalDistributionParams::LogNormal {
360 mu: 0.5, sigma: 0.5,
362 },
363 },
364 Breakpoint {
365 threshold: 15.0,
366 distribution: ConditionalDistributionParams::LogNormal {
367 mu: 1.0, sigma: 0.5,
369 },
370 },
371 Breakpoint {
372 threshold: 30.0,
373 distribution: ConditionalDistributionParams::LogNormal {
374 mu: 1.5, sigma: 0.6,
376 },
377 },
378 ],
379 default_distribution: ConditionalDistributionParams::LogNormal {
380 mu: 0.0, sigma: 0.4,
382 },
383 min_value: Some(0.5),
384 max_value: Some(30.0),
385 decimal_places: 1,
386 }
387 }
388
389 pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
391 ConditionalDistributionConfig {
392 output_field: "payment_terms_days".to_string(),
393 input_field: "credit_score".to_string(),
394 breakpoints: vec![
395 Breakpoint {
396 threshold: 300.0, distribution: ConditionalDistributionParams::Discrete {
398 values: vec![0.0, 15.0], weights: vec![0.7, 0.3],
400 },
401 },
402 Breakpoint {
403 threshold: 500.0, distribution: ConditionalDistributionParams::Discrete {
405 values: vec![15.0, 30.0],
406 weights: vec![0.5, 0.5],
407 },
408 },
409 Breakpoint {
410 threshold: 650.0, distribution: ConditionalDistributionParams::Discrete {
412 values: vec![30.0, 45.0, 60.0],
413 weights: vec![0.5, 0.3, 0.2],
414 },
415 },
416 Breakpoint {
417 threshold: 750.0, distribution: ConditionalDistributionParams::Discrete {
419 values: vec![30.0, 60.0, 90.0],
420 weights: vec![0.3, 0.4, 0.3],
421 },
422 },
423 ],
424 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, min_value: Some(0.0),
426 max_value: Some(90.0),
427 decimal_places: 0,
428 }
429 }
430}
431
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Breakpoint at `threshold` that always yields `value`.
    fn fixed_step(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: ConditionalDistributionParams::Fixed { value },
        }
    }

    /// Config made of constant-valued `(threshold, value)` steps with a constant `0.0` fallback.
    fn fixed_config(steps: &[(f64, f64)]) -> ConditionalDistributionConfig {
        let breakpoints = steps.iter().map(|&(t, v)| fixed_step(t, v)).collect();
        ConditionalDistributionConfig::new(
            "output",
            "input",
            breakpoints,
            ConditionalDistributionParams::Fixed { value: 0.0 },
        )
    }

    /// Mean of `n` consecutive samples drawn for `input`.
    fn mean_output(sampler: &mut ConditionalSampler, input: f64, n: usize) -> f64 {
        let total: f64 = (0..n).map(|_| sampler.sample(input)).sum();
        total / n as f64
    }

    #[test]
    fn test_conditional_config_validation() {
        let ascending = fixed_config(&[(100.0, 1.0), (200.0, 2.0)]);
        assert!(ascending.validate().is_ok());

        let descending = fixed_config(&[(200.0, 2.0), (100.0, 1.0)]);
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = fixed_config(&[(100.0, 10.0), (200.0, 20.0)]);
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold the fallback applies; above it, each band's constant.
        for (input, expected) in [(50.0, 0.0), (150.0, 10.0), (250.0, 20.0)] {
            assert_eq!(sampler.sample(input), expected);
        }
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        let avg_small = mean_output(&mut sampler, 500.0, 100);
        assert!(avg_small < 0.01);

        sampler.reset(42);
        let avg_medium = mean_output(&mut sampler, 3000.0, 100);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        sampler.reset(42);
        let avg_large = mean_output(&mut sampler, 150000.0, 100);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        assert_eq!(sampler.sample(500.0), 1.0);

        sampler.reset(42);
        let avg_level = mean_output(&mut sampler, 75000.0, 100);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        // One RNG is shared by all three checks, so their order matters for reproducibility.
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_mean = (0..1000).map(|_| normal.sample(&mut rng)).sum::<f64>() / 1000.0;
        assert!((normal_mean - 10.0).abs() < 0.5);

        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_draws: Vec<f64> = (0..1000).map(|_| beta.sample(&mut rng)).collect();
        assert!(beta_draws.iter().all(|x| (0.0..=1.0).contains(x)));

        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let ones = (0..1000)
            .map(|_| discrete.sample(&mut rng))
            .filter(|&x| x == 1.0)
            .count();
        assert!(ones > 400 && ones < 600);
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();
        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}