Skip to main content

datasynth_runtime/
orchestrator.rs

1//! Generation orchestrator for coordinating data generation.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use datasynth_config::schema::GeneratorConfig;
8use datasynth_core::error::{SynthError, SynthResult};
9use datasynth_core::models::*;
10use datasynth_core::traits::{Generator, ParallelGenerator};
11use datasynth_generators::{ChartOfAccountsGenerator, JournalEntryGenerator};
12use indicatif::{ProgressBar, ProgressStyle};
13use rayon::prelude::*;
14
15/// Result of a generation run.
16pub struct GenerationResult {
17    /// Generated chart of accounts
18    pub chart_of_accounts: ChartOfAccounts,
19    /// Generated journal entries
20    pub journal_entries: Vec<JournalEntry>,
21    /// Statistics about the generation
22    pub statistics: GenerationStatistics,
23}
24
/// Summary counters describing a generation run.
///
/// The type is plain data (only integer counters), so it derives equality and
/// `Default` in addition to `Debug`/`Clone`; this lets callers and tests compare
/// two runs directly or start from an all-zero value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GenerationStatistics {
    /// Total journal entries generated.
    pub total_entries: u64,
    /// Total line items generated, summed over all journal entries.
    pub total_line_items: u64,
    /// Number of accounts in the chart of accounts.
    pub accounts_count: usize,
    /// Number of companies in the configuration.
    pub companies_count: usize,
    /// Length of the generated period, in months.
    pub period_months: u32,
}
39
40/// Main orchestrator for generation.
41pub struct GenerationOrchestrator {
42    config: GeneratorConfig,
43    coa: Option<Arc<ChartOfAccounts>>,
44    /// Optional pause flag for external control (e.g., signal handlers).
45    pause_flag: Option<Arc<AtomicBool>>,
46}
47
48impl GenerationOrchestrator {
49    /// Create a new orchestrator.
50    pub fn new(config: GeneratorConfig) -> SynthResult<Self> {
51        // Validate config
52        datasynth_config::validate_config(&config)?;
53
54        Ok(Self {
55            config,
56            coa: None,
57            pause_flag: None,
58        })
59    }
60
61    /// Set a pause flag that can be controlled externally (e.g., by a signal handler).
62    /// When the flag is true, generation will pause until it becomes false.
63    pub fn with_pause_flag(mut self, flag: Arc<AtomicBool>) -> Self {
64        self.pause_flag = Some(flag);
65        self
66    }
67
68    /// Check if generation is currently paused.
69    fn is_paused(&self) -> bool {
70        self.pause_flag
71            .as_ref()
72            .map(|f| f.load(Ordering::Relaxed))
73            .unwrap_or(false)
74    }
75
76    /// Wait while paused, checking periodically.
77    fn wait_while_paused(&self, pb: &ProgressBar) {
78        let was_paused = self.is_paused();
79        if was_paused {
80            pb.set_message("PAUSED - send SIGUSR1 to resume");
81        }
82
83        while self.is_paused() {
84            std::thread::sleep(Duration::from_millis(100));
85        }
86
87        if was_paused {
88            pb.set_message("");
89        }
90    }
91
92    /// Generate the chart of accounts.
93    pub fn generate_coa(&mut self) -> SynthResult<Arc<ChartOfAccounts>> {
94        let seed = self.config.global.seed.unwrap_or_else(rand::random);
95        let mut gen = ChartOfAccountsGenerator::new(
96            self.config.chart_of_accounts.complexity,
97            self.config.global.industry,
98            seed,
99        );
100
101        let coa = Arc::new(gen.generate());
102        self.coa = Some(Arc::clone(&coa));
103        Ok(coa)
104    }
105
106    /// Calculate total transactions to generate.
107    pub fn calculate_total_transactions(&self) -> u64 {
108        let months = self.config.global.period_months as f64;
109
110        self.config
111            .companies
112            .iter()
113            .map(|c| {
114                let annual = c.annual_transaction_volume.count() as f64;
115                let weighted = annual * c.volume_weight;
116                (weighted * months / 12.0) as u64
117            })
118            .sum()
119    }
120
121    /// Run the generation.
122    pub fn generate(&mut self) -> SynthResult<GenerationResult> {
123        // Generate CoA if not already done
124        let coa = match &self.coa {
125            Some(c) => Arc::clone(c),
126            None => self.generate_coa()?,
127        };
128
129        let total = self.calculate_total_transactions();
130        let seed = self.config.global.seed.unwrap_or_else(rand::random);
131
132        // Parse dates
133        let start_date =
134            chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
135                .map_err(|e| SynthError::config(format!("Invalid start_date: {}", e)))?;
136
137        let end_date = start_date + chrono::Months::new(self.config.global.period_months);
138
139        // Create progress bar
140        let pb = ProgressBar::new(total);
141        pb.set_style(
142            ProgressStyle::default_bar()
143                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({per_sec})")
144                .expect("Progress bar template is a compile-time constant and should always be valid")
145                .progress_chars("#>-"),
146        );
147
148        // Get company codes
149        let company_codes: Vec<String> = self
150            .config
151            .companies
152            .iter()
153            .map(|c| c.code.clone())
154            .collect();
155
156        // Generate entries with fraud config
157        let mut generator = JournalEntryGenerator::new_with_params(
158            self.config.transactions.clone(),
159            Arc::clone(&coa),
160            company_codes,
161            start_date,
162            end_date,
163            seed,
164        )
165        .with_fraud_config(self.config.fraud.clone());
166
167        // Parallel generation: split across available cores for large datasets
168        let num_threads = num_cpus::get().max(1).min(total as usize).max(1);
169
170        let entries = if total >= 10_000 && num_threads > 1 {
171            let sub_generators = generator.split(num_threads);
172            let entries_per_thread = total as usize / num_threads;
173            let remainder = total as usize % num_threads;
174
175            let batches: Vec<Vec<JournalEntry>> = sub_generators
176                .into_par_iter()
177                .enumerate()
178                .map(|(i, mut gen)| {
179                    let count = entries_per_thread + if i < remainder { 1 } else { 0 };
180                    gen.generate_batch(count)
181                })
182                .collect();
183
184            let entries = JournalEntryGenerator::merge_results(batches);
185            pb.inc(total);
186            entries
187        } else {
188            let mut entries = Vec::with_capacity(total as usize);
189            for _ in 0..total {
190                self.wait_while_paused(&pb);
191                let entry = generator.generate();
192                entries.push(entry);
193                pb.inc(1);
194            }
195            entries
196        };
197
198        let total_lines: u64 = entries.iter().map(|e| e.line_count() as u64).sum();
199
200        pb.finish_with_message("Generation complete");
201
202        Ok(GenerationResult {
203            chart_of_accounts: (*coa).clone(),
204            journal_entries: entries,
205            statistics: GenerationStatistics {
206                total_entries: total,
207                total_line_items: total_lines,
208                accounts_count: coa.account_count(),
209                companies_count: self.config.companies.len(),
210                period_months: self.config.global.period_months,
211            },
212        })
213    }
214}