datasynth_core/distributions/
line_item.rs1use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct LineItemDistributionConfig {
21 pub two_items: f64,
23 pub three_items: f64,
25 pub four_items: f64,
27 pub five_items: f64,
29 pub six_items: f64,
31 pub seven_items: f64,
33 pub eight_items: f64,
35 pub nine_items: f64,
37 pub ten_to_ninety_nine: f64,
39 pub hundred_to_nine_ninety_nine: f64,
41 pub thousand_plus: f64,
43}
44
45impl Default for LineItemDistributionConfig {
46 fn default() -> Self {
47 Self {
56 two_items: 0.62,
57 three_items: 0.10,
58 four_items: 0.16,
59 five_items: 0.04,
60 six_items: 0.025,
61 seven_items: 0.012,
62 eight_items: 0.012,
63 nine_items: 0.006,
64 ten_to_ninety_nine: 0.024,
65 hundred_to_nine_ninety_nine: 0.0008,
66 thousand_plus: 0.00002,
67 }
68 }
69}
70
71impl LineItemDistributionConfig {
72 pub fn paper_reference() -> Self {
79 Self {
80 two_items: 0.6068,
81 three_items: 0.0577,
82 four_items: 0.1663,
83 five_items: 0.0306,
84 six_items: 0.0332,
85 seven_items: 0.0113,
86 eight_items: 0.0188,
87 nine_items: 0.0042,
88 ten_to_ninety_nine: 0.0633,
89 hundred_to_nine_ninety_nine: 0.0076,
90 thousand_plus: 0.0002,
91 }
92 }
93
94 pub fn validate(&self) -> Result<(), String> {
96 let sum = self.two_items
97 + self.three_items
98 + self.four_items
99 + self.five_items
100 + self.six_items
101 + self.seven_items
102 + self.eight_items
103 + self.nine_items
104 + self.ten_to_ninety_nine
105 + self.hundred_to_nine_ninety_nine
106 + self.thousand_plus;
107
108 if (sum - 1.0).abs() > 0.01 {
109 return Err(format!(
110 "Line item distribution probabilities sum to {sum}, expected ~1.0"
111 ));
112 }
113 Ok(())
114 }
115
116 fn cumulative(&self) -> [f64; 11] {
118 let mut cum = [0.0; 11];
119 cum[0] = self.two_items;
120 cum[1] = cum[0] + self.three_items;
121 cum[2] = cum[1] + self.four_items;
122 cum[3] = cum[2] + self.five_items;
123 cum[4] = cum[3] + self.six_items;
124 cum[5] = cum[4] + self.seven_items;
125 cum[6] = cum[5] + self.eight_items;
126 cum[7] = cum[6] + self.nine_items;
127 cum[8] = cum[7] + self.ten_to_ninety_nine;
128 cum[9] = cum[8] + self.hundred_to_nine_ninety_nine;
129 cum[10] = cum[9] + self.thousand_plus;
130 cum
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct EvenOddDistributionConfig {
137 pub even: f64,
139 pub odd: f64,
141}
142
143impl Default for EvenOddDistributionConfig {
144 fn default() -> Self {
145 Self {
147 even: 0.88,
148 odd: 0.12,
149 }
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct DebitCreditDistributionConfig {
156 pub equal: f64,
158 pub more_debit: f64,
160 pub more_credit: f64,
162}
163
164impl Default for DebitCreditDistributionConfig {
165 fn default() -> Self {
166 Self {
168 equal: 0.82,
169 more_debit: 0.07,
170 more_credit: 0.11,
171 }
172 }
173}
174
175pub struct LineItemSampler {
180 rng: ChaCha8Rng,
182 even_odd_config: EvenOddDistributionConfig,
184 debit_credit_config: DebitCreditDistributionConfig,
186 cumulative: [f64; 11],
188}
189
190impl LineItemSampler {
191 pub fn new(seed: u64) -> Self {
193 let line_config = LineItemDistributionConfig::default();
194 let cumulative = line_config.cumulative();
195
196 Self {
197 rng: ChaCha8Rng::seed_from_u64(seed),
198 even_odd_config: EvenOddDistributionConfig::default(),
199 debit_credit_config: DebitCreditDistributionConfig::default(),
200 cumulative,
201 }
202 }
203
204 pub fn with_config(
206 seed: u64,
207 line_config: LineItemDistributionConfig,
208 even_odd_config: EvenOddDistributionConfig,
209 debit_credit_config: DebitCreditDistributionConfig,
210 ) -> Self {
211 let cumulative = line_config.cumulative();
212
213 Self {
214 rng: ChaCha8Rng::seed_from_u64(seed),
215 even_odd_config,
216 debit_credit_config,
217 cumulative,
218 }
219 }
220
221 pub fn sample_count(&mut self) -> usize {
223 let p: f64 = self.rng.random();
224
225 if p < self.cumulative[0] {
227 2
228 } else if p < self.cumulative[1] {
229 3
230 } else if p < self.cumulative[2] {
231 4
232 } else if p < self.cumulative[3] {
233 5
234 } else if p < self.cumulative[4] {
235 6
236 } else if p < self.cumulative[5] {
237 7
238 } else if p < self.cumulative[6] {
239 8
240 } else if p < self.cumulative[7] {
241 9
242 } else if p < self.cumulative[8] {
243 self.rng.random_range(10..100)
245 } else if p < self.cumulative[9] {
246 self.rng.random_range(100..1000)
248 } else {
249 self.rng.random_range(1000..10000)
251 }
252 }
253
254 pub fn sample_even(&mut self) -> bool {
256 self.rng.random::<f64>() < self.even_odd_config.even
257 }
258
259 pub fn sample_count_with_parity(&mut self) -> usize {
264 let base_count = self.sample_count();
265 let should_be_even = self.sample_even();
266
267 let is_even = base_count.is_multiple_of(2);
269 if should_be_even != is_even {
270 if base_count <= 2 {
272 base_count + 1
274 } else if self.rng.random::<bool>() {
275 base_count + 1
277 } else {
278 base_count - 1
280 }
281 } else {
282 base_count
283 }
284 }
285
286 pub fn sample_debit_credit_type(&mut self) -> DebitCreditSplit {
288 let p: f64 = self.rng.random();
289
290 if p < self.debit_credit_config.equal {
291 DebitCreditSplit::Equal
292 } else if p < self.debit_credit_config.equal + self.debit_credit_config.more_debit {
293 DebitCreditSplit::MoreDebit
294 } else {
295 DebitCreditSplit::MoreCredit
296 }
297 }
298
299 pub fn sample(&mut self) -> LineItemSpec {
301 let total_count = self.sample_count_with_parity();
302 let split_type = self.sample_debit_credit_type();
303
304 let (debit_count, credit_count) = match split_type {
305 DebitCreditSplit::Equal => {
306 let half = total_count / 2;
307 (half, total_count - half)
308 }
309 DebitCreditSplit::MoreDebit => {
310 let debit = (total_count as f64 * 0.6).round() as usize;
312 let debit = debit.max(1).min(total_count - 1);
313 (debit, total_count - debit)
314 }
315 DebitCreditSplit::MoreCredit => {
316 let credit = (total_count as f64 * 0.6).round() as usize;
318 let credit = credit.max(1).min(total_count - 1);
319 (total_count - credit, credit)
320 }
321 };
322
323 LineItemSpec {
324 total_count,
325 debit_count,
326 credit_count,
327 split_type,
328 }
329 }
330
331 pub fn reset(&mut self, seed: u64) {
333 self.rng = ChaCha8Rng::seed_from_u64(seed);
334 }
335}
336
/// How an entry's lines are divided between debits and credits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DebitCreditSplit {
    /// Lines divided as evenly as the total allows.
    Equal,
    /// Debit side gets the larger share of lines, where the total allows it.
    MoreDebit,
    /// Credit side gets the larger share of lines, where the total allows it.
    MoreCredit,
}
347
348#[derive(Debug, Clone)]
350pub struct LineItemSpec {
351 pub total_count: usize,
353 pub debit_count: usize,
355 pub credit_count: usize,
357 pub split_type: DebitCreditSplit,
359}
360
361impl LineItemSpec {
362 pub fn is_valid(&self) -> bool {
364 self.total_count >= 2
365 && self.debit_count >= 1
366 && self.credit_count >= 1
367 && self.debit_count + self.credit_count == self.total_count
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config_valid() {
        let config = LineItemDistributionConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_paper_reference_config_valid() {
        let config = LineItemDistributionConfig::paper_reference();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_sum_far_from_one() {
        let config = LineItemDistributionConfig {
            two_items: 0.9,
            ..LineItemDistributionConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_sampler_determinism() {
        let mut sampler1 = LineItemSampler::new(42);
        let mut sampler2 = LineItemSampler::new(42);

        for _ in 0..100 {
            assert_eq!(sampler1.sample_count(), sampler2.sample_count());
        }
    }

    #[test]
    fn test_full_sample_determinism() {
        let mut a = LineItemSampler::new(11);
        let mut b = LineItemSampler::new(11);

        for _ in 0..100 {
            let (x, y) = (a.sample(), b.sample());
            assert_eq!(
                (x.total_count, x.debit_count, x.credit_count, x.split_type),
                (y.total_count, y.debit_count, y.credit_count, y.split_type)
            );
        }
    }

    #[test]
    fn test_reset_replays_sequence() {
        let mut sampler = LineItemSampler::new(99);
        let first: Vec<usize> = (0..100).map(|_| sampler.sample_count()).collect();
        sampler.reset(99);
        let second: Vec<usize> = (0..100).map(|_| sampler.sample_count()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn test_sampler_distribution() {
        let mut sampler = LineItemSampler::new(42);
        let sample_size = 100_000;

        let mut counts = std::collections::HashMap::new();
        for _ in 0..sample_size {
            let count = sampler.sample_count();
            *counts.entry(count).or_insert(0) += 1;
        }

        let two_count = *counts.get(&2).unwrap_or(&0) as f64 / sample_size as f64;
        assert!(
            two_count > 0.55 && two_count < 0.65,
            "Expected ~62% 2-item entries, got {}%",
            two_count * 100.0
        );

        let four_count = *counts.get(&4).unwrap_or(&0) as f64 / sample_size as f64;
        assert!(
            four_count > 0.13 && four_count < 0.20,
            "Expected ~16% 4-item entries, got {}%",
            four_count * 100.0
        );
    }

    #[test]
    fn default_line_count_mean_is_corpus_scale() {
        let mut sampler = LineItemSampler::new(7);
        let n = 200_000;
        let sum: usize = (0..n).map(|_| sampler.sample_count()).sum();
        let mean = sum as f64 / n as f64;
        // Short entries dominate. The rare 10+ line buckets pull the mean up,
        // so a mean far above this band points at an over-weighted long tail.
        assert!(
            (3.0..=6.5).contains(&mean),
            "default line-count mean should be corpus-scale (~4.5), got {mean:.2}"
        );
    }

    #[test]
    fn test_parity_adjusted_count_never_below_two() {
        let mut sampler = LineItemSampler::new(5);
        for _ in 0..1000 {
            assert!(sampler.sample_count_with_parity() >= 2);
        }
    }

    #[test]
    fn test_line_item_spec_valid() {
        let mut sampler = LineItemSampler::new(42);

        for _ in 0..1000 {
            let spec = sampler.sample();
            assert!(spec.is_valid(), "Invalid spec: {:?}", spec);
        }
    }

    #[test]
    fn test_line_item_spec_rejects_one_sided_entry() {
        let spec = LineItemSpec {
            total_count: 3,
            debit_count: 3,
            credit_count: 0,
            split_type: DebitCreditSplit::MoreDebit,
        };
        assert!(!spec.is_valid());
    }
}