Skip to main content

datasynth_runtime/
generation_session.rs

1//! Multi-period generation session with checkpoint/resume support.
2//!
3//! [`GenerationSession`] wraps [`EnhancedOrchestrator`] and drives it through
4//! a sequence of [`GenerationPeriod`]s, persisting state to `.dss` files so
5//! that long runs can be resumed after interruption.
6
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12    add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13    PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Layout of the filesystem output produced by a generation session.
///
/// Both variants carry the root directory; they differ only in whether each
/// period gets its own sub-directory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputMode {
    /// Write straight into the given directory (used when there is a single period).
    Batch(PathBuf),
    /// Write each period into `<root>/<period label>` under the given root directory.
    MultiPeriod(PathBuf),
}

impl OutputMode {
    /// The root output directory, regardless of layout.
    pub fn root(&self) -> &Path {
        match self {
            OutputMode::Batch(root) | OutputMode::MultiPeriod(root) => root,
        }
    }
}
29
30/// Summary of a single completed period generation.
31#[derive(Debug)]
32pub struct PeriodResult {
33    /// The period that was generated.
34    pub period: GenerationPeriod,
35    /// Filesystem path where this period's output was written.
36    pub output_path: PathBuf,
37    /// Number of journal entries generated in this period.
38    pub journal_entry_count: usize,
39    /// Number of document flow records generated in this period.
40    pub document_count: usize,
41    /// Number of anomalies injected in this period.
42    pub anomaly_count: usize,
43    /// Wall-clock duration for generating this period (seconds).
44    pub duration_secs: f64,
45}
46
47/// A multi-period generation session with checkpoint/resume support.
48///
49/// The session decomposes the total requested time span into fiscal-year-aligned
50/// periods and generates each one sequentially, carrying forward balance and ID
51/// state between periods.
52#[derive(Debug)]
53pub struct GenerationSession {
54    config: GeneratorConfig,
55    state: SessionState,
56    periods: Vec<GenerationPeriod>,
57    output_mode: OutputMode,
58    phase_config: PhaseConfig,
59}
60
61impl GenerationSession {
62    /// Create a new session from a config and output path.
63    ///
64    /// The total time span is decomposed into fiscal-year-aligned periods
65    /// based on `config.global.fiscal_year_months` (defaults to `period_months`
66    /// if not set, yielding a single period).
67    pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68        let start_date =
69            chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
70                .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?;
71
72        let total_months = config.global.period_months;
73        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
74        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
75
76        let output_mode = if periods.len() > 1 {
77            OutputMode::MultiPeriod(output_path)
78        } else {
79            OutputMode::Batch(output_path)
80        };
81
82        let seed = config.global.seed.unwrap_or(42);
83        let config_hash = Self::compute_config_hash(&config);
84
85        let state = SessionState {
86            rng_seed: seed,
87            period_cursor: 0,
88            balance_state: BalanceState::default(),
89            document_id_state: DocumentIdState::default(),
90            entity_counts: EntityCounts::default(),
91            generation_log: Vec::new(),
92            config_hash,
93        };
94
95        Ok(Self {
96            config,
97            state,
98            periods,
99            output_mode,
100            phase_config: PhaseConfig::default(),
101        })
102    }
103
104    /// Resume a session from a `.dss` checkpoint file.
105    ///
106    /// The config hash is verified against the checkpoint to ensure the config
107    /// has not changed since the session was last saved.
108    pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
109        let data = fs::read_to_string(path)
110            .map_err(|e| SynthError::generation(format!("Failed to read .dss: {}", e)))?;
111        let state: SessionState = serde_json::from_str(&data)
112            .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {}", e)))?;
113
114        let current_hash = Self::compute_config_hash(&config);
115        if state.config_hash != current_hash {
116            return Err(SynthError::generation(
117                "Config has changed since last checkpoint. Cannot resume with different config."
118                    .to_string(),
119            ));
120        }
121
122        let start_date =
123            chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
124                .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?;
125
126        let total_months = config.global.period_months;
127        let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
128        let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
129
130        let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
131        let output_mode = if periods.len() > 1 {
132            OutputMode::MultiPeriod(output_dir)
133        } else {
134            OutputMode::Batch(output_dir)
135        };
136
137        Ok(Self {
138            config,
139            state,
140            periods,
141            output_mode,
142            phase_config: PhaseConfig::default(),
143        })
144    }
145
146    /// Persist the current session state to a `.dss` file.
147    pub fn save(&self, path: &Path) -> SynthResult<()> {
148        let data = serde_json::to_string_pretty(&self.state)
149            .map_err(|e| SynthError::generation(format!("Failed to serialize state: {}", e)))?;
150        fs::write(path, data)
151            .map_err(|e| SynthError::generation(format!("Failed to write .dss: {}", e)))?;
152        Ok(())
153    }
154
155    /// Generate the next period in the sequence.
156    ///
157    /// Returns `Ok(None)` if all periods have been generated.
158    pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
159        if self.state.period_cursor >= self.periods.len() {
160            return Ok(None);
161        }
162
163        let period = self.periods[self.state.period_cursor].clone();
164        let start = std::time::Instant::now();
165
166        let period_seed = advance_seed(self.state.rng_seed, period.index);
167
168        let mut period_config = self.config.clone();
169        period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
170        period_config.global.period_months = period.months;
171        period_config.global.seed = Some(period_seed);
172
173        let output_path = match &self.output_mode {
174            OutputMode::Batch(p) => p.clone(),
175            OutputMode::MultiPeriod(p) => p.join(&period.label),
176        };
177
178        fs::create_dir_all(&output_path)
179            .map_err(|e| SynthError::generation(format!("Failed to create output dir: {}", e)))?;
180
181        let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
182        let mut orchestrator = orchestrator.with_output_path(&output_path);
183        let result = orchestrator.generate()?;
184
185        let duration = start.elapsed().as_secs_f64();
186
187        // Count journal entries from the result vec
188        let je_count = result.journal_entries.len();
189
190        // Count documents from the document_flows snapshot
191        let doc_count = result.document_flows.purchase_orders.len()
192            + result.document_flows.sales_orders.len()
193            + result.document_flows.goods_receipts.len()
194            + result.document_flows.vendor_invoices.len()
195            + result.document_flows.customer_invoices.len()
196            + result.document_flows.deliveries.len()
197            + result.document_flows.payments.len();
198
199        // Count anomalies from anomaly_labels
200        let anomaly_count = result.anomaly_labels.labels.len();
201
202        self.state.generation_log.push(PeriodLog {
203            period_label: period.label.clone(),
204            journal_entries: je_count,
205            documents: doc_count,
206            anomalies: anomaly_count,
207            duration_secs: duration,
208        });
209
210        self.state.period_cursor += 1;
211
212        Ok(Some(PeriodResult {
213            period,
214            output_path,
215            journal_entry_count: je_count,
216            document_count: doc_count,
217            anomaly_count,
218            duration_secs: duration,
219        }))
220    }
221
222    /// Generate all remaining periods in the sequence.
223    pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
224        let mut results = Vec::new();
225        while let Some(result) = self.generate_next_period()? {
226            results.push(result);
227        }
228        Ok(results)
229    }
230
231    /// Extend the session with additional months and generate them.
232    pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
233        let last_end = if let Some(last_period) = self.periods.last() {
234            add_months(last_period.end_date, 1)
235        } else {
236            chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
237                .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?
238        };
239
240        let fy_months = self
241            .config
242            .global
243            .fiscal_year_months
244            .unwrap_or(self.config.global.period_months);
245        let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
246
247        let base_index = self.periods.len();
248        let new_periods: Vec<GenerationPeriod> = new_periods
249            .into_iter()
250            .enumerate()
251            .map(|(i, mut p)| {
252                p.index = base_index + i;
253                p
254            })
255            .collect();
256
257        self.periods.extend(new_periods);
258        self.generate_all()
259    }
260
261    /// Read-only access to the session state.
262    pub fn state(&self) -> &SessionState {
263        &self.state
264    }
265
266    /// Read-only access to the period list.
267    pub fn periods(&self) -> &[GenerationPeriod] {
268        &self.periods
269    }
270
271    /// Number of periods that have not yet been generated.
272    pub fn remaining_periods(&self) -> usize {
273        self.periods.len().saturating_sub(self.state.period_cursor)
274    }
275
276    /// Compute a hash of the config for drift detection.
277    fn compute_config_hash(config: &GeneratorConfig) -> String {
278        use std::hash::{Hash, Hasher};
279        let json = serde_json::to_string(config).unwrap_or_default();
280        let mut hasher = std::collections::hash_map::DefaultHasher::new();
281        json.hash(&mut hasher);
282        format!("{:016x}", hasher.finish())
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest config the session accepts: one US company, 12 months from
    /// 2024-01-01, seed 42.
    const MINIMAL_YAML: &str = r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#;

    fn minimal_config() -> GeneratorConfig {
        serde_yaml::from_str(MINIMAL_YAML).expect("minimal config should parse")
    }

    /// Build a session rooted at `dir`, panicking if construction fails.
    fn session_at(config: GeneratorConfig, dir: &str) -> GenerationSession {
        GenerationSession::new(config, PathBuf::from(dir)).unwrap()
    }

    #[test]
    fn test_session_new_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_session_single");
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let mut config = minimal_config();
        config.global.period_months = 36;
        config.global.fiscal_year_months = Some(12);
        let session = session_at(config, "/tmp/test_session_multi");
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_save_and_resume() {
        let config = minimal_config();
        let checkpoint = std::env::temp_dir().join("test_gen_session.dss");
        session_at(config.clone(), "/tmp/test_session_save")
            .save(&checkpoint)
            .unwrap();

        let resumed = GenerationSession::resume(&checkpoint, config).unwrap();
        assert_eq!(resumed.state().period_cursor, 0);
        assert_eq!(resumed.state().rng_seed, 42);

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let config = minimal_config();
        let checkpoint = std::env::temp_dir().join("test_gen_session_mismatch.dss");
        session_at(config.clone(), "/tmp/test_session_mismatch")
            .save(&checkpoint)
            .unwrap();

        let mut drifted = config;
        drifted.global.seed = Some(999);
        let err_msg = match GenerationSession::resume(&checkpoint, drifted) {
            Ok(_) => panic!("resume should reject a changed config"),
            Err(err) => err.to_string(),
        };
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {}",
            err_msg
        );

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_remaining_periods() {
        let session = session_at(minimal_config(), "/tmp/test_session_remaining");
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let config = minimal_config();
        let first = GenerationSession::compute_config_hash(&config);
        assert_eq!(first, GenerationSession::compute_config_hash(&config));
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let original = minimal_config();
        let mut modified = original.clone();
        modified.global.seed = Some(999);
        assert_ne!(
            GenerationSession::compute_config_hash(&original),
            GenerationSession::compute_config_hash(&modified)
        );
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session = session_at(minimal_config(), "/tmp/test_batch_mode");
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let mut config = minimal_config();
        config.global.period_months = 24;
        config.global.fiscal_year_months = Some(12);
        let session = session_at(config, "/tmp/test_multi_mode");
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}