Skip to main content

datasynth_generators/esg/
workforce_generator.rs

1//! Workforce ESG generator — derives diversity metrics, pay equity ratios,
2//! safety incidents, and aggregate safety metrics from employee data.
3
4use chrono::NaiveDate;
5use datasynth_config::schema::SocialConfig;
6use datasynth_core::models::{
7    DiversityDimension, GovernanceMetric, IncidentType, OrganizationLevel, PayEquityMetric,
8    SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
9};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14use rust_decimal_macros::dec;
15
16/// Generates workforce diversity, pay equity, and safety metrics.
17pub struct WorkforceGenerator {
18    rng: ChaCha8Rng,
19    uuid_factory: DeterministicUuidFactory,
20    config: SocialConfig,
21    counter: u64,
22}
23
24impl WorkforceGenerator {
25    /// Create a new workforce generator.
26    pub fn new(seed: u64, config: SocialConfig) -> Self {
27        Self {
28            rng: ChaCha8Rng::seed_from_u64(seed),
29            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::Esg),
30            config,
31            counter: 0,
32        }
33    }
34
35    // ----- Diversity -----
36
37    /// Generate workforce diversity metrics for a reporting period.
38    ///
39    /// Produces one record per (dimension × category × level) combination.
40    pub fn generate_diversity(
41        &mut self,
42        entity_id: &str,
43        total_headcount: u32,
44        period: NaiveDate,
45    ) -> Vec<WorkforceDiversityMetric> {
46        if !self.config.diversity.enabled || total_headcount == 0 {
47            return Vec::new();
48        }
49
50        let mut metrics = Vec::new();
51        let levels = [
52            OrganizationLevel::Corporate,
53            OrganizationLevel::Executive,
54            OrganizationLevel::Board,
55        ];
56
57        for dimension in &[
58            DiversityDimension::Gender,
59            DiversityDimension::Ethnicity,
60            DiversityDimension::Age,
61        ] {
62            let categories = self.categories_for(*dimension);
63            // Distribute headcount per level
64            for level in &levels {
65                let level_hc = match level {
66                    OrganizationLevel::Corporate => total_headcount,
67                    OrganizationLevel::Executive => (total_headcount / 50).max(5),
68                    OrganizationLevel::Board => 11,
69                    _ => total_headcount,
70                };
71
72                let shares = self.random_shares(categories.len());
73                for (i, cat) in categories.iter().enumerate() {
74                    self.counter += 1;
75                    let headcount = (Decimal::from(level_hc)
76                        * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
77                    .round_dp(0);
78                    let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
79
80                    let percentage = if level_hc > 0 {
81                        (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
82                    } else {
83                        Decimal::ZERO
84                    };
85
86                    metrics.push(WorkforceDiversityMetric {
87                        id: format!("DV-{:06}", self.counter),
88                        entity_id: entity_id.to_string(),
89                        period,
90                        dimension: *dimension,
91                        level: *level,
92                        category: cat.to_string(),
93                        headcount: hc,
94                        total_headcount: level_hc,
95                        percentage,
96                    });
97                }
98            }
99        }
100
101        metrics
102    }
103
104    fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
105        match dimension {
106            DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
107            DiversityDimension::Ethnicity => {
108                vec!["White", "Asian", "Hispanic", "Black", "Other"]
109            }
110            DiversityDimension::Age => {
111                vec!["Under 30", "30-50", "Over 50"]
112            }
113            DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
114            DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
115        }
116    }
117
118    /// Generate random shares that sum to 1.0.
119    fn random_shares(&mut self, count: usize) -> Vec<f64> {
120        let mut raw: Vec<f64> = (0..count).map(|_| self.rng.gen::<f64>()).collect();
121        let total: f64 = raw.iter().sum();
122        if total > 0.0 {
123            for v in &mut raw {
124                *v /= total;
125            }
126        }
127        raw
128    }
129
130    // ----- Pay Equity -----
131
132    /// Generate pay equity metrics for common group comparisons.
133    pub fn generate_pay_equity(
134        &mut self,
135        entity_id: &str,
136        period: NaiveDate,
137    ) -> Vec<PayEquityMetric> {
138        if !self.config.pay_equity.enabled {
139            return Vec::new();
140        }
141
142        let comparisons = [
143            (DiversityDimension::Gender, "Male", "Female"),
144            (DiversityDimension::Ethnicity, "White", "Asian"),
145            (DiversityDimension::Ethnicity, "White", "Hispanic"),
146            (DiversityDimension::Ethnicity, "White", "Black"),
147        ];
148
149        comparisons
150            .iter()
151            .map(|(dim, ref_group, cmp_group)| {
152                self.counter += 1;
153                let ref_salary: f64 = self.rng.gen_range(70_000.0..120_000.0);
154                // Pay gap: comparison group earns 85-105% of reference
155                let gap_factor: f64 = self.rng.gen_range(0.85..1.05);
156                let cmp_salary = ref_salary * gap_factor;
157
158                let ref_dec = Decimal::from_f64_retain(ref_salary)
159                    .unwrap_or(dec!(90000))
160                    .round_dp(2);
161                let cmp_dec = Decimal::from_f64_retain(cmp_salary)
162                    .unwrap_or(dec!(85000))
163                    .round_dp(2);
164                let ratio = if ref_dec.is_zero() {
165                    dec!(1.00)
166                } else {
167                    (cmp_dec / ref_dec).round_dp(4)
168                };
169
170                PayEquityMetric {
171                    id: format!("PE-{:06}", self.counter),
172                    entity_id: entity_id.to_string(),
173                    period,
174                    dimension: *dim,
175                    reference_group: ref_group.to_string(),
176                    comparison_group: cmp_group.to_string(),
177                    reference_median_salary: ref_dec,
178                    comparison_median_salary: cmp_dec,
179                    pay_gap_ratio: ratio,
180                    sample_size: self.rng.gen_range(50..500),
181                }
182            })
183            .collect()
184    }
185
186    // ----- Safety -----
187
188    /// Generate safety incidents for a period.
189    pub fn generate_safety_incidents(
190        &mut self,
191        entity_id: &str,
192        facility_count: u32,
193        start_date: NaiveDate,
194        end_date: NaiveDate,
195    ) -> Vec<SafetyIncident> {
196        if !self.config.safety.enabled {
197            return Vec::new();
198        }
199
200        let total_incidents = self.config.safety.incident_count;
201        let period_days = (end_date - start_date).num_days().max(1);
202
203        (0..total_incidents)
204            .map(|_| {
205                self.counter += 1;
206                let day_offset = self.rng.gen_range(0..period_days);
207                let date = start_date + chrono::Duration::days(day_offset);
208                let fac = self.rng.gen_range(1..=facility_count.max(1));
209
210                let incident_type = self.pick_incident_type();
211                let days_away = match incident_type {
212                    IncidentType::Fatality => 0,
213                    IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
214                    IncidentType::Injury => self.rng.gen_range(0..30u32),
215                    IncidentType::Illness => self.rng.gen_range(1..15u32),
216                };
217                let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
218
219                let description = match incident_type {
220                    IncidentType::Injury => "Workplace injury incident".to_string(),
221                    IncidentType::Illness => "Occupational illness reported".to_string(),
222                    IncidentType::NearMiss => "Near miss event documented".to_string(),
223                    IncidentType::Fatality => "Fatal workplace incident".to_string(),
224                    IncidentType::PropertyDamage => "Property damage incident".to_string(),
225                };
226
227                SafetyIncident {
228                    id: format!("SI-{:06}", self.counter),
229                    entity_id: entity_id.to_string(),
230                    facility_id: format!("FAC-{:03}", fac),
231                    date,
232                    incident_type,
233                    days_away,
234                    is_recordable,
235                    description,
236                }
237            })
238            .collect()
239    }
240
241    fn pick_incident_type(&mut self) -> IncidentType {
242        let roll: f64 = self.rng.gen::<f64>();
243        if roll < 0.35 {
244            IncidentType::NearMiss
245        } else if roll < 0.65 {
246            IncidentType::Injury
247        } else if roll < 0.80 {
248            IncidentType::Illness
249        } else if roll < 0.95 {
250            IncidentType::PropertyDamage
251        } else {
252            IncidentType::Fatality
253        }
254    }
255
256    /// Compute aggregate safety metrics from incidents.
257    pub fn compute_safety_metrics(
258        &mut self,
259        entity_id: &str,
260        incidents: &[SafetyIncident],
261        total_hours_worked: u64,
262        period: NaiveDate,
263    ) -> SafetyMetric {
264        self.counter += 1;
265
266        let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
267        let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
268        let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
269        let near_misses = incidents
270            .iter()
271            .filter(|i| i.incident_type == IncidentType::NearMiss)
272            .count() as u32;
273        let fatalities = incidents
274            .iter()
275            .filter(|i| i.incident_type == IncidentType::Fatality)
276            .count() as u32;
277
278        let hours_dec = Decimal::from(total_hours_worked);
279        let base = dec!(200000);
280
281        let trir = if total_hours_worked > 0 {
282            (Decimal::from(recordable) * base / hours_dec).round_dp(4)
283        } else {
284            Decimal::ZERO
285        };
286        let ltir = if total_hours_worked > 0 {
287            (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
288        } else {
289            Decimal::ZERO
290        };
291        let dart_rate = if total_hours_worked > 0 {
292            (Decimal::from(days_away) * base / hours_dec).round_dp(4)
293        } else {
294            Decimal::ZERO
295        };
296
297        SafetyMetric {
298            id: format!("SM-{:06}", self.counter),
299            entity_id: entity_id.to_string(),
300            period,
301            total_hours_worked,
302            recordable_incidents: recordable,
303            lost_time_incidents: lost_time,
304            days_away,
305            near_misses,
306            fatalities,
307            trir,
308            ltir,
309            dart_rate,
310        }
311    }
312}
313
314/// Generates [`GovernanceMetric`] records.
315pub struct GovernanceGenerator {
316    rng: ChaCha8Rng,
317    counter: u64,
318    board_size: u32,
319    independence_target: f64,
320}
321
322impl GovernanceGenerator {
323    /// Create a new governance generator.
324    pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
325        Self {
326            rng: ChaCha8Rng::seed_from_u64(seed),
327            counter: 0,
328            board_size: board_size.max(3),
329            independence_target,
330        }
331    }
332
333    /// Generate a governance metric for a period.
334    pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
335        self.counter += 1;
336
337        // Independent directors: aim near target with some noise
338        let ind_frac: f64 = self.rng.gen_range(
339            (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
340        );
341        let independent = (self.board_size as f64 * ind_frac).round() as u32;
342        let independent = independent.min(self.board_size);
343
344        // Female directors: 20-40% range
345        let fem_frac: f64 = self.rng.gen_range(0.20..0.40);
346        let female = (self.board_size as f64 * fem_frac).round() as u32;
347        let female = female.min(self.board_size);
348
349        let independence_ratio = if self.board_size > 0 {
350            (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
351        } else {
352            Decimal::ZERO
353        };
354        let gender_ratio = if self.board_size > 0 {
355            (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
356        } else {
357            Decimal::ZERO
358        };
359
360        let ethics_pct: f64 = self.rng.gen_range(0.85..0.99);
361        let whistleblower: u32 = self.rng.gen_range(0..5);
362        let anti_corruption: u32 = if self.rng.gen::<f64>() < 0.10 { 1 } else { 0 };
363
364        GovernanceMetric {
365            id: format!("GV-{:06}", self.counter),
366            entity_id: entity_id.to_string(),
367            period,
368            board_size: self.board_size,
369            independent_directors: independent,
370            female_directors: female,
371            board_independence_ratio: independence_ratio,
372            board_gender_diversity_ratio: gender_ratio,
373            ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
374            whistleblower_reports: whistleblower,
375            anti_corruption_violations: anti_corruption,
376        }
377    }
378}
379
380#[cfg(test)]
381#[allow(clippy::unwrap_used)]
382mod tests {
383    use super::*;
384
385    fn d(s: &str) -> NaiveDate {
386        NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap()
387    }
388
389    #[test]
390    fn test_diversity_percentages_sum_to_one() {
391        let config = SocialConfig::default();
392        let mut gen = WorkforceGenerator::new(42, config);
393        let metrics = gen.generate_diversity("C001", 1000, d("2025-01-01"));
394
395        assert!(!metrics.is_empty());
396
397        // Group by (dimension, level) and check percentages sum ≈ 1.0
398        let mut groups: std::collections::HashMap<
399            (String, String),
400            Vec<&WorkforceDiversityMetric>,
401        > = std::collections::HashMap::new();
402        for m in &metrics {
403            let key = (format!("{:?}", m.dimension), format!("{:?}", m.level));
404            groups.entry(key).or_default().push(m);
405        }
406
407        for (key, group) in &groups {
408            let total_hc: u32 = group.iter().map(|m| m.headcount).sum();
409            let expected = group[0].total_headcount;
410            // Allow rounding tolerance of ±1
411            assert!(
412                total_hc.abs_diff(expected) <= 1,
413                "Group {:?}: headcount sum {} != expected {}",
414                key,
415                total_hc,
416                expected
417            );
418        }
419    }
420
421    #[test]
422    fn test_pay_equity_ratios() {
423        let config = SocialConfig::default();
424        let mut gen = WorkforceGenerator::new(42, config);
425        let metrics = gen.generate_pay_equity("C001", d("2025-01-01"));
426
427        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
428        for m in &metrics {
429            assert!(m.pay_gap_ratio > dec!(0.80) && m.pay_gap_ratio < dec!(1.10));
430            assert!(m.reference_median_salary > Decimal::ZERO);
431            assert!(m.comparison_median_salary > Decimal::ZERO);
432            assert!(m.sample_size > 0);
433        }
434    }
435
436    #[test]
437    fn test_safety_incidents() {
438        let config = SocialConfig {
439            safety: datasynth_config::schema::SafetySchemaConfig {
440                enabled: true,
441                target_trir: 2.5,
442                incident_count: 30,
443            },
444            ..Default::default()
445        };
446        let mut gen = WorkforceGenerator::new(42, config);
447        let incidents = gen.generate_safety_incidents("C001", 3, d("2025-01-01"), d("2025-12-31"));
448
449        assert_eq!(incidents.len(), 30);
450
451        let recordable = incidents.iter().filter(|i| i.is_recordable).count();
452        let near_miss = incidents
453            .iter()
454            .filter(|i| i.incident_type == IncidentType::NearMiss)
455            .count();
456        // Near misses are not recordable
457        assert_eq!(recordable + near_miss, 30);
458    }
459
460    #[test]
461    fn test_safety_metric_trir_computation() {
462        let config = SocialConfig::default();
463        let mut gen = WorkforceGenerator::new(42, config);
464
465        let incidents = vec![
466            SafetyIncident {
467                id: "SI-001".into(),
468                entity_id: "C001".into(),
469                facility_id: "FAC-001".into(),
470                date: d("2025-03-15"),
471                incident_type: IncidentType::Injury,
472                days_away: 5,
473                is_recordable: true,
474                description: "Test".into(),
475            },
476            SafetyIncident {
477                id: "SI-002".into(),
478                entity_id: "C001".into(),
479                facility_id: "FAC-001".into(),
480                date: d("2025-06-20"),
481                incident_type: IncidentType::NearMiss,
482                days_away: 0,
483                is_recordable: false,
484                description: "Test".into(),
485            },
486        ];
487
488        let metric = gen.compute_safety_metrics("C001", &incidents, 500_000, d("2025-01-01"));
489
490        assert_eq!(metric.recordable_incidents, 1);
491        assert_eq!(metric.near_misses, 1);
492        assert_eq!(metric.lost_time_incidents, 1);
493        assert_eq!(metric.days_away, 5);
494        // TRIR = 1 × 200,000 / 500,000 = 0.4
495        assert_eq!(metric.trir, dec!(0.4000));
496        assert_eq!(metric.computed_trir(), dec!(0.4000));
497    }
498
499    #[test]
500    fn test_governance_generation() {
501        let mut gen = GovernanceGenerator::new(42, 11, 0.67);
502        let metric = gen.generate("C001", d("2025-01-01"));
503
504        assert_eq!(metric.board_size, 11);
505        assert!(metric.independent_directors <= 11);
506        assert!(metric.female_directors <= 11);
507        assert!(metric.board_independence_ratio > Decimal::ZERO);
508        assert!(metric.ethics_training_completion_pct >= 0.85);
509    }
510
511    #[test]
512    fn test_disabled_diversity() {
513        let mut config = SocialConfig::default();
514        config.diversity.enabled = false;
515        let mut gen = WorkforceGenerator::new(42, config);
516        let metrics = gen.generate_diversity("C001", 1000, d("2025-01-01"));
517        assert!(metrics.is_empty());
518    }
519}