Skip to main content

datasynth_runtime/
generation_session.rs

1//! Multi-period generation session with checkpoint/resume support.
2//!
3//! [`GenerationSession`] wraps [`EnhancedOrchestrator`] and drives it through
4//! a sequence of [`GenerationPeriod`]s, persisting state to `.dss` files so
5//! that long runs can be resumed after interruption.
6
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12    add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13    PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Controls how period output directories are laid out.
///
/// [`GenerationSession::new`] and [`GenerationSession::resume`] pick the
/// variant from the number of computed periods: more than one period selects
/// `MultiPeriod`, otherwise `Batch`.
#[derive(Debug, Clone)]
pub enum OutputMode {
    /// Single output directory: the period's output is written directly into
    /// the wrapped path.
    Batch(PathBuf),
    /// One sub-directory per period under the wrapped root directory; each
    /// sub-directory is named after the period's label.
    MultiPeriod(PathBuf),
}
29
/// Summary of a single completed period generation.
///
/// Returned by [`GenerationSession::generate_next_period`]; the same counts
/// are also appended to the session's generation log.
#[derive(Debug)]
pub struct PeriodResult {
    /// The period that was generated.
    pub period: GenerationPeriod,
    /// Filesystem path where this period's output was written.
    pub output_path: PathBuf,
    /// Number of journal entries generated in this period.
    pub journal_entry_count: usize,
    /// Number of document flow records generated in this period: the sum of
    /// purchase orders, sales orders, goods receipts, vendor invoices,
    /// customer invoices, deliveries and payments.
    pub document_count: usize,
    /// Number of anomaly labels produced for this period.
    pub anomaly_count: usize,
    /// Wall-clock duration for generating this period (seconds). Measured
    /// up to the point the orchestrator returns, i.e. before the session's
    /// own carry-forward bookkeeping runs.
    pub duration_secs: f64,
}
46
/// A multi-period generation session with checkpoint/resume support.
///
/// The session decomposes the total requested time span into fiscal-year-aligned
/// periods and generates each one sequentially, carrying forward balance and ID
/// state between periods.
#[derive(Debug)]
pub struct GenerationSession {
    // Base configuration; cloned and specialised (start date, length, seed)
    // for every generated period.
    config: GeneratorConfig,
    // Checkpointable state. This is the only part written by `save` and read
    // back by `resume`.
    state: SessionState,
    // Ordered list of periods; `state.period_cursor` indexes into it.
    periods: Vec<GenerationPeriod>,
    // Where each period's output directory is placed.
    output_mode: OutputMode,
    // Phase configuration handed to every per-period orchestrator. Both
    // constructors initialise it with `PhaseConfig::default()`.
    phase_config: PhaseConfig,
}
60
61impl GenerationSession {
62    /// Create a new session from a config and output path.
63    ///
64    /// The total time span is decomposed into fiscal-year-aligned periods
65    /// based on `config.global.fiscal_year_months` (defaults to `period_months`
66    /// if not set, yielding a single period).
67    pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68        let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
69            .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
70
71        let total_months = config.global.period_months;
72        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
73        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
74
75        let output_mode = if periods.len() > 1 {
76            OutputMode::MultiPeriod(output_path)
77        } else {
78            OutputMode::Batch(output_path)
79        };
80
81        let seed = config.global.seed.unwrap_or(42);
82        let config_hash = Self::compute_config_hash(&config);
83
84        let state = SessionState {
85            rng_seed: seed,
86            period_cursor: 0,
87            balance_state: BalanceState::default(),
88            document_id_state: DocumentIdState::default(),
89            entity_counts: EntityCounts::default(),
90            generation_log: Vec::new(),
91            config_hash,
92        };
93
94        Ok(Self {
95            config,
96            state,
97            periods,
98            output_mode,
99            phase_config: PhaseConfig::default(),
100        })
101    }
102
103    /// Resume a session from a `.dss` checkpoint file.
104    ///
105    /// The config hash is verified against the checkpoint to ensure the config
106    /// has not changed since the session was last saved.
107    pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
108        let data = fs::read_to_string(path)
109            .map_err(|e| SynthError::generation(format!("Failed to read .dss: {e}")))?;
110        let state: SessionState = serde_json::from_str(&data)
111            .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {e}")))?;
112
113        let current_hash = Self::compute_config_hash(&config);
114        if state.config_hash != current_hash {
115            return Err(SynthError::generation(
116                "Config has changed since last checkpoint. Cannot resume with different config."
117                    .to_string(),
118            ));
119        }
120
121        let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
122            .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
123
124        let total_months = config.global.period_months;
125        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
126        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
127
128        let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
129        let output_mode = if periods.len() > 1 {
130            OutputMode::MultiPeriod(output_dir)
131        } else {
132            OutputMode::Batch(output_dir)
133        };
134
135        Ok(Self {
136            config,
137            state,
138            periods,
139            output_mode,
140            phase_config: PhaseConfig::default(),
141        })
142    }
143
144    /// Persist the current session state to a `.dss` file.
145    pub fn save(&self, path: &Path) -> SynthResult<()> {
146        let data = serde_json::to_string_pretty(&self.state)
147            .map_err(|e| SynthError::generation(format!("Failed to serialize state: {e}")))?;
148        fs::write(path, data)
149            .map_err(|e| SynthError::generation(format!("Failed to write .dss: {e}")))?;
150        Ok(())
151    }
152
153    /// Generate the next period in the sequence.
154    ///
155    /// Returns `Ok(None)` if all periods have been generated.
156    pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
157        if self.state.period_cursor >= self.periods.len() {
158            return Ok(None);
159        }
160
161        let period = self.periods[self.state.period_cursor].clone();
162        let start = std::time::Instant::now();
163
164        let period_seed = advance_seed(self.state.rng_seed, period.index);
165
166        let mut period_config = self.config.clone();
167        period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
168        period_config.global.period_months = period.months;
169        period_config.global.seed = Some(period_seed);
170
171        let output_path = match &self.output_mode {
172            OutputMode::Batch(p) => p.clone(),
173            OutputMode::MultiPeriod(p) => p.join(&period.label),
174        };
175
176        fs::create_dir_all(&output_path)
177            .map_err(|e| SynthError::generation(format!("Failed to create output dir: {e}")))?;
178
179        let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
180        let mut orchestrator = orchestrator.with_output_path(&output_path);
181        let result = orchestrator.generate()?;
182
183        let duration = start.elapsed().as_secs_f64();
184
185        // Count journal entries from the result vec
186        let je_count = result.journal_entries.len();
187
188        // Count documents from the document_flows snapshot
189        let doc_count = result.document_flows.purchase_orders.len()
190            + result.document_flows.sales_orders.len()
191            + result.document_flows.goods_receipts.len()
192            + result.document_flows.vendor_invoices.len()
193            + result.document_flows.customer_invoices.len()
194            + result.document_flows.deliveries.len()
195            + result.document_flows.payments.len();
196
197        // Count anomalies from anomaly_labels
198        let anomaly_count = result.anomaly_labels.labels.len();
199
200        // ---------------------------------------------------------------
201        // Balance carry-forward: aggregate closing GL balances from JEs
202        // so the next period starts from this period's closing position.
203        // ---------------------------------------------------------------
204        {
205            use std::collections::HashMap;
206
207            // Build net balance per GL account (debit positive, credit negative).
208            let mut gl_net: HashMap<String, f64> = HashMap::new();
209            for je in &result.journal_entries {
210                for line in &je.lines {
211                    let account = line.gl_account.clone();
212                    let delta = f64::try_from(line.debit_amount).unwrap_or(0.0)
213                        - f64::try_from(line.credit_amount).unwrap_or(0.0);
214                    *gl_net.entry(account).or_insert(0.0) += delta;
215                }
216            }
217
218            // Carry forward as opening balances for the next period.
219            // We merge into any existing carry-forward from prior periods.
220            for (account, delta) in gl_net {
221                *self
222                    .state
223                    .balance_state
224                    .gl_balances
225                    .entry(account)
226                    .or_insert(0.0) += delta;
227            }
228
229            // Derive aggregate subledger totals from the balance map.
230            // AR is represented by account 1100, AP by account 2000 (sign convention:
231            // positive = debit balance for AR, positive credit balance treated as
232            // positive AP by flipping sign).
233            self.state.balance_state.ar_total = self
234                .state
235                .balance_state
236                .gl_balances
237                .get("1100")
238                .copied()
239                .unwrap_or(0.0)
240                .max(0.0);
241            self.state.balance_state.ap_total = (-self
242                .state
243                .balance_state
244                .gl_balances
245                .get("2000")
246                .copied()
247                .unwrap_or(0.0))
248            .max(0.0);
249
250            // Retained earnings: sum of all income statement accounts (4xxx–8xxx range).
251            // Positive retained earnings arise when revenues (credit-normal) exceed expenses.
252            let retained: f64 = self
253                .state
254                .balance_state
255                .gl_balances
256                .iter()
257                .filter_map(|(acct, &bal)| {
258                    acct.parse::<u32>()
259                        .ok()
260                        .filter(|&n| (4000..=8999).contains(&n))
261                        .map(|_| -bal) // credit-normal income accounts are negative in debit-net map
262                })
263                .sum();
264            self.state.balance_state.retained_earnings += retained;
265
266            // Advance document ID counters so each period's IDs are globally unique.
267            self.state.document_id_state.next_je_number += je_count as u64;
268            self.state.document_id_state.next_po_number +=
269                result.document_flows.purchase_orders.len() as u64;
270            self.state.document_id_state.next_so_number +=
271                result.document_flows.sales_orders.len() as u64;
272            self.state.document_id_state.next_invoice_number +=
273                (result.document_flows.vendor_invoices.len()
274                    + result.document_flows.customer_invoices.len()) as u64;
275            self.state.document_id_state.next_payment_number +=
276                result.document_flows.payments.len() as u64;
277            self.state.document_id_state.next_gr_number +=
278                result.document_flows.goods_receipts.len() as u64;
279        }
280
281        self.state.generation_log.push(PeriodLog {
282            period_label: period.label.clone(),
283            journal_entries: je_count,
284            documents: doc_count,
285            anomalies: anomaly_count,
286            duration_secs: duration,
287        });
288
289        self.state.period_cursor += 1;
290
291        Ok(Some(PeriodResult {
292            period,
293            output_path,
294            journal_entry_count: je_count,
295            document_count: doc_count,
296            anomaly_count,
297            duration_secs: duration,
298        }))
299    }
300
301    /// Generate all remaining periods in the sequence.
302    pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
303        let mut results = Vec::new();
304        while let Some(result) = self.generate_next_period()? {
305            results.push(result);
306        }
307        Ok(results)
308    }
309
310    /// Extend the session with additional months and generate them.
311    pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
312        let last_end = if let Some(last_period) = self.periods.last() {
313            add_months(last_period.end_date, 1)
314        } else {
315            chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
316                .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?
317        };
318
319        let fy_months = self
320            .config
321            .global
322            .fiscal_year_months
323            .unwrap_or(self.config.global.period_months);
324        let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
325
326        let base_index = self.periods.len();
327        let new_periods: Vec<GenerationPeriod> = new_periods
328            .into_iter()
329            .enumerate()
330            .map(|(i, mut p)| {
331                p.index = base_index + i;
332                p
333            })
334            .collect();
335
336        self.periods.extend(new_periods);
337        self.generate_all()
338    }
339
340    /// Read-only access to the session state.
341    pub fn state(&self) -> &SessionState {
342        &self.state
343    }
344
345    /// Read-only access to the period list.
346    pub fn periods(&self) -> &[GenerationPeriod] {
347        &self.periods
348    }
349
350    /// Number of periods that have not yet been generated.
351    pub fn remaining_periods(&self) -> usize {
352        self.periods.len().saturating_sub(self.state.period_cursor)
353    }
354
355    /// Compute a hash of the config for drift detection.
356    fn compute_config_hash(config: &GeneratorConfig) -> String {
357        use std::hash::{Hash, Hasher};
358        let json = serde_json::to_string(config).unwrap_or_default();
359        let mut hasher = std::collections::hash_map::DefaultHasher::new();
360        json.hash(&mut hasher);
361        format!("{:016x}", hasher.finish())
362    }
363}
364
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Smallest config that parses: one company, a 12-month span, seed 42.
    fn minimal_config() -> GeneratorConfig {
        const YAML: &str = r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#;
        serde_yaml::from_str(YAML).expect("minimal config should parse")
    }

    /// Builds a session for `config` rooted at `dir`; nothing is written to disk.
    fn session_at(config: GeneratorConfig, dir: &str) -> GenerationSession {
        GenerationSession::new(config, PathBuf::from(dir)).unwrap()
    }

    /// `minimal_config` stretched to `total` months with 12-month fiscal years.
    fn multi_year_config(total: u32) -> GeneratorConfig {
        let mut cfg = minimal_config();
        cfg.global.period_months = total;
        cfg.global.fiscal_year_months = Some(12);
        cfg
    }

    #[test]
    fn test_session_new_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_session_single");
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let session = session_at(multi_year_config(36), "/tmp/test_session_multi");
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_save_and_resume() {
        let cfg = minimal_config();
        let original = session_at(cfg.clone(), "/tmp/test_session_save");
        let checkpoint = std::env::temp_dir().join("test_gen_session.dss");
        original.save(&checkpoint).unwrap();

        let resumed = GenerationSession::resume(&checkpoint, cfg).unwrap();
        let restored = resumed.state();
        assert_eq!(restored.period_cursor, 0);
        assert_eq!(restored.rng_seed, 42);

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let cfg = minimal_config();
        let original = session_at(cfg.clone(), "/tmp/test_session_mismatch");
        let checkpoint = std::env::temp_dir().join("test_gen_session_mismatch.dss");
        original.save(&checkpoint).unwrap();

        // Changing the seed must change the hash and therefore block resume.
        let mut drifted = cfg;
        drifted.global.seed = Some(999);

        let outcome = GenerationSession::resume(&checkpoint, drifted);
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {}",
            err_msg
        );

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_remaining_periods() {
        let session = session_at(minimal_config(), "/tmp/test_session_remaining");
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let cfg = minimal_config();
        let first = GenerationSession::compute_config_hash(&cfg);
        let second = GenerationSession::compute_config_hash(&cfg);
        assert_eq!(first, second);
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let cfg = minimal_config();
        let before = GenerationSession::compute_config_hash(&cfg);

        let mut mutated = cfg;
        mutated.global.seed = Some(999);
        let after = GenerationSession::compute_config_hash(&mutated);

        assert_ne!(before, after);
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_batch_mode");
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let session = session_at(multi_year_config(24), "/tmp/test_multi_mode");
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}