1use chrono::NaiveDate;
4use datasynth_config::schema::SocialConfig;
5use datasynth_core::models::{
6 DiversityDimension, GovernanceMetric, IncidentType, OrganizationLevel, PayEquityMetric,
7 SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
8};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14
/// Generates workforce-related social metrics: diversity breakdowns, pay-equity
/// comparisons, safety incidents, and aggregated safety metrics.
pub struct WorkforceGenerator {
    /// Random source, created from the seed passed to `new`.
    rng: ChaCha8Rng,
    /// Social configuration; its `enabled` flags gate the `generate_*` methods.
    config: SocialConfig,
    /// Running record counter shared by all record kinds; it supplies the
    /// numeric part of every generated ID.
    counter: u64,
}
21
22impl WorkforceGenerator {
23 pub fn new(config: SocialConfig, seed: u64) -> Self {
25 Self {
26 rng: seeded_rng(seed, 0),
27 config,
28 counter: 0,
29 }
30 }
31
32 pub fn generate_diversity(
38 &mut self,
39 entity_id: &str,
40 total_headcount: u32,
41 period: NaiveDate,
42 ) -> Vec<WorkforceDiversityMetric> {
43 if !self.config.diversity.enabled || total_headcount == 0 {
44 return Vec::new();
45 }
46
47 let mut metrics = Vec::new();
48 let levels = [
49 OrganizationLevel::Corporate,
50 OrganizationLevel::Executive,
51 OrganizationLevel::Board,
52 ];
53
54 for dimension in &[
55 DiversityDimension::Gender,
56 DiversityDimension::Ethnicity,
57 DiversityDimension::Age,
58 ] {
59 let categories = self.categories_for(*dimension);
60 for level in &levels {
62 let level_hc = match level {
63 OrganizationLevel::Corporate => total_headcount,
64 OrganizationLevel::Executive => (total_headcount / 50).max(5),
65 OrganizationLevel::Board => 11,
66 _ => total_headcount,
67 };
68
69 let shares = self.random_shares(categories.len());
70 for (i, cat) in categories.iter().enumerate() {
71 self.counter += 1;
72 let headcount = (Decimal::from(level_hc)
73 * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
74 .round_dp(0);
75 let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
76
77 let percentage = if level_hc > 0 {
78 (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
79 } else {
80 Decimal::ZERO
81 };
82
83 metrics.push(WorkforceDiversityMetric {
84 id: format!("DV-{:06}", self.counter),
85 entity_id: entity_id.to_string(),
86 period,
87 dimension: *dimension,
88 level: *level,
89 category: cat.to_string(),
90 headcount: hc,
91 total_headcount: level_hc,
92 percentage,
93 });
94 }
95 }
96 }
97
98 metrics
99 }
100
101 fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
102 match dimension {
103 DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
104 DiversityDimension::Ethnicity => {
105 vec!["White", "Asian", "Hispanic", "Black", "Other"]
106 }
107 DiversityDimension::Age => {
108 vec!["Under 30", "30-50", "Over 50"]
109 }
110 DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
111 DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
112 }
113 }
114
115 fn random_shares(&mut self, count: usize) -> Vec<f64> {
117 let mut raw: Vec<f64> = (0..count).map(|_| self.rng.random::<f64>()).collect();
118 let total: f64 = raw.iter().sum();
119 if total > 0.0 {
120 for v in &mut raw {
121 *v /= total;
122 }
123 }
124 raw
125 }
126
127 pub fn generate_pay_equity(
131 &mut self,
132 entity_id: &str,
133 period: NaiveDate,
134 ) -> Vec<PayEquityMetric> {
135 if !self.config.pay_equity.enabled {
136 return Vec::new();
137 }
138
139 let comparisons = [
140 (DiversityDimension::Gender, "Male", "Female"),
141 (DiversityDimension::Ethnicity, "White", "Asian"),
142 (DiversityDimension::Ethnicity, "White", "Hispanic"),
143 (DiversityDimension::Ethnicity, "White", "Black"),
144 ];
145
146 comparisons
147 .iter()
148 .map(|(dim, ref_group, cmp_group)| {
149 self.counter += 1;
150 let ref_salary: f64 = self.rng.random_range(70_000.0..120_000.0);
151 let gap_factor: f64 = self.rng.random_range(0.85..1.05);
153 let cmp_salary = ref_salary * gap_factor;
154
155 let ref_dec = Decimal::from_f64_retain(ref_salary)
156 .unwrap_or(dec!(90000))
157 .round_dp(2);
158 let cmp_dec = Decimal::from_f64_retain(cmp_salary)
159 .unwrap_or(dec!(85000))
160 .round_dp(2);
161 let ratio = if ref_dec.is_zero() {
162 dec!(1.00)
163 } else {
164 (cmp_dec / ref_dec).round_dp(4)
165 };
166
167 PayEquityMetric {
168 id: format!("PE-{:06}", self.counter),
169 entity_id: entity_id.to_string(),
170 period,
171 dimension: *dim,
172 reference_group: ref_group.to_string(),
173 comparison_group: cmp_group.to_string(),
174 reference_median_salary: ref_dec,
175 comparison_median_salary: cmp_dec,
176 pay_gap_ratio: ratio,
177 sample_size: self.rng.random_range(50..500),
178 }
179 })
180 .collect()
181 }
182
183 pub fn generate_safety_incidents(
187 &mut self,
188 entity_id: &str,
189 facility_count: u32,
190 start_date: NaiveDate,
191 end_date: NaiveDate,
192 ) -> Vec<SafetyIncident> {
193 if !self.config.safety.enabled {
194 return Vec::new();
195 }
196
197 let total_incidents = self.config.safety.incident_count;
198 let period_days = (end_date - start_date).num_days().max(1);
199
200 (0..total_incidents)
201 .map(|_| {
202 self.counter += 1;
203 let day_offset = self.rng.random_range(0..period_days);
204 let date = start_date + chrono::Duration::days(day_offset);
205 let fac = self.rng.random_range(1..=facility_count.max(1));
206
207 let incident_type = self.pick_incident_type();
208 let days_away = match incident_type {
209 IncidentType::Fatality => 0,
210 IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
211 IncidentType::Injury => self.rng.random_range(0..30u32),
212 IncidentType::Illness => self.rng.random_range(1..15u32),
213 };
214 let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
215
216 let description = match incident_type {
217 IncidentType::Injury => "Workplace injury incident".to_string(),
218 IncidentType::Illness => "Occupational illness reported".to_string(),
219 IncidentType::NearMiss => "Near miss event documented".to_string(),
220 IncidentType::Fatality => "Fatal workplace incident".to_string(),
221 IncidentType::PropertyDamage => "Property damage incident".to_string(),
222 };
223
224 SafetyIncident {
225 id: format!("SI-{:06}", self.counter),
226 entity_id: entity_id.to_string(),
227 facility_id: format!("FAC-{:03}", fac),
228 date,
229 incident_type,
230 days_away,
231 is_recordable,
232 description,
233 }
234 })
235 .collect()
236 }
237
238 fn pick_incident_type(&mut self) -> IncidentType {
239 let roll: f64 = self.rng.random::<f64>();
240 if roll < 0.35 {
241 IncidentType::NearMiss
242 } else if roll < 0.65 {
243 IncidentType::Injury
244 } else if roll < 0.80 {
245 IncidentType::Illness
246 } else if roll < 0.95 {
247 IncidentType::PropertyDamage
248 } else {
249 IncidentType::Fatality
250 }
251 }
252
253 pub fn compute_safety_metrics(
255 &mut self,
256 entity_id: &str,
257 incidents: &[SafetyIncident],
258 total_hours_worked: u64,
259 period: NaiveDate,
260 ) -> SafetyMetric {
261 self.counter += 1;
262
263 let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
264 let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
265 let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
266 let near_misses = incidents
267 .iter()
268 .filter(|i| i.incident_type == IncidentType::NearMiss)
269 .count() as u32;
270 let fatalities = incidents
271 .iter()
272 .filter(|i| i.incident_type == IncidentType::Fatality)
273 .count() as u32;
274
275 let hours_dec = Decimal::from(total_hours_worked);
276 let base = dec!(200000);
277
278 let trir = if total_hours_worked > 0 {
279 (Decimal::from(recordable) * base / hours_dec).round_dp(4)
280 } else {
281 Decimal::ZERO
282 };
283 let ltir = if total_hours_worked > 0 {
284 (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
285 } else {
286 Decimal::ZERO
287 };
288 let dart_rate = if total_hours_worked > 0 {
289 (Decimal::from(days_away) * base / hours_dec).round_dp(4)
290 } else {
291 Decimal::ZERO
292 };
293
294 SafetyMetric {
295 id: format!("SM-{:06}", self.counter),
296 entity_id: entity_id.to_string(),
297 period,
298 total_hours_worked,
299 recordable_incidents: recordable,
300 lost_time_incidents: lost_time,
301 days_away,
302 near_misses,
303 fatalities,
304 trir,
305 ltir,
306 dart_rate,
307 }
308 }
309}
310
/// Generates board-level governance metrics for an entity.
pub struct GovernanceGenerator {
    /// Random source, created from the seed passed to `new`.
    rng: ChaCha8Rng,
    /// Running record counter; it supplies the numeric part of generated IDs.
    counter: u64,
    /// Number of board seats; `new` raises values below 3 to 3.
    board_size: u32,
    /// Target share of independent directors, expressed as a fraction.
    independence_target: f64,
}
318
319impl GovernanceGenerator {
320 pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
322 Self {
323 rng: seeded_rng(seed, 0),
324 counter: 0,
325 board_size: board_size.max(3),
326 independence_target,
327 }
328 }
329
330 pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
332 self.counter += 1;
333
334 let ind_frac: f64 = self.rng.random_range(
336 (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
337 );
338 let independent = (self.board_size as f64 * ind_frac).round() as u32;
339 let independent = independent.min(self.board_size);
340
341 let fem_frac: f64 = self.rng.random_range(0.20..0.40);
343 let female = (self.board_size as f64 * fem_frac).round() as u32;
344 let female = female.min(self.board_size);
345
346 let independence_ratio = if self.board_size > 0 {
347 (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
348 } else {
349 Decimal::ZERO
350 };
351 let gender_ratio = if self.board_size > 0 {
352 (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
353 } else {
354 Decimal::ZERO
355 };
356
357 let ethics_pct: f64 = self.rng.random_range(0.85..0.99);
358 let whistleblower: u32 = self.rng.random_range(0..5);
359 let anti_corruption: u32 = if self.rng.random::<f64>() < 0.10 {
360 1
361 } else {
362 0
363 };
364
365 GovernanceMetric {
366 id: format!("GV-{:06}", self.counter),
367 entity_id: entity_id.to_string(),
368 period,
369 board_size: self.board_size,
370 independent_directors: independent,
371 female_directors: female,
372 board_independence_ratio: independence_ratio,
373 board_gender_diversity_ratio: gender_ratio,
374 ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
375 whistleblower_reports: whistleblower,
376 anti_corruption_violations: anti_corruption,
377 }
378 }
379}
380
381#[cfg(test)]
382#[allow(clippy::unwrap_used)]
383mod tests {
384 use super::*;
385
386 fn d(s: &str) -> NaiveDate {
387 NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap()
388 }
389
390 #[test]
391 fn test_diversity_percentages_sum_to_one() {
392 let config = SocialConfig::default();
393 let mut gen = WorkforceGenerator::new(config, 42);
394 let metrics = gen.generate_diversity("C001", 1000, d("2025-01-01"));
395
396 assert!(!metrics.is_empty());
397
398 let mut groups: std::collections::HashMap<
400 (String, String),
401 Vec<&WorkforceDiversityMetric>,
402 > = std::collections::HashMap::new();
403 for m in &metrics {
404 let key = (format!("{:?}", m.dimension), format!("{:?}", m.level));
405 groups.entry(key).or_default().push(m);
406 }
407
408 for (key, group) in &groups {
409 let total_hc: u32 = group.iter().map(|m| m.headcount).sum();
410 let expected = group[0].total_headcount;
411 assert!(
413 total_hc.abs_diff(expected) <= 1,
414 "Group {:?}: headcount sum {} != expected {}",
415 key,
416 total_hc,
417 expected
418 );
419 }
420 }
421
422 #[test]
423 fn test_pay_equity_ratios() {
424 let config = SocialConfig::default();
425 let mut gen = WorkforceGenerator::new(config, 42);
426 let metrics = gen.generate_pay_equity("C001", d("2025-01-01"));
427
428 assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
429 for m in &metrics {
430 assert!(m.pay_gap_ratio > dec!(0.80) && m.pay_gap_ratio < dec!(1.10));
431 assert!(m.reference_median_salary > Decimal::ZERO);
432 assert!(m.comparison_median_salary > Decimal::ZERO);
433 assert!(m.sample_size > 0);
434 }
435 }
436
437 #[test]
438 fn test_safety_incidents() {
439 let config = SocialConfig {
440 safety: datasynth_config::schema::SafetySchemaConfig {
441 enabled: true,
442 target_trir: 2.5,
443 incident_count: 30,
444 },
445 ..Default::default()
446 };
447 let mut gen = WorkforceGenerator::new(config, 42);
448 let incidents = gen.generate_safety_incidents("C001", 3, d("2025-01-01"), d("2025-12-31"));
449
450 assert_eq!(incidents.len(), 30);
451
452 let recordable = incidents.iter().filter(|i| i.is_recordable).count();
453 let near_miss = incidents
454 .iter()
455 .filter(|i| i.incident_type == IncidentType::NearMiss)
456 .count();
457 assert_eq!(recordable + near_miss, 30);
459 }
460
461 #[test]
462 fn test_safety_metric_trir_computation() {
463 let config = SocialConfig::default();
464 let mut gen = WorkforceGenerator::new(config, 42);
465
466 let incidents = vec![
467 SafetyIncident {
468 id: "SI-001".into(),
469 entity_id: "C001".into(),
470 facility_id: "FAC-001".into(),
471 date: d("2025-03-15"),
472 incident_type: IncidentType::Injury,
473 days_away: 5,
474 is_recordable: true,
475 description: "Test".into(),
476 },
477 SafetyIncident {
478 id: "SI-002".into(),
479 entity_id: "C001".into(),
480 facility_id: "FAC-001".into(),
481 date: d("2025-06-20"),
482 incident_type: IncidentType::NearMiss,
483 days_away: 0,
484 is_recordable: false,
485 description: "Test".into(),
486 },
487 ];
488
489 let metric = gen.compute_safety_metrics("C001", &incidents, 500_000, d("2025-01-01"));
490
491 assert_eq!(metric.recordable_incidents, 1);
492 assert_eq!(metric.near_misses, 1);
493 assert_eq!(metric.lost_time_incidents, 1);
494 assert_eq!(metric.days_away, 5);
495 assert_eq!(metric.trir, dec!(0.4000));
497 assert_eq!(metric.computed_trir(), dec!(0.4000));
498 }
499
500 #[test]
501 fn test_governance_generation() {
502 let mut gen = GovernanceGenerator::new(42, 11, 0.67);
503 let metric = gen.generate("C001", d("2025-01-01"));
504
505 assert_eq!(metric.board_size, 11);
506 assert!(metric.independent_directors <= 11);
507 assert!(metric.female_directors <= 11);
508 assert!(metric.board_independence_ratio > Decimal::ZERO);
509 assert!(metric.ethics_training_completion_pct >= 0.85);
510 }
511
512 #[test]
513 fn test_disabled_diversity() {
514 let mut config = SocialConfig::default();
515 config.diversity.enabled = false;
516 let mut gen = WorkforceGenerator::new(config, 42);
517 let metrics = gen.generate_diversity("C001", 1000, d("2025-01-01"));
518 assert!(metrics.is_empty());
519 }
520}