Skip to main content

datasynth_core/distributions/
conditional.rs

1//! Conditional distributions for dependent value generation.
2//!
3//! This module provides tools for generating values that depend on
4//! other values through breakpoint-based conditional logic, such as:
5//! - Discount percentage depends on order amount
6//! - Processing time depends on transaction complexity
7//! - Approval level depends on amount thresholds
8
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15/// A breakpoint defining where distribution parameters change.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18    /// The threshold value where this breakpoint applies
19    pub threshold: f64,
20    /// Distribution parameters for values at or above this threshold
21    pub distribution: ConditionalDistributionParams,
22}
23
24/// Parameters for the conditional distribution at a given breakpoint.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28    /// Fixed value
29    Fixed { value: f64 },
30    /// Normal distribution
31    Normal { mu: f64, sigma: f64 },
32    /// Log-normal distribution
33    LogNormal { mu: f64, sigma: f64 },
34    /// Uniform distribution
35    Uniform { min: f64, max: f64 },
36    /// Beta distribution (scaled to min-max range)
37    Beta {
38        alpha: f64,
39        beta: f64,
40        min: f64,
41        max: f64,
42    },
43    /// Discrete choice from a set of values
44    Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48    fn default() -> Self {
49        Self::Fixed { value: 0.0 }
50    }
51}
52
53impl ConditionalDistributionParams {
54    /// Sample from this distribution.
55    pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56        match self {
57            Self::Fixed { value } => *value,
58            Self::Normal { mu, sigma } => {
59                let dist =
60                    Normal::new(*mu, *sigma).unwrap_or_else(|_| Normal::new(0.0, 1.0).unwrap());
61                dist.sample(rng)
62            }
63            Self::LogNormal { mu, sigma } => {
64                let dist = LogNormal::new(*mu, *sigma)
65                    .unwrap_or_else(|_| LogNormal::new(0.0, 1.0).unwrap());
66                dist.sample(rng)
67            }
68            Self::Uniform { min, max } => {
69                let dist = Uniform::new(*min, *max);
70                dist.sample(rng)
71            }
72            Self::Beta {
73                alpha,
74                beta,
75                min,
76                max,
77            } => {
78                let dist =
79                    Beta::new(*alpha, *beta).unwrap_or_else(|_| Beta::new(2.0, 2.0).unwrap());
80                let u = dist.sample(rng);
81                min + u * (max - min)
82            }
83            Self::Discrete { values, weights } => {
84                if values.is_empty() {
85                    return 0.0;
86                }
87                if weights.is_empty() || weights.len() != values.len() {
88                    // Equal weights
89                    return *values.choose(rng).unwrap_or(&0.0);
90                }
91                // Weighted selection
92                let total: f64 = weights.iter().sum();
93                let mut p: f64 = rng.gen::<f64>() * total;
94                for (i, w) in weights.iter().enumerate() {
95                    p -= w;
96                    if p <= 0.0 {
97                        return values[i];
98                    }
99                }
100                *values.last().unwrap_or(&0.0)
101            }
102        }
103    }
104}
105
106/// Configuration for a conditional distribution.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ConditionalDistributionConfig {
109    /// Name of the dependent field (the output)
110    pub output_field: String,
111    /// Name of the conditioning field (the input)
112    pub input_field: String,
113    /// Breakpoints defining the conditional distribution
114    /// Must be sorted by threshold in ascending order
115    pub breakpoints: Vec<Breakpoint>,
116    /// Distribution for values below the first breakpoint
117    pub default_distribution: ConditionalDistributionParams,
118    /// Minimum output value (clamps)
119    #[serde(default)]
120    pub min_value: Option<f64>,
121    /// Maximum output value (clamps)
122    #[serde(default)]
123    pub max_value: Option<f64>,
124    /// Number of decimal places for rounding
125    #[serde(default = "default_decimal_places")]
126    pub decimal_places: u8,
127}
128
/// Serde default for `decimal_places`: round to two decimal places.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
132
133impl Default for ConditionalDistributionConfig {
134    fn default() -> Self {
135        Self {
136            output_field: "output".to_string(),
137            input_field: "input".to_string(),
138            breakpoints: vec![],
139            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
140            min_value: None,
141            max_value: None,
142            decimal_places: 2,
143        }
144    }
145}
146
147impl ConditionalDistributionConfig {
148    /// Create a new conditional distribution configuration.
149    pub fn new(
150        output_field: impl Into<String>,
151        input_field: impl Into<String>,
152        breakpoints: Vec<Breakpoint>,
153        default: ConditionalDistributionParams,
154    ) -> Self {
155        Self {
156            output_field: output_field.into(),
157            input_field: input_field.into(),
158            breakpoints,
159            default_distribution: default,
160            min_value: None,
161            max_value: None,
162            decimal_places: 2,
163        }
164    }
165
166    /// Validate the configuration.
167    pub fn validate(&self) -> Result<(), String> {
168        // Check breakpoints are in ascending order
169        for i in 1..self.breakpoints.len() {
170            if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
171                return Err(format!(
172                    "Breakpoints must be in ascending order: {} is not > {}",
173                    self.breakpoints[i].threshold,
174                    self.breakpoints[i - 1].threshold
175                ));
176            }
177        }
178
179        if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
180            if max <= min {
181                return Err("max_value must be greater than min_value".to_string());
182            }
183        }
184
185        Ok(())
186    }
187
188    /// Get the distribution parameters for a given input value.
189    pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
190        // Find the highest breakpoint that the input exceeds
191        for breakpoint in self.breakpoints.iter().rev() {
192            if input_value >= breakpoint.threshold {
193                return &breakpoint.distribution;
194            }
195        }
196        &self.default_distribution
197    }
198}
199
200/// Sampler for conditional distributions.
201pub struct ConditionalSampler {
202    rng: ChaCha8Rng,
203    config: ConditionalDistributionConfig,
204    decimal_multiplier: f64,
205}
206
207impl ConditionalSampler {
208    /// Create a new conditional sampler.
209    pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
210        config.validate()?;
211        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
212        Ok(Self {
213            rng: ChaCha8Rng::seed_from_u64(seed),
214            config,
215            decimal_multiplier,
216        })
217    }
218
219    /// Sample a value given the conditioning input.
220    pub fn sample(&mut self, input_value: f64) -> f64 {
221        let dist = self.config.get_distribution(input_value);
222        let mut value = dist.sample(&mut self.rng);
223
224        // Apply constraints
225        if let Some(min) = self.config.min_value {
226            value = value.max(min);
227        }
228        if let Some(max) = self.config.max_value {
229            value = value.min(max);
230        }
231
232        // Round to decimal places
233        (value * self.decimal_multiplier).round() / self.decimal_multiplier
234    }
235
236    /// Sample a value as Decimal.
237    pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
238        let value = self.sample(input_value);
239        Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
240    }
241
242    /// Reset the sampler with a new seed.
243    pub fn reset(&mut self, seed: u64) {
244        self.rng = ChaCha8Rng::seed_from_u64(seed);
245    }
246
247    /// Get the configuration.
248    pub fn config(&self) -> &ConditionalDistributionConfig {
249        &self.config
250    }
251}
252
253/// Preset conditional distribution configurations.
254pub mod conditional_presets {
255    use super::*;
256
257    /// Discount percentage based on order amount.
258    /// Higher amounts get higher discount percentages.
259    pub fn discount_by_amount() -> ConditionalDistributionConfig {
260        ConditionalDistributionConfig {
261            output_field: "discount_percent".to_string(),
262            input_field: "order_amount".to_string(),
263            breakpoints: vec![
264                Breakpoint {
265                    threshold: 1000.0,
266                    distribution: ConditionalDistributionParams::Beta {
267                        alpha: 2.0,
268                        beta: 8.0,
269                        min: 0.01,
270                        max: 0.05, // 1-5%
271                    },
272                },
273                Breakpoint {
274                    threshold: 5000.0,
275                    distribution: ConditionalDistributionParams::Beta {
276                        alpha: 2.0,
277                        beta: 5.0,
278                        min: 0.02,
279                        max: 0.08, // 2-8%
280                    },
281                },
282                Breakpoint {
283                    threshold: 25000.0,
284                    distribution: ConditionalDistributionParams::Beta {
285                        alpha: 3.0,
286                        beta: 3.0,
287                        min: 0.05,
288                        max: 0.12, // 5-12%
289                    },
290                },
291                Breakpoint {
292                    threshold: 100000.0,
293                    distribution: ConditionalDistributionParams::Beta {
294                        alpha: 5.0,
295                        beta: 2.0,
296                        min: 0.08,
297                        max: 0.15, // 8-15%
298                    },
299                },
300            ],
301            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
302            min_value: Some(0.0),
303            max_value: Some(0.20),
304            decimal_places: 4,
305        }
306    }
307
308    /// Approval level based on transaction amount.
309    pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
310        ConditionalDistributionConfig {
311            output_field: "approval_level".to_string(),
312            input_field: "amount".to_string(),
313            breakpoints: vec![
314                Breakpoint {
315                    threshold: 1000.0,
316                    distribution: ConditionalDistributionParams::Discrete {
317                        values: vec![1.0, 2.0],
318                        weights: vec![0.9, 0.1],
319                    },
320                },
321                Breakpoint {
322                    threshold: 10000.0,
323                    distribution: ConditionalDistributionParams::Discrete {
324                        values: vec![2.0, 3.0],
325                        weights: vec![0.7, 0.3],
326                    },
327                },
328                Breakpoint {
329                    threshold: 50000.0,
330                    distribution: ConditionalDistributionParams::Discrete {
331                        values: vec![3.0, 4.0],
332                        weights: vec![0.6, 0.4],
333                    },
334                },
335                Breakpoint {
336                    threshold: 100000.0,
337                    distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
338                },
339            ],
340            default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
341            min_value: Some(1.0),
342            max_value: Some(4.0),
343            decimal_places: 0,
344        }
345    }
346
347    /// Processing days based on order complexity (number of line items).
348    pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
349        ConditionalDistributionConfig {
350            output_field: "processing_days".to_string(),
351            input_field: "line_item_count".to_string(),
352            breakpoints: vec![
353                Breakpoint {
354                    threshold: 5.0,
355                    distribution: ConditionalDistributionParams::LogNormal {
356                        mu: 0.5, // ~1.6 days median
357                        sigma: 0.5,
358                    },
359                },
360                Breakpoint {
361                    threshold: 15.0,
362                    distribution: ConditionalDistributionParams::LogNormal {
363                        mu: 1.0, // ~2.7 days median
364                        sigma: 0.5,
365                    },
366                },
367                Breakpoint {
368                    threshold: 30.0,
369                    distribution: ConditionalDistributionParams::LogNormal {
370                        mu: 1.5, // ~4.5 days median
371                        sigma: 0.6,
372                    },
373                },
374            ],
375            default_distribution: ConditionalDistributionParams::LogNormal {
376                mu: 0.0, // ~1 day median
377                sigma: 0.4,
378            },
379            min_value: Some(0.5),
380            max_value: Some(30.0),
381            decimal_places: 1,
382        }
383    }
384
385    /// Payment terms (days) based on customer credit rating.
386    pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
387        ConditionalDistributionConfig {
388            output_field: "payment_terms_days".to_string(),
389            input_field: "credit_score".to_string(),
390            breakpoints: vec![
391                Breakpoint {
392                    threshold: 300.0, // Poor credit
393                    distribution: ConditionalDistributionParams::Discrete {
394                        values: vec![0.0, 15.0], // Due on receipt or Net 15
395                        weights: vec![0.7, 0.3],
396                    },
397                },
398                Breakpoint {
399                    threshold: 500.0, // Fair credit
400                    distribution: ConditionalDistributionParams::Discrete {
401                        values: vec![15.0, 30.0],
402                        weights: vec![0.5, 0.5],
403                    },
404                },
405                Breakpoint {
406                    threshold: 650.0, // Good credit
407                    distribution: ConditionalDistributionParams::Discrete {
408                        values: vec![30.0, 45.0, 60.0],
409                        weights: vec![0.5, 0.3, 0.2],
410                    },
411                },
412                Breakpoint {
413                    threshold: 750.0, // Excellent credit
414                    distribution: ConditionalDistributionParams::Discrete {
415                        values: vec![30.0, 60.0, 90.0],
416                        weights: vec![0.3, 0.4, 0.3],
417                    },
418                },
419            ],
420            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, // Due on receipt
421            min_value: Some(0.0),
422            max_value: Some(90.0),
423            decimal_places: 0,
424        }
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a constant distribution.
    fn fixed(value: f64) -> ConditionalDistributionParams {
        ConditionalDistributionParams::Fixed { value }
    }

    /// Breakpoint that yields the constant `value` from `threshold` upwards.
    fn fixed_step(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: fixed(value),
        }
    }

    /// Mean of `count` consecutive samples drawn for the same input.
    fn mean_of_samples(sampler: &mut ConditionalSampler, input: f64, count: usize) -> f64 {
        let drawn: Vec<f64> = (0..count).map(|_| sampler.sample(input)).collect();
        drawn.iter().sum::<f64>() / count as f64
    }

    #[test]
    fn test_conditional_config_validation() {
        // Ascending thresholds are accepted.
        let ascending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(100.0, 1.0), fixed_step(200.0, 2.0)],
            fixed(0.0),
        );
        assert!(ascending.validate().is_ok());

        // Descending thresholds are rejected.
        let descending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(200.0, 2.0), fixed_step(100.0, 1.0)],
            fixed(0.0),
        );
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_step(100.0, 10.0), fixed_step(200.0, 20.0)],
            fixed(0.0),
        );
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold, between the two, and above the second.
        assert_eq!(sampler.sample(50.0), 0.0);
        assert_eq!(sampler.sample(150.0), 10.0);
        assert_eq!(sampler.sample(250.0), 20.0);
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small orders: no discount or a very small one.
        let avg_small = mean_of_samples(&mut sampler, 500.0, 100);
        assert!(avg_small < 0.01);

        // Medium orders: small discount.
        sampler.reset(42);
        let avg_medium = mean_of_samples(&mut sampler, 3000.0, 100);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        // Large orders: higher discount.
        sampler.reset(42);
        let avg_large = mean_of_samples(&mut sampler, 150000.0, 100);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small amounts always land on level 1.
        assert_eq!(sampler.sample(500.0), 1.0);

        // Large amounts land on level 3 or 4.
        sampler.reset(42);
        let avg_level = mean_of_samples(&mut sampler, 75000.0, 100);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        // One generator is shared by all three checks, in this order.
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        // Normal: the sample mean should sit close to `mu`.
        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_draws: Vec<f64> = (0..1000).map(|_| normal.sample(&mut rng)).collect();
        let normal_mean = normal_draws.iter().sum::<f64>() / 1000.0;
        assert!((normal_mean - 10.0).abs() < 0.5);

        // Beta: every draw stays inside the scaled range.
        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_draws: Vec<f64> = (0..1000).map(|_| beta.sample(&mut rng)).collect();
        assert!(beta_draws.iter().all(|draw| (0.0..=1.0).contains(draw)));

        // Discrete: the first value carries ~50% of the weight.
        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let discrete_draws: Vec<f64> = (0..1000).map(|_| discrete.sample(&mut rng)).collect();
        let ones = discrete_draws.iter().filter(|&&draw| draw == 1.0).count();
        assert!(ones > 400 && ones < 600);
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        // Identical seeds and inputs must produce identical outputs.
        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}