datasynth_runtime/
orchestrator.rs1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use datasynth_config::schema::GeneratorConfig;
8use datasynth_core::error::{SynthError, SynthResult};
9use datasynth_core::models::*;
10use datasynth_core::traits::{Generator, ParallelGenerator};
11use datasynth_generators::{ChartOfAccountsGenerator, JournalEntryGenerator};
12use indicatif::{ProgressBar, ProgressStyle};
13use rayon::prelude::*;
14
15pub struct GenerationResult {
17 pub chart_of_accounts: ChartOfAccounts,
19 pub journal_entries: Vec<JournalEntry>,
21 pub statistics: GenerationStatistics,
23}
24
/// Summary counts for a generation run.
///
/// Derives `Default`/`PartialEq`/`Eq` so callers can compare runs and build
/// zeroed baselines without hand-writing every field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GenerationStatistics {
    /// Number of journal entries generated.
    pub total_entries: u64,
    /// Sum of line items across all generated journal entries.
    pub total_line_items: u64,
    /// Number of accounts in the chart of accounts used for the run.
    pub accounts_count: usize,
    /// Number of companies configured for the run.
    pub companies_count: usize,
    /// Length of the generated period, in months.
    pub period_months: u32,
}
39
40pub struct GenerationOrchestrator {
42 config: GeneratorConfig,
43 coa: Option<Arc<ChartOfAccounts>>,
44 pause_flag: Option<Arc<AtomicBool>>,
46}
47
48impl GenerationOrchestrator {
49 pub fn new(config: GeneratorConfig) -> SynthResult<Self> {
51 datasynth_config::validate_config(&config)?;
53
54 Ok(Self {
55 config,
56 coa: None,
57 pause_flag: None,
58 })
59 }
60
61 pub fn with_pause_flag(mut self, flag: Arc<AtomicBool>) -> Self {
64 self.pause_flag = Some(flag);
65 self
66 }
67
68 fn is_paused(&self) -> bool {
70 self.pause_flag
71 .as_ref()
72 .map(|f| f.load(Ordering::Relaxed))
73 .unwrap_or(false)
74 }
75
76 fn wait_while_paused(&self, pb: &ProgressBar) {
78 let was_paused = self.is_paused();
79 if was_paused {
80 pb.set_message("PAUSED - send SIGUSR1 to resume");
81 }
82
83 while self.is_paused() {
84 std::thread::sleep(Duration::from_millis(100));
85 }
86
87 if was_paused {
88 pb.set_message("");
89 }
90 }
91
92 pub fn generate_coa(&mut self) -> SynthResult<Arc<ChartOfAccounts>> {
94 let seed = self.config.global.seed.unwrap_or_else(rand::random);
95 let mut gen = ChartOfAccountsGenerator::new(
96 self.config.chart_of_accounts.complexity,
97 self.config.global.industry,
98 seed,
99 );
100
101 let coa = Arc::new(gen.generate());
102 self.coa = Some(Arc::clone(&coa));
103 Ok(coa)
104 }
105
106 pub fn calculate_total_transactions(&self) -> u64 {
108 let months = self.config.global.period_months as f64;
109
110 self.config
111 .companies
112 .iter()
113 .map(|c| {
114 let annual = c.annual_transaction_volume.count() as f64;
115 let weighted = annual * c.volume_weight;
116 (weighted * months / 12.0) as u64
117 })
118 .sum()
119 }
120
121 pub fn generate(&mut self) -> SynthResult<GenerationResult> {
123 let coa = match &self.coa {
125 Some(c) => Arc::clone(c),
126 None => self.generate_coa()?,
127 };
128
129 let total = self.calculate_total_transactions();
130 let seed = self.config.global.seed.unwrap_or_else(rand::random);
131
132 let start_date =
134 chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
135 .map_err(|e| SynthError::config(format!("Invalid start_date: {}", e)))?;
136
137 let end_date = start_date + chrono::Months::new(self.config.global.period_months);
138
139 let pb = ProgressBar::new(total);
141 pb.set_style(
142 ProgressStyle::default_bar()
143 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({per_sec})")
144 .expect("Progress bar template is a compile-time constant and should always be valid")
145 .progress_chars("#>-"),
146 );
147
148 let company_codes: Vec<String> = self
150 .config
151 .companies
152 .iter()
153 .map(|c| c.code.clone())
154 .collect();
155
156 let mut generator = JournalEntryGenerator::new_with_params(
158 self.config.transactions.clone(),
159 Arc::clone(&coa),
160 company_codes,
161 start_date,
162 end_date,
163 seed,
164 )
165 .with_fraud_config(self.config.fraud.clone());
166
167 let num_threads = num_cpus::get().max(1).min(total as usize).max(1);
169
170 let entries = if total >= 10_000 && num_threads > 1 {
171 let sub_generators = generator.split(num_threads);
172 let entries_per_thread = total as usize / num_threads;
173 let remainder = total as usize % num_threads;
174
175 let batches: Vec<Vec<JournalEntry>> = sub_generators
176 .into_par_iter()
177 .enumerate()
178 .map(|(i, mut gen)| {
179 let count = entries_per_thread + if i < remainder { 1 } else { 0 };
180 gen.generate_batch(count)
181 })
182 .collect();
183
184 let entries = JournalEntryGenerator::merge_results(batches);
185 pb.inc(total);
186 entries
187 } else {
188 let mut entries = Vec::with_capacity(total as usize);
189 for _ in 0..total {
190 self.wait_while_paused(&pb);
191 let entry = generator.generate();
192 entries.push(entry);
193 pb.inc(1);
194 }
195 entries
196 };
197
198 let total_lines: u64 = entries.iter().map(|e| e.line_count() as u64).sum();
199
200 pb.finish_with_message("Generation complete");
201
202 Ok(GenerationResult {
203 chart_of_accounts: (*coa).clone(),
204 journal_entries: entries,
205 statistics: GenerationStatistics {
206 total_entries: total,
207 total_line_items: total_lines,
208 accounts_count: coa.account_count(),
209 companies_count: self.config.companies.len(),
210 period_months: self.config.global.period_months,
211 },
212 })
213 }
214}