Skip to main content

datasynth_runtime/
generation_session.rs

1//! Multi-period generation session with checkpoint/resume support.
2//!
3//! [`GenerationSession`] wraps [`EnhancedOrchestrator`] and drives it through
4//! a sequence of [`GenerationPeriod`]s, persisting state to `.dss` files so
5//! that long runs can be resumed after interruption.
6
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12    add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13    PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Layout strategy for the directories that receive generated output.
#[derive(Debug, Clone)]
pub enum OutputMode {
    /// Everything is written directly into this directory (single-period runs).
    Batch(PathBuf),
    /// Each period is written into its own sub-directory beneath this root.
    MultiPeriod(PathBuf),
}
29
30/// Summary of a single completed period generation.
31#[derive(Debug)]
32pub struct PeriodResult {
33    /// The period that was generated.
34    pub period: GenerationPeriod,
35    /// Filesystem path where this period's output was written.
36    pub output_path: PathBuf,
37    /// Number of journal entries generated in this period.
38    pub journal_entry_count: usize,
39    /// Number of document flow records generated in this period.
40    pub document_count: usize,
41    /// Number of anomalies injected in this period.
42    pub anomaly_count: usize,
43    /// Wall-clock duration for generating this period (seconds).
44    pub duration_secs: f64,
45}
46
47/// A multi-period generation session with checkpoint/resume support.
48///
49/// The session decomposes the total requested time span into fiscal-year-aligned
50/// periods and generates each one sequentially, carrying forward balance and ID
51/// state between periods.
52#[derive(Debug)]
53pub struct GenerationSession {
54    config: GeneratorConfig,
55    state: SessionState,
56    periods: Vec<GenerationPeriod>,
57    output_mode: OutputMode,
58    phase_config: PhaseConfig,
59}
60
61impl GenerationSession {
62    /// Create a new session from a config and output path.
63    ///
64    /// The total time span is decomposed into fiscal-year-aligned periods
65    /// based on `config.global.fiscal_year_months` (defaults to `period_months`
66    /// if not set, yielding a single period).
67    pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68        let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
69            .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
70
71        let total_months = config.global.period_months;
72        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
73        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
74
75        let output_mode = if periods.len() > 1 {
76            OutputMode::MultiPeriod(output_path)
77        } else {
78            OutputMode::Batch(output_path)
79        };
80
81        let seed = config.global.seed.unwrap_or(42);
82        let config_hash = Self::compute_config_hash(&config);
83
84        let state = SessionState {
85            rng_seed: seed,
86            period_cursor: 0,
87            balance_state: BalanceState::default(),
88            document_id_state: DocumentIdState::default(),
89            entity_counts: EntityCounts::default(),
90            generation_log: Vec::new(),
91            config_hash,
92        };
93
94        Ok(Self {
95            config,
96            state,
97            periods,
98            output_mode,
99            phase_config: PhaseConfig::default(),
100        })
101    }
102
103    /// Resume a session from a `.dss` checkpoint file.
104    ///
105    /// The config hash is verified against the checkpoint to ensure the config
106    /// has not changed since the session was last saved.
107    pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
108        let data = fs::read_to_string(path)
109            .map_err(|e| SynthError::generation(format!("Failed to read .dss: {e}")))?;
110        let state: SessionState = serde_json::from_str(&data)
111            .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {e}")))?;
112
113        let current_hash = Self::compute_config_hash(&config);
114        if state.config_hash != current_hash {
115            return Err(SynthError::generation(
116                "Config has changed since last checkpoint. Cannot resume with different config."
117                    .to_string(),
118            ));
119        }
120
121        let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
122            .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
123
124        let total_months = config.global.period_months;
125        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
126        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
127
128        let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
129        let output_mode = if periods.len() > 1 {
130            OutputMode::MultiPeriod(output_dir)
131        } else {
132            OutputMode::Batch(output_dir)
133        };
134
135        Ok(Self {
136            config,
137            state,
138            periods,
139            output_mode,
140            phase_config: PhaseConfig::default(),
141        })
142    }
143
144    /// Persist the current session state to a `.dss` file.
145    pub fn save(&self, path: &Path) -> SynthResult<()> {
146        let data = serde_json::to_string_pretty(&self.state)
147            .map_err(|e| SynthError::generation(format!("Failed to serialize state: {e}")))?;
148        fs::write(path, data)
149            .map_err(|e| SynthError::generation(format!("Failed to write .dss: {e}")))?;
150        Ok(())
151    }
152
153    /// Generate the next period in the sequence.
154    ///
155    /// Returns `Ok(None)` if all periods have been generated.
156    pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
157        if self.state.period_cursor >= self.periods.len() {
158            return Ok(None);
159        }
160
161        let period = self.periods[self.state.period_cursor].clone();
162        let start = std::time::Instant::now();
163
164        let period_seed = advance_seed(self.state.rng_seed, period.index);
165
166        let mut period_config = self.config.clone();
167        period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
168        period_config.global.period_months = period.months;
169        period_config.global.seed = Some(period_seed);
170
171        let output_path = match &self.output_mode {
172            OutputMode::Batch(p) => p.clone(),
173            OutputMode::MultiPeriod(p) => p.join(&period.label),
174        };
175
176        fs::create_dir_all(&output_path)
177            .map_err(|e| SynthError::generation(format!("Failed to create output dir: {e}")))?;
178
179        let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
180        let mut orchestrator = orchestrator.with_output_path(&output_path);
181        let result = orchestrator.generate()?;
182
183        let duration = start.elapsed().as_secs_f64();
184
185        // Count journal entries from the result vec
186        let je_count = result.journal_entries.len();
187
188        // Count documents from the document_flows snapshot
189        let doc_count = result.document_flows.purchase_orders.len()
190            + result.document_flows.sales_orders.len()
191            + result.document_flows.goods_receipts.len()
192            + result.document_flows.vendor_invoices.len()
193            + result.document_flows.customer_invoices.len()
194            + result.document_flows.deliveries.len()
195            + result.document_flows.payments.len();
196
197        // Count anomalies from anomaly_labels
198        let anomaly_count = result.anomaly_labels.labels.len();
199
200        self.state.generation_log.push(PeriodLog {
201            period_label: period.label.clone(),
202            journal_entries: je_count,
203            documents: doc_count,
204            anomalies: anomaly_count,
205            duration_secs: duration,
206        });
207
208        self.state.period_cursor += 1;
209
210        Ok(Some(PeriodResult {
211            period,
212            output_path,
213            journal_entry_count: je_count,
214            document_count: doc_count,
215            anomaly_count,
216            duration_secs: duration,
217        }))
218    }
219
220    /// Generate all remaining periods in the sequence.
221    pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
222        let mut results = Vec::new();
223        while let Some(result) = self.generate_next_period()? {
224            results.push(result);
225        }
226        Ok(results)
227    }
228
229    /// Extend the session with additional months and generate them.
230    pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
231        let last_end = if let Some(last_period) = self.periods.last() {
232            add_months(last_period.end_date, 1)
233        } else {
234            chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
235                .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?
236        };
237
238        let fy_months = self
239            .config
240            .global
241            .fiscal_year_months
242            .unwrap_or(self.config.global.period_months);
243        let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
244
245        let base_index = self.periods.len();
246        let new_periods: Vec<GenerationPeriod> = new_periods
247            .into_iter()
248            .enumerate()
249            .map(|(i, mut p)| {
250                p.index = base_index + i;
251                p
252            })
253            .collect();
254
255        self.periods.extend(new_periods);
256        self.generate_all()
257    }
258
259    /// Read-only access to the session state.
260    pub fn state(&self) -> &SessionState {
261        &self.state
262    }
263
264    /// Read-only access to the period list.
265    pub fn periods(&self) -> &[GenerationPeriod] {
266        &self.periods
267    }
268
269    /// Number of periods that have not yet been generated.
270    pub fn remaining_periods(&self) -> usize {
271        self.periods.len().saturating_sub(self.state.period_cursor)
272    }
273
274    /// Compute a hash of the config for drift detection.
275    fn compute_config_hash(config: &GeneratorConfig) -> String {
276        use std::hash::{Hash, Hasher};
277        let json = serde_json::to_string(config).unwrap_or_default();
278        let mut hasher = std::collections::hash_map::DefaultHasher::new();
279        json.hash(&mut hasher);
280        format!("{:016x}", hasher.finish())
281    }
282}
283
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Smallest config that parses: one company, 12 months from 2024-01-01.
    fn minimal_config() -> GeneratorConfig {
        const YAML: &str = r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#;
        serde_yaml::from_str(YAML).expect("minimal config should parse")
    }

    /// Build a session for `config` rooted at `dir`, panicking on failure.
    fn session_at(config: GeneratorConfig, dir: &str) -> GenerationSession {
        GenerationSession::new(config, PathBuf::from(dir)).unwrap()
    }

    /// Minimal config stretched to `total` months with 12-month fiscal years.
    fn multi_year_config(total: u32) -> GeneratorConfig {
        let mut config = minimal_config();
        config.global.period_months = total;
        config.global.fiscal_year_months = Some(12);
        config
    }

    #[test]
    fn test_session_new_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_session_single");
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let session = session_at(multi_year_config(36), "/tmp/test_session_multi");
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_save_and_resume() {
        let config = minimal_config();
        let checkpoint = std::env::temp_dir().join("test_gen_session.dss");

        session_at(config.clone(), "/tmp/test_session_save")
            .save(&checkpoint)
            .unwrap();

        let resumed = GenerationSession::resume(&checkpoint, config).unwrap();
        assert_eq!(resumed.state().period_cursor, 0);
        assert_eq!(resumed.state().rng_seed, 42);
        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let config = minimal_config();
        let checkpoint = std::env::temp_dir().join("test_gen_session_mismatch.dss");

        session_at(config.clone(), "/tmp/test_session_mismatch")
            .save(&checkpoint)
            .unwrap();

        // Any change to the config must be detected on resume.
        let mut different = config;
        different.global.seed = Some(999);

        let result = GenerationSession::resume(&checkpoint, different);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {}",
            err_msg
        );
        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_remaining_periods() {
        let session = session_at(minimal_config(), "/tmp/test_session_remaining");
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let config = minimal_config();
        let first = GenerationSession::compute_config_hash(&config);
        let second = GenerationSession::compute_config_hash(&config);
        assert_eq!(first, second);
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let original = minimal_config();
        let before = GenerationSession::compute_config_hash(&original);

        let mut modified = original;
        modified.global.seed = Some(999);
        let after = GenerationSession::compute_config_hash(&modified);

        assert_ne!(before, after);
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_batch_mode");
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let session = session_at(multi_year_config(24), "/tmp/test_multi_mode");
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}
410}