1use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12 add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13 PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Where a `GenerationSession` writes its generated data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputMode {
    /// Single-period run: output is written directly into this directory.
    Batch(PathBuf),
    /// Multi-period run: each period is written into a subdirectory of this
    /// directory named after the period's label.
    MultiPeriod(PathBuf),
}
29
30#[derive(Debug)]
32pub struct PeriodResult {
33 pub period: GenerationPeriod,
35 pub output_path: PathBuf,
37 pub journal_entry_count: usize,
39 pub document_count: usize,
41 pub anomaly_count: usize,
43 pub duration_secs: f64,
45}
46
47#[derive(Debug)]
53pub struct GenerationSession {
54 config: GeneratorConfig,
55 state: SessionState,
56 periods: Vec<GenerationPeriod>,
57 output_mode: OutputMode,
58 phase_config: PhaseConfig,
59}
60
61impl GenerationSession {
62 pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
69 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
70
71 let total_months = config.global.period_months;
72 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
73 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
74
75 let output_mode = if periods.len() > 1 {
76 OutputMode::MultiPeriod(output_path)
77 } else {
78 OutputMode::Batch(output_path)
79 };
80
81 let seed = config.global.seed.unwrap_or(42);
82 let config_hash = Self::compute_config_hash(&config);
83
84 let state = SessionState {
85 rng_seed: seed,
86 period_cursor: 0,
87 balance_state: BalanceState::default(),
88 document_id_state: DocumentIdState::default(),
89 entity_counts: EntityCounts::default(),
90 generation_log: Vec::new(),
91 config_hash,
92 };
93
94 Ok(Self {
95 config,
96 state,
97 periods,
98 output_mode,
99 phase_config: PhaseConfig::default(),
100 })
101 }
102
103 pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
108 let data = fs::read_to_string(path)
109 .map_err(|e| SynthError::generation(format!("Failed to read .dss: {e}")))?;
110 let state: SessionState = serde_json::from_str(&data)
111 .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {e}")))?;
112
113 let current_hash = Self::compute_config_hash(&config);
114 if state.config_hash != current_hash {
115 return Err(SynthError::generation(
116 "Config has changed since last checkpoint. Cannot resume with different config."
117 .to_string(),
118 ));
119 }
120
121 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
122 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
123
124 let total_months = config.global.period_months;
125 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
126 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
127
128 let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
129 let output_mode = if periods.len() > 1 {
130 OutputMode::MultiPeriod(output_dir)
131 } else {
132 OutputMode::Batch(output_dir)
133 };
134
135 Ok(Self {
136 config,
137 state,
138 periods,
139 output_mode,
140 phase_config: PhaseConfig::default(),
141 })
142 }
143
144 pub fn save(&self, path: &Path) -> SynthResult<()> {
146 let data = serde_json::to_string_pretty(&self.state)
147 .map_err(|e| SynthError::generation(format!("Failed to serialize state: {e}")))?;
148 fs::write(path, data)
149 .map_err(|e| SynthError::generation(format!("Failed to write .dss: {e}")))?;
150 Ok(())
151 }
152
153 pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
157 if self.state.period_cursor >= self.periods.len() {
158 return Ok(None);
159 }
160
161 let period = self.periods[self.state.period_cursor].clone();
162 let start = std::time::Instant::now();
163
164 let period_seed = advance_seed(self.state.rng_seed, period.index);
165
166 let mut period_config = self.config.clone();
167 period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
168 period_config.global.period_months = period.months;
169 period_config.global.seed = Some(period_seed);
170
171 let output_path = match &self.output_mode {
172 OutputMode::Batch(p) => p.clone(),
173 OutputMode::MultiPeriod(p) => p.join(&period.label),
174 };
175
176 fs::create_dir_all(&output_path)
177 .map_err(|e| SynthError::generation(format!("Failed to create output dir: {e}")))?;
178
179 let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
180 let mut orchestrator = orchestrator.with_output_path(&output_path);
181 let result = orchestrator.generate()?;
182
183 let duration = start.elapsed().as_secs_f64();
184
185 let je_count = result.journal_entries.len();
187
188 let doc_count = result.document_flows.purchase_orders.len()
190 + result.document_flows.sales_orders.len()
191 + result.document_flows.goods_receipts.len()
192 + result.document_flows.vendor_invoices.len()
193 + result.document_flows.customer_invoices.len()
194 + result.document_flows.deliveries.len()
195 + result.document_flows.payments.len();
196
197 let anomaly_count = result.anomaly_labels.labels.len();
199
200 self.state.generation_log.push(PeriodLog {
201 period_label: period.label.clone(),
202 journal_entries: je_count,
203 documents: doc_count,
204 anomalies: anomaly_count,
205 duration_secs: duration,
206 });
207
208 self.state.period_cursor += 1;
209
210 Ok(Some(PeriodResult {
211 period,
212 output_path,
213 journal_entry_count: je_count,
214 document_count: doc_count,
215 anomaly_count,
216 duration_secs: duration,
217 }))
218 }
219
220 pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
222 let mut results = Vec::new();
223 while let Some(result) = self.generate_next_period()? {
224 results.push(result);
225 }
226 Ok(results)
227 }
228
229 pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
231 let last_end = if let Some(last_period) = self.periods.last() {
232 add_months(last_period.end_date, 1)
233 } else {
234 chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
235 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?
236 };
237
238 let fy_months = self
239 .config
240 .global
241 .fiscal_year_months
242 .unwrap_or(self.config.global.period_months);
243 let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
244
245 let base_index = self.periods.len();
246 let new_periods: Vec<GenerationPeriod> = new_periods
247 .into_iter()
248 .enumerate()
249 .map(|(i, mut p)| {
250 p.index = base_index + i;
251 p
252 })
253 .collect();
254
255 self.periods.extend(new_periods);
256 self.generate_all()
257 }
258
259 pub fn state(&self) -> &SessionState {
261 &self.state
262 }
263
264 pub fn periods(&self) -> &[GenerationPeriod] {
266 &self.periods
267 }
268
269 pub fn remaining_periods(&self) -> usize {
271 self.periods.len().saturating_sub(self.state.period_cursor)
272 }
273
274 fn compute_config_hash(config: &GeneratorConfig) -> String {
276 use std::hash::{Hash, Hasher};
277 let json = serde_json::to_string(config).unwrap_or_default();
278 let mut hasher = std::collections::hash_map::DefaultHasher::new();
279 json.hash(&mut hasher);
280 format!("{:016x}", hasher.finish())
281 }
282}
283
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    fn minimal_config() -> GeneratorConfig {
        serde_yaml::from_str(
            r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#,
        )
        .expect("minimal config should parse")
    }

    /// Output directory under the platform temp dir. These tests never
    /// generate data, so the directory is not created.
    fn scratch_dir(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    /// A `.dss` path in the temp dir that is deleted on drop, so a failing
    /// assertion does not leave the file behind. The process id keeps
    /// concurrent test runs from sharing checkpoint files.
    struct TempDss(PathBuf);

    impl TempDss {
        fn new(stem: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{stem}_{}.dss", std::process::id())))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDss {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_session_new_single_period() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_session_single")).unwrap();
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let mut config = minimal_config();
        config.global.period_months = 36;
        config.global.fiscal_year_months = Some(12);
        let session = GenerationSession::new(config, scratch_dir("test_session_multi")).unwrap();
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_save_and_resume() {
        let config = minimal_config();
        let session =
            GenerationSession::new(config.clone(), scratch_dir("test_session_save")).unwrap();
        let dss = TempDss::new("test_gen_session");
        session.save(dss.path()).unwrap();
        let resumed = GenerationSession::resume(dss.path(), config).unwrap();
        assert_eq!(resumed.state().period_cursor, 0);
        assert_eq!(resumed.state().rng_seed, 42);
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let config = minimal_config();
        let session =
            GenerationSession::new(config.clone(), scratch_dir("test_session_mismatch")).unwrap();
        let dss = TempDss::new("test_gen_session_mismatch");
        session.save(dss.path()).unwrap();
        let mut different = config;
        different.global.seed = Some(999);
        let result = GenerationSession::resume(dss.path(), different);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {err_msg}"
        );
    }

    #[test]
    fn test_session_remaining_periods() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_session_remaining"))
                .unwrap();
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let config = minimal_config();
        let h1 = GenerationSession::compute_config_hash(&config);
        let h2 = GenerationSession::compute_config_hash(&config);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let config = minimal_config();
        let h1 = GenerationSession::compute_config_hash(&config);
        let mut modified = config;
        modified.global.seed = Some(999);
        let h2 = GenerationSession::compute_config_hash(&modified);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_session_config_hash_is_16_hex_digits() {
        let hash = GenerationSession::compute_config_hash(&minimal_config());
        assert_eq!(hash.len(), 16);
        assert!(
            hash.chars().all(|c| c.is_ascii_hexdigit()),
            "unexpected hash: {hash}"
        );
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_batch_mode")).unwrap();
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let mut config = minimal_config();
        config.global.period_months = 24;
        config.global.fiscal_year_months = Some(12);
        let session = GenerationSession::new(config, scratch_dir("test_multi_mode")).unwrap();
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}