Skip to main content

datasynth_generators/esg/
workforce_generator.rs

1//! Workforce ESG generator — derives diversity metrics, pay equity ratios,
2//! safety incidents, and aggregate safety metrics from employee data.
3use chrono::NaiveDate;
4use datasynth_config::schema::SocialConfig;
5use datasynth_core::models::{
6    DiversityDimension, Employee, GovernanceMetric, IncidentType, OrganizationLevel,
7    PayEquityMetric, PayrollLineItem, SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
8};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14
/// Generates workforce diversity, pay equity, and safety metrics.
///
/// Holds its own random stream and a single running counter that numbers
/// every record it emits, whatever the record kind.
pub struct WorkforceGenerator {
    /// Random source; created in `new` from `seeded_rng(seed, 0)`.
    rng: ChaCha8Rng,
    /// Social-reporting configuration; its `diversity`, `pay_equity` and
    /// `safety` sections switch the random generators on or off.
    config: SocialConfig,
    /// Running record number embedded in generated IDs. It is incremented
    /// once per emitted record and shared across all record kinds.
    counter: u64,
}
21
22impl WorkforceGenerator {
23    /// Create a new workforce generator.
24    pub fn new(config: SocialConfig, seed: u64) -> Self {
25        Self {
26            rng: seeded_rng(seed, 0),
27            config,
28            counter: 0,
29        }
30    }
31
32    // ----- Diversity -----
33
34    /// Generate workforce diversity metrics for a reporting period.
35    ///
36    /// Produces one record per (dimension × category × level) combination.
37    pub fn generate_diversity(
38        &mut self,
39        entity_id: &str,
40        total_headcount: u32,
41        period: NaiveDate,
42    ) -> Vec<WorkforceDiversityMetric> {
43        if !self.config.diversity.enabled || total_headcount == 0 {
44            return Vec::new();
45        }
46
47        let mut metrics = Vec::new();
48        let levels = [
49            OrganizationLevel::Corporate,
50            OrganizationLevel::Executive,
51            OrganizationLevel::Board,
52        ];
53
54        for dimension in &[
55            DiversityDimension::Gender,
56            DiversityDimension::Ethnicity,
57            DiversityDimension::Age,
58        ] {
59            let categories = self.categories_for(*dimension);
60            // Distribute headcount per level
61            for level in &levels {
62                let level_hc = match level {
63                    OrganizationLevel::Corporate => total_headcount,
64                    OrganizationLevel::Executive => (total_headcount / 50).max(5),
65                    OrganizationLevel::Board => 11,
66                    _ => total_headcount,
67                };
68
69                let shares = self.random_shares(categories.len());
70                for (i, cat) in categories.iter().enumerate() {
71                    self.counter += 1;
72                    let headcount = (Decimal::from(level_hc)
73                        * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
74                    .round_dp(0);
75                    let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
76
77                    let percentage = if level_hc > 0 {
78                        (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
79                    } else {
80                        Decimal::ZERO
81                    };
82
83                    metrics.push(WorkforceDiversityMetric {
84                        id: format!("DV-{:06}", self.counter),
85                        entity_id: entity_id.to_string(),
86                        period,
87                        dimension: *dimension,
88                        level: *level,
89                        category: cat.to_string(),
90                        headcount: hc,
91                        total_headcount: level_hc,
92                        percentage,
93                    });
94                }
95            }
96        }
97
98        metrics
99    }
100
101    fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
102        match dimension {
103            DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
104            DiversityDimension::Ethnicity => {
105                vec!["White", "Asian", "Hispanic", "Black", "Other"]
106            }
107            DiversityDimension::Age => {
108                vec!["Under 30", "30-50", "Over 50"]
109            }
110            DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
111            DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
112        }
113    }
114
115    /// Generate random shares that sum to 1.0.
116    fn random_shares(&mut self, count: usize) -> Vec<f64> {
117        let mut raw: Vec<f64> = (0..count).map(|_| self.rng.random::<f64>()).collect();
118        let total: f64 = raw.iter().sum();
119        if total > 0.0 {
120            for v in &mut raw {
121                *v /= total;
122            }
123        }
124        raw
125    }
126
127    // ----- Pay Equity -----
128
129    /// Generate pay equity metrics for common group comparisons.
130    pub fn generate_pay_equity(
131        &mut self,
132        entity_id: &str,
133        period: NaiveDate,
134    ) -> Vec<PayEquityMetric> {
135        if !self.config.pay_equity.enabled {
136            return Vec::new();
137        }
138
139        let comparisons = [
140            (DiversityDimension::Gender, "Male", "Female"),
141            (DiversityDimension::Ethnicity, "White", "Asian"),
142            (DiversityDimension::Ethnicity, "White", "Hispanic"),
143            (DiversityDimension::Ethnicity, "White", "Black"),
144        ];
145
146        comparisons
147            .iter()
148            .map(|(dim, ref_group, cmp_group)| {
149                self.counter += 1;
150                let ref_salary: f64 = self.rng.random_range(70_000.0..120_000.0);
151                // Pay gap: comparison group earns 85-105% of reference
152                let gap_factor: f64 = self.rng.random_range(0.85..1.05);
153                let cmp_salary = ref_salary * gap_factor;
154
155                let ref_dec = Decimal::from_f64_retain(ref_salary)
156                    .unwrap_or(dec!(90000))
157                    .round_dp(2);
158                let cmp_dec = Decimal::from_f64_retain(cmp_salary)
159                    .unwrap_or(dec!(85000))
160                    .round_dp(2);
161                let ratio = if ref_dec.is_zero() {
162                    dec!(1.00)
163                } else {
164                    (cmp_dec / ref_dec).round_dp(4)
165                };
166
167                PayEquityMetric {
168                    id: format!("PE-{:06}", self.counter),
169                    entity_id: entity_id.to_string(),
170                    period,
171                    dimension: *dim,
172                    reference_group: ref_group.to_string(),
173                    comparison_group: cmp_group.to_string(),
174                    reference_median_salary: ref_dec,
175                    comparison_median_salary: cmp_dec,
176                    pay_gap_ratio: ratio,
177                    sample_size: self.rng.random_range(50..500),
178                }
179            })
180            .collect()
181    }
182
183    // ----- Safety -----
184
185    /// Generate safety incidents for a period.
186    pub fn generate_safety_incidents(
187        &mut self,
188        entity_id: &str,
189        facility_count: u32,
190        start_date: NaiveDate,
191        end_date: NaiveDate,
192    ) -> Vec<SafetyIncident> {
193        if !self.config.safety.enabled {
194            return Vec::new();
195        }
196
197        let total_incidents = self.config.safety.incident_count;
198        let period_days = (end_date - start_date).num_days().max(1);
199
200        (0..total_incidents)
201            .map(|_| {
202                self.counter += 1;
203                let day_offset = self.rng.random_range(0..period_days);
204                let date = start_date + chrono::Duration::days(day_offset);
205                let fac = self.rng.random_range(1..=facility_count.max(1));
206
207                let incident_type = self.pick_incident_type();
208                let days_away = match incident_type {
209                    IncidentType::Fatality => 0,
210                    IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
211                    IncidentType::Injury => self.rng.random_range(0..30u32),
212                    IncidentType::Illness => self.rng.random_range(1..15u32),
213                };
214                let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
215
216                let description = match incident_type {
217                    IncidentType::Injury => "Workplace injury incident".to_string(),
218                    IncidentType::Illness => "Occupational illness reported".to_string(),
219                    IncidentType::NearMiss => "Near miss event documented".to_string(),
220                    IncidentType::Fatality => "Fatal workplace incident".to_string(),
221                    IncidentType::PropertyDamage => "Property damage incident".to_string(),
222                };
223
224                SafetyIncident {
225                    id: format!("SI-{:06}", self.counter),
226                    entity_id: entity_id.to_string(),
227                    facility_id: format!("FAC-{fac:03}"),
228                    date,
229                    incident_type,
230                    days_away,
231                    is_recordable,
232                    description,
233                }
234            })
235            .collect()
236    }
237
238    fn pick_incident_type(&mut self) -> IncidentType {
239        let roll: f64 = self.rng.random::<f64>();
240        if roll < 0.35 {
241            IncidentType::NearMiss
242        } else if roll < 0.65 {
243            IncidentType::Injury
244        } else if roll < 0.80 {
245            IncidentType::Illness
246        } else if roll < 0.95 {
247            IncidentType::PropertyDamage
248        } else {
249            IncidentType::Fatality
250        }
251    }
252
253    /// Compute aggregate safety metrics from incidents.
254    pub fn compute_safety_metrics(
255        &mut self,
256        entity_id: &str,
257        incidents: &[SafetyIncident],
258        total_hours_worked: u64,
259        period: NaiveDate,
260    ) -> SafetyMetric {
261        self.counter += 1;
262
263        let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
264        let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
265        let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
266        let near_misses = incidents
267            .iter()
268            .filter(|i| i.incident_type == IncidentType::NearMiss)
269            .count() as u32;
270        let fatalities = incidents
271            .iter()
272            .filter(|i| i.incident_type == IncidentType::Fatality)
273            .count() as u32;
274
275        let hours_dec = Decimal::from(total_hours_worked);
276        let base = dec!(200000);
277
278        let trir = if total_hours_worked > 0 {
279            (Decimal::from(recordable) * base / hours_dec).round_dp(4)
280        } else {
281            Decimal::ZERO
282        };
283        let ltir = if total_hours_worked > 0 {
284            (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
285        } else {
286            Decimal::ZERO
287        };
288        let dart_rate = if total_hours_worked > 0 {
289            (Decimal::from(days_away) * base / hours_dec).round_dp(4)
290        } else {
291            Decimal::ZERO
292        };
293
294        SafetyMetric {
295            id: format!("SM-{:06}", self.counter),
296            entity_id: entity_id.to_string(),
297            period,
298            total_hours_worked,
299            recordable_incidents: recordable,
300            lost_time_incidents: lost_time,
301            days_away,
302            near_misses,
303            fatalities,
304            trir,
305            ltir,
306            dart_rate,
307        }
308    }
309
310    // ----- HR bridge methods -----
311
312    /// Derive workforce diversity metrics from actual `Employee` records.
313    ///
314    /// Groups employees by `department_id` (falling back to `"Unknown"` when
315    /// the field is absent) and produces one [`WorkforceDiversityMetric`] per
316    /// department showing that department's share of the total headcount.
317    ///
318    /// The `Employee` model does not carry an explicit gender field, so
319    /// `DiversityDimension::Gender` is used as the primary dimension while the
320    /// `category` field stores the department identifier — preserving the ESG
321    /// schema while surfacing real organisational distribution data.
322    ///
323    /// An additional "total" record at [`OrganizationLevel::Corporate`] is
324    /// emitted so downstream consumers can verify the headcount roll-up.
325    pub fn generate_diversity_from_employees(
326        &mut self,
327        entity_id: &str,
328        employees: &[Employee],
329        period: NaiveDate,
330    ) -> Vec<WorkforceDiversityMetric> {
331        if employees.is_empty() {
332            return Vec::new();
333        }
334
335        let total_headcount = employees.len() as u32;
336
337        // --- Count employees per department ---
338        let mut dept_counts: std::collections::HashMap<String, u32> =
339            std::collections::HashMap::new();
340        for emp in employees {
341            let dept = emp
342                .department_id
343                .clone()
344                .unwrap_or_else(|| "Unknown".to_string());
345            *dept_counts.entry(dept).or_insert(0) += 1;
346        }
347
348        let mut metrics = Vec::new();
349
350        // One record per department at Department level
351        let mut sorted_depts: Vec<(String, u32)> = dept_counts.into_iter().collect();
352        sorted_depts.sort_by(|a, b| a.0.cmp(&b.0));
353
354        for (dept, count) in &sorted_depts {
355            self.counter += 1;
356            let percentage = (Decimal::from(*count) / Decimal::from(total_headcount)).round_dp(4);
357            metrics.push(WorkforceDiversityMetric {
358                id: format!("DV-HR-{:06}", self.counter),
359                entity_id: entity_id.to_string(),
360                period,
361                dimension: DiversityDimension::Gender,
362                level: OrganizationLevel::Department,
363                category: dept.clone(),
364                headcount: *count,
365                total_headcount,
366                percentage,
367            });
368        }
369
370        // Corporate-level total record (all employees, one bucket)
371        self.counter += 1;
372        metrics.push(WorkforceDiversityMetric {
373            id: format!("DV-HR-{:06}", self.counter),
374            entity_id: entity_id.to_string(),
375            period,
376            dimension: DiversityDimension::Gender,
377            level: OrganizationLevel::Corporate,
378            category: "All".to_string(),
379            headcount: total_headcount,
380            total_headcount,
381            percentage: dec!(1.0000),
382        });
383
384        metrics
385    }
386
387    /// Derive pay equity metrics from actual `PayrollLineItem` records.
388    ///
389    /// Groups line items by `department` (falling back to `cost_center`, then
390    /// `"Unknown"`) and computes the average `gross_pay` for each group.
391    /// Produces one [`PayEquityMetric`] per non-baseline group, comparing its
392    /// average pay against the group with the highest average pay (the
393    /// reference group).
394    ///
395    /// Returns an empty `Vec` when fewer than two distinct groups are found
396    /// (no meaningful comparison is possible).
397    pub fn generate_pay_equity_from_payroll(
398        &mut self,
399        entity_id: &str,
400        payroll_items: &[PayrollLineItem],
401        period: NaiveDate,
402    ) -> Vec<PayEquityMetric> {
403        if payroll_items.is_empty() {
404            return Vec::new();
405        }
406
407        // --- Group gross_pay by department / cost_center ---
408        let mut group_totals: std::collections::HashMap<String, (Decimal, u32)> =
409            std::collections::HashMap::new();
410        for item in payroll_items {
411            let group = item
412                .department
413                .clone()
414                .or_else(|| item.cost_center.clone())
415                .unwrap_or_else(|| "Unknown".to_string());
416            let entry = group_totals.entry(group).or_insert((Decimal::ZERO, 0));
417            entry.0 += item.gross_pay;
418            entry.1 += 1;
419        }
420
421        if group_totals.len() < 2 {
422            return Vec::new();
423        }
424
425        // Compute average gross_pay per group
426        let mut averages: Vec<(String, Decimal, u32)> = group_totals
427            .into_iter()
428            .map(|(g, (total, count))| {
429                let avg = if count > 0 {
430                    (total / Decimal::from(count)).round_dp(2)
431                } else {
432                    Decimal::ZERO
433                };
434                (g, avg, count)
435            })
436            .collect();
437
438        // Sort deterministically and pick highest-average group as reference
439        averages.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
440
441        let (ref_group, ref_avg, ref_count) = averages[0].clone();
442
443        let mut metrics = Vec::new();
444        for (cmp_group, cmp_avg, cmp_count) in averages.iter().skip(1) {
445            self.counter += 1;
446            let ratio = if ref_avg.is_zero() {
447                dec!(1.0000)
448            } else {
449                (cmp_avg / ref_avg).round_dp(4)
450            };
451            let sample = ref_count + cmp_count;
452            metrics.push(PayEquityMetric {
453                id: format!("PE-HR-{:06}", self.counter),
454                entity_id: entity_id.to_string(),
455                period,
456                dimension: DiversityDimension::Gender,
457                reference_group: ref_group.clone(),
458                comparison_group: cmp_group.clone(),
459                reference_median_salary: ref_avg,
460                comparison_median_salary: *cmp_avg,
461                pay_gap_ratio: ratio,
462                sample_size: sample,
463            });
464        }
465
466        metrics
467    }
468}
469
/// Generates [`GovernanceMetric`] records.
///
/// Board size and the independence target are fixed at construction; each
/// call to `generate` draws fresh random values around them.
pub struct GovernanceGenerator {
    /// Random source; created in `new` from `seeded_rng(seed, 0)`.
    rng: ChaCha8Rng,
    /// Running record number embedded in generated `GV-` IDs.
    counter: u64,
    /// Number of board seats; `new` raises values below 3 to 3.
    board_size: u32,
    /// Target fraction of independent directors; `generate` samples the
    /// actual fraction within ±0.10 of this value.
    independence_target: f64,
}
477
478impl GovernanceGenerator {
479    /// Create a new governance generator.
480    pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
481        Self {
482            rng: seeded_rng(seed, 0),
483            counter: 0,
484            board_size: board_size.max(3),
485            independence_target,
486        }
487    }
488
489    /// Generate a governance metric for a period.
490    pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
491        self.counter += 1;
492
493        // Independent directors: aim near target with some noise
494        let ind_frac: f64 = self.rng.random_range(
495            (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
496        );
497        let independent = (self.board_size as f64 * ind_frac).round() as u32;
498        let independent = independent.min(self.board_size);
499
500        // Female directors: 20-40% range
501        let fem_frac: f64 = self.rng.random_range(0.20..0.40);
502        let female = (self.board_size as f64 * fem_frac).round() as u32;
503        let female = female.min(self.board_size);
504
505        let independence_ratio = if self.board_size > 0 {
506            (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
507        } else {
508            Decimal::ZERO
509        };
510        let gender_ratio = if self.board_size > 0 {
511            (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
512        } else {
513            Decimal::ZERO
514        };
515
516        let ethics_pct: f64 = self.rng.random_range(0.85..0.99);
517        let whistleblower: u32 = self.rng.random_range(0..5);
518        let anti_corruption: u32 = if self.rng.random::<f64>() < 0.10 {
519            1
520        } else {
521            0
522        };
523
524        GovernanceMetric {
525            id: format!("GV-{:06}", self.counter),
526            entity_id: entity_id.to_string(),
527            period,
528            board_size: self.board_size,
529            independent_directors: independent,
530            female_directors: female,
531            board_independence_ratio: independence_ratio,
532            board_gender_diversity_ratio: gender_ratio,
533            ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
534            whistleblower_reports: whistleblower,
535            anti_corruption_violations: anti_corruption,
536        }
537    }
538}
539
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parse a `YYYY-MM-DD` literal into a date; panics on malformed input.
    fn d(s: &str) -> NaiveDate {
        NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap()
    }

    /// Within every (dimension, level) group the category headcounts must add
    /// up to the group's total headcount, give or take one for rounding.
    #[test]
    fn test_diversity_percentages_sum_to_one() {
        let mut gen = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = gen.generate_diversity("C001", 1000, d("2025-01-01"));

        assert!(!metrics.is_empty());

        // Per group: (running headcount sum, total headcount of first record).
        let mut groups: std::collections::HashMap<(String, String), (u32, u32)> =
            std::collections::HashMap::new();
        for m in &metrics {
            let key = (format!("{:?}", m.dimension), format!("{:?}", m.level));
            let entry = groups.entry(key).or_insert((0, m.total_headcount));
            entry.0 += m.headcount;
        }

        for (key, (total_hc, expected)) in &groups {
            // Allow rounding tolerance of ±1
            assert!(
                total_hc.abs_diff(*expected) <= 1,
                "Group {:?}: headcount sum {} != expected {}",
                key,
                total_hc,
                expected
            );
        }
    }

    /// Random pay-equity records: four pairs, plausible ratios, positive values.
    #[test]
    fn test_pay_equity_ratios() {
        let mut gen = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = gen.generate_pay_equity("C001", d("2025-01-01"));

        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
        for metric in &metrics {
            let ratio = metric.pay_gap_ratio;
            assert!(ratio > dec!(0.80) && ratio < dec!(1.10));
            assert!(metric.reference_median_salary > Decimal::ZERO);
            assert!(metric.comparison_median_salary > Decimal::ZERO);
            assert!(metric.sample_size > 0);
        }
    }

    /// The configured incident count is honoured and only near misses are
    /// non-recordable.
    #[test]
    fn test_safety_incidents() {
        let safety = datasynth_config::schema::SafetySchemaConfig {
            enabled: true,
            target_trir: 2.5,
            incident_count: 30,
        };
        let config = SocialConfig {
            safety,
            ..Default::default()
        };
        let mut gen = WorkforceGenerator::new(config, 42);
        let incidents = gen.generate_safety_incidents("C001", 3, d("2025-01-01"), d("2025-12-31"));

        assert_eq!(incidents.len(), 30);

        let mut recordable = 0usize;
        let mut near_miss = 0usize;
        for incident in &incidents {
            if incident.is_recordable {
                recordable += 1;
            }
            if incident.incident_type == IncidentType::NearMiss {
                near_miss += 1;
            }
        }
        // Near misses are not recordable
        assert_eq!(recordable + near_miss, 30);
    }

    /// Aggregates and TRIR for one recordable injury plus one near miss.
    #[test]
    fn test_safety_metric_trir_computation() {
        /// Build a test incident at facility `FAC-001` for entity `C001`.
        fn incident(
            id: &str,
            date: NaiveDate,
            incident_type: IncidentType,
            days_away: u32,
            is_recordable: bool,
        ) -> SafetyIncident {
            SafetyIncident {
                id: id.into(),
                entity_id: "C001".into(),
                facility_id: "FAC-001".into(),
                date,
                incident_type,
                days_away,
                is_recordable,
                description: "Test".into(),
            }
        }

        let mut gen = WorkforceGenerator::new(SocialConfig::default(), 42);

        let incidents = [
            incident("SI-001", d("2025-03-15"), IncidentType::Injury, 5, true),
            incident("SI-002", d("2025-06-20"), IncidentType::NearMiss, 0, false),
        ];

        let metric = gen.compute_safety_metrics("C001", &incidents, 500_000, d("2025-01-01"));

        assert_eq!(metric.recordable_incidents, 1);
        assert_eq!(metric.near_misses, 1);
        assert_eq!(metric.lost_time_incidents, 1);
        assert_eq!(metric.days_away, 5);
        // TRIR = 1 × 200,000 / 500,000 = 0.4
        assert_eq!(metric.trir, dec!(0.4000));
        assert_eq!(metric.computed_trir(), dec!(0.4000));
    }

    /// Governance output respects the board size and basic value bounds.
    #[test]
    fn test_governance_generation() {
        let board_seats = 11;
        let mut gen = GovernanceGenerator::new(42, board_seats, 0.67);
        let metric = gen.generate("C001", d("2025-01-01"));

        assert_eq!(metric.board_size, board_seats);
        assert!(metric.independent_directors <= board_seats);
        assert!(metric.female_directors <= board_seats);
        assert!(metric.board_independence_ratio > Decimal::ZERO);
        assert!(metric.ethics_training_completion_pct >= 0.85);
    }

    /// Disabling diversity in the configuration suppresses all output.
    #[test]
    fn test_disabled_diversity() {
        let mut config = SocialConfig::default();
        config.diversity.enabled = false;

        let metrics =
            WorkforceGenerator::new(config, 42).generate_diversity("C001", 1000, d("2025-01-01"));
        assert!(metrics.is_empty());
    }
}