1use chrono::NaiveDate;
4use datasynth_config::schema::SocialConfig;
5use datasynth_core::models::{
6 DiversityDimension, Employee, GovernanceMetric, IncidentType, OrganizationLevel,
7 PayEquityMetric, PayrollLineItem, SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
8};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14
15pub struct WorkforceGenerator {
17 rng: ChaCha8Rng,
18 config: SocialConfig,
19 counter: u64,
20}
21
22impl WorkforceGenerator {
23 pub fn new(config: SocialConfig, seed: u64) -> Self {
25 Self {
26 rng: seeded_rng(seed, 0),
27 config,
28 counter: 0,
29 }
30 }
31
32 pub fn generate_diversity(
38 &mut self,
39 entity_id: &str,
40 total_headcount: u32,
41 period: NaiveDate,
42 ) -> Vec<WorkforceDiversityMetric> {
43 if !self.config.diversity.enabled || total_headcount == 0 {
44 return Vec::new();
45 }
46
47 let mut metrics = Vec::new();
48 let levels = [
49 OrganizationLevel::Corporate,
50 OrganizationLevel::Executive,
51 OrganizationLevel::Board,
52 ];
53
54 for dimension in &[
55 DiversityDimension::Gender,
56 DiversityDimension::Ethnicity,
57 DiversityDimension::Age,
58 ] {
59 let categories = self.categories_for(*dimension);
60 for level in &levels {
62 let level_hc = match level {
63 OrganizationLevel::Corporate => total_headcount,
64 OrganizationLevel::Executive => (total_headcount / 50).max(5),
65 OrganizationLevel::Board => 11,
66 _ => total_headcount,
67 };
68
69 let shares = self.random_shares(categories.len());
70 for (i, cat) in categories.iter().enumerate() {
71 self.counter += 1;
72 let headcount = (Decimal::from(level_hc)
73 * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
74 .round_dp(0);
75 let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
76
77 let percentage = if level_hc > 0 {
78 (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
79 } else {
80 Decimal::ZERO
81 };
82
83 metrics.push(WorkforceDiversityMetric {
84 id: format!("DV-{:06}", self.counter),
85 entity_id: entity_id.to_string(),
86 period,
87 dimension: *dimension,
88 level: *level,
89 category: cat.to_string(),
90 headcount: hc,
91 total_headcount: level_hc,
92 percentage,
93 });
94 }
95 }
96 }
97
98 metrics
99 }
100
101 fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
102 match dimension {
103 DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
104 DiversityDimension::Ethnicity => {
105 vec!["White", "Asian", "Hispanic", "Black", "Other"]
106 }
107 DiversityDimension::Age => {
108 vec!["Under 30", "30-50", "Over 50"]
109 }
110 DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
111 DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
112 }
113 }
114
115 fn random_shares(&mut self, count: usize) -> Vec<f64> {
117 let mut raw: Vec<f64> = (0..count).map(|_| self.rng.random::<f64>()).collect();
118 let total: f64 = raw.iter().sum();
119 if total > 0.0 {
120 for v in &mut raw {
121 *v /= total;
122 }
123 }
124 raw
125 }
126
127 pub fn generate_pay_equity(
131 &mut self,
132 entity_id: &str,
133 period: NaiveDate,
134 ) -> Vec<PayEquityMetric> {
135 if !self.config.pay_equity.enabled {
136 return Vec::new();
137 }
138
139 let comparisons = [
140 (DiversityDimension::Gender, "Male", "Female"),
141 (DiversityDimension::Ethnicity, "White", "Asian"),
142 (DiversityDimension::Ethnicity, "White", "Hispanic"),
143 (DiversityDimension::Ethnicity, "White", "Black"),
144 ];
145
146 comparisons
147 .iter()
148 .map(|(dim, ref_group, cmp_group)| {
149 self.counter += 1;
150 let ref_salary: f64 = self.rng.random_range(70_000.0..120_000.0);
151 let gap_factor: f64 = self.rng.random_range(0.85..1.05);
153 let cmp_salary = ref_salary * gap_factor;
154
155 let ref_dec = Decimal::from_f64_retain(ref_salary)
156 .unwrap_or(dec!(90000))
157 .round_dp(2);
158 let cmp_dec = Decimal::from_f64_retain(cmp_salary)
159 .unwrap_or(dec!(85000))
160 .round_dp(2);
161 let ratio = if ref_dec.is_zero() {
162 dec!(1.00)
163 } else {
164 (cmp_dec / ref_dec).round_dp(4)
165 };
166
167 PayEquityMetric {
168 id: format!("PE-{:06}", self.counter),
169 entity_id: entity_id.to_string(),
170 period,
171 dimension: *dim,
172 reference_group: ref_group.to_string(),
173 comparison_group: cmp_group.to_string(),
174 reference_median_salary: ref_dec,
175 comparison_median_salary: cmp_dec,
176 pay_gap_ratio: ratio,
177 sample_size: self.rng.random_range(50..500),
178 }
179 })
180 .collect()
181 }
182
183 pub fn generate_safety_incidents(
187 &mut self,
188 entity_id: &str,
189 facility_count: u32,
190 start_date: NaiveDate,
191 end_date: NaiveDate,
192 ) -> Vec<SafetyIncident> {
193 if !self.config.safety.enabled {
194 return Vec::new();
195 }
196
197 let total_incidents = self.config.safety.incident_count;
198 let period_days = (end_date - start_date).num_days().max(1);
199
200 (0..total_incidents)
201 .map(|_| {
202 self.counter += 1;
203 let day_offset = self.rng.random_range(0..period_days);
204 let date = start_date + chrono::Duration::days(day_offset);
205 let fac = self.rng.random_range(1..=facility_count.max(1));
206
207 let incident_type = self.pick_incident_type();
208 let days_away = match incident_type {
209 IncidentType::Fatality => 0,
210 IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
211 IncidentType::Injury => self.rng.random_range(0..30u32),
212 IncidentType::Illness => self.rng.random_range(1..15u32),
213 };
214 let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
215
216 let description = match incident_type {
217 IncidentType::Injury => "Workplace injury incident".to_string(),
218 IncidentType::Illness => "Occupational illness reported".to_string(),
219 IncidentType::NearMiss => "Near miss event documented".to_string(),
220 IncidentType::Fatality => "Fatal workplace incident".to_string(),
221 IncidentType::PropertyDamage => "Property damage incident".to_string(),
222 };
223
224 SafetyIncident {
225 id: format!("SI-{:06}", self.counter),
226 entity_id: entity_id.to_string(),
227 facility_id: format!("FAC-{fac:03}"),
228 date,
229 incident_type,
230 days_away,
231 is_recordable,
232 description,
233 }
234 })
235 .collect()
236 }
237
238 fn pick_incident_type(&mut self) -> IncidentType {
239 let roll: f64 = self.rng.random::<f64>();
240 if roll < 0.35 {
241 IncidentType::NearMiss
242 } else if roll < 0.65 {
243 IncidentType::Injury
244 } else if roll < 0.80 {
245 IncidentType::Illness
246 } else if roll < 0.95 {
247 IncidentType::PropertyDamage
248 } else {
249 IncidentType::Fatality
250 }
251 }
252
253 pub fn compute_safety_metrics(
255 &mut self,
256 entity_id: &str,
257 incidents: &[SafetyIncident],
258 total_hours_worked: u64,
259 period: NaiveDate,
260 ) -> SafetyMetric {
261 self.counter += 1;
262
263 let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
264 let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
265 let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
266 let near_misses = incidents
267 .iter()
268 .filter(|i| i.incident_type == IncidentType::NearMiss)
269 .count() as u32;
270 let fatalities = incidents
271 .iter()
272 .filter(|i| i.incident_type == IncidentType::Fatality)
273 .count() as u32;
274
275 let hours_dec = Decimal::from(total_hours_worked);
276 let base = dec!(200000);
277
278 let trir = if total_hours_worked > 0 {
279 (Decimal::from(recordable) * base / hours_dec).round_dp(4)
280 } else {
281 Decimal::ZERO
282 };
283 let ltir = if total_hours_worked > 0 {
284 (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
285 } else {
286 Decimal::ZERO
287 };
288 let dart_rate = if total_hours_worked > 0 {
289 (Decimal::from(days_away) * base / hours_dec).round_dp(4)
290 } else {
291 Decimal::ZERO
292 };
293
294 SafetyMetric {
295 id: format!("SM-{:06}", self.counter),
296 entity_id: entity_id.to_string(),
297 period,
298 total_hours_worked,
299 recordable_incidents: recordable,
300 lost_time_incidents: lost_time,
301 days_away,
302 near_misses,
303 fatalities,
304 trir,
305 ltir,
306 dart_rate,
307 }
308 }
309
310 pub fn generate_diversity_from_employees(
326 &mut self,
327 entity_id: &str,
328 employees: &[Employee],
329 period: NaiveDate,
330 ) -> Vec<WorkforceDiversityMetric> {
331 if employees.is_empty() {
332 return Vec::new();
333 }
334
335 let total_headcount = employees.len() as u32;
336
337 let mut dept_counts: std::collections::HashMap<String, u32> =
339 std::collections::HashMap::new();
340 for emp in employees {
341 let dept = emp
342 .department_id
343 .clone()
344 .unwrap_or_else(|| "Unknown".to_string());
345 *dept_counts.entry(dept).or_insert(0) += 1;
346 }
347
348 let mut metrics = Vec::new();
349
350 let mut sorted_depts: Vec<(String, u32)> = dept_counts.into_iter().collect();
352 sorted_depts.sort_by(|a, b| a.0.cmp(&b.0));
353
354 for (dept, count) in &sorted_depts {
355 self.counter += 1;
356 let percentage = (Decimal::from(*count) / Decimal::from(total_headcount)).round_dp(4);
357 metrics.push(WorkforceDiversityMetric {
358 id: format!("DV-HR-{:06}", self.counter),
359 entity_id: entity_id.to_string(),
360 period,
361 dimension: DiversityDimension::Gender,
362 level: OrganizationLevel::Department,
363 category: dept.clone(),
364 headcount: *count,
365 total_headcount,
366 percentage,
367 });
368 }
369
370 self.counter += 1;
372 metrics.push(WorkforceDiversityMetric {
373 id: format!("DV-HR-{:06}", self.counter),
374 entity_id: entity_id.to_string(),
375 period,
376 dimension: DiversityDimension::Gender,
377 level: OrganizationLevel::Corporate,
378 category: "All".to_string(),
379 headcount: total_headcount,
380 total_headcount,
381 percentage: dec!(1.0000),
382 });
383
384 metrics
385 }
386
387 pub fn generate_pay_equity_from_payroll(
398 &mut self,
399 entity_id: &str,
400 payroll_items: &[PayrollLineItem],
401 period: NaiveDate,
402 ) -> Vec<PayEquityMetric> {
403 if payroll_items.is_empty() {
404 return Vec::new();
405 }
406
407 let mut group_totals: std::collections::HashMap<String, (Decimal, u32)> =
409 std::collections::HashMap::new();
410 for item in payroll_items {
411 let group = item
412 .department
413 .clone()
414 .or_else(|| item.cost_center.clone())
415 .unwrap_or_else(|| "Unknown".to_string());
416 let entry = group_totals.entry(group).or_insert((Decimal::ZERO, 0));
417 entry.0 += item.gross_pay;
418 entry.1 += 1;
419 }
420
421 if group_totals.len() < 2 {
422 return Vec::new();
423 }
424
425 let mut averages: Vec<(String, Decimal, u32)> = group_totals
427 .into_iter()
428 .map(|(g, (total, count))| {
429 let avg = if count > 0 {
430 (total / Decimal::from(count)).round_dp(2)
431 } else {
432 Decimal::ZERO
433 };
434 (g, avg, count)
435 })
436 .collect();
437
438 averages.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
440
441 let (ref_group, ref_avg, ref_count) = averages[0].clone();
442
443 let mut metrics = Vec::new();
444 for (cmp_group, cmp_avg, cmp_count) in averages.iter().skip(1) {
445 self.counter += 1;
446 let ratio = if ref_avg.is_zero() {
447 dec!(1.0000)
448 } else {
449 (cmp_avg / ref_avg).round_dp(4)
450 };
451 let sample = ref_count + cmp_count;
452 metrics.push(PayEquityMetric {
453 id: format!("PE-HR-{:06}", self.counter),
454 entity_id: entity_id.to_string(),
455 period,
456 dimension: DiversityDimension::Gender,
457 reference_group: ref_group.clone(),
458 comparison_group: cmp_group.clone(),
459 reference_median_salary: ref_avg,
460 comparison_median_salary: *cmp_avg,
461 pay_gap_ratio: ratio,
462 sample_size: sample,
463 });
464 }
465
466 metrics
467 }
468}
469
470pub struct GovernanceGenerator {
472 rng: ChaCha8Rng,
473 counter: u64,
474 board_size: u32,
475 independence_target: f64,
476}
477
478impl GovernanceGenerator {
479 pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
481 Self {
482 rng: seeded_rng(seed, 0),
483 counter: 0,
484 board_size: board_size.max(3),
485 independence_target,
486 }
487 }
488
489 pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
491 self.counter += 1;
492
493 let ind_frac: f64 = self.rng.random_range(
495 (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
496 );
497 let independent = (self.board_size as f64 * ind_frac).round() as u32;
498 let independent = independent.min(self.board_size);
499
500 let fem_frac: f64 = self.rng.random_range(0.20..0.40);
502 let female = (self.board_size as f64 * fem_frac).round() as u32;
503 let female = female.min(self.board_size);
504
505 let independence_ratio = if self.board_size > 0 {
506 (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
507 } else {
508 Decimal::ZERO
509 };
510 let gender_ratio = if self.board_size > 0 {
511 (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
512 } else {
513 Decimal::ZERO
514 };
515
516 let ethics_pct: f64 = self.rng.random_range(0.85..0.99);
517 let whistleblower: u32 = self.rng.random_range(0..5);
518 let anti_corruption: u32 = if self.rng.random::<f64>() < 0.10 {
519 1
520 } else {
521 0
522 };
523
524 GovernanceMetric {
525 id: format!("GV-{:06}", self.counter),
526 entity_id: entity_id.to_string(),
527 period,
528 board_size: self.board_size,
529 independent_directors: independent,
530 female_directors: female,
531 board_independence_ratio: independence_ratio,
532 board_gender_diversity_ratio: gender_ratio,
533 ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
534 whistleblower_reports: whistleblower,
535 anti_corruption_violations: anti_corruption,
536 }
537 }
538}
539
#[cfg(test)]
// Tests may unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parses an ISO `YYYY-MM-DD` date literal.
    fn ymd(text: &str) -> NaiveDate {
        NaiveDate::parse_from_str(text, "%Y-%m-%d").unwrap()
    }

    /// Builds a safety incident for entity `C001` at facility `FAC-001`.
    fn incident(
        id: &str,
        date: &str,
        incident_type: IncidentType,
        days_away: u32,
        is_recordable: bool,
    ) -> SafetyIncident {
        SafetyIncident {
            id: id.into(),
            entity_id: "C001".into(),
            facility_id: "FAC-001".into(),
            date: ymd(date),
            incident_type,
            days_away,
            is_recordable,
            description: "Test".into(),
        }
    }

    /// Category headcounts of each dimension/level group must add up to the
    /// group's total headcount, allowing a rounding slack of one.
    #[test]
    fn test_diversity_percentages_sum_to_one() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = generator.generate_diversity("C001", 1000, ymd("2025-01-01"));

        assert!(!metrics.is_empty());

        // (dimension, level) -> (sum of category headcounts, expected total)
        let mut totals: HashMap<(String, String), (u32, u32)> = HashMap::new();
        for metric in &metrics {
            let key = (
                format!("{:?}", metric.dimension),
                format!("{:?}", metric.level),
            );
            let entry = totals.entry(key).or_insert((0, metric.total_headcount));
            entry.0 += metric.headcount;
        }

        for (key, (total_hc, expected)) in &totals {
            assert!(
                total_hc.abs_diff(*expected) <= 1,
                "Group {:?}: headcount sum {} != expected {}",
                key,
                total_hc,
                expected
            );
        }
    }

    /// Four comparison pairs are generated with plausible ratios and salaries.
    #[test]
    fn test_pay_equity_ratios() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = generator.generate_pay_equity("C001", ymd("2025-01-01"));

        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
        for metric in &metrics {
            assert!(metric.pay_gap_ratio > dec!(0.80) && metric.pay_gap_ratio < dec!(1.10));
            assert!(metric.reference_median_salary > Decimal::ZERO);
            assert!(metric.comparison_median_salary > Decimal::ZERO);
            assert!(metric.sample_size > 0);
        }
    }

    /// The configured number of incidents is produced, and every incident is
    /// either recordable or a near miss.
    #[test]
    fn test_safety_incidents() {
        let config = SocialConfig {
            safety: datasynth_config::schema::SafetySchemaConfig {
                enabled: true,
                target_trir: 2.5,
                incident_count: 30,
            },
            ..Default::default()
        };
        let mut generator = WorkforceGenerator::new(config, 42);
        let incidents = generator.generate_safety_incidents(
            "C001",
            3,
            ymd("2025-01-01"),
            ymd("2025-12-31"),
        );

        assert_eq!(incidents.len(), 30);

        let mut recordable = 0usize;
        let mut near_miss = 0usize;
        for item in &incidents {
            if item.is_recordable {
                recordable += 1;
            }
            if item.incident_type == IncidentType::NearMiss {
                near_miss += 1;
            }
        }
        assert_eq!(recordable + near_miss, 30);
    }

    /// One recordable injury over 500,000 hours gives a TRIR of 0.4.
    #[test]
    fn test_safety_metric_trir_computation() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);

        let incidents = vec![
            incident("SI-001", "2025-03-15", IncidentType::Injury, 5, true),
            incident("SI-002", "2025-06-20", IncidentType::NearMiss, 0, false),
        ];

        let metric =
            generator.compute_safety_metrics("C001", &incidents, 500_000, ymd("2025-01-01"));

        assert_eq!(metric.recordable_incidents, 1);
        assert_eq!(metric.near_misses, 1);
        assert_eq!(metric.lost_time_incidents, 1);
        assert_eq!(metric.days_away, 5);
        assert_eq!(metric.trir, dec!(0.4000));
        assert_eq!(metric.computed_trir(), dec!(0.4000));
    }

    /// Generated governance figures stay within the board size and the
    /// sampled ranges.
    #[test]
    fn test_governance_generation() {
        let mut generator = GovernanceGenerator::new(42, 11, 0.67);
        let metric = generator.generate("C001", ymd("2025-01-01"));

        assert_eq!(metric.board_size, 11);
        assert!(metric.independent_directors <= 11);
        assert!(metric.female_directors <= 11);
        assert!(metric.board_independence_ratio > Decimal::ZERO);
        assert!(metric.ethics_training_completion_pct >= 0.85);
    }

    /// Disabling diversity in the configuration yields no metrics.
    #[test]
    fn test_disabled_diversity() {
        let mut config = SocialConfig::default();
        config.diversity.enabled = false;

        let mut generator = WorkforceGenerator::new(config, 42);
        let metrics = generator.generate_diversity("C001", 1000, ymd("2025-01-01"));
        assert!(metrics.is_empty());
    }
}
679}