Skip to main content

datasynth_core/distributions/
conditional.rs

//! Conditional distributions for dependent value generation.
//!
//! This module generates values that depend on another value through
//! breakpoint-based conditional logic, for example:
//! - discount percentage as a function of order amount,
//! - processing time as a function of transaction complexity,
//! - approval level as a function of amount thresholds,
//! - payment terms as a function of customer credit score.
8
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution, LogNormal, Normal, Uniform};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15/// A breakpoint defining where distribution parameters change.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Breakpoint {
18    /// The threshold value where this breakpoint applies
19    pub threshold: f64,
20    /// Distribution parameters for values at or above this threshold
21    pub distribution: ConditionalDistributionParams,
22}
23
24/// Parameters for the conditional distribution at a given breakpoint.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case", tag = "type")]
27pub enum ConditionalDistributionParams {
28    /// Fixed value
29    Fixed { value: f64 },
30    /// Normal distribution
31    Normal { mu: f64, sigma: f64 },
32    /// Log-normal distribution
33    LogNormal { mu: f64, sigma: f64 },
34    /// Uniform distribution
35    Uniform { min: f64, max: f64 },
36    /// Beta distribution (scaled to min-max range)
37    Beta {
38        alpha: f64,
39        beta: f64,
40        min: f64,
41        max: f64,
42    },
43    /// Discrete choice from a set of values
44    Discrete { values: Vec<f64>, weights: Vec<f64> },
45}
46
47impl Default for ConditionalDistributionParams {
48    fn default() -> Self {
49        Self::Fixed { value: 0.0 }
50    }
51}
52
53impl ConditionalDistributionParams {
54    /// Sample from this distribution.
55    pub fn sample(&self, rng: &mut ChaCha8Rng) -> f64 {
56        match self {
57            Self::Fixed { value } => *value,
58            Self::Normal { mu, sigma } => {
59                let dist = Normal::new(*mu, *sigma).unwrap_or_else(|_| {
60                    Normal::new(0.0, 1.0).expect("valid fallback distribution params")
61                });
62                dist.sample(rng)
63            }
64            Self::LogNormal { mu, sigma } => {
65                let dist = LogNormal::new(*mu, *sigma).unwrap_or_else(|_| {
66                    LogNormal::new(0.0, 1.0).expect("valid fallback distribution params")
67                });
68                dist.sample(rng)
69            }
70            Self::Uniform { min, max } => {
71                let dist = Uniform::new(*min, *max);
72                dist.sample(rng)
73            }
74            Self::Beta {
75                alpha,
76                beta,
77                min,
78                max,
79            } => {
80                let dist = Beta::new(*alpha, *beta).unwrap_or_else(|_| {
81                    Beta::new(2.0, 2.0).expect("valid fallback distribution params")
82                });
83                let u = dist.sample(rng);
84                min + u * (max - min)
85            }
86            Self::Discrete { values, weights } => {
87                if values.is_empty() {
88                    return 0.0;
89                }
90                if weights.is_empty() || weights.len() != values.len() {
91                    // Equal weights
92                    return *values.choose(rng).unwrap_or(&0.0);
93                }
94                // Weighted selection
95                let total: f64 = weights.iter().sum();
96                let mut p: f64 = rng.gen::<f64>() * total;
97                for (i, w) in weights.iter().enumerate() {
98                    p -= w;
99                    if p <= 0.0 {
100                        return values[i];
101                    }
102                }
103                *values.last().unwrap_or(&0.0)
104            }
105        }
106    }
107}
108
109/// Configuration for a conditional distribution.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ConditionalDistributionConfig {
112    /// Name of the dependent field (the output)
113    pub output_field: String,
114    /// Name of the conditioning field (the input)
115    pub input_field: String,
116    /// Breakpoints defining the conditional distribution
117    /// Must be sorted by threshold in ascending order
118    pub breakpoints: Vec<Breakpoint>,
119    /// Distribution for values below the first breakpoint
120    pub default_distribution: ConditionalDistributionParams,
121    /// Minimum output value (clamps)
122    #[serde(default)]
123    pub min_value: Option<f64>,
124    /// Maximum output value (clamps)
125    #[serde(default)]
126    pub max_value: Option<f64>,
127    /// Number of decimal places for rounding
128    #[serde(default = "default_decimal_places")]
129    pub decimal_places: u8,
130}
131
/// Serde default for `ConditionalDistributionConfig::decimal_places`:
/// two decimal places.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
135
136impl Default for ConditionalDistributionConfig {
137    fn default() -> Self {
138        Self {
139            output_field: "output".to_string(),
140            input_field: "input".to_string(),
141            breakpoints: vec![],
142            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
143            min_value: None,
144            max_value: None,
145            decimal_places: 2,
146        }
147    }
148}
149
150impl ConditionalDistributionConfig {
151    /// Create a new conditional distribution configuration.
152    pub fn new(
153        output_field: impl Into<String>,
154        input_field: impl Into<String>,
155        breakpoints: Vec<Breakpoint>,
156        default: ConditionalDistributionParams,
157    ) -> Self {
158        Self {
159            output_field: output_field.into(),
160            input_field: input_field.into(),
161            breakpoints,
162            default_distribution: default,
163            min_value: None,
164            max_value: None,
165            decimal_places: 2,
166        }
167    }
168
169    /// Validate the configuration.
170    pub fn validate(&self) -> Result<(), String> {
171        // Check breakpoints are in ascending order
172        for i in 1..self.breakpoints.len() {
173            if self.breakpoints[i].threshold <= self.breakpoints[i - 1].threshold {
174                return Err(format!(
175                    "Breakpoints must be in ascending order: {} is not > {}",
176                    self.breakpoints[i].threshold,
177                    self.breakpoints[i - 1].threshold
178                ));
179            }
180        }
181
182        if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
183            if max <= min {
184                return Err("max_value must be greater than min_value".to_string());
185            }
186        }
187
188        Ok(())
189    }
190
191    /// Get the distribution parameters for a given input value.
192    pub fn get_distribution(&self, input_value: f64) -> &ConditionalDistributionParams {
193        // Find the highest breakpoint that the input exceeds
194        for breakpoint in self.breakpoints.iter().rev() {
195            if input_value >= breakpoint.threshold {
196                return &breakpoint.distribution;
197            }
198        }
199        &self.default_distribution
200    }
201}
202
203/// Sampler for conditional distributions.
204pub struct ConditionalSampler {
205    rng: ChaCha8Rng,
206    config: ConditionalDistributionConfig,
207    decimal_multiplier: f64,
208}
209
210impl ConditionalSampler {
211    /// Create a new conditional sampler.
212    pub fn new(seed: u64, config: ConditionalDistributionConfig) -> Result<Self, String> {
213        config.validate()?;
214        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
215        Ok(Self {
216            rng: ChaCha8Rng::seed_from_u64(seed),
217            config,
218            decimal_multiplier,
219        })
220    }
221
222    /// Sample a value given the conditioning input.
223    pub fn sample(&mut self, input_value: f64) -> f64 {
224        let dist = self.config.get_distribution(input_value);
225        let mut value = dist.sample(&mut self.rng);
226
227        // Apply constraints
228        if let Some(min) = self.config.min_value {
229            value = value.max(min);
230        }
231        if let Some(max) = self.config.max_value {
232            value = value.min(max);
233        }
234
235        // Round to decimal places
236        (value * self.decimal_multiplier).round() / self.decimal_multiplier
237    }
238
239    /// Sample a value as Decimal.
240    pub fn sample_decimal(&mut self, input_value: f64) -> Decimal {
241        let value = self.sample(input_value);
242        Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
243    }
244
245    /// Reset the sampler with a new seed.
246    pub fn reset(&mut self, seed: u64) {
247        self.rng = ChaCha8Rng::seed_from_u64(seed);
248    }
249
250    /// Get the configuration.
251    pub fn config(&self) -> &ConditionalDistributionConfig {
252        &self.config
253    }
254}
255
256/// Preset conditional distribution configurations.
257pub mod conditional_presets {
258    use super::*;
259
260    /// Discount percentage based on order amount.
261    /// Higher amounts get higher discount percentages.
262    pub fn discount_by_amount() -> ConditionalDistributionConfig {
263        ConditionalDistributionConfig {
264            output_field: "discount_percent".to_string(),
265            input_field: "order_amount".to_string(),
266            breakpoints: vec![
267                Breakpoint {
268                    threshold: 1000.0,
269                    distribution: ConditionalDistributionParams::Beta {
270                        alpha: 2.0,
271                        beta: 8.0,
272                        min: 0.01,
273                        max: 0.05, // 1-5%
274                    },
275                },
276                Breakpoint {
277                    threshold: 5000.0,
278                    distribution: ConditionalDistributionParams::Beta {
279                        alpha: 2.0,
280                        beta: 5.0,
281                        min: 0.02,
282                        max: 0.08, // 2-8%
283                    },
284                },
285                Breakpoint {
286                    threshold: 25000.0,
287                    distribution: ConditionalDistributionParams::Beta {
288                        alpha: 3.0,
289                        beta: 3.0,
290                        min: 0.05,
291                        max: 0.12, // 5-12%
292                    },
293                },
294                Breakpoint {
295                    threshold: 100000.0,
296                    distribution: ConditionalDistributionParams::Beta {
297                        alpha: 5.0,
298                        beta: 2.0,
299                        min: 0.08,
300                        max: 0.15, // 8-15%
301                    },
302                },
303            ],
304            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 },
305            min_value: Some(0.0),
306            max_value: Some(0.20),
307            decimal_places: 4,
308        }
309    }
310
311    /// Approval level based on transaction amount.
312    pub fn approval_level_by_amount() -> ConditionalDistributionConfig {
313        ConditionalDistributionConfig {
314            output_field: "approval_level".to_string(),
315            input_field: "amount".to_string(),
316            breakpoints: vec![
317                Breakpoint {
318                    threshold: 1000.0,
319                    distribution: ConditionalDistributionParams::Discrete {
320                        values: vec![1.0, 2.0],
321                        weights: vec![0.9, 0.1],
322                    },
323                },
324                Breakpoint {
325                    threshold: 10000.0,
326                    distribution: ConditionalDistributionParams::Discrete {
327                        values: vec![2.0, 3.0],
328                        weights: vec![0.7, 0.3],
329                    },
330                },
331                Breakpoint {
332                    threshold: 50000.0,
333                    distribution: ConditionalDistributionParams::Discrete {
334                        values: vec![3.0, 4.0],
335                        weights: vec![0.6, 0.4],
336                    },
337                },
338                Breakpoint {
339                    threshold: 100000.0,
340                    distribution: ConditionalDistributionParams::Fixed { value: 4.0 },
341                },
342            ],
343            default_distribution: ConditionalDistributionParams::Fixed { value: 1.0 },
344            min_value: Some(1.0),
345            max_value: Some(4.0),
346            decimal_places: 0,
347        }
348    }
349
350    /// Processing days based on order complexity (number of line items).
351    pub fn processing_time_by_complexity() -> ConditionalDistributionConfig {
352        ConditionalDistributionConfig {
353            output_field: "processing_days".to_string(),
354            input_field: "line_item_count".to_string(),
355            breakpoints: vec![
356                Breakpoint {
357                    threshold: 5.0,
358                    distribution: ConditionalDistributionParams::LogNormal {
359                        mu: 0.5, // ~1.6 days median
360                        sigma: 0.5,
361                    },
362                },
363                Breakpoint {
364                    threshold: 15.0,
365                    distribution: ConditionalDistributionParams::LogNormal {
366                        mu: 1.0, // ~2.7 days median
367                        sigma: 0.5,
368                    },
369                },
370                Breakpoint {
371                    threshold: 30.0,
372                    distribution: ConditionalDistributionParams::LogNormal {
373                        mu: 1.5, // ~4.5 days median
374                        sigma: 0.6,
375                    },
376                },
377            ],
378            default_distribution: ConditionalDistributionParams::LogNormal {
379                mu: 0.0, // ~1 day median
380                sigma: 0.4,
381            },
382            min_value: Some(0.5),
383            max_value: Some(30.0),
384            decimal_places: 1,
385        }
386    }
387
388    /// Payment terms (days) based on customer credit rating.
389    pub fn payment_terms_by_credit_rating() -> ConditionalDistributionConfig {
390        ConditionalDistributionConfig {
391            output_field: "payment_terms_days".to_string(),
392            input_field: "credit_score".to_string(),
393            breakpoints: vec![
394                Breakpoint {
395                    threshold: 300.0, // Poor credit
396                    distribution: ConditionalDistributionParams::Discrete {
397                        values: vec![0.0, 15.0], // Due on receipt or Net 15
398                        weights: vec![0.7, 0.3],
399                    },
400                },
401                Breakpoint {
402                    threshold: 500.0, // Fair credit
403                    distribution: ConditionalDistributionParams::Discrete {
404                        values: vec![15.0, 30.0],
405                        weights: vec![0.5, 0.5],
406                    },
407                },
408                Breakpoint {
409                    threshold: 650.0, // Good credit
410                    distribution: ConditionalDistributionParams::Discrete {
411                        values: vec![30.0, 45.0, 60.0],
412                        weights: vec![0.5, 0.3, 0.2],
413                    },
414                },
415                Breakpoint {
416                    threshold: 750.0, // Excellent credit
417                    distribution: ConditionalDistributionParams::Discrete {
418                        values: vec![30.0, 60.0, 90.0],
419                        weights: vec![0.3, 0.4, 0.3],
420                    },
421                },
422            ],
423            default_distribution: ConditionalDistributionParams::Fixed { value: 0.0 }, // Due on receipt
424            min_value: Some(0.0),
425            max_value: Some(90.0),
426            decimal_places: 0,
427        }
428    }
429}
430
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Breakpoint at `threshold` that always yields `value`.
    fn fixed_at(threshold: f64, value: f64) -> Breakpoint {
        Breakpoint {
            threshold,
            distribution: ConditionalDistributionParams::Fixed { value },
        }
    }

    /// Constant-zero fallback distribution shared by the hand-built configs.
    fn zero() -> ConditionalDistributionParams {
        ConditionalDistributionParams::Fixed { value: 0.0 }
    }

    /// Mean of `count` samples drawn from `sampler` at `input`.
    fn mean_output(sampler: &mut ConditionalSampler, input: f64, count: usize) -> f64 {
        let total: f64 = (0..count).map(|_| sampler.sample(input)).sum();
        total / count as f64
    }

    #[test]
    fn test_conditional_config_validation() {
        let ascending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_at(100.0, 1.0), fixed_at(200.0, 2.0)],
            zero(),
        );
        assert!(ascending.validate().is_ok());

        // Thresholds out of order must be rejected.
        let descending = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_at(200.0, 2.0), fixed_at(100.0, 1.0)],
            zero(),
        );
        assert!(descending.validate().is_err());
    }

    #[test]
    fn test_conditional_sampling() {
        let config = ConditionalDistributionConfig::new(
            "output",
            "input",
            vec![fixed_at(100.0, 10.0), fixed_at(200.0, 20.0)],
            zero(),
        );
        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Below, between and above the two thresholds.
        for (input, expected) in [(50.0, 0.0), (150.0, 10.0), (250.0, 20.0)] {
            assert_eq!(sampler.sample(input), expected, "input {}", input);
        }
    }

    #[test]
    fn test_discount_by_amount_preset() {
        let config = conditional_presets::discount_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small orders fall below the first breakpoint: fixed zero discount.
        let small = mean_output(&mut sampler, 500.0, 100);
        assert!(small < 0.01);

        // Medium orders: small discount.
        sampler.reset(42);
        let medium = mean_output(&mut sampler, 3000.0, 100);
        assert!(medium > 0.01 && medium < 0.06);

        // Large orders: higher discount.
        sampler.reset(42);
        let large = mean_output(&mut sampler, 150000.0, 100);
        assert!(large > 0.08);
    }

    #[test]
    fn test_approval_level_preset() {
        let config = conditional_presets::approval_level_by_amount();
        assert!(config.validate().is_ok());

        let mut sampler = ConditionalSampler::new(42, config).unwrap();

        // Small amounts always need level 1.
        assert_eq!(sampler.sample(500.0), 1.0);

        // Large amounts need level 3 or 4.
        sampler.reset(42);
        let avg_level = mean_output(&mut sampler, 75000.0, 100);
        assert!(avg_level >= 3.0);
    }

    #[test]
    fn test_distribution_params_sampling() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);

        let normal = ConditionalDistributionParams::Normal {
            mu: 10.0,
            sigma: 1.0,
        };
        let normal_total: f64 = (0..1000).map(|_| normal.sample(&mut rng)).sum();
        assert!((normal_total / 1000.0 - 10.0).abs() < 0.5);

        let beta = ConditionalDistributionParams::Beta {
            alpha: 2.0,
            beta: 5.0,
            min: 0.0,
            max: 1.0,
        };
        let beta_in_range = (0..1000)
            .map(|_| beta.sample(&mut rng))
            .all(|x| (0.0..=1.0).contains(&x));
        assert!(beta_in_range);

        let discrete = ConditionalDistributionParams::Discrete {
            values: vec![1.0, 2.0, 3.0],
            weights: vec![0.5, 0.3, 0.2],
        };
        let ones = (0..1000)
            .map(|_| discrete.sample(&mut rng))
            .filter(|&x| x == 1.0)
            .count();
        assert!(ones > 400 && ones < 600); // ~50%
    }

    #[test]
    fn test_conditional_determinism() {
        let config = conditional_presets::discount_by_amount();

        let mut first = ConditionalSampler::new(42, config.clone()).unwrap();
        let mut second = ConditionalSampler::new(42, config).unwrap();

        for amount in [100.0, 1000.0, 10000.0, 100000.0] {
            assert_eq!(first.sample(amount), second.sample(amount));
        }
    }
}