// datasynth_generators/sourcing/spend_analysis_generator.rs
use datasynth_config::schema::SpendAnalysisConfig;
use datasynth_core::models::sourcing::{SpendAnalysis, VendorSpendShare};
use datasynth_core::utils::seeded_rng;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rust_decimal::Decimal;

12pub struct SpendAnalysisGenerator {
14 rng: ChaCha8Rng,
15 config: SpendAnalysisConfig,
16}
17
18impl SpendAnalysisGenerator {
19 pub fn new(seed: u64) -> Self {
21 Self {
22 rng: seeded_rng(seed, 0),
23 config: SpendAnalysisConfig::default(),
24 }
25 }
26
27 pub fn with_config(seed: u64, config: SpendAnalysisConfig) -> Self {
29 Self {
30 rng: seeded_rng(seed, 0),
31 config,
32 }
33 }
34
35 pub fn generate(
43 &mut self,
44 company_code: &str,
45 vendor_ids: &[String],
46 categories: &[(String, String)],
47 fiscal_year: u16,
48 ) -> Vec<SpendAnalysis> {
49 let mut analyses = Vec::new();
50
51 for (cat_id, cat_name) in categories {
52 let vendor_count = self.rng.random_range(3..=vendor_ids.len().min(15));
54 let mut cat_vendors: Vec<&String> = vendor_ids
55 .choose_multiple(&mut self.rng, vendor_count)
56 .collect();
57 cat_vendors.shuffle(&mut self.rng);
58
59 let mut raw_shares: Vec<f64> = (0..cat_vendors.len())
61 .map(|i| 1.0 / ((i as f64 + 1.0).powf(0.8)))
62 .collect();
63 let total: f64 = raw_shares.iter().sum();
64 for s in &mut raw_shares {
65 *s /= total;
66 }
67
68 let total_spend = Decimal::from(self.rng.random_range(100_000i64..=5_000_000));
69 let transaction_count = self.rng.random_range(50..=2000);
70
71 let hhi: f64 = raw_shares.iter().map(|s| (s * 100.0).powi(2)).sum();
73
74 let contract_coverage = self.rng.random_range(0.3..=0.95);
75 let preferred_coverage = contract_coverage * self.rng.random_range(0.7..=1.0);
76
77 let vendor_shares: Vec<VendorSpendShare> = cat_vendors
78 .iter()
79 .zip(raw_shares.iter())
80 .map(|(vid, share)| VendorSpendShare {
81 vendor_id: vid.to_string(),
82 vendor_name: format!("Vendor {}", vid),
83 spend_amount: Decimal::from_f64_retain(
84 total_spend.to_string().parse::<f64>().unwrap_or(0.0) * share,
85 )
86 .unwrap_or(Decimal::ZERO),
87 share: *share,
88 is_preferred: *share > 0.15 && self.rng.random_bool(preferred_coverage),
89 })
90 .collect();
91
92 analyses.push(SpendAnalysis {
93 category_id: cat_id.clone(),
94 category_name: cat_name.clone(),
95 company_code: company_code.to_string(),
96 total_spend,
97 vendor_count: cat_vendors.len() as u32,
98 transaction_count,
99 hhi_index: hhi,
100 vendor_shares,
101 contract_coverage,
102 preferred_vendor_coverage: preferred_coverage,
103 price_trend_pct: self.rng.random_range(-0.05..=0.10),
104 fiscal_year,
105 });
106 }
107
108 analyses
109 }
110
111 pub fn hhi_threshold(&self) -> f64 {
113 self.config.hhi_threshold
114 }
115}
116
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // Generator locals are named `generator` rather than `gen`, because `gen` is a reserved
    // keyword starting with the Rust 2024 edition.

    /// Ten vendor ids, `V0001` through `V0010`.
    fn test_vendor_ids() -> Vec<String> {
        (1..=10).map(|i| format!("V{:04}", i)).collect()
    }

    /// Two sample `(category_id, category_name)` pairs.
    fn test_categories() -> Vec<(String, String)> {
        vec![
            ("CAT-001".to_string(), "Office Supplies".to_string()),
            ("CAT-002".to_string(), "IT Equipment".to_string()),
        ]
    }

    #[test]
    fn test_basic_generation() {
        let mut generator = SpendAnalysisGenerator::new(42);
        let results = generator.generate("C001", &test_vendor_ids(), &test_categories(), 2024);

        assert_eq!(results.len(), 2);
        for analysis in &results {
            assert_eq!(analysis.company_code, "C001");
            assert_eq!(analysis.fiscal_year, 2024);
            assert!(!analysis.category_id.is_empty());
            assert!(!analysis.category_name.is_empty());
            assert!(analysis.vendor_count > 0);
            assert!(analysis.transaction_count > 0);
            assert!(analysis.total_spend > Decimal::ZERO);
            assert!(analysis.hhi_index > 0.0);
            assert!(!analysis.vendor_shares.is_empty());
        }
    }

    #[test]
    fn test_deterministic() {
        let mut generator_a = SpendAnalysisGenerator::new(42);
        let mut generator_b = SpendAnalysisGenerator::new(42);
        let vendors = test_vendor_ids();
        let cats = test_categories();

        let r1 = generator_a.generate("C001", &vendors, &cats, 2024);
        let r2 = generator_b.generate("C001", &vendors, &cats, 2024);

        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.category_id, b.category_id);
            assert_eq!(a.total_spend, b.total_spend);
            assert_eq!(a.vendor_count, b.vendor_count);
            assert_eq!(a.transaction_count, b.transaction_count);
            assert_eq!(a.hhi_index, b.hhi_index);
            assert_eq!(a.vendor_shares.len(), b.vendor_shares.len());
            for (x, y) in a.vendor_shares.iter().zip(b.vendor_shares.iter()) {
                assert_eq!(x.vendor_id, y.vendor_id);
                assert_eq!(x.spend_amount, y.spend_amount);
                assert_eq!(x.is_preferred, y.is_preferred);
            }
        }
    }

    #[test]
    fn test_field_constraints() {
        let mut generator = SpendAnalysisGenerator::new(99);
        let results = generator.generate("C001", &test_vendor_ids(), &test_categories(), 2024);

        for analysis in &results {
            let share_sum: f64 = analysis.vendor_shares.iter().map(|s| s.share).sum();
            assert!(
                (share_sum - 1.0).abs() < 0.01,
                "shares should sum to ~1.0, got {}",
                share_sum
            );

            assert!(analysis.contract_coverage >= 0.0 && analysis.contract_coverage <= 1.0);
            assert!(
                analysis.preferred_vendor_coverage >= 0.0
                    && analysis.preferred_vendor_coverage <= 1.0
            );
            assert!(analysis.price_trend_pct >= -0.05 && analysis.price_trend_pct <= 0.10);

            for vs in &analysis.vendor_shares {
                assert!(!vs.vendor_id.is_empty());
                assert!(vs.share > 0.0 && vs.share <= 1.0);
            }
        }
    }

    #[test]
    fn test_vendor_selection_and_spend_reconcile() {
        let vendors = test_vendor_ids();
        let mut generator = SpendAnalysisGenerator::new(7);
        let results = generator.generate("C001", &vendors, &test_categories(), 2024);

        for analysis in &results {
            // With a pool of 10 vendors, each category draws between 3 and 10 of them.
            assert_eq!(analysis.vendor_count as usize, analysis.vendor_shares.len());
            assert!((3..=10).contains(&analysis.vendor_count));

            let unique: HashSet<&str> = analysis
                .vendor_shares
                .iter()
                .map(|vs| vs.vendor_id.as_str())
                .collect();
            assert_eq!(
                unique.len(),
                analysis.vendor_shares.len(),
                "a vendor must not appear twice within one category"
            );
            assert!(analysis
                .vendor_shares
                .iter()
                .all(|vs| vendors.contains(&vs.vendor_id)));

            assert!(analysis.hhi_index > 0.0 && analysis.hhi_index <= 10_000.0 + 1e-9);

            assert!(analysis
                .vendor_shares
                .iter()
                .all(|vs| vs.spend_amount > Decimal::ZERO));
            let allocated: Decimal = analysis
                .vendor_shares
                .iter()
                .map(|vs| vs.spend_amount)
                .sum();
            assert!(
                (allocated - analysis.total_spend).abs() < Decimal::ONE,
                "vendor spend {} should reconcile to total {}",
                allocated,
                analysis.total_spend
            );
        }
    }

    #[test]
    fn test_hhi_threshold() {
        let generator = SpendAnalysisGenerator::new(42);
        assert_eq!(generator.hhi_threshold(), 2500.0);
    }
}