Skip to main content

datasynth_core/distributions/
conditional.rs

//! Conditional distributions for dependent value generation.
//!
//! This module provides tools for generating values that depend on
//! other values through breakpoint-based conditional logic, such as:
//! - Discount percentage depends on order amount
//! - Processing time depends on transaction complexity
//! - Approval level depends on amount thresholds
8
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15/// A breakpoint defining where distribution parameters change.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18    /// The threshold value where this breakpoint applies
19    pub threshold: f64,
20    /// Distribution parameters for values at or above this threshold
21    pub distribution: ConditionalDistributionParams,
22}
23
24/// Parameters for the conditional distribution at a given breakpoint.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28    /// Fixed value
29    Fixed { value: f64 },
30    /// Normal distribution
31    Normal { mu: f64, sigma: f64 },
32    /// Log-normal distribution
33    LogNormal { mu: f64, sigma: f64 },
34    /// Uniform distribution
35    Uniform { min: f64, max: f64 },
36    /// Beta distribution (scaled to min-max range)
37    Beta {
38        alpha: f64,
39        beta: f64,
40        min: f64,
41        max: f64,
42    },
43    /// Discrete choice from a set of values
44    Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48    fn default() -> Self {
49        Self::Fixed { value: 0.0 }
50    }
51}
52
53impl ConditionalDistributionParams {
54    /// Sample from this distribution.
55    pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56        match self {
57            Self::Fixed { value } => *value,
58            Self::Normal { mu, sigma } => {
59                let dist = Normal::new(*mu, *sigma).unwrap_or_else(|_| {
60                    Normal::new(0.0, 1.0).expect("valid fallback distribution params")
61                });
62                dist.sample(rng)
63            }
64            Self::LogNormal { mu, sigma } => {
65                let dist = LogNormal::new(*mu, *sigma).unwrap_or_else(|_| {
66                    LogNormal::new(0.0, 1.0).expect("valid fallback distribution params")
67                });
68                dist.sample(rng)
69            }
70            Self::Uniform { min, max } => {
71                let dist = Uniform::new(*min, *max).expect("valid uniform params");
72                dist.sample(rng)
73            }
74            Self::Beta {
75                alpha,
76                beta,
77                min,
78                max,
79            } => {
80                let dist = Beta::new(*alpha, *beta).unwrap_or_else(|_| {
81                    Beta::new(2.0, 2.0).expect("valid fallback distribution params")
82                });
83                let u = dist.sample(rng);
84                min + u * (max - min)
85            }
86            Self::Discrete { values, weights } => {
87                if values.is_empty() {
88                    return 0.0;
89                }
90                if weights.is_empty() || weights.len() != values.len() {
91                    // Equal weights
92                    return *values.choose(rng).unwrap_or(&0.0);
93                }
94                // Weighted selection
95                let total: f64 = weights.iter().sum();
96                let mut p: f64 = rng.random::<f64>() * total;
97                for (i, w) in weights.iter().enumerate() {
98                    p -= w;
99                    if p <= 0.0 {
100                        return values[i];
101                    }
102                }
103                *values.last().unwrap_or(&0.0)
104            }
105        }
106    }
107}
108
109/// Configuration for a conditional distribution.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ConditionalDistributionConfig {
112    /// Name of the dependent field (the output)
113    pub output_field: String,
114    /// Name of the conditioning field (the input)
115    pub input_field: String,
116    /// Breakpoints defining the conditional distribution
117    /// Must be sorted by threshold in ascending order
118    pub breakpoints: Vec<Breakpoint>,
119    /// Distribution for values below the first breakpoint
120    pub default_distribution: ConditionalDistributionParams,
121    /// Minimum output value (clamps)
122    #[serde(default)]
123    pub min_value: Option<f64>,
124    /// Maximum output value (clamps)
125    #[serde(default)]
126    pub max_value: Option<f64>,
127    /// Number of decimal places for rounding
128    #[serde(default = "default_decimal_places")]
129    pub decimal_places: u8,
130}
131
/// Serde fallback for `decimal_places`: two digits after the decimal point.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
135
136impl Default for ConditionalDistributionConfig {
137    fn default() -> Self {
138        Self {
139            output_field: "output".to_string(),
140            input_field: "input".to_string(),
141            breakpoints: vec![],
142            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
143            min_value: None,
144            max_value: None,
145            decimal_places: 2,
146        }
147    }
148}
149
150impl ConditionalDistributionConfig {
151    /// Create a new conditional distribution configuration.
152    pub fn new(
153        output_field: impl Into<String>,
154        input_field: impl Into<String>,
155        breakpoints: Vec<Breakpoint>,
156        default: ConditionalDistributionParams,
157    ) -> Self {
158        Self {
159            output_field: output_field.into(),
160            input_field: input_field.into(),
161            breakpoints,
162            default_distribution: default,
163            min_value: None,
164            max_value: None,
165            decimal_places: 2,
166        }
167    }
168
169    /// Validate the configuration.
170    pub fn validate(&self) -> Result<(), String> {
171        // Check breakpoints are in ascending order
172        for i in 1..self.breakpoints.len() {
173            if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
174                return Err(format!(
175                    "Breakpoints must be in ascending order: {} is not > {}",
176                    self.breakpoints[i].threshold,
177                    self.breakpoints[i - 1].threshold
178                ));
179            }
180        }
181
182        if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
183            if max <= min {
184                return Err("max_value must be greater than min_value".to_string());
185            }
186        }
187
188        Ok(())
189    }
190
191    /// Get the distribution parameters for a given input value.
192    pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
193        // Find the highest breakpoint that the input exceeds
194        for breakpoint in self.breakpoints.iter().rev() {
195            if input_value >= breakpoint.threshold {
196                return &breakpoint.distribution;
197            }
198        }
199        &self.default_distribution
200    }
201}
202
203/// Sampler for conditional distributions.
204#[derive(Clone)]
205pub struct ConditionalSampler {
206    rng: ChaCha8Rng,
207    config: ConditionalDistributionConfig,
208    decimal_multiplier: f64,
209}
210
211impl ConditionalSampler {
212    /// Create a new conditional sampler.
213    pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
214        config.validate()?;
215        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
216        Ok(Self {
217            rng: ChaCha8Rng::seed_from_u64(seed),
218            config,
219            decimal_multiplier,
220        })
221    }
222
223    /// Sample a value given the conditioning input.
224    pub fn sample(&mut self, input_value: f64) -> f64 {
225        let dist = self.config.get_distribution(input_value);
226        let mut value = dist.sample(&mut self.rng);
227
228        // Apply constraints
229        if let Some(min) = self.config.min_value {
230            value = value.max(min);
231        }
232        if let Some(max) = self.config.max_value {
233            value = value.min(max);
234        }
235
236        // Round to decimal places
237        (value * self.decimal_multiplier).round() / self.decimal_multiplier
238    }
239
240    /// Sample a value as Decimal.
241    pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
242        let value = self.sample(input_value);
243        Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
244    }
245
246    /// Reset the sampler with a new seed.
247    pub fn reset(&mut self, seed: u64) {
248        self.rng = ChaCha8Rng::seed_from_u64(seed);
249    }
250
251    /// Get the configuration.
252    pub fn config(&self) -> &ConditionalDistributionConfig {
253        &self.config
254    }
255}
256
257/// Preset conditional distribution configurations.
258pub mod conditional_presets {
259    use super::*;
260
261    /// Discount percentage based on order amount.
262    /// Higher amounts get higher discount percentages.
263    pub fn discount_by_amount() -> ConditionalDistributionConfig {
264        ConditionalDistributionConfig {
265            output_field: "discount_percent".to_string(),
266            input_field: "order_amount".to_string(),
267            breakpoints: vec![
268                Breakpoint {
269                    threshold: 1000.0,
270                    distribution: ConditionalDistributionParams::Beta {
271                        alpha: 2.0,
272                        beta: 8.0,
273                        min: 0.01,
274                        max: 0.05, // 1-5%
275                    },
276                },
277                Breakpoint {
278                    threshold: 5000.0,
279                    distribution: ConditionalDistributionParams::Beta {
280                        alpha: 2.0,
281                        beta: 5.0,
282                        min: 0.02,
283                        max: 0.08, // 2-8%
284                    },
285                },
286                Breakpoint {
287                    threshold: 25000.0,
288                    distribution: ConditionalDistributionParams::Beta {
289                        alpha: 3.0,
290                        beta: 3.0,
291                        min: 0.05,
292                        max: 0.12, // 5-12%
293                    },
294                },
295                Breakpoint {
296                    threshold: 100000.0,
297                    distribution: ConditionalDistributionParams::Beta {
298                        alpha: 5.0,
299                        beta: 2.0,
300                        min: 0.08,
301                        max: 0.15, // 8-15%
302                    },
303                },
304            ],
305            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
306            min_value: Some(0.0),
307            max_value: Some(0.20),
308            decimal_places: 4,
309        }
310    }
311
312    /// Approval level based on transaction amount.
313    pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
314        ConditionalDistributionConfig {
315            output_field: "approval_level".to_string(),
316            input_field: "amount".to_string(),
317            breakpoints: vec![
318                Breakpoint {
319                    threshold: 1000.0,
320                    distribution: ConditionalDistributionParams::Discrete {
321                        values: vec![1.0, 2.0],
322                        weights: vec![0.9, 0.1],
323                    },
324                },
325                Breakpoint {
326                    threshold: 10000.0,
327                    distribution: ConditionalDistributionParams::Discrete {
328                        values: vec![2.0, 3.0],
329                        weights: vec![0.7, 0.3],
330                    },
331                },
332                Breakpoint {
333                    threshold: 50000.0,
334                    distribution: ConditionalDistributionParams::Discrete {
335                        values: vec![3.0, 4.0],
336                        weights: vec![0.6, 0.4],
337                    },
338                },
339                Breakpoint {
340                    threshold: 100000.0,
341                    distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
342                },
343            ],
344            default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
345            min_value: Some(1.0),
346            max_value: Some(4.0),
347            decimal_places: 0,
348        }
349    }
350
351    /// Processing days based on order complexity (number of line items).
352    pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
353        ConditionalDistributionConfig {
354            output_field: "processing_days".to_string(),
355            input_field: "line_item_count".to_string(),
356            breakpoints: vec![
357                Breakpoint {
358                    threshold: 5.0,
359                    distribution: ConditionalDistributionParams::LogNormal {
360                        mu: 0.5, // ~1.6 days median
361                        sigma: 0.5,
362                    },
363                },
364                Breakpoint {
365                    threshold: 15.0,
366                    distribution: ConditionalDistributionParams::LogNormal {
367                        mu: 1.0, // ~2.7 days median
368                        sigma: 0.5,
369                    },
370                },
371                Breakpoint {
372                    threshold: 30.0,
373                    distribution: ConditionalDistributionParams::LogNormal {
374                        mu: 1.5, // ~4.5 days median
375                        sigma: 0.6,
376                    },
377                },
378            ],
379            default_distribution: ConditionalDistributionParams::LogNormal {
380                mu: 0.0, // ~1 day median
381                sigma: 0.4,
382            },
383            min_value: Some(0.5),
384            max_value: Some(30.0),
385            decimal_places: 1,
386        }
387    }
388
389    /// Payment terms (days) based on customer credit rating.
390    pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
391        ConditionalDistributionConfig {
392            output_field: "payment_terms_days".to_string(),
393            input_field: "credit_score".to_string(),
394            breakpoints: vec![
395                Breakpoint {
396                    threshold: 300.0, // Poor credit
397                    distribution: ConditionalDistributionParams::Discrete {
398                        values: vec![0.0, 15.0], // Due on receipt or Net 15
399                        weights: vec![0.7, 0.3],
400                    },
401                },
402                Breakpoint {
403                    threshold: 500.0, // Fair credit
404                    distribution: ConditionalDistributionParams::Discrete {
405                        values: vec![15.0, 30.0],
406                        weights: vec![0.5, 0.5],
407                    },
408                },
409                Breakpoint {
410                    threshold: 650.0, // Good credit
411                    distribution: ConditionalDistributionParams::Discrete {
412                        values: vec![30.0, 45.0, 60.0],
413                        weights: vec![0.5, 0.3, 0.2],
414                    },
415                },
416                Breakpoint {
417                    threshold: 750.0, // Excellent credit
418                    distribution: ConditionalDistributionParams::Discrete {
419                        values: vec![30.0, 60.0, 90.0],
420                        weights: vec![0.3, 0.4, 0.3],
421                    },
422                },
423            ],
424            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, // Due on receipt
425            min_value: Some(0.0),
426            max_value: Some(90.0),
427            decimal_places: 0,
428        }
429    }
430}
431
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Breakpoint whose distribution always yields `value`.
    fn fixed_step(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: ConditionalDistributionParams::Fixed { value },
        }
    }

    /// Config over `(threshold, value)` steps with a constant `0.0` default.
    fn fixed_config(steps: &[(f64, f64)]) -> ConditionalDistributionConfig {
        ConditionalDistributionConfig::new(
            "output",
            "input",
            steps
                .iter()
                .map(|&(threshold, value)| fixed_step(threshold, value))
                .collect(),
            ConditionalDistributionParams::Fixed { value: 0.0 },
        )
    }

    /// Mean of 100 consecutive samples drawn for the same input.
    fn mean_of_100(sampler: &mut ConditionalSampler, input: f64) -> f64 {
        (0..100).map(|_| sampler.sample(input)).sum::<f64>() / 100.0
    }

    #[test]
    fn test_conditional_config_validation() {
        let ascending = fixed_config(&[(100.0, 1.0), (200.0, 2.0)]);
        assert!(ascending.validate().is_ok());

        // Invalid: breakpoints not in order
        let descending = fixed_config(&[(200.0, 2.0), (100.0, 1.0)]);
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = fixed_config(&[(100.0, 10.0), (200.0, 20.0)]);
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below the first threshold, between the two, and above the second.
        for (input, expected) in [(50.0, 0.0), (150.0, 10.0), (250.0, 20.0)] {
            assert_eq!(sampler.sample(input), expected);
        }
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small orders: no discount or very small
        let avg_small = mean_of_100(&mut sampler, 500.0);
        assert!(avg_small < 0.01); // Should be 0 or very small

        // Medium orders: small discount
        sampler.reset(42);
        let avg_medium = mean_of_100(&mut sampler, 3000.0);
        assert!(avg_medium > 0.01 && avg_medium < 0.06);

        // Large orders: higher discount
        sampler.reset(42);
        let avg_large = mean_of_100(&mut sampler, 150000.0);
        assert!(avg_large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small amounts: level 1
        assert_eq!(sampler.sample(500.0), 1.0);

        // Large amounts: level 3-4
        sampler.reset(42);
        let avg_level = mean_of_100(&mut sampler, 75000.0);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        // Normal: the sample mean should sit close to mu.
        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_mean = (0..1000).map(|_| normal.sample(&mut rng)).sum::<f64>() / 1000.0;
        assert!((normal_mean - 10.0).abs() < 0.5);

        // Beta: every draw must stay inside the scaled range.
        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_draws: Vec<f64> = (0..1000).map(|_| beta.sample(&mut rng)).collect();
        for draw in &beta_draws {
            assert!((0.0..=1.0).contains(draw));
        }

        // Discrete: the first value carries weight 0.5.
        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let discrete_draws: Vec<f64> = (0..1000).map(|_| discrete.sample(&mut rng)).collect();
        let ones = discrete_draws.iter().filter(|&&draw| draw == 1.0).count();
        assert!(ones > 400 && ones < 600); // ~50%
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        // Identical seeds must produce identical sequences.
        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}