1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18 pub threshold: f64,
20 pub distribution: ConditionalDistributionParams,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28 Fixed { value: f64 },
30 Normal { mu: f64, sigma: f64 },
32 LogNormal { mu: f64, sigma: f64 },
34 Uniform { min: f64, max: f64 },
36 Beta {
38 alpha: f64,
39 beta: f64,
40 min: f64,
41 max: f64,
42 },
43 Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48 fn default() -> Self {
49 Self::Fixed { value: 0.0 }
50 }
51}
52
53impl ConditionalDistributionParams {
54 pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56 match self {
57 Self::Fixed { value } => *value,
58 Self::Normal { mu, sigma } => {
59 let dist = Normal::new(*mu, *sigma).unwrap_or_else(|_| {
60 Normal::new(0.0, 1.0).expect("valid fallback distribution params")
61 });
62 dist.sample(rng)
63 }
64 Self::LogNormal { mu, sigma } => {
65 let dist = LogNormal::new(*mu, *sigma).unwrap_or_else(|_| {
66 LogNormal::new(0.0, 1.0).expect("valid fallback distribution params")
67 });
68 dist.sample(rng)
69 }
70 Self::Uniform { min, max } => {
71 let dist = Uniform::new(*min, *max).expect("valid uniform params");
72 dist.sample(rng)
73 }
74 Self::Beta {
75 alpha,
76 beta,
77 min,
78 max,
79 } => {
80 let dist = Beta::new(*alpha, *beta).unwrap_or_else(|_| {
81 Beta::new(2.0, 2.0).expect("valid fallback distribution params")
82 });
83 let u = dist.sample(rng);
84 min + u * (max - min)
85 }
86 Self::Discrete { values, weights } => {
87 if values.is_empty() {
88 return 0.0;
89 }
90 if weights.is_empty() || weights.len() != values.len() {
91 return *values.choose(rng).unwrap_or(&0.0);
93 }
94 let total: f64 = weights.iter().sum();
96 let mut p: f64 = rng.random::<f64>() * total;
97 for (i, w) in weights.iter().enumerate() {
98 p -= w;
99 if p <= 0.0 {
100 return values[i];
101 }
102 }
103 *values.last().unwrap_or(&0.0)
104 }
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ConditionalDistributionConfig {
112 pub output_field: String,
114 pub input_field: String,
116 pub breakpoints: Vec<Breakpoint>,
119 pub default_distribution: ConditionalDistributionParams,
121 #[serde(default)]
123 pub min_value: Option<f64>,
124 #[serde(default)]
126 pub max_value: Option<f64>,
127 #[serde(default = "default_decimal_places")]
129 pub decimal_places: u8,
130}
131
/// Serde default for the `decimal_places` field: two decimal places.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
135
136impl Default for ConditionalDistributionConfig {
137 fn default() -> Self {
138 Self {
139 output_field: "output".to_string(),
140 input_field: "input".to_string(),
141 breakpoints: vec![],
142 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
143 min_value: None,
144 max_value: None,
145 decimal_places: 2,
146 }
147 }
148}
149
150impl ConditionalDistributionConfig {
151 pub fn new(
153 output_field: impl Into<String>,
154 input_field: impl Into<String>,
155 breakpoints: Vec<Breakpoint>,
156 default: ConditionalDistributionParams,
157 ) -> Self {
158 Self {
159 output_field: output_field.into(),
160 input_field: input_field.into(),
161 breakpoints,
162 default_distribution: default,
163 min_value: None,
164 max_value: None,
165 decimal_places: 2,
166 }
167 }
168
169 pub fn validate(&self) -> Result<(), String> {
171 for i in 1..self.breakpoints.len() {
173 if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
174 return Err(format!(
175 "Breakpoints must be in ascending order: {} is not > {}",
176 self.breakpoints[i].threshold,
177 self.breakpoints[i - 1].threshold
178 ));
179 }
180 }
181
182 if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
183 if max <= min {
184 return Err("max_value must be greater than min_value".to_string());
185 }
186 }
187
188 Ok(())
189 }
190
191 pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
193 for breakpoint in self.breakpoints.iter().rev() {
195 if input_value >= breakpoint.threshold {
196 return &breakpoint.distribution;
197 }
198 }
199 &self.default_distribution
200 }
201}
202
203#[derive(Clone)]
205pub struct ConditionalSampler {
206 rng: ChaCha8Rng,
207 config: ConditionalDistributionConfig,
208 decimal_multiplier: f64,
209}
210
211impl ConditionalSampler {
212 pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
214 config.validate()?;
215 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
216 Ok(Self {
217 rng: ChaCha8Rng::seed_from_u64(seed),
218 config,
219 decimal_multiplier,
220 })
221 }
222
223 pub fn sample(&mut self, input_value: f64) -> f64 {
225 let dist = self.config.get_distribution(input_value);
226 let mut value = dist.sample(&mut self.rng);
227
228 if let Some(min) = self.config.min_value {
230 value = value.max(min);
231 }
232 if let Some(max) = self.config.max_value {
233 value = value.min(max);
234 }
235
236 (value * self.decimal_multiplier).round() / self.decimal_multiplier
238 }
239
240 pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
242 let value = self.sample(input_value);
243 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
244 }
245
246 pub fn reset(&mut self, seed: u64) {
248 self.rng = ChaCha8Rng::seed_from_u64(seed);
249 }
250
251 pub fn config(&self) -> &ConditionalDistributionConfig {
253 &self.config
254 }
255}
256
257pub mod conditional_presets {
259 use super::*;
260
261 pub fn discount_by_amount() -> ConditionalDistributionConfig {
264 ConditionalDistributionConfig {
265 output_field: "discount_percent".to_string(),
266 input_field: "order_amount".to_string(),
267 breakpoints: vec![
268 Breakpoint {
269 threshold: 1000.0,
270 distribution: ConditionalDistributionParams::Beta {
271 alpha: 2.0,
272 beta: 8.0,
273 min: 0.01,
274 max: 0.05, },
276 },
277 Breakpoint {
278 threshold: 5000.0,
279 distribution: ConditionalDistributionParams::Beta {
280 alpha: 2.0,
281 beta: 5.0,
282 min: 0.02,
283 max: 0.08, },
285 },
286 Breakpoint {
287 threshold: 25000.0,
288 distribution: ConditionalDistributionParams::Beta {
289 alpha: 3.0,
290 beta: 3.0,
291 min: 0.05,
292 max: 0.12, },
294 },
295 Breakpoint {
296 threshold: 100000.0,
297 distribution: ConditionalDistributionParams::Beta {
298 alpha: 5.0,
299 beta: 2.0,
300 min: 0.08,
301 max: 0.15, },
303 },
304 ],
305 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
306 min_value: Some(0.0),
307 max_value: Some(0.20),
308 decimal_places: 4,
309 }
310 }
311
312 pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
314 ConditionalDistributionConfig {
315 output_field: "approval_level".to_string(),
316 input_field: "amount".to_string(),
317 breakpoints: vec![
318 Breakpoint {
319 threshold: 1000.0,
320 distribution: ConditionalDistributionParams::Discrete {
321 values: vec![1.0, 2.0],
322 weights: vec![0.9, 0.1],
323 },
324 },
325 Breakpoint {
326 threshold: 10000.0,
327 distribution: ConditionalDistributionParams::Discrete {
328 values: vec![2.0, 3.0],
329 weights: vec![0.7, 0.3],
330 },
331 },
332 Breakpoint {
333 threshold: 50000.0,
334 distribution: ConditionalDistributionParams::Discrete {
335 values: vec![3.0, 4.0],
336 weights: vec![0.6, 0.4],
337 },
338 },
339 Breakpoint {
340 threshold: 100000.0,
341 distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
342 },
343 ],
344 default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
345 min_value: Some(1.0),
346 max_value: Some(4.0),
347 decimal_places: 0,
348 }
349 }
350
351 pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
353 ConditionalDistributionConfig {
354 output_field: "processing_days".to_string(),
355 input_field: "line_item_count".to_string(),
356 breakpoints: vec![
357 Breakpoint {
358 threshold: 5.0,
359 distribution: ConditionalDistributionParams::LogNormal {
360 mu: 0.5, sigma: 0.5,
362 },
363 },
364 Breakpoint {
365 threshold: 15.0,
366 distribution: ConditionalDistributionParams::LogNormal {
367 mu: 1.0, sigma: 0.5,
369 },
370 },
371 Breakpoint {
372 threshold: 30.0,
373 distribution: ConditionalDistributionParams::LogNormal {
374 mu: 1.5, sigma: 0.6,
376 },
377 },
378 ],
379 default_distribution: ConditionalDistributionParams::LogNormal {
380 mu: 0.0, sigma: 0.4,
382 },
383 min_value: Some(0.5),
384 max_value: Some(30.0),
385 decimal_places: 1,
386 }
387 }
388
389 pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
391 ConditionalDistributionConfig {
392 output_field: "payment_terms_days".to_string(),
393 input_field: "credit_score".to_string(),
394 breakpoints: vec![
395 Breakpoint {
396 threshold: 300.0, distribution: ConditionalDistributionParams::Discrete {
398 values: vec![0.0, 15.0], weights: vec![0.7, 0.3],
400 },
401 },
402 Breakpoint {
403 threshold: 500.0, distribution: ConditionalDistributionParams::Discrete {
405 values: vec![15.0, 30.0],
406 weights: vec![0.5, 0.5],
407 },
408 },
409 Breakpoint {
410 threshold: 650.0, distribution: ConditionalDistributionParams::Discrete {
412 values: vec![30.0, 45.0, 60.0],
413 weights: vec![0.5, 0.3, 0.2],
414 },
415 },
416 Breakpoint {
417 threshold: 750.0, distribution: ConditionalDistributionParams::Discrete {
419 values: vec![30.0, 60.0, 90.0],
420 weights: vec![0.3, 0.4, 0.3],
421 },
422 },
423 ],
424 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, min_value: Some(0.0),
426 max_value: Some(90.0),
427 decimal_places: 0,
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Breakpoint whose distribution is the constant `value`.
    fn fixed_tier(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: ConditionalDistributionParams::Fixed { value },
        }
    }

    /// Config named "output"/"input" with the given tiers and a constant
    /// `0.0` default distribution.
    fn config_with(tiers: Vec<Breakpoint>) -> ConditionalDistributionConfig {
        ConditionalDistributionConfig::new(
            "output",
            "input",
            tiers,
            ConditionalDistributionParams::Fixed { value: 0.0 },
        )
    }

    /// Mean of `count` consecutive samples drawn for `input`.
    fn sample_mean(sampler: &mut ConditionalSampler, input: f64, count: usize) -> f64 {
        let drawn: Vec<f64> = (0..count).map(|_| sampler.sample(input)).collect();
        drawn.iter().sum::<f64>() / count as f64
    }

    #[test]
    fn test_conditional_config_validation() {
        // Ascending thresholds are accepted.
        let ascending = config_with(vec![fixed_tier(100.0, 1.0), fixed_tier(200.0, 2.0)]);
        assert!(ascending.validate().is_ok());

        // Descending thresholds are rejected.
        let descending = config_with(vec![fixed_tier(200.0, 2.0), fixed_tier(100.0, 1.0)]);
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = config_with(vec![fixed_tier(100.0, 10.0), fixed_tier(200.0, 20.0)]);
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold the default distribution applies.
        assert_eq!(sampler.sample(50.0), 0.0);
        // Between the two thresholds the first tier applies.
        assert_eq!(sampler.sample(150.0), 10.0);
        // Above the second threshold the second tier applies.
        assert_eq!(sampler.sample(250.0), 20.0);
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        let avg_small = sample_mean(&mut sampler, 500.0, 100);
        assert!(avg_small < 0.01);

        sampler.reset(42);
        let avg_medium = sample_mean(&mut sampler, 3000.0, 100);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        sampler.reset(42);
        let avg_large = sample_mean(&mut sampler, 150000.0, 100);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small amounts fall through to the constant default level.
        assert_eq!(sampler.sample(500.0), 1.0);

        sampler.reset(42);
        let avg_level = sample_mean(&mut sampler, 75000.0, 100);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        // The three checks share one RNG and draw 1000 samples each, in this order.
        let mut draw = |params: &ConditionalDistributionParams| -> Vec<f64> {
            (0..1000).map(|_| params.sample(&mut rng)).collect()
        };

        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_samples = draw(&normal);
        let mean = normal_samples.iter().sum::<f64>() / 1000.0;
        assert!((mean - 10.0).abs() < 0.5);

        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_samples = draw(&beta);
        assert!(beta_samples.iter().all(|x| (0.0..=1.0).contains(x)));

        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let discrete_samples = draw(&discrete);
        let count_1 = discrete_samples.iter().filter(|x| **x == 1.0).count();
        assert!(count_1 > 400 && count_1 < 600);
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        // Identical seeds and inputs must give identical outputs.
        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}