Skip to main content

datasynth_generators/esg/
workforce_generator.rs

1//! Workforce ESG generator — derives diversity metrics, pay equity ratios,
2//! safety incidents, and aggregate safety metrics from employee data.
3use chrono::NaiveDate;
4use datasynth_config::schema::SocialConfig;
5use datasynth_core::models::{
6    DiversityDimension, GovernanceMetric, IncidentType, OrganizationLevel, PayEquityMetric,
7    SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
8};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14
15/// Generates workforce diversity, pay equity, and safety metrics.
16pub struct WorkforceGenerator {
17    rng: ChaCha8Rng,
18    config: SocialConfig,
19    counter: u64,
20}
21
22impl WorkforceGenerator {
23    /// Create a new workforce generator.
24    pub fn new(config: SocialConfig, seed: u64) -> Self {
25        Self {
26            rng: seeded_rng(seed, 0),
27            config,
28            counter: 0,
29        }
30    }
31
32    // ----- Diversity -----
33
34    /// Generate workforce diversity metrics for a reporting period.
35    ///
36    /// Produces one record per (dimension × category × level) combination.
37    pub fn generate_diversity(
38        &mut self,
39        entity_id: &str,
40        total_headcount: u32,
41        period: NaiveDate,
42    ) -> Vec<WorkforceDiversityMetric> {
43        if !self.config.diversity.enabled || total_headcount == 0 {
44            return Vec::new();
45        }
46
47        let mut metrics = Vec::new();
48        let levels = [
49            OrganizationLevel::Corporate,
50            OrganizationLevel::Executive,
51            OrganizationLevel::Board,
52        ];
53
54        for dimension in &[
55            DiversityDimension::Gender,
56            DiversityDimension::Ethnicity,
57            DiversityDimension::Age,
58        ] {
59            let categories = self.categories_for(*dimension);
60            // Distribute headcount per level
61            for level in &levels {
62                let level_hc = match level {
63                    OrganizationLevel::Corporate => total_headcount,
64                    OrganizationLevel::Executive => (total_headcount / 50).max(5),
65                    OrganizationLevel::Board => 11,
66                    _ => total_headcount,
67                };
68
69                let shares = self.random_shares(categories.len());
70                for (i, cat) in categories.iter().enumerate() {
71                    self.counter += 1;
72                    let headcount = (Decimal::from(level_hc)
73                        * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
74                    .round_dp(0);
75                    let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
76
77                    let percentage = if level_hc > 0 {
78                        (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
79                    } else {
80                        Decimal::ZERO
81                    };
82
83                    metrics.push(WorkforceDiversityMetric {
84                        id: format!("DV-{:06}", self.counter),
85                        entity_id: entity_id.to_string(),
86                        period,
87                        dimension: *dimension,
88                        level: *level,
89                        category: cat.to_string(),
90                        headcount: hc,
91                        total_headcount: level_hc,
92                        percentage,
93                    });
94                }
95            }
96        }
97
98        metrics
99    }
100
101    fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
102        match dimension {
103            DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
104            DiversityDimension::Ethnicity => {
105                vec!["White", "Asian", "Hispanic", "Black", "Other"]
106            }
107            DiversityDimension::Age => {
108                vec!["Under 30", "30-50", "Over 50"]
109            }
110            DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
111            DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
112        }
113    }
114
115    /// Generate random shares that sum to 1.0.
116    fn random_shares(&mut self, count: usize) -> Vec<f64> {
117        let mut raw: Vec<f64> = (0..count).map(|_| self.rng.gen::<f64>()).collect();
118        let total: f64 = raw.iter().sum();
119        if total > 0.0 {
120            for v in &mut raw {
121                *v /= total;
122            }
123        }
124        raw
125    }
126
127    // ----- Pay Equity -----
128
129    /// Generate pay equity metrics for common group comparisons.
130    pub fn generate_pay_equity(
131        &mut self,
132        entity_id: &str,
133        period: NaiveDate,
134    ) -> Vec<PayEquityMetric> {
135        if !self.config.pay_equity.enabled {
136            return Vec::new();
137        }
138
139        let comparisons = [
140            (DiversityDimension::Gender, "Male", "Female"),
141            (DiversityDimension::Ethnicity, "White", "Asian"),
142            (DiversityDimension::Ethnicity, "White", "Hispanic"),
143            (DiversityDimension::Ethnicity, "White", "Black"),
144        ];
145
146        comparisons
147            .iter()
148            .map(|(dim, ref_group, cmp_group)| {
149                self.counter += 1;
150                let ref_salary: f64 = self.rng.gen_range(70_000.0..120_000.0);
151                // Pay gap: comparison group earns 85-105% of reference
152                let gap_factor: f64 = self.rng.gen_range(0.85..1.05);
153                let cmp_salary = ref_salary * gap_factor;
154
155                let ref_dec = Decimal::from_f64_retain(ref_salary)
156                    .unwrap_or(dec!(90000))
157                    .round_dp(2);
158                let cmp_dec = Decimal::from_f64_retain(cmp_salary)
159                    .unwrap_or(dec!(85000))
160                    .round_dp(2);
161                let ratio = if ref_dec.is_zero() {
162                    dec!(1.00)
163                } else {
164                    (cmp_dec / ref_dec).round_dp(4)
165                };
166
167                PayEquityMetric {
168                    id: format!("PE-{:06}", self.counter),
169                    entity_id: entity_id.to_string(),
170                    period,
171                    dimension: *dim,
172                    reference_group: ref_group.to_string(),
173                    comparison_group: cmp_group.to_string(),
174                    reference_median_salary: ref_dec,
175                    comparison_median_salary: cmp_dec,
176                    pay_gap_ratio: ratio,
177                    sample_size: self.rng.gen_range(50..500),
178                }
179            })
180            .collect()
181    }
182
183    // ----- Safety -----
184
185    /// Generate safety incidents for a period.
186    pub fn generate_safety_incidents(
187        &mut self,
188        entity_id: &str,
189        facility_count: u32,
190        start_date: NaiveDate,
191        end_date: NaiveDate,
192    ) -> Vec<SafetyIncident> {
193        if !self.config.safety.enabled {
194            return Vec::new();
195        }
196
197        let total_incidents = self.config.safety.incident_count;
198        let period_days = (end_date - start_date).num_days().max(1);
199
200        (0..total_incidents)
201            .map(|_| {
202                self.counter += 1;
203                let day_offset = self.rng.gen_range(0..period_days);
204                let date = start_date + chrono::Duration::days(day_offset);
205                let fac = self.rng.gen_range(1..=facility_count.max(1));
206
207                let incident_type = self.pick_incident_type();
208                let days_away = match incident_type {
209                    IncidentType::Fatality => 0,
210                    IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
211                    IncidentType::Injury => self.rng.gen_range(0..30u32),
212                    IncidentType::Illness => self.rng.gen_range(1..15u32),
213                };
214                let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
215
216                let description = match incident_type {
217                    IncidentType::Injury => "Workplace injury incident".to_string(),
218                    IncidentType::Illness => "Occupational illness reported".to_string(),
219                    IncidentType::NearMiss => "Near miss event documented".to_string(),
220                    IncidentType::Fatality => "Fatal workplace incident".to_string(),
221                    IncidentType::PropertyDamage => "Property damage incident".to_string(),
222                };
223
224                SafetyIncident {
225                    id: format!("SI-{:06}", self.counter),
226                    entity_id: entity_id.to_string(),
227                    facility_id: format!("FAC-{:03}", fac),
228                    date,
229                    incident_type,
230                    days_away,
231                    is_recordable,
232                    description,
233                }
234            })
235            .collect()
236    }
237
238    fn pick_incident_type(&mut self) -> IncidentType {
239        let roll: f64 = self.rng.gen::<f64>();
240        if roll < 0.35 {
241            IncidentType::NearMiss
242        } else if roll < 0.65 {
243            IncidentType::Injury
244        } else if roll < 0.80 {
245            IncidentType::Illness
246        } else if roll < 0.95 {
247            IncidentType::PropertyDamage
248        } else {
249            IncidentType::Fatality
250        }
251    }
252
253    /// Compute aggregate safety metrics from incidents.
254    pub fn compute_safety_metrics(
255        &mut self,
256        entity_id: &str,
257        incidents: &[SafetyIncident],
258        total_hours_worked: u64,
259        period: NaiveDate,
260    ) -> SafetyMetric {
261        self.counter += 1;
262
263        let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
264        let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
265        let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
266        let near_misses = incidents
267            .iter()
268            .filter(|i| i.incident_type == IncidentType::NearMiss)
269            .count() as u32;
270        let fatalities = incidents
271            .iter()
272            .filter(|i| i.incident_type == IncidentType::Fatality)
273            .count() as u32;
274
275        let hours_dec = Decimal::from(total_hours_worked);
276        let base = dec!(200000);
277
278        let trir = if total_hours_worked > 0 {
279            (Decimal::from(recordable) * base / hours_dec).round_dp(4)
280        } else {
281            Decimal::ZERO
282        };
283        let ltir = if total_hours_worked > 0 {
284            (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
285        } else {
286            Decimal::ZERO
287        };
288        let dart_rate = if total_hours_worked > 0 {
289            (Decimal::from(days_away) * base / hours_dec).round_dp(4)
290        } else {
291            Decimal::ZERO
292        };
293
294        SafetyMetric {
295            id: format!("SM-{:06}", self.counter),
296            entity_id: entity_id.to_string(),
297            period,
298            total_hours_worked,
299            recordable_incidents: recordable,
300            lost_time_incidents: lost_time,
301            days_away,
302            near_misses,
303            fatalities,
304            trir,
305            ltir,
306            dart_rate,
307        }
308    }
309}
310
311/// Generates [`GovernanceMetric`] records.
312pub struct GovernanceGenerator {
313    rng: ChaCha8Rng,
314    counter: u64,
315    board_size: u32,
316    independence_target: f64,
317}
318
319impl GovernanceGenerator {
320    /// Create a new governance generator.
321    pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
322        Self {
323            rng: seeded_rng(seed, 0),
324            counter: 0,
325            board_size: board_size.max(3),
326            independence_target,
327        }
328    }
329
330    /// Generate a governance metric for a period.
331    pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
332        self.counter += 1;
333
334        // Independent directors: aim near target with some noise
335        let ind_frac: f64 = self.rng.gen_range(
336            (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
337        );
338        let independent = (self.board_size as f64 * ind_frac).round() as u32;
339        let independent = independent.min(self.board_size);
340
341        // Female directors: 20-40% range
342        let fem_frac: f64 = self.rng.gen_range(0.20..0.40);
343        let female = (self.board_size as f64 * fem_frac).round() as u32;
344        let female = female.min(self.board_size);
345
346        let independence_ratio = if self.board_size > 0 {
347            (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
348        } else {
349            Decimal::ZERO
350        };
351        let gender_ratio = if self.board_size > 0 {
352            (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
353        } else {
354            Decimal::ZERO
355        };
356
357        let ethics_pct: f64 = self.rng.gen_range(0.85..0.99);
358        let whistleblower: u32 = self.rng.gen_range(0..5);
359        let anti_corruption: u32 = if self.rng.gen::<f64>() < 0.10 { 1 } else { 0 };
360
361        GovernanceMetric {
362            id: format!("GV-{:06}", self.counter),
363            entity_id: entity_id.to_string(),
364            period,
365            board_size: self.board_size,
366            independent_directors: independent,
367            female_directors: female,
368            board_independence_ratio: independence_ratio,
369            board_gender_diversity_ratio: gender_ratio,
370            ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
371            whistleblower_reports: whistleblower,
372            anti_corruption_violations: anti_corruption,
373        }
374    }
375}
376
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parse an ISO `YYYY-MM-DD` date literal.
    fn d(s: &str) -> NaiveDate {
        NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap()
    }

    /// Workforce generator with the default social configuration and seed 42.
    fn default_generator() -> WorkforceGenerator {
        WorkforceGenerator::new(SocialConfig::default(), 42)
    }

    /// Build a minimal incident fixture for the aggregate-metric test.
    fn incident(
        id: &str,
        date: &str,
        incident_type: IncidentType,
        days_away: u32,
        is_recordable: bool,
    ) -> SafetyIncident {
        SafetyIncident {
            id: id.into(),
            entity_id: "C001".into(),
            facility_id: "FAC-001".into(),
            date: d(date),
            incident_type,
            days_away,
            is_recordable,
            description: "Test".into(),
        }
    }

    #[test]
    fn test_diversity_percentages_sum_to_one() {
        let mut generator = default_generator();
        let metrics = generator.generate_diversity("C001", 1000, d("2025-01-01"));

        assert!(!metrics.is_empty());

        // Per (dimension, level) group: (sum of category headcounts, level
        // headcount taken from the group's first record).
        let mut groups: HashMap<(String, String), (u32, u32)> = HashMap::new();
        for m in &metrics {
            let key = (format!("{:?}", m.dimension), format!("{:?}", m.level));
            let totals = groups.entry(key).or_insert((0, m.total_headcount));
            totals.0 += m.headcount;
        }

        for (key, (total_hc, expected)) in &groups {
            // Allow rounding tolerance of ±1
            assert!(
                total_hc.abs_diff(*expected) <= 1,
                "Group {:?}: headcount sum {} != expected {}",
                key,
                total_hc,
                expected
            );
        }
    }

    #[test]
    fn test_pay_equity_ratios() {
        let mut generator = default_generator();
        let metrics = generator.generate_pay_equity("C001", d("2025-01-01"));

        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
        for m in &metrics {
            assert!(m.pay_gap_ratio > dec!(0.80) && m.pay_gap_ratio < dec!(1.10));
            assert!(m.reference_median_salary > Decimal::ZERO);
            assert!(m.comparison_median_salary > Decimal::ZERO);
            assert!(m.sample_size > 0);
        }
    }

    #[test]
    fn test_safety_incidents() {
        let safety = datasynth_config::schema::SafetySchemaConfig {
            enabled: true,
            target_trir: 2.5,
            incident_count: 30,
        };
        let config = SocialConfig {
            safety,
            ..Default::default()
        };
        let mut generator = WorkforceGenerator::new(config, 42);
        let incidents =
            generator.generate_safety_incidents("C001", 3, d("2025-01-01"), d("2025-12-31"));

        assert_eq!(incidents.len(), 30);

        let recordable = incidents.iter().filter(|i| i.is_recordable).count();
        let near_miss = incidents
            .iter()
            .filter(|i| i.incident_type == IncidentType::NearMiss)
            .count();
        // Near misses are not recordable
        assert_eq!(recordable + near_miss, 30);
    }

    #[test]
    fn test_safety_metric_trir_computation() {
        let mut generator = default_generator();

        let incidents = vec![
            incident("SI-001", "2025-03-15", IncidentType::Injury, 5, true),
            incident("SI-002", "2025-06-20", IncidentType::NearMiss, 0, false),
        ];

        let metric =
            generator.compute_safety_metrics("C001", &incidents, 500_000, d("2025-01-01"));

        assert_eq!(metric.recordable_incidents, 1);
        assert_eq!(metric.near_misses, 1);
        assert_eq!(metric.lost_time_incidents, 1);
        assert_eq!(metric.days_away, 5);
        // TRIR = 1 × 200,000 / 500,000 = 0.4
        assert_eq!(metric.trir, dec!(0.4000));
        assert_eq!(metric.computed_trir(), dec!(0.4000));
    }

    #[test]
    fn test_governance_generation() {
        let mut generator = GovernanceGenerator::new(42, 11, 0.67);
        let metric = generator.generate("C001", d("2025-01-01"));

        assert_eq!(metric.board_size, 11);
        assert!(metric.independent_directors <= 11);
        assert!(metric.female_directors <= 11);
        assert!(metric.board_independence_ratio > Decimal::ZERO);
        assert!(metric.ethics_training_completion_pct >= 0.85);
    }

    #[test]
    fn test_disabled_diversity() {
        let mut config = SocialConfig::default();
        config.diversity.enabled = false;
        let mut generator = WorkforceGenerator::new(config, 42);
        let metrics = generator.generate_diversity("C001", 1000, d("2025-01-01"));
        assert!(metrics.is_empty());
    }
}