1use chrono::NaiveDate;
5use datasynth_config::schema::SocialConfig;
6use datasynth_core::models::{
7 DiversityDimension, GovernanceMetric, IncidentType, OrganizationLevel, PayEquityMetric,
8 SafetyIncident, SafetyMetric, WorkforceDiversityMetric,
9};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14use rust_decimal_macros::dec;
15
16pub struct WorkforceGenerator {
18 rng: ChaCha8Rng,
19 uuid_factory: DeterministicUuidFactory,
20 config: SocialConfig,
21 counter: u64,
22}
23
24impl WorkforceGenerator {
25 pub fn new(seed: u64, config: SocialConfig) -> Self {
27 Self {
28 rng: ChaCha8Rng::seed_from_u64(seed),
29 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::Esg),
30 config,
31 counter: 0,
32 }
33 }
34
35 pub fn generate_diversity(
41 &mut self,
42 entity_id: &str,
43 total_headcount: u32,
44 period: NaiveDate,
45 ) -> Vec<WorkforceDiversityMetric> {
46 if !self.config.diversity.enabled || total_headcount == 0 {
47 return Vec::new();
48 }
49
50 let mut metrics = Vec::new();
51 let levels = [
52 OrganizationLevel::Corporate,
53 OrganizationLevel::Executive,
54 OrganizationLevel::Board,
55 ];
56
57 for dimension in &[
58 DiversityDimension::Gender,
59 DiversityDimension::Ethnicity,
60 DiversityDimension::Age,
61 ] {
62 let categories = self.categories_for(*dimension);
63 for level in &levels {
65 let level_hc = match level {
66 OrganizationLevel::Corporate => total_headcount,
67 OrganizationLevel::Executive => (total_headcount / 50).max(5),
68 OrganizationLevel::Board => 11,
69 _ => total_headcount,
70 };
71
72 let shares = self.random_shares(categories.len());
73 for (i, cat) in categories.iter().enumerate() {
74 self.counter += 1;
75 let headcount = (Decimal::from(level_hc)
76 * Decimal::from_f64_retain(shares[i]).unwrap_or(Decimal::ZERO))
77 .round_dp(0);
78 let hc = headcount.to_string().parse::<u32>().unwrap_or(0);
79
80 let percentage = if level_hc > 0 {
81 (Decimal::from(hc) / Decimal::from(level_hc)).round_dp(4)
82 } else {
83 Decimal::ZERO
84 };
85
86 metrics.push(WorkforceDiversityMetric {
87 id: format!("DV-{:06}", self.counter),
88 entity_id: entity_id.to_string(),
89 period,
90 dimension: *dimension,
91 level: *level,
92 category: cat.to_string(),
93 headcount: hc,
94 total_headcount: level_hc,
95 percentage,
96 });
97 }
98 }
99 }
100
101 metrics
102 }
103
104 fn categories_for(&self, dimension: DiversityDimension) -> Vec<&'static str> {
105 match dimension {
106 DiversityDimension::Gender => vec!["Male", "Female", "Non-Binary"],
107 DiversityDimension::Ethnicity => {
108 vec!["White", "Asian", "Hispanic", "Black", "Other"]
109 }
110 DiversityDimension::Age => {
111 vec!["Under 30", "30-50", "Over 50"]
112 }
113 DiversityDimension::Disability => vec!["No Disability", "Has Disability"],
114 DiversityDimension::VeteranStatus => vec!["Non-Veteran", "Veteran"],
115 }
116 }
117
118 fn random_shares(&mut self, count: usize) -> Vec<f64> {
120 let mut raw: Vec<f64> = (0..count).map(|_| self.rng.gen::<f64>()).collect();
121 let total: f64 = raw.iter().sum();
122 if total > 0.0 {
123 for v in &mut raw {
124 *v /= total;
125 }
126 }
127 raw
128 }
129
130 pub fn generate_pay_equity(
134 &mut self,
135 entity_id: &str,
136 period: NaiveDate,
137 ) -> Vec<PayEquityMetric> {
138 if !self.config.pay_equity.enabled {
139 return Vec::new();
140 }
141
142 let comparisons = [
143 (DiversityDimension::Gender, "Male", "Female"),
144 (DiversityDimension::Ethnicity, "White", "Asian"),
145 (DiversityDimension::Ethnicity, "White", "Hispanic"),
146 (DiversityDimension::Ethnicity, "White", "Black"),
147 ];
148
149 comparisons
150 .iter()
151 .map(|(dim, ref_group, cmp_group)| {
152 self.counter += 1;
153 let ref_salary: f64 = self.rng.gen_range(70_000.0..120_000.0);
154 let gap_factor: f64 = self.rng.gen_range(0.85..1.05);
156 let cmp_salary = ref_salary * gap_factor;
157
158 let ref_dec = Decimal::from_f64_retain(ref_salary)
159 .unwrap_or(dec!(90000))
160 .round_dp(2);
161 let cmp_dec = Decimal::from_f64_retain(cmp_salary)
162 .unwrap_or(dec!(85000))
163 .round_dp(2);
164 let ratio = if ref_dec.is_zero() {
165 dec!(1.00)
166 } else {
167 (cmp_dec / ref_dec).round_dp(4)
168 };
169
170 PayEquityMetric {
171 id: format!("PE-{:06}", self.counter),
172 entity_id: entity_id.to_string(),
173 period,
174 dimension: *dim,
175 reference_group: ref_group.to_string(),
176 comparison_group: cmp_group.to_string(),
177 reference_median_salary: ref_dec,
178 comparison_median_salary: cmp_dec,
179 pay_gap_ratio: ratio,
180 sample_size: self.rng.gen_range(50..500),
181 }
182 })
183 .collect()
184 }
185
186 pub fn generate_safety_incidents(
190 &mut self,
191 entity_id: &str,
192 facility_count: u32,
193 start_date: NaiveDate,
194 end_date: NaiveDate,
195 ) -> Vec<SafetyIncident> {
196 if !self.config.safety.enabled {
197 return Vec::new();
198 }
199
200 let total_incidents = self.config.safety.incident_count;
201 let period_days = (end_date - start_date).num_days().max(1);
202
203 (0..total_incidents)
204 .map(|_| {
205 self.counter += 1;
206 let day_offset = self.rng.gen_range(0..period_days);
207 let date = start_date + chrono::Duration::days(day_offset);
208 let fac = self.rng.gen_range(1..=facility_count.max(1));
209
210 let incident_type = self.pick_incident_type();
211 let days_away = match incident_type {
212 IncidentType::Fatality => 0,
213 IncidentType::NearMiss | IncidentType::PropertyDamage => 0,
214 IncidentType::Injury => self.rng.gen_range(0..30u32),
215 IncidentType::Illness => self.rng.gen_range(1..15u32),
216 };
217 let is_recordable = !matches!(incident_type, IncidentType::NearMiss);
218
219 let description = match incident_type {
220 IncidentType::Injury => "Workplace injury incident".to_string(),
221 IncidentType::Illness => "Occupational illness reported".to_string(),
222 IncidentType::NearMiss => "Near miss event documented".to_string(),
223 IncidentType::Fatality => "Fatal workplace incident".to_string(),
224 IncidentType::PropertyDamage => "Property damage incident".to_string(),
225 };
226
227 SafetyIncident {
228 id: format!("SI-{:06}", self.counter),
229 entity_id: entity_id.to_string(),
230 facility_id: format!("FAC-{:03}", fac),
231 date,
232 incident_type,
233 days_away,
234 is_recordable,
235 description,
236 }
237 })
238 .collect()
239 }
240
241 fn pick_incident_type(&mut self) -> IncidentType {
242 let roll: f64 = self.rng.gen::<f64>();
243 if roll < 0.35 {
244 IncidentType::NearMiss
245 } else if roll < 0.65 {
246 IncidentType::Injury
247 } else if roll < 0.80 {
248 IncidentType::Illness
249 } else if roll < 0.95 {
250 IncidentType::PropertyDamage
251 } else {
252 IncidentType::Fatality
253 }
254 }
255
256 pub fn compute_safety_metrics(
258 &mut self,
259 entity_id: &str,
260 incidents: &[SafetyIncident],
261 total_hours_worked: u64,
262 period: NaiveDate,
263 ) -> SafetyMetric {
264 self.counter += 1;
265
266 let recordable = incidents.iter().filter(|i| i.is_recordable).count() as u32;
267 let lost_time = incidents.iter().filter(|i| i.days_away > 0).count() as u32;
268 let days_away: u32 = incidents.iter().map(|i| i.days_away).sum();
269 let near_misses = incidents
270 .iter()
271 .filter(|i| i.incident_type == IncidentType::NearMiss)
272 .count() as u32;
273 let fatalities = incidents
274 .iter()
275 .filter(|i| i.incident_type == IncidentType::Fatality)
276 .count() as u32;
277
278 let hours_dec = Decimal::from(total_hours_worked);
279 let base = dec!(200000);
280
281 let trir = if total_hours_worked > 0 {
282 (Decimal::from(recordable) * base / hours_dec).round_dp(4)
283 } else {
284 Decimal::ZERO
285 };
286 let ltir = if total_hours_worked > 0 {
287 (Decimal::from(lost_time) * base / hours_dec).round_dp(4)
288 } else {
289 Decimal::ZERO
290 };
291 let dart_rate = if total_hours_worked > 0 {
292 (Decimal::from(days_away) * base / hours_dec).round_dp(4)
293 } else {
294 Decimal::ZERO
295 };
296
297 SafetyMetric {
298 id: format!("SM-{:06}", self.counter),
299 entity_id: entity_id.to_string(),
300 period,
301 total_hours_worked,
302 recordable_incidents: recordable,
303 lost_time_incidents: lost_time,
304 days_away,
305 near_misses,
306 fatalities,
307 trir,
308 ltir,
309 dart_rate,
310 }
311 }
312}
313
314pub struct GovernanceGenerator {
316 rng: ChaCha8Rng,
317 counter: u64,
318 board_size: u32,
319 independence_target: f64,
320}
321
322impl GovernanceGenerator {
323 pub fn new(seed: u64, board_size: u32, independence_target: f64) -> Self {
325 Self {
326 rng: ChaCha8Rng::seed_from_u64(seed),
327 counter: 0,
328 board_size: board_size.max(3),
329 independence_target,
330 }
331 }
332
333 pub fn generate(&mut self, entity_id: &str, period: NaiveDate) -> GovernanceMetric {
335 self.counter += 1;
336
337 let ind_frac: f64 = self.rng.gen_range(
339 (self.independence_target - 0.10).max(0.0)..(self.independence_target + 0.10).min(1.0),
340 );
341 let independent = (self.board_size as f64 * ind_frac).round() as u32;
342 let independent = independent.min(self.board_size);
343
344 let fem_frac: f64 = self.rng.gen_range(0.20..0.40);
346 let female = (self.board_size as f64 * fem_frac).round() as u32;
347 let female = female.min(self.board_size);
348
349 let independence_ratio = if self.board_size > 0 {
350 (Decimal::from(independent) / Decimal::from(self.board_size)).round_dp(4)
351 } else {
352 Decimal::ZERO
353 };
354 let gender_ratio = if self.board_size > 0 {
355 (Decimal::from(female) / Decimal::from(self.board_size)).round_dp(4)
356 } else {
357 Decimal::ZERO
358 };
359
360 let ethics_pct: f64 = self.rng.gen_range(0.85..0.99);
361 let whistleblower: u32 = self.rng.gen_range(0..5);
362 let anti_corruption: u32 = if self.rng.gen::<f64>() < 0.10 { 1 } else { 0 };
363
364 GovernanceMetric {
365 id: format!("GV-{:06}", self.counter),
366 entity_id: entity_id.to_string(),
367 period,
368 board_size: self.board_size,
369 independent_directors: independent,
370 female_directors: female,
371 board_independence_ratio: independence_ratio,
372 board_gender_diversity_ratio: gender_ratio,
373 ethics_training_completion_pct: (ethics_pct * 100.0).round() / 100.0,
374 whistleblower_reports: whistleblower,
375 anti_corruption_violations: anti_corruption,
376 }
377 }
378}
379
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parses a `YYYY-MM-DD` literal into a date; panics on malformed input.
    fn d(s: &str) -> NaiveDate {
        NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap()
    }

    /// Headcounts within each (dimension, level) group must add up to the
    /// group's total, allowing one person of rounding slack.
    #[test]
    fn test_diversity_percentages_sum_to_one() {
        let mut generator = WorkforceGenerator::new(42, SocialConfig::default());
        let metrics = generator.generate_diversity("C001", 1000, d("2025-01-01"));

        assert!(!metrics.is_empty());

        // Bucket the records by the debug names of their dimension and level.
        let mut by_group: HashMap<(String, String), Vec<&WorkforceDiversityMetric>> =
            HashMap::new();
        for metric in &metrics {
            let dimension = format!("{:?}", metric.dimension);
            let level = format!("{:?}", metric.level);
            by_group.entry((dimension, level)).or_default().push(metric);
        }

        for (key, members) in &by_group {
            let summed: u32 = members.iter().map(|m| m.headcount).sum();
            let expected = members[0].total_headcount;
            assert!(
                summed.abs_diff(expected) <= 1,
                "Group {:?}: headcount sum {} != expected {}",
                key,
                summed,
                expected
            );
        }
    }

    /// Every pay-equity record has a plausible ratio, positive salaries and a
    /// non-empty sample.
    #[test]
    fn test_pay_equity_ratios() {
        let mut generator = WorkforceGenerator::new(42, SocialConfig::default());
        let metrics = generator.generate_pay_equity("C001", d("2025-01-01"));

        assert_eq!(metrics.len(), 4, "Should have 4 comparison pairs");
        for metric in &metrics {
            assert!(metric.pay_gap_ratio > dec!(0.80) && metric.pay_gap_ratio < dec!(1.10));
            assert!(metric.reference_median_salary > Decimal::ZERO);
            assert!(metric.comparison_median_salary > Decimal::ZERO);
            assert!(metric.sample_size > 0);
        }
    }

    /// The configured incident count is honoured and every incident is either
    /// recordable or a near miss.
    #[test]
    fn test_safety_incidents() {
        let safety = datasynth_config::schema::SafetySchemaConfig {
            enabled: true,
            target_trir: 2.5,
            incident_count: 30,
        };
        let config = SocialConfig {
            safety,
            ..Default::default()
        };
        let mut generator = WorkforceGenerator::new(42, config);
        let incidents =
            generator.generate_safety_incidents("C001", 3, d("2025-01-01"), d("2025-12-31"));

        assert_eq!(incidents.len(), 30);

        let recordable = incidents.iter().filter(|i| i.is_recordable).count();
        let near_miss = incidents
            .iter()
            .filter(|i| i.incident_type == IncidentType::NearMiss)
            .count();
        assert_eq!(recordable + near_miss, 30);
    }

    /// One recordable injury over 500,000 hours gives a TRIR of 0.4.
    #[test]
    fn test_safety_metric_trir_computation() {
        let mut generator = WorkforceGenerator::new(42, SocialConfig::default());

        // Small fixture builder so the two incidents differ only where it matters.
        let incident = |id: &str, date: &str, kind: IncidentType, days: u32, recordable: bool| {
            SafetyIncident {
                id: id.into(),
                entity_id: "C001".into(),
                facility_id: "FAC-001".into(),
                date: d(date),
                incident_type: kind,
                days_away: days,
                is_recordable: recordable,
                description: "Test".into(),
            }
        };
        let incidents = vec![
            incident("SI-001", "2025-03-15", IncidentType::Injury, 5, true),
            incident("SI-002", "2025-06-20", IncidentType::NearMiss, 0, false),
        ];

        let metric =
            generator.compute_safety_metrics("C001", &incidents, 500_000, d("2025-01-01"));

        assert_eq!(metric.recordable_incidents, 1);
        assert_eq!(metric.near_misses, 1);
        assert_eq!(metric.lost_time_incidents, 1);
        assert_eq!(metric.days_away, 5);
        assert_eq!(metric.trir, dec!(0.4000));
        assert_eq!(metric.computed_trir(), dec!(0.4000));
    }

    /// Director counts never exceed the board size and the generated ratios
    /// stay within their expected bounds.
    #[test]
    fn test_governance_generation() {
        let mut generator = GovernanceGenerator::new(42, 11, 0.67);
        let metric = generator.generate("C001", d("2025-01-01"));

        assert_eq!(metric.board_size, 11);
        assert!(metric.independent_directors <= 11);
        assert!(metric.female_directors <= 11);
        assert!(metric.board_independence_ratio > Decimal::ZERO);
        assert!(metric.ethics_training_completion_pct >= 0.85);
    }

    /// Disabling diversity in the configuration suppresses all output.
    #[test]
    fn test_disabled_diversity() {
        let mut config = SocialConfig::default();
        config.diversity.enabled = false;

        let mut generator = WorkforceGenerator::new(42, config);
        let metrics = generator.generate_diversity("C001", 1000, d("2025-01-01"));
        assert!(metrics.is_empty());
    }
}