1use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12 add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13 PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Directory layout a session uses when writing generated data.
///
/// The wrapped [`PathBuf`] is the root output directory in both variants.
#[derive(Clone, Debug)]
pub enum OutputMode {
    /// Every period writes straight into the root directory.
    Batch(PathBuf),
    /// Every period writes into a sub-directory of the root named after the
    /// period's label.
    MultiPeriod(PathBuf),
}
29
30#[derive(Debug)]
32pub struct PeriodResult {
33 pub period: GenerationPeriod,
35 pub output_path: PathBuf,
37 pub journal_entry_count: usize,
39 pub document_count: usize,
41 pub anomaly_count: usize,
43 pub duration_secs: f64,
45}
46
47#[derive(Debug)]
53pub struct GenerationSession {
54 config: GeneratorConfig,
55 state: SessionState,
56 periods: Vec<GenerationPeriod>,
57 output_mode: OutputMode,
58 phase_config: PhaseConfig,
59}
60
61impl GenerationSession {
62 pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
69 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
70
71 let total_months = config.global.period_months;
72 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
73 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
74
75 let output_mode = if periods.len() > 1 {
76 OutputMode::MultiPeriod(output_path)
77 } else {
78 OutputMode::Batch(output_path)
79 };
80
81 let seed = config.global.seed.unwrap_or(42);
82 let config_hash = Self::compute_config_hash(&config);
83
84 let state = SessionState {
85 rng_seed: seed,
86 period_cursor: 0,
87 balance_state: BalanceState::default(),
88 document_id_state: DocumentIdState::default(),
89 entity_counts: EntityCounts::default(),
90 generation_log: Vec::new(),
91 config_hash,
92 };
93
94 Ok(Self {
95 config,
96 state,
97 periods,
98 output_mode,
99 phase_config: PhaseConfig::default(),
100 })
101 }
102
103 pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
108 let data = fs::read_to_string(path)
109 .map_err(|e| SynthError::generation(format!("Failed to read .dss: {e}")))?;
110 let state: SessionState = serde_json::from_str(&data)
111 .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {e}")))?;
112
113 let current_hash = Self::compute_config_hash(&config);
114 if state.config_hash != current_hash {
115 return Err(SynthError::generation(
116 "Config has changed since last checkpoint. Cannot resume with different config."
117 .to_string(),
118 ));
119 }
120
121 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
122 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
123
124 let total_months = config.global.period_months;
125 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
126 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
127
128 let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
129 let output_mode = if periods.len() > 1 {
130 OutputMode::MultiPeriod(output_dir)
131 } else {
132 OutputMode::Batch(output_dir)
133 };
134
135 Ok(Self {
136 config,
137 state,
138 periods,
139 output_mode,
140 phase_config: PhaseConfig::default(),
141 })
142 }
143
144 pub fn save(&self, path: &Path) -> SynthResult<()> {
146 let data = serde_json::to_string_pretty(&self.state)
147 .map_err(|e| SynthError::generation(format!("Failed to serialize state: {e}")))?;
148 fs::write(path, data)
149 .map_err(|e| SynthError::generation(format!("Failed to write .dss: {e}")))?;
150 Ok(())
151 }
152
153 pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
157 if self.state.period_cursor >= self.periods.len() {
158 return Ok(None);
159 }
160
161 let period = self.periods[self.state.period_cursor].clone();
162 let start = std::time::Instant::now();
163
164 let period_seed = advance_seed(self.state.rng_seed, period.index);
165
166 let mut period_config = self.config.clone();
167 period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
168 period_config.global.period_months = period.months;
169 period_config.global.seed = Some(period_seed);
170
171 let output_path = match &self.output_mode {
172 OutputMode::Batch(p) => p.clone(),
173 OutputMode::MultiPeriod(p) => p.join(&period.label),
174 };
175
176 fs::create_dir_all(&output_path)
177 .map_err(|e| SynthError::generation(format!("Failed to create output dir: {e}")))?;
178
179 let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
180 let mut orchestrator = orchestrator.with_output_path(&output_path);
181 let result = orchestrator.generate()?;
182
183 let duration = start.elapsed().as_secs_f64();
184
185 let je_count = result.journal_entries.len();
187
188 let doc_count = result.document_flows.purchase_orders.len()
190 + result.document_flows.sales_orders.len()
191 + result.document_flows.goods_receipts.len()
192 + result.document_flows.vendor_invoices.len()
193 + result.document_flows.customer_invoices.len()
194 + result.document_flows.deliveries.len()
195 + result.document_flows.payments.len();
196
197 let anomaly_count = result.anomaly_labels.labels.len();
199
200 {
205 use std::collections::HashMap;
206
207 let mut gl_net: HashMap<String, f64> = HashMap::new();
209 for je in &result.journal_entries {
210 for line in &je.lines {
211 let account = line.gl_account.clone();
212 let delta = f64::try_from(line.debit_amount).unwrap_or(0.0)
213 - f64::try_from(line.credit_amount).unwrap_or(0.0);
214 *gl_net.entry(account).or_insert(0.0) += delta;
215 }
216 }
217
218 for (account, delta) in gl_net {
221 *self
222 .state
223 .balance_state
224 .gl_balances
225 .entry(account)
226 .or_insert(0.0) += delta;
227 }
228
229 self.state.balance_state.ar_total = self
234 .state
235 .balance_state
236 .gl_balances
237 .get("1100")
238 .copied()
239 .unwrap_or(0.0)
240 .max(0.0);
241 self.state.balance_state.ap_total = (-self
242 .state
243 .balance_state
244 .gl_balances
245 .get("2000")
246 .copied()
247 .unwrap_or(0.0))
248 .max(0.0);
249
250 let retained: f64 = self
253 .state
254 .balance_state
255 .gl_balances
256 .iter()
257 .filter_map(|(acct, &bal)| {
258 acct.parse::<u32>()
259 .ok()
260 .filter(|&n| (4000..=8999).contains(&n))
261 .map(|_| -bal) })
263 .sum();
264 self.state.balance_state.retained_earnings += retained;
265
266 self.state.document_id_state.next_je_number += je_count as u64;
268 self.state.document_id_state.next_po_number +=
269 result.document_flows.purchase_orders.len() as u64;
270 self.state.document_id_state.next_so_number +=
271 result.document_flows.sales_orders.len() as u64;
272 self.state.document_id_state.next_invoice_number +=
273 (result.document_flows.vendor_invoices.len()
274 + result.document_flows.customer_invoices.len()) as u64;
275 self.state.document_id_state.next_payment_number +=
276 result.document_flows.payments.len() as u64;
277 self.state.document_id_state.next_gr_number +=
278 result.document_flows.goods_receipts.len() as u64;
279 }
280
281 self.state.generation_log.push(PeriodLog {
282 period_label: period.label.clone(),
283 journal_entries: je_count,
284 documents: doc_count,
285 anomalies: anomaly_count,
286 duration_secs: duration,
287 });
288
289 self.state.period_cursor += 1;
290
291 Ok(Some(PeriodResult {
292 period,
293 output_path,
294 journal_entry_count: je_count,
295 document_count: doc_count,
296 anomaly_count,
297 duration_secs: duration,
298 }))
299 }
300
301 pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
303 let mut results = Vec::new();
304 while let Some(result) = self.generate_next_period()? {
305 results.push(result);
306 }
307 Ok(results)
308 }
309
310 pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
312 let last_end = if let Some(last_period) = self.periods.last() {
313 add_months(last_period.end_date, 1)
314 } else {
315 chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
316 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?
317 };
318
319 let fy_months = self
320 .config
321 .global
322 .fiscal_year_months
323 .unwrap_or(self.config.global.period_months);
324 let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
325
326 let base_index = self.periods.len();
327 let new_periods: Vec<GenerationPeriod> = new_periods
328 .into_iter()
329 .enumerate()
330 .map(|(i, mut p)| {
331 p.index = base_index + i;
332 p
333 })
334 .collect();
335
336 self.periods.extend(new_periods);
337 self.generate_all()
338 }
339
340 pub fn state(&self) -> &SessionState {
342 &self.state
343 }
344
345 pub fn periods(&self) -> &[GenerationPeriod] {
347 &self.periods
348 }
349
350 pub fn remaining_periods(&self) -> usize {
352 self.periods.len().saturating_sub(self.state.period_cursor)
353 }
354
355 fn compute_config_hash(config: &GeneratorConfig) -> String {
357 use std::hash::{Hash, Hasher};
358 let json = serde_json::to_string(config).unwrap_or_default();
359 let mut hasher = std::collections::hash_map::DefaultHasher::new();
360 json.hash(&mut hasher);
361 format!("{:016x}", hasher.finish())
362 }
363}
364
365#[cfg(test)]
366#[allow(clippy::unwrap_used)]
367mod tests {
368 use super::*;
369
370 fn minimal_config() -> GeneratorConfig {
371 serde_yaml::from_str(
372 r#"
373global:
374 seed: 42
375 industry: retail
376 start_date: "2024-01-01"
377 period_months: 12
378companies:
379 - code: "C001"
380 name: "Test Corp"
381 currency: "USD"
382 country: "US"
383 annual_transaction_volume: ten_k
384chart_of_accounts:
385 complexity: small
386output:
387 output_directory: "./output"
388"#,
389 )
390 .expect("minimal config should parse")
391 }
392
/// The minimal config (no fiscal-year length set) yields exactly one pending period.
#[test]
fn test_session_new_single_period() {
    let out_dir = PathBuf::from("/tmp/test_session_single");
    let session = GenerationSession::new(minimal_config(), out_dir).unwrap();

    assert_eq!(session.periods().len(), 1);
    assert_eq!(session.remaining_periods(), 1);
}
401
/// 36 months split into 12-month fiscal years yields three pending periods.
#[test]
fn test_session_new_multi_period() {
    let mut cfg = minimal_config();
    cfg.global.fiscal_year_months = Some(12);
    cfg.global.period_months = 36;

    let out_dir = PathBuf::from("/tmp/test_session_multi");
    let session = GenerationSession::new(cfg, out_dir).unwrap();

    assert_eq!(session.periods().len(), 3);
    assert_eq!(session.remaining_periods(), 3);
}
412
/// A freshly saved session resumes with the same cursor and seed.
#[test]
fn test_session_save_and_resume() {
    let cfg = minimal_config();
    let checkpoint = std::env::temp_dir().join("test_gen_session.dss");

    let original =
        GenerationSession::new(cfg.clone(), PathBuf::from("/tmp/test_session_save")).unwrap();
    original.save(&checkpoint).unwrap();

    let resumed = GenerationSession::resume(&checkpoint, cfg).unwrap();
    let restored = resumed.state();
    assert_eq!(restored.period_cursor, 0);
    assert_eq!(restored.rng_seed, 42);

    // Best-effort cleanup of the checkpoint file.
    let _ = fs::remove_file(&checkpoint);
}
426
/// Resuming with a config whose seed differs from the checkpointed one is rejected.
#[test]
fn test_session_resume_config_mismatch() {
    let base = minimal_config();
    let checkpoint = std::env::temp_dir().join("test_gen_session_mismatch.dss");

    GenerationSession::new(base.clone(), PathBuf::from("/tmp/test_session_mismatch"))
        .unwrap()
        .save(&checkpoint)
        .unwrap();

    // Any change to the config must alter its hash; the seed is the simplest one.
    let mut drifted = base;
    drifted.global.seed = Some(999);

    let outcome = GenerationSession::resume(&checkpoint, drifted);
    assert!(outcome.is_err());

    let err_msg = outcome.unwrap_err().to_string();
    assert!(
        err_msg.contains("Config has changed"),
        "Expected config drift error, got: {}",
        err_msg
    );

    // Best-effort cleanup of the checkpoint file.
    let _ = fs::remove_file(&checkpoint);
}
447
/// Before anything is generated, the single period of the minimal config is still pending.
#[test]
fn test_session_remaining_periods() {
    let out_dir = PathBuf::from("/tmp/test_session_remaining");
    let session = GenerationSession::new(minimal_config(), out_dir).unwrap();

    let pending = session.remaining_periods();
    assert_eq!(pending, 1);
}
455
/// Hashing the same config twice gives the same fingerprint.
#[test]
fn test_session_config_hash_deterministic() {
    let cfg = minimal_config();

    let first = GenerationSession::compute_config_hash(&cfg);
    let second = GenerationSession::compute_config_hash(&cfg);

    assert_eq!(first, second);
}
463
/// Changing the seed changes the config fingerprint.
#[test]
fn test_session_config_hash_changes_on_mutation() {
    let original = minimal_config();
    let before = GenerationSession::compute_config_hash(&original);

    let mut mutated = original;
    mutated.global.seed = Some(999);
    let after = GenerationSession::compute_config_hash(&mutated);

    assert_ne!(before, after);
}
473
/// A single-period session writes in batch mode.
#[test]
fn test_session_output_mode_batch_for_single_period() {
    let out_dir = PathBuf::from("/tmp/test_batch_mode");
    let session = GenerationSession::new(minimal_config(), out_dir).unwrap();

    let mode = &session.output_mode;
    assert!(matches!(mode, OutputMode::Batch(_)));
}
481
/// A session spanning two fiscal years writes in multi-period mode.
#[test]
fn test_session_output_mode_multi_for_multiple_periods() {
    let mut cfg = minimal_config();
    cfg.global.fiscal_year_months = Some(12);
    cfg.global.period_months = 24;

    let out_dir = PathBuf::from("/tmp/test_multi_mode");
    let session = GenerationSession::new(cfg, out_dir).unwrap();

    let mode = &session.output_mode;
    assert!(matches!(mode, OutputMode::MultiPeriod(_)));
}
491}