datasynth_core/distributions/
line_item.rs1use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct LineItemDistributionConfig {
21 pub two_items: f64,
23 pub three_items: f64,
25 pub four_items: f64,
27 pub five_items: f64,
29 pub six_items: f64,
31 pub seven_items: f64,
33 pub eight_items: f64,
35 pub nine_items: f64,
37 pub ten_to_ninety_nine: f64,
39 pub hundred_to_nine_ninety_nine: f64,
41 pub thousand_plus: f64,
43 #[serde(default = "default_tail_decay")]
49 pub tail_decay: bool,
50}
51
/// Serde fallback for the `tail_decay` field of [`LineItemDistributionConfig`].
///
/// Tail decay stays enabled unless a serialized config explicitly turns it off.
fn default_tail_decay() -> bool {
    true
}
55
56impl Default for LineItemDistributionConfig {
57 fn default() -> Self {
58 Self {
70 two_items: 0.601,
71 three_items: 0.11,
72 four_items: 0.16,
73 five_items: 0.04,
74 six_items: 0.03,
75 seven_items: 0.015,
76 eight_items: 0.015,
77 nine_items: 0.008,
78 ten_to_ninety_nine: 0.0195,
79 hundred_to_nine_ninety_nine: 0.0015,
80 thousand_plus: 0.0,
81 tail_decay: true,
82 }
83 }
84}
85
86impl LineItemDistributionConfig {
87 pub fn paper_reference() -> Self {
94 Self {
95 two_items: 0.6068,
96 three_items: 0.0577,
97 four_items: 0.1663,
98 five_items: 0.0306,
99 six_items: 0.0332,
100 seven_items: 0.0113,
101 eight_items: 0.0188,
102 nine_items: 0.0042,
103 ten_to_ninety_nine: 0.0633,
104 hundred_to_nine_ninety_nine: 0.0076,
105 thousand_plus: 0.0002,
106 tail_decay: false,
110 }
111 }
112
113 pub fn validate(&self) -> Result<(), String> {
115 let sum = self.two_items
116 + self.three_items
117 + self.four_items
118 + self.five_items
119 + self.six_items
120 + self.seven_items
121 + self.eight_items
122 + self.nine_items
123 + self.ten_to_ninety_nine
124 + self.hundred_to_nine_ninety_nine
125 + self.thousand_plus;
126
127 if (sum - 1.0).abs() > 0.01 {
128 return Err(format!(
129 "Line item distribution probabilities sum to {sum}, expected ~1.0"
130 ));
131 }
132 Ok(())
133 }
134
135 fn cumulative(&self) -> [f64; 11] {
137 let mut cum = [0.0; 11];
138 cum[0] = self.two_items;
139 cum[1] = cum[0] + self.three_items;
140 cum[2] = cum[1] + self.four_items;
141 cum[3] = cum[2] + self.five_items;
142 cum[4] = cum[3] + self.six_items;
143 cum[5] = cum[4] + self.seven_items;
144 cum[6] = cum[5] + self.eight_items;
145 cum[7] = cum[6] + self.nine_items;
146 cum[8] = cum[7] + self.ten_to_ninety_nine;
147 cum[9] = cum[8] + self.hundred_to_nine_ninety_nine;
148 cum[10] = cum[9] + self.thousand_plus;
149 cum
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct EvenOddDistributionConfig {
156 pub even: f64,
158 pub odd: f64,
160}
161
162impl Default for EvenOddDistributionConfig {
163 fn default() -> Self {
164 Self {
166 even: 0.88,
167 odd: 0.12,
168 }
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct DebitCreditDistributionConfig {
175 pub equal: f64,
177 pub more_debit: f64,
179 pub more_credit: f64,
181}
182
183impl Default for DebitCreditDistributionConfig {
184 fn default() -> Self {
185 Self {
187 equal: 0.82,
188 more_debit: 0.07,
189 more_credit: 0.11,
190 }
191 }
192}
193
194pub struct LineItemSampler {
199 rng: ChaCha8Rng,
201 even_odd_config: EvenOddDistributionConfig,
203 debit_credit_config: DebitCreditDistributionConfig,
205 cumulative: [f64; 11],
207 tail_decay: bool,
209}
210
211impl LineItemSampler {
212 pub fn new(seed: u64) -> Self {
214 let line_config = LineItemDistributionConfig::default();
215 let cumulative = line_config.cumulative();
216
217 Self {
218 rng: ChaCha8Rng::seed_from_u64(seed),
219 even_odd_config: EvenOddDistributionConfig::default(),
220 debit_credit_config: DebitCreditDistributionConfig::default(),
221 cumulative,
222 tail_decay: line_config.tail_decay,
223 }
224 }
225
226 pub fn with_config(
228 seed: u64,
229 line_config: LineItemDistributionConfig,
230 even_odd_config: EvenOddDistributionConfig,
231 debit_credit_config: DebitCreditDistributionConfig,
232 ) -> Self {
233 let cumulative = line_config.cumulative();
234 let tail_decay = line_config.tail_decay;
235
236 Self {
237 rng: ChaCha8Rng::seed_from_u64(seed),
238 even_odd_config,
239 debit_credit_config,
240 cumulative,
241 tail_decay,
242 }
243 }
244
245 pub fn sample_count(&mut self) -> usize {
247 let p: f64 = self.rng.random();
248
249 if p < self.cumulative[0] {
251 2
252 } else if p < self.cumulative[1] {
253 3
254 } else if p < self.cumulative[2] {
255 4
256 } else if p < self.cumulative[3] {
257 5
258 } else if p < self.cumulative[4] {
259 6
260 } else if p < self.cumulative[5] {
261 7
262 } else if p < self.cumulative[6] {
263 8
264 } else if p < self.cumulative[7] {
265 9
266 } else if p < self.cumulative[8] {
267 if self.tail_decay {
273 let u: f64 = self.rng.random();
274 (10.0 + -(1.0 - u).ln() * 8.0).min(99.0) as usize
275 } else {
276 self.rng.random_range(10..100)
277 }
278 } else if p < self.cumulative[9] {
279 if self.tail_decay {
282 let u: f64 = self.rng.random();
283 (100.0 + -(1.0 - u).ln() * 120.0).min(600.0) as usize
284 } else {
285 self.rng.random_range(100..1000)
286 }
287 } else if self.tail_decay {
288 let u: f64 = self.rng.random();
291 (100.0 + -(1.0 - u).ln() * 120.0).min(600.0) as usize
292 } else {
293 self.rng.random_range(1000..10000)
295 }
296 }
297
298 pub fn sample_even(&mut self) -> bool {
300 self.rng.random::<f64>() < self.even_odd_config.even
301 }
302
303 pub fn sample_count_with_parity(&mut self) -> usize {
308 let base_count = self.sample_count();
309 let should_be_even = self.sample_even();
310
311 let is_even = base_count.is_multiple_of(2);
313 if should_be_even != is_even {
314 if base_count <= 2 {
316 base_count + 1
318 } else if self.rng.random::<bool>() {
319 base_count + 1
321 } else {
322 base_count - 1
324 }
325 } else {
326 base_count
327 }
328 }
329
330 pub fn sample_debit_credit_type(&mut self) -> DebitCreditSplit {
332 let p: f64 = self.rng.random();
333
334 if p < self.debit_credit_config.equal {
335 DebitCreditSplit::Equal
336 } else if p < self.debit_credit_config.equal + self.debit_credit_config.more_debit {
337 DebitCreditSplit::MoreDebit
338 } else {
339 DebitCreditSplit::MoreCredit
340 }
341 }
342
343 pub fn sample(&mut self) -> LineItemSpec {
345 let total_count = self.sample_count_with_parity();
346 let split_type = self.sample_debit_credit_type();
347
348 let (debit_count, credit_count) = match split_type {
349 DebitCreditSplit::Equal => {
350 let half = total_count / 2;
351 (half, total_count - half)
352 }
353 DebitCreditSplit::MoreDebit => {
354 let debit = (total_count as f64 * 0.6).round() as usize;
356 let debit = debit.max(1).min(total_count - 1);
357 (debit, total_count - debit)
358 }
359 DebitCreditSplit::MoreCredit => {
360 let credit = (total_count as f64 * 0.6).round() as usize;
362 let credit = credit.max(1).min(total_count - 1);
363 (total_count - credit, credit)
364 }
365 };
366
367 LineItemSpec {
368 total_count,
369 debit_count,
370 credit_count,
371 split_type,
372 }
373 }
374
375 pub fn reset(&mut self, seed: u64) {
377 self.rng = ChaCha8Rng::seed_from_u64(seed);
378 }
379}
380
/// How an entry's lines are divided between the debit and credit sides.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DebitCreditSplit {
    /// The total is halved; with an odd total, credits get one more line.
    Equal,
    /// Debits get about 60% of the lines (rounded), each side keeping at least one.
    MoreDebit,
    /// Credits get about 60% of the lines (rounded), each side keeping at least one.
    MoreCredit,
}
391
392#[derive(Debug, Clone)]
394pub struct LineItemSpec {
395 pub total_count: usize,
397 pub debit_count: usize,
399 pub credit_count: usize,
401 pub split_type: DebitCreditSplit,
403}
404
405impl LineItemSpec {
406 pub fn is_valid(&self) -> bool {
408 self.total_count >= 2
409 && self.debit_count >= 1
410 && self.credit_count >= 1
411 && self.debit_count + self.credit_count == self.total_count
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Fraction of `total` draws recorded in `counts` that produced exactly `value` lines.
    fn share_of(counts: &HashMap<usize, usize>, value: usize, total: usize) -> f64 {
        *counts.get(&value).unwrap_or(&0) as f64 / total as f64
    }

    #[test]
    fn test_default_config_valid() {
        let config = LineItemDistributionConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_paper_reference_config_valid() {
        let config = LineItemDistributionConfig::paper_reference();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_sampler_determinism() {
        let mut sampler1 = LineItemSampler::new(42);
        let mut sampler2 = LineItemSampler::new(42);

        for _ in 0..100 {
            assert_eq!(sampler1.sample_count(), sampler2.sample_count());
        }
    }

    #[test]
    fn test_with_default_configs_matches_new() {
        let mut plain = LineItemSampler::new(11);
        let mut explicit = LineItemSampler::with_config(
            11,
            LineItemDistributionConfig::default(),
            EvenOddDistributionConfig::default(),
            DebitCreditDistributionConfig::default(),
        );

        for _ in 0..200 {
            let a = plain.sample();
            let b = explicit.sample();
            assert_eq!(
                (a.total_count, a.debit_count, a.credit_count, a.split_type),
                (b.total_count, b.debit_count, b.credit_count, b.split_type)
            );
        }
    }

    #[test]
    fn test_reset_replays_sequence() {
        let mut sampler = LineItemSampler::new(99);
        let first: Vec<usize> = (0..100).map(|_| sampler.sample_count()).collect();
        sampler.reset(99);
        let second: Vec<usize> = (0..100).map(|_| sampler.sample_count()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn test_sampler_distribution() {
        let mut sampler = LineItemSampler::new(42);
        let sample_size: usize = 100_000;

        let mut counts: HashMap<usize, usize> = HashMap::new();
        for _ in 0..sample_size {
            *counts.entry(sampler.sample_count()).or_default() += 1;
        }

        let two_share = share_of(&counts, 2, sample_size);
        assert!(
            two_share > 0.55 && two_share < 0.65,
            "Expected ~60% 2-item entries, got {}%",
            two_share * 100.0
        );

        let four_share = share_of(&counts, 4, sample_size);
        assert!(
            four_share > 0.13 && four_share < 0.20,
            "Expected ~16% 4-item entries, got {}%",
            four_share * 100.0
        );
    }

    #[test]
    fn default_line_count_mean_is_corpus_scale() {
        let mut sampler = LineItemSampler::new(7);
        let n = 500_000;
        let mut counts: Vec<usize> = (0..n).map(|_| sampler.sample_count()).collect();
        let mean = counts.iter().sum::<usize>() as f64 / n as f64;
        counts.sort_unstable();
        let p99 = counts[(n as f64 * 0.99) as usize];
        let max = *counts.last().unwrap();
        assert!(
            (3.0..=5.0).contains(&mean),
            "default line-count mean should be corpus-scale (~3.5), got {mean:.2}"
        );
        assert!(
            p99 <= 30,
            "p99 line count should be ~18 (corpus), got {p99}"
        );
        assert!(
            max <= 600,
            "max line count must be capped at the corpus max (~600), got {max}"
        );
    }

    #[test]
    fn test_parity_adjusted_count_at_least_two() {
        let mut sampler = LineItemSampler::new(3);
        for _ in 0..10_000 {
            let count = sampler.sample_count_with_parity();
            assert!(count >= 2, "parity adjustment produced {count}");
        }
    }

    #[test]
    fn test_line_item_spec_valid() {
        let mut sampler = LineItemSampler::new(42);

        for _ in 0..1000 {
            let spec = sampler.sample();
            assert!(spec.is_valid(), "Invalid spec: {spec:?}");
        }
    }

    #[test]
    fn test_uniform_tail_spec_valid() {
        let mut sampler = LineItemSampler::with_config(
            5,
            LineItemDistributionConfig::paper_reference(),
            EvenOddDistributionConfig::default(),
            DebitCreditDistributionConfig::default(),
        );

        for _ in 0..1000 {
            let spec = sampler.sample();
            assert!(spec.is_valid(), "Invalid spec: {spec:?}");
        }
    }
}