Skip to main content

datasynth_generators/esg/
workforce_generator.rs

1//! Workforce ESG generator — derives diversity metrics, pay equity ratios,
2//! safety incidents, and aggregate safety metrics from employee data.
3use chrono::NaiveDate;
4use datasynth_config::schema::SocialConfig;
5use datasynth_core::models::{
6    DiversityDimension, GovernanceMetric, IncidentType, OrganizationLevel, PayEquityMetric,
7    SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
8};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14
15/// Generates workforce diversity, pay equity, and safety metrics.
16pub struct WorkforceGenerator {
17    rng: ChaCha8Rng,
18    config: SocialConfig,
19    counter: u64,
20}
21
22impl WorkforceGenerator {
23    /// Create a new workforce generator.
24    pub fn new(config: SocialConfig, seed: u64) -> Self {
25        Self {
26            rng: seeded_rng(seed, 0),
27            config,
28            counter: 0,
29        }
30    }
31
32    // ----- Diversity -----
33
34    /// Generate workforce diversity metrics for a reporting period.
35    ///
36    /// Produces one record per (dimension × category × level) combination.
37    pub fn generate_diversity(
38        &mut self,
39        entity_id: &str,
40        total_headcount: u32,
41        period: NaiveDate,
42    ) -> Vec<WorkforceDiversityMetric> {
43        if !self.config.diversity.enabled || total_headcount == 0 {
44            return Vec::new();
45        }
46
47        let mut metrics = Vec::new();
48        let levels = [
49            OrganizationLevel::Corporate,
50            OrganizationLevel::Executive,
51            OrganizationLevel::Board,
52        ];
53
54        for dimension in &[
55            DiversityDimension::Gender,
56            DiversityDimension::Ethnicity,
57            DiversityDimension::Age,
58        ] {
59            let categories = self.categories_for(*dimension);
60            // Distribute headcount per level
61            for level in &levels {
62                let level_hc = match level {
63                    OrganizationLevel::Corporate => total_headcount,
64                    OrganizationLevel::Executive => (total_headcount / 50).max(5),
65                    OrganizationLevel::Board => 11,
66                    _ => total_headcount,
67                };
68
69                let shares = self.random_shares(categories.len());
70                for (i, cat) in categories.iter().enumerate() {
71                    self.counter += 1;
72                    let headcount = (Decimal::from(level_hc)
73                        * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
74                    .round_dp(0);
75                    let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
76
77                    let percentage = if level_hc > 0 {
78                        (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
79                    } else {
80                        Decimal::ZERO
81                    };
82
83                    metrics.push(WorkforceDiversityMetric {
84                        id: format!("DV-{:06}", self.counter),
85                        entity_id: entity_id.to_string(),
86                        period,
87                        dimension: *dimension,
88                        level: *level,
89                        category: cat.to_string(),
90                        headcount: hc,
91                        total_headcount: level_hc,
92                        percentage,
93                    });
94                }
95            }
96        }
97
98        metrics
99    }
100
101    fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
102        match dimension {
103            DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
104            DiversityDimension::Ethnicity => {
105                vec!["White", "Asian", "Hispanic", "Black", "Other"]
106            }
107            DiversityDimension::Age => {
108                vec!["Under 30", "30-50", "Over 50"]
109            }
110            DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
111            DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
112        }
113    }
114
115    /// Generate random shares that sum to 1.0.
116    fn random_shares(&mut self, count: usize) -> Vec<f64> {
117        let mut raw: Vec<f64> = (0..count).map(|_| self.rng.random::<f64>()).collect();
118        let total: f64 = raw.iter().sum();
119        if total > 0.0 {
120            for v in &mut raw {
121                *v /= total;
122            }
123        }
124        raw
125    }
126
127    // ----- Pay Equity -----
128
129    /// Generate pay equity metrics for common group comparisons.
130    pub fn generate_pay_equity(
131        &mut self,
132        entity_id: &str,
133        period: NaiveDate,
134    ) -> Vec<PayEquityMetric> {
135        if !self.config.pay_equity.enabled {
136            return Vec::new();
137        }
138
139        let comparisons = [
140            (DiversityDimension::Gender, "Male", "Female"),
141            (DiversityDimension::Ethnicity, "White", "Asian"),
142            (DiversityDimension::Ethnicity, "White", "Hispanic"),
143            (DiversityDimension::Ethnicity, "White", "Black"),
144        ];
145
146        comparisons
147            .iter()
148            .map(|(dim, ref_group, cmp_group)| {
149                self.counter += 1;
150                let ref_salary: f64 = self.rng.random_range(70_000.0..120_000.0);
151                // Pay gap: comparison group earns 85-105% of reference
152                let gap_factor: f64 = self.rng.random_range(0.85..1.05);
153                let cmp_salary = ref_salary * gap_factor;
154
155                let ref_dec = Decimal::from_f64_retain(ref_salary)
156                    .unwrap_or(dec!(90000))
157                    .round_dp(2);
158                let cmp_dec = Decimal::from_f64_retain(cmp_salary)
159                    .unwrap_or(dec!(85000))
160                    .round_dp(2);
161                let ratio = if ref_dec.is_zero() {
162                    dec!(1.00)
163                } else {
164                    (cmp_dec / ref_dec).round_dp(4)
165                };
166
167                PayEquityMetric {
168                    id: format!("PE-{:06}", self.counter),
169                    entity_id: entity_id.to_string(),
170                    period,
171                    dimension: *dim,
172                    reference_group: ref_group.to_string(),
173                    comparison_group: cmp_group.to_string(),
174                    reference_median_salary: ref_dec,
175                    comparison_median_salary: cmp_dec,
176                    pay_gap_ratio: ratio,
177                    sample_size: self.rng.random_range(50..500),
178                }
179            })
180            .collect()
181    }
182
183    // ----- Safety -----
184
185    /// Generate safety incidents for a period.
186    pub fn generate_safety_incidents(
187        &mut self,
188        entity_id: &str,
189        facility_count: u32,
190        start_date: NaiveDate,
191        end_date: NaiveDate,
192    ) -> Vec<SafetyIncident> {
193        if !self.config.safety.enabled {
194            return Vec::new();
195        }
196
197        let total_incidents = self.config.safety.incident_count;
198        let period_days = (end_date - start_date).num_days().max(1);
199
200        (0..total_incidents)
201            .map(|_| {
202                self.counter += 1;
203                let day_offset = self.rng.random_range(0..period_days);
204                let date = start_date + chrono::Duration::days(day_offset);
205                let fac = self.rng.random_range(1..=facility_count.max(1));
206
207                let incident_type = self.pick_incident_type();
208                let days_away = match incident_type {
209                    IncidentType::Fatality => 0,
210                    IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
211                    IncidentType::Injury => self.rng.random_range(0..30u32),
212                    IncidentType::Illness => self.rng.random_range(1..15u32),
213                };
214                let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
215
216                let description = match incident_type {
217                    IncidentType::Injury => "Workplace injury incident".to_string(),
218                    IncidentType::Illness => "Occupational illness reported".to_string(),
219                    IncidentType::NearMiss => "Near miss event documented".to_string(),
220                    IncidentType::Fatality => "Fatal workplace incident".to_string(),
221                    IncidentType::PropertyDamage => "Property damage incident".to_string(),
222                };
223
224                SafetyIncident {
225                    id: format!("SI-{:06}", self.counter),
226                    entity_id: entity_id.to_string(),
227                    facility_id: format!("FAC-{:03}", fac),
228                    date,
229                    incident_type,
230                    days_away,
231                    is_recordable,
232                    description,
233                }
234            })
235            .collect()
236    }
237
238    fn pick_incident_type(&mut self) -> IncidentType {
239        let roll: f64 = self.rng.random::<f64>();
240        if roll < 0.35 {
241            IncidentType::NearMiss
242        } else if roll < 0.65 {
243            IncidentType::Injury
244        } else if roll < 0.80 {
245            IncidentType::Illness
246        } else if roll < 0.95 {
247            IncidentType::PropertyDamage
248        } else {
249            IncidentType::Fatality
250        }
251    }
252
253    /// Compute aggregate safety metrics from incidents.
254    pub fn compute_safety_metrics(
255        &mut self,
256        entity_id: &str,
257        incidents: &[SafetyIncident],
258        total_hours_worked: u64,
259        period: NaiveDate,
260    ) -> SafetyMetric {
261        self.counter += 1;
262
263        let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
264        let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
265        let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
266        let near_misses = incidents
267            .iter()
268            .filter(|i| i.incident_type == IncidentType::NearMiss)
269            .count() as u32;
270        let fatalities = incidents
271            .iter()
272            .filter(|i| i.incident_type == IncidentType::Fatality)
273            .count() as u32;
274
275        let hours_dec = Decimal::from(total_hours_worked);
276        let base = dec!(200000);
277
278        let trir = if total_hours_worked > 0 {
279            (Decimal::from(recordable) * base / hours_dec).round_dp(4)
280        } else {
281            Decimal::ZERO
282        };
283        let ltir = if total_hours_worked > 0 {
284            (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
285        } else {
286            Decimal::ZERO
287        };
288        let dart_rate = if total_hours_worked > 0 {
289            (Decimal::from(days_away) * base / hours_dec).round_dp(4)
290        } else {
291            Decimal::ZERO
292        };
293
294        SafetyMetric {
295            id: format!("SM-{:06}", self.counter),
296            entity_id: entity_id.to_string(),
297            period,
298            total_hours_worked,
299            recordable_incidents: recordable,
300            lost_time_incidents: lost_time,
301            days_away,
302            near_misses,
303            fatalities,
304            trir,
305            ltir,
306            dart_rate,
307        }
308    }
309}
310
311/// Generates [`GovernanceMetric`] records.
312pub struct GovernanceGenerator {
313    rng: ChaCha8Rng,
314    counter: u64,
315    board_size: u32,
316    independence_target: f64,
317}
318
319impl GovernanceGenerator {
320    /// Create a new governance generator.
321    pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
322        Self {
323            rng: seeded_rng(seed, 0),
324            counter: 0,
325            board_size: board_size.max(3),
326            independence_target,
327        }
328    }
329
330    /// Generate a governance metric for a period.
331    pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
332        self.counter += 1;
333
334        // Independent directors: aim near target with some noise
335        let ind_frac: f64 = self.rng.random_range(
336            (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
337        );
338        let independent = (self.board_size as f64 * ind_frac).round() as u32;
339        let independent = independent.min(self.board_size);
340
341        // Female directors: 20-40% range
342        let fem_frac: f64 = self.rng.random_range(0.20..0.40);
343        let female = (self.board_size as f64 * fem_frac).round() as u32;
344        let female = female.min(self.board_size);
345
346        let independence_ratio = if self.board_size > 0 {
347            (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
348        } else {
349            Decimal::ZERO
350        };
351        let gender_ratio = if self.board_size > 0 {
352            (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
353        } else {
354            Decimal::ZERO
355        };
356
357        let ethics_pct: f64 = self.rng.random_range(0.85..0.99);
358        let whistleblower: u32 = self.rng.random_range(0..5);
359        let anti_corruption: u32 = if self.rng.random::<f64>() < 0.10 {
360            1
361        } else {
362            0
363        };
364
365        GovernanceMetric {
366            id: format!("GV-{:06}", self.counter),
367            entity_id: entity_id.to_string(),
368            period,
369            board_size: self.board_size,
370            independent_directors: independent,
371            female_directors: female,
372            board_independence_ratio: independence_ratio,
373            board_gender_diversity_ratio: gender_ratio,
374            ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
375            whistleblower_reports: whistleblower,
376            anti_corruption_violations: anti_corruption,
377        }
378    }
379}
380
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parse a `YYYY-MM-DD` literal into a [`NaiveDate`]; panics on bad input.
    fn ymd(text: &str) -> NaiveDate {
        NaiveDate::parse_from_str(text, "%Y-%m-%d").unwrap()
    }

    /// Category headcounts within each (dimension, level) group should add up
    /// to that group's total headcount, within rounding tolerance.
    #[test]
    fn test_diversity_percentages_sum_to_one() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = generator.generate_diversity("C001", 1000, ymd("2025-01-01"));

        assert!(!metrics.is_empty());

        // Per (dimension, level): running headcount sum and the total
        // reported by the first record seen for that group.
        let mut groups: HashMap<(String, String), (u32, u32)> = HashMap::new();
        for metric in &metrics {
            let key = (
                format!("{:?}", metric.dimension),
                format!("{:?}", metric.level),
            );
            let slot = groups.entry(key).or_insert((0, metric.total_headcount));
            slot.0 += metric.headcount;
        }

        for (key, (total_hc, expected)) in &groups {
            // Allow rounding tolerance of ±1
            assert!(
                total_hc.abs_diff(*expected) <= 1,
                "Group {:?}: headcount sum {} != expected {}",
                key,
                total_hc,
                expected
            );
        }
    }

    /// Each of the four comparison pairs yields a plausible ratio and
    /// positive salaries.
    #[test]
    fn test_pay_equity_ratios() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);
        let metrics = generator.generate_pay_equity("C001", ymd("2025-01-01"));

        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
        for metric in &metrics {
            assert!(metric.pay_gap_ratio > dec!(0.80) && metric.pay_gap_ratio < dec!(1.10));
            assert!(metric.reference_median_salary > Decimal::ZERO);
            assert!(metric.comparison_median_salary > Decimal::ZERO);
            assert!(metric.sample_size > 0);
        }
    }

    /// The configured incident count is honoured and recordable incidents
    /// are exactly the ones that are not near misses.
    #[test]
    fn test_safety_incidents() {
        let config = SocialConfig {
            safety: datasynth_config::schema::SafetySchemaConfig {
                enabled: true,
                target_trir: 2.5,
                incident_count: 30,
            },
            ..Default::default()
        };
        let mut generator = WorkforceGenerator::new(config, 42);
        let incidents =
            generator.generate_safety_incidents("C001", 3, ymd("2025-01-01"), ymd("2025-12-31"));

        assert_eq!(incidents.len(), 30);

        let recordable_count = incidents.iter().filter(|i| i.is_recordable).count();
        let near_miss_count = incidents
            .iter()
            .filter(|i| i.incident_type == IncidentType::NearMiss)
            .count();
        // Near misses are not recordable
        assert_eq!(recordable_count + near_miss_count, 30);
    }

    /// Aggregates and TRIR are computed from a hand-built incident list.
    #[test]
    fn test_safety_metric_trir_computation() {
        let mut generator = WorkforceGenerator::new(SocialConfig::default(), 42);

        // Fixture builder: only the fields that vary between incidents are
        // parameters.
        let incident = |id: &str,
                        date: &str,
                        incident_type: IncidentType,
                        days_away: u32,
                        is_recordable: bool| SafetyIncident {
            id: id.into(),
            entity_id: "C001".into(),
            facility_id: "FAC-001".into(),
            date: ymd(date),
            incident_type,
            days_away,
            is_recordable,
            description: "Test".into(),
        };

        let incidents = vec![
            incident("SI-001", "2025-03-15", IncidentType::Injury, 5, true),
            incident("SI-002", "2025-06-20", IncidentType::NearMiss, 0, false),
        ];

        let metric =
            generator.compute_safety_metrics("C001", &incidents, 500_000, ymd("2025-01-01"));

        assert_eq!(metric.recordable_incidents, 1);
        assert_eq!(metric.near_misses, 1);
        assert_eq!(metric.lost_time_incidents, 1);
        assert_eq!(metric.days_away, 5);
        // TRIR = 1 × 200,000 / 500,000 = 0.4
        assert_eq!(metric.trir, dec!(0.4000));
        assert_eq!(metric.computed_trir(), dec!(0.4000));
    }

    /// Governance output respects the board size and value bounds.
    #[test]
    fn test_governance_generation() {
        let mut generator = GovernanceGenerator::new(42, 11, 0.67);
        let metric = generator.generate("C001", ymd("2025-01-01"));

        assert_eq!(metric.board_size, 11);
        assert!(metric.independent_directors <= 11);
        assert!(metric.female_directors <= 11);
        assert!(metric.board_independence_ratio > Decimal::ZERO);
        assert!(metric.ethics_training_completion_pct >= 0.85);
    }

    /// Disabling diversity reporting yields no records.
    #[test]
    fn test_disabled_diversity() {
        let mut config = SocialConfig::default();
        config.diversity.enabled = false;
        let mut generator = WorkforceGenerator::new(config, 42);
        let metrics = generator.generate_diversity("C001", 1000, ymd("2025-01-01"));
        assert!(metrics.is_empty());
    }
}