// File: datasynth_runtime/orchestrator.rs

1//! Generation orchestrator for coordinating data generation.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use datasynth_config::schema::GeneratorConfig;
8use datasynth_core::error::{SynthError, SynthResult};
9use datasynth_core::models::*;
10use datasynth_generators::{ChartOfAccountsGenerator, JournalEntryGenerator};
11use indicatif::{ProgressBar, ProgressStyle};
12
13/// Result of a generation run.
14pub struct GenerationResult {
15    /// Generated chart of accounts
16    pub chart_of_accounts: ChartOfAccounts,
17    /// Generated journal entries
18    pub journal_entries: Vec<JournalEntry>,
19    /// Statistics about the generation
20    pub statistics: GenerationStatistics,
21}
22
/// Summary counters describing a generation run.
///
/// All fields are plain integers, so the type is cheap to clone and can be
/// compared for equality (useful when checking that two runs with the same
/// seed produce the same volume). `Default` yields all-zero counters.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct GenerationStatistics {
    /// Total journal entries generated
    pub total_entries: u64,
    /// Total line items generated, summed over all journal entries
    pub total_line_items: u64,
    /// Number of accounts in the chart of accounts
    pub accounts_count: usize,
    /// Number of companies in the configuration
    pub companies_count: usize,
    /// Length of the generation period in months
    pub period_months: u32,
}
37
38/// Main orchestrator for generation.
39pub struct GenerationOrchestrator {
40    config: GeneratorConfig,
41    coa: Option<Arc<ChartOfAccounts>>,
42    /// Optional pause flag for external control (e.g., signal handlers).
43    pause_flag: Option<Arc<AtomicBool>>,
44}
45
46impl GenerationOrchestrator {
47    /// Create a new orchestrator.
48    pub fn new(config: GeneratorConfig) -> SynthResult<Self> {
49        // Validate config
50        datasynth_config::validate_config(&config)?;
51
52        Ok(Self {
53            config,
54            coa: None,
55            pause_flag: None,
56        })
57    }
58
59    /// Set a pause flag that can be controlled externally (e.g., by a signal handler).
60    /// When the flag is true, generation will pause until it becomes false.
61    pub fn with_pause_flag(mut self, flag: Arc<AtomicBool>) -> Self {
62        self.pause_flag = Some(flag);
63        self
64    }
65
66    /// Check if generation is currently paused.
67    fn is_paused(&self) -> bool {
68        self.pause_flag
69            .as_ref()
70            .map(|f| f.load(Ordering::Relaxed))
71            .unwrap_or(false)
72    }
73
74    /// Wait while paused, checking periodically.
75    fn wait_while_paused(&self, pb: &ProgressBar) {
76        let was_paused = self.is_paused();
77        if was_paused {
78            pb.set_message("PAUSED - send SIGUSR1 to resume");
79        }
80
81        while self.is_paused() {
82            std::thread::sleep(Duration::from_millis(100));
83        }
84
85        if was_paused {
86            pb.set_message("");
87        }
88    }
89
90    /// Generate the chart of accounts.
91    pub fn generate_coa(&mut self) -> SynthResult<Arc<ChartOfAccounts>> {
92        let seed = self.config.global.seed.unwrap_or_else(rand::random);
93        let mut gen = ChartOfAccountsGenerator::new(
94            self.config.chart_of_accounts.complexity,
95            self.config.global.industry,
96            seed,
97        );
98
99        let coa = Arc::new(gen.generate());
100        self.coa = Some(Arc::clone(&coa));
101        Ok(coa)
102    }
103
104    /// Calculate total transactions to generate.
105    pub fn calculate_total_transactions(&self) -> u64 {
106        let months = self.config.global.period_months as f64;
107
108        self.config
109            .companies
110            .iter()
111            .map(|c| {
112                let annual = c.annual_transaction_volume.count() as f64;
113                let weighted = annual * c.volume_weight;
114                (weighted * months / 12.0) as u64
115            })
116            .sum()
117    }
118
119    /// Run the generation.
120    pub fn generate(&mut self) -> SynthResult<GenerationResult> {
121        // Generate CoA if not already done
122        let coa = match &self.coa {
123            Some(c) => Arc::clone(c),
124            None => self.generate_coa()?,
125        };
126
127        let total = self.calculate_total_transactions();
128        let seed = self.config.global.seed.unwrap_or_else(rand::random);
129
130        // Parse dates
131        let start_date =
132            chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
133                .map_err(|e| SynthError::config(format!("Invalid start_date: {}", e)))?;
134
135        let end_date = start_date + chrono::Months::new(self.config.global.period_months);
136
137        // Create progress bar
138        let pb = ProgressBar::new(total);
139        pb.set_style(
140            ProgressStyle::default_bar()
141                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({per_sec})")
142                .expect("Progress bar template is a compile-time constant and should always be valid")
143                .progress_chars("#>-"),
144        );
145
146        // Get company codes
147        let company_codes: Vec<String> = self
148            .config
149            .companies
150            .iter()
151            .map(|c| c.code.clone())
152            .collect();
153
154        // Generate entries with fraud config
155        let mut generator = JournalEntryGenerator::new_with_params(
156            self.config.transactions.clone(),
157            Arc::clone(&coa),
158            company_codes,
159            start_date,
160            end_date,
161            seed,
162        )
163        .with_fraud_config(self.config.fraud.clone());
164
165        let mut entries = Vec::with_capacity(total as usize);
166        let mut total_lines = 0u64;
167
168        for _ in 0..total {
169            // Check and wait if paused
170            self.wait_while_paused(&pb);
171
172            let entry = generator.generate();
173            total_lines += entry.line_count() as u64;
174            entries.push(entry);
175            pb.inc(1);
176        }
177
178        pb.finish_with_message("Generation complete");
179
180        Ok(GenerationResult {
181            chart_of_accounts: (*coa).clone(),
182            journal_entries: entries,
183            statistics: GenerationStatistics {
184                total_entries: total,
185                total_line_items: total_lines,
186                accounts_count: coa.account_count(),
187                companies_count: self.config.companies.len(),
188                period_months: self.config.global.period_months,
189            },
190        })
191    }
192}