1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
/// One step of a piecewise conditional distribution.
///
/// `ConditionalDistributionConfig::get_distribution` selects the distribution
/// of the highest breakpoint whose `threshold` is less than or equal to the
/// input value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Breakpoint {
    /// Inclusive lower bound on the input value for this step to apply.
    pub threshold: f64,
    /// Distribution sampled when this step is the one selected.
    pub distribution: ConditionalDistributionParams,
}
23
/// Parameters of the distribution attached to a breakpoint (or used as the
/// default when no breakpoint matches).
///
/// Serialized as an internally tagged enum: the variant name is written in
/// `snake_case` under the `type` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ConditionalDistributionParams {
    /// Always yields `value`.
    Fixed { value: f64 },
    /// Parameters passed to `rand_distr::Normal::new(mu, sigma)`.
    Normal { mu: f64, sigma: f64 },
    /// Parameters passed to `rand_distr::LogNormal::new(mu, sigma)`.
    LogNormal { mu: f64, sigma: f64 },
    /// Bounds passed to `Uniform::new(min, max)`.
    Uniform { min: f64, max: f64 },
    /// A `Beta(alpha, beta)` draw `u`, mapped to `min + u * (max - min)`.
    Beta {
        alpha: f64,
        beta: f64,
        min: f64,
        max: f64,
    },
    /// One of `values`, picked according to the parallel `weights` vector.
    Discrete { values: Vec<f64>, weights: Vec<f64> },
}
46
47impl Default for ConditionalDistributionParams {
48 fn default() -> Self {
49 Self::Fixed { value: 0.0 }
50 }
51}
52
53impl ConditionalDistributionParams {
54 pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56 match self {
57 Self::Fixed { value } => *value,
58 Self::Normal { mu, sigma } => {
59 let dist = Normal::new(*mu, *sigma).unwrap_or_else(|_| {
60 Normal::new(0.0, 1.0).expect("valid fallback distribution params")
61 });
62 dist.sample(rng)
63 }
64 Self::LogNormal { mu, sigma } => {
65 let dist = LogNormal::new(*mu, *sigma).unwrap_or_else(|_| {
66 LogNormal::new(0.0, 1.0).expect("valid fallback distribution params")
67 });
68 dist.sample(rng)
69 }
70 Self::Uniform { min, max } => {
71 let dist = Uniform::new(*min, *max);
72 dist.sample(rng)
73 }
74 Self::Beta {
75 alpha,
76 beta,
77 min,
78 max,
79 } => {
80 let dist = Beta::new(*alpha, *beta).unwrap_or_else(|_| {
81 Beta::new(2.0, 2.0).expect("valid fallback distribution params")
82 });
83 let u = dist.sample(rng);
84 min + u * (max - min)
85 }
86 Self::Discrete { values, weights } => {
87 if values.is_empty() {
88 return 0.0;
89 }
90 if weights.is_empty() || weights.len() != values.len() {
91 return *values.choose(rng).unwrap_or(&0.0);
93 }
94 let total: f64 = weights.iter().sum();
96 let mut p: f64 = rng.gen::<f64>() * total;
97 for (i, w) in weights.iter().enumerate() {
98 p -= w;
99 if p <= 0.0 {
100 return values[i];
101 }
102 }
103 *values.last().unwrap_or(&0.0)
104 }
105 }
106 }
107}
108
/// Configuration of a conditional (piecewise) distribution: the distribution
/// used for the output depends on which breakpoint the input value reaches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionalDistributionConfig {
    /// Name of the generated field. Not read by the sampling code in this file.
    pub output_field: String,
    /// Name of the field the output is conditioned on. Not read by the sampling
    /// code in this file.
    pub input_field: String,
    /// Steps of the piecewise definition; `validate` requires strictly
    /// ascending thresholds.
    pub breakpoints: Vec<Breakpoint>,
    /// Distribution used when the input is below every breakpoint threshold.
    pub default_distribution: ConditionalDistributionParams,
    /// Optional lower clamp applied to sampled values by `ConditionalSampler`.
    #[serde(default)]
    pub min_value: Option<f64>,
    /// Optional upper clamp applied to sampled values by `ConditionalSampler`.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places sampled values are rounded to (defaults to 2
    /// when absent from the serialized form).
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
131
/// Serde default for `ConditionalDistributionConfig::decimal_places`.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
135
136impl Default for ConditionalDistributionConfig {
137 fn default() -> Self {
138 Self {
139 output_field: "output".to_string(),
140 input_field: "input".to_string(),
141 breakpoints: vec![],
142 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
143 min_value: None,
144 max_value: None,
145 decimal_places: 2,
146 }
147 }
148}
149
150impl ConditionalDistributionConfig {
151 pub fn new(
153 output_field: impl Into<String>,
154 input_field: impl Into<String>,
155 breakpoints: Vec<Breakpoint>,
156 default: ConditionalDistributionParams,
157 ) -> Self {
158 Self {
159 output_field: output_field.into(),
160 input_field: input_field.into(),
161 breakpoints,
162 default_distribution: default,
163 min_value: None,
164 max_value: None,
165 decimal_places: 2,
166 }
167 }
168
169 pub fn validate(&self) -> Result<(), String> {
171 for i in 1..self.breakpoints.len() {
173 if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
174 return Err(format!(
175 "Breakpoints must be in ascending order: {} is not > {}",
176 self.breakpoints[i].threshold,
177 self.breakpoints[i - 1].threshold
178 ));
179 }
180 }
181
182 if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
183 if max <= min {
184 return Err("max_value must be greater than min_value".to_string());
185 }
186 }
187
188 Ok(())
189 }
190
191 pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
193 for breakpoint in self.breakpoints.iter().rev() {
195 if input_value >= breakpoint.threshold {
196 return &breakpoint.distribution;
197 }
198 }
199 &self.default_distribution
200 }
201}
202
203pub struct ConditionalSampler {
205 rng: ChaCha8Rng,
206 config: ConditionalDistributionConfig,
207 decimal_multiplier: f64,
208}
209
210impl ConditionalSampler {
211 pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
213 config.validate()?;
214 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
215 Ok(Self {
216 rng: ChaCha8Rng::seed_from_u64(seed),
217 config,
218 decimal_multiplier,
219 })
220 }
221
222 pub fn sample(&mut self, input_value: f64) -> f64 {
224 let dist = self.config.get_distribution(input_value);
225 let mut value = dist.sample(&mut self.rng);
226
227 if let Some(min) = self.config.min_value {
229 value = value.max(min);
230 }
231 if let Some(max) = self.config.max_value {
232 value = value.min(max);
233 }
234
235 (value * self.decimal_multiplier).round() / self.decimal_multiplier
237 }
238
239 pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
241 let value = self.sample(input_value);
242 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
243 }
244
245 pub fn reset(&mut self, seed: u64) {
247 self.rng = ChaCha8Rng::seed_from_u64(seed);
248 }
249
250 pub fn config(&self) -> &ConditionalDistributionConfig {
252 &self.config
253 }
254}
255
256pub mod conditional_presets {
258 use super::*;
259
260 pub fn discount_by_amount() -> ConditionalDistributionConfig {
263 ConditionalDistributionConfig {
264 output_field: "discount_percent".to_string(),
265 input_field: "order_amount".to_string(),
266 breakpoints: vec![
267 Breakpoint {
268 threshold: 1000.0,
269 distribution: ConditionalDistributionParams::Beta {
270 alpha: 2.0,
271 beta: 8.0,
272 min: 0.01,
273 max: 0.05, },
275 },
276 Breakpoint {
277 threshold: 5000.0,
278 distribution: ConditionalDistributionParams::Beta {
279 alpha: 2.0,
280 beta: 5.0,
281 min: 0.02,
282 max: 0.08, },
284 },
285 Breakpoint {
286 threshold: 25000.0,
287 distribution: ConditionalDistributionParams::Beta {
288 alpha: 3.0,
289 beta: 3.0,
290 min: 0.05,
291 max: 0.12, },
293 },
294 Breakpoint {
295 threshold: 100000.0,
296 distribution: ConditionalDistributionParams::Beta {
297 alpha: 5.0,
298 beta: 2.0,
299 min: 0.08,
300 max: 0.15, },
302 },
303 ],
304 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
305 min_value: Some(0.0),
306 max_value: Some(0.20),
307 decimal_places: 4,
308 }
309 }
310
311 pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
313 ConditionalDistributionConfig {
314 output_field: "approval_level".to_string(),
315 input_field: "amount".to_string(),
316 breakpoints: vec![
317 Breakpoint {
318 threshold: 1000.0,
319 distribution: ConditionalDistributionParams::Discrete {
320 values: vec![1.0, 2.0],
321 weights: vec![0.9, 0.1],
322 },
323 },
324 Breakpoint {
325 threshold: 10000.0,
326 distribution: ConditionalDistributionParams::Discrete {
327 values: vec![2.0, 3.0],
328 weights: vec![0.7, 0.3],
329 },
330 },
331 Breakpoint {
332 threshold: 50000.0,
333 distribution: ConditionalDistributionParams::Discrete {
334 values: vec![3.0, 4.0],
335 weights: vec![0.6, 0.4],
336 },
337 },
338 Breakpoint {
339 threshold: 100000.0,
340 distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
341 },
342 ],
343 default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
344 min_value: Some(1.0),
345 max_value: Some(4.0),
346 decimal_places: 0,
347 }
348 }
349
350 pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
352 ConditionalDistributionConfig {
353 output_field: "processing_days".to_string(),
354 input_field: "line_item_count".to_string(),
355 breakpoints: vec![
356 Breakpoint {
357 threshold: 5.0,
358 distribution: ConditionalDistributionParams::LogNormal {
359 mu: 0.5, sigma: 0.5,
361 },
362 },
363 Breakpoint {
364 threshold: 15.0,
365 distribution: ConditionalDistributionParams::LogNormal {
366 mu: 1.0, sigma: 0.5,
368 },
369 },
370 Breakpoint {
371 threshold: 30.0,
372 distribution: ConditionalDistributionParams::LogNormal {
373 mu: 1.5, sigma: 0.6,
375 },
376 },
377 ],
378 default_distribution: ConditionalDistributionParams::LogNormal {
379 mu: 0.0, sigma: 0.4,
381 },
382 min_value: Some(0.5),
383 max_value: Some(30.0),
384 decimal_places: 1,
385 }
386 }
387
388 pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
390 ConditionalDistributionConfig {
391 output_field: "payment_terms_days".to_string(),
392 input_field: "credit_score".to_string(),
393 breakpoints: vec![
394 Breakpoint {
395 threshold: 300.0, distribution: ConditionalDistributionParams::Discrete {
397 values: vec![0.0, 15.0], weights: vec![0.7, 0.3],
399 },
400 },
401 Breakpoint {
402 threshold: 500.0, distribution: ConditionalDistributionParams::Discrete {
404 values: vec![15.0, 30.0],
405 weights: vec![0.5, 0.5],
406 },
407 },
408 Breakpoint {
409 threshold: 650.0, distribution: ConditionalDistributionParams::Discrete {
411 values: vec![30.0, 45.0, 60.0],
412 weights: vec![0.5, 0.3, 0.2],
413 },
414 },
415 Breakpoint {
416 threshold: 750.0, distribution: ConditionalDistributionParams::Discrete {
418 values: vec![30.0, 60.0, 90.0],
419 weights: vec![0.3, 0.4, 0.3],
420 },
421 },
422 ],
423 default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, min_value: Some(0.0),
425 max_value: Some(90.0),
426 decimal_places: 0,
427 }
428 }
429}
430
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for a constant distribution.
    fn fixed(value: f64) -> ConditionalDistributionParams {
        ConditionalDistributionParams::Fixed { value }
    }

    /// Shorthand for a breakpoint with a constant distribution.
    fn fixed_step(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: fixed(value),
        }
    }

    /// Draws `draws` samples for `input` and returns their arithmetic mean.
    fn mean_of_samples(sampler: &mut ConditionalSampler, input: f64, draws: u32) -> f64 {
        let total: f64 = (0..draws).map(|_| sampler.sample(input)).sum();
        total / f64::from(draws)
    }

    #[test]
    fn test_conditional_config_validation() {
        // Ascending thresholds are accepted.
        let ascending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(100.0, 1.0), fixed_step(200.0, 2.0)],
            fixed(0.0),
        );
        assert!(ascending.validate().is_ok());

        // Descending thresholds are rejected.
        let descending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(200.0, 2.0), fixed_step(100.0, 1.0)],
            fixed(0.0),
        );
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(100.0, 10.0), fixed_step(200.0, 20.0)],
            fixed(0.0),
        );
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold: default distribution.
        assert_eq!(sampler.sample(50.0), 0.0);

        // Between the two thresholds: first breakpoint.
        assert_eq!(sampler.sample(150.0), 10.0);

        // Above the last threshold: last breakpoint.
        assert_eq!(sampler.sample(250.0), 20.0);
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Inputs below the first threshold fall back to the zero default.
        let avg_small = mean_of_samples(&mut sampler, 500.0, 100);
        assert!(avg_small < 0.01);

        sampler.reset(42);
        let avg_medium = mean_of_samples(&mut sampler, 3000.0, 100);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        sampler.reset(42);
        let avg_large = mean_of_samples(&mut sampler, 150000.0, 100);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold the level is the fixed default.
        let level = sampler.sample(500.0);
        assert_eq!(level, 1.0);

        sampler.reset(42);
        let avg_level = mean_of_samples(&mut sampler, 75000.0, 100);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        const DRAWS: usize = 1000;
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        // Normal: the sample mean should be close to mu.
        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_draws: Vec<f64> = (0..DRAWS).map(|_| normal.sample(&mut rng)).collect();
        let mean: f64 = normal_draws.iter().sum::<f64>() / 1000.0;
        assert!((mean - 10.0).abs() < 0.5);

        // Beta: every draw stays inside the configured range.
        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_draws: Vec<f64> = (0..DRAWS).map(|_| beta.sample(&mut rng)).collect();
        assert!(beta_draws.iter().all(|&x| (0.0..=1.0).contains(&x)));

        // Discrete: the first value should show up roughly half of the time.
        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let discrete_draws: Vec<f64> = (0..DRAWS).map(|_| discrete.sample(&mut rng)).collect();
        let count_1 = discrete_draws.iter().filter(|&&x| x == 1.0).count();
        assert!(count_1 > 400 && count_1 < 600);
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        // Two samplers with the same seed and config must agree draw for draw.
        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}