1use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12 add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13 PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
19type SynthResult<T> = Result<T, SynthError>;
20
/// Directory layout used by a generation session when writing period output.
#[derive(Clone, Debug)]
pub enum OutputMode {
    /// A single period is written directly into this directory.
    Batch(PathBuf),
    /// Each period is written into `<this directory>/<period label>`.
    MultiPeriod(PathBuf),
}
29
30#[derive(Debug)]
32pub struct PeriodResult {
33 pub period: GenerationPeriod,
35 pub output_path: PathBuf,
37 pub journal_entry_count: usize,
39 pub document_count: usize,
41 pub anomaly_count: usize,
43 pub duration_secs: f64,
45}
46
47#[derive(Debug)]
53pub struct GenerationSession {
54 config: GeneratorConfig,
55 state: SessionState,
56 periods: Vec<GenerationPeriod>,
57 output_mode: OutputMode,
58 phase_config: PhaseConfig,
59}
60
61impl GenerationSession {
62 pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
69 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
70
71 let total_months = config.global.period_months;
72 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
73 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
74
75 let output_mode = if periods.len() > 1 {
76 OutputMode::MultiPeriod(output_path)
77 } else {
78 OutputMode::Batch(output_path)
79 };
80
81 let seed = config.global.seed.unwrap_or(42);
82 let config_hash = Self::compute_config_hash(&config);
83
84 let state = SessionState {
85 rng_seed: seed,
86 period_cursor: 0,
87 balance_state: BalanceState::default(),
88 document_id_state: DocumentIdState::default(),
89 entity_counts: EntityCounts::default(),
90 generation_log: Vec::new(),
91 config_hash,
92 };
93
94 Ok(Self {
95 config,
96 state,
97 periods,
98 output_mode,
99 phase_config: PhaseConfig::default(),
100 })
101 }
102
103 pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
108 let data = fs::read_to_string(path)
109 .map_err(|e| SynthError::generation(format!("Failed to read .dss: {e}")))?;
110 let state: SessionState = serde_json::from_str(&data)
111 .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {e}")))?;
112
113 let current_hash = Self::compute_config_hash(&config);
114 if state.config_hash != current_hash {
115 return Err(SynthError::generation(
116 "Config has changed since last checkpoint. Cannot resume with different config."
117 .to_string(),
118 ));
119 }
120
121 let start_date = chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
122 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?;
123
124 let total_months = config.global.period_months;
125 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
126 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
127
128 let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
129 let output_mode = if periods.len() > 1 {
130 OutputMode::MultiPeriod(output_dir)
131 } else {
132 OutputMode::Batch(output_dir)
133 };
134
135 Ok(Self {
136 config,
137 state,
138 periods,
139 output_mode,
140 phase_config: PhaseConfig::default(),
141 })
142 }
143
144 pub fn save(&self, path: &Path) -> SynthResult<()> {
146 let data = serde_json::to_string_pretty(&self.state)
147 .map_err(|e| SynthError::generation(format!("Failed to serialize state: {e}")))?;
148 fs::write(path, data)
149 .map_err(|e| SynthError::generation(format!("Failed to write .dss: {e}")))?;
150 Ok(())
151 }
152
153 pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
157 if self.state.period_cursor >= self.periods.len() {
158 return Ok(None);
159 }
160
161 let period = self.periods[self.state.period_cursor].clone();
162 let start = std::time::Instant::now();
163
164 let period_seed = advance_seed(self.state.rng_seed, period.index);
165
166 let mut period_config = self.config.clone();
167 period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
168 period_config.global.period_months = period.months;
169 period_config.global.seed = Some(period_seed);
170
171 let output_path = match &self.output_mode {
172 OutputMode::Batch(p) => p.clone(),
173 OutputMode::MultiPeriod(p) => p.join(&period.label),
174 };
175
176 fs::create_dir_all(&output_path)
177 .map_err(|e| SynthError::generation(format!("Failed to create output dir: {e}")))?;
178
179 let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
180 let mut orchestrator = orchestrator.with_output_path(&output_path);
181 let result = orchestrator.generate()?;
182
183 let duration = start.elapsed().as_secs_f64();
184
185 let je_count = result.journal_entries.len();
187
188 let doc_count = result.document_flows.purchase_orders.len()
190 + result.document_flows.sales_orders.len()
191 + result.document_flows.goods_receipts.len()
192 + result.document_flows.vendor_invoices.len()
193 + result.document_flows.customer_invoices.len()
194 + result.document_flows.deliveries.len()
195 + result.document_flows.payments.len();
196
197 let anomaly_count = result.anomaly_labels.labels.len();
199
200 {
205 use std::collections::HashMap;
206
207 let mut gl_net: HashMap<String, f64> = HashMap::new();
209 for je in &result.journal_entries {
210 for line in &je.lines {
211 let account = line.gl_account.clone();
212 let delta = f64::try_from(line.debit_amount).unwrap_or(0.0)
213 - f64::try_from(line.credit_amount).unwrap_or(0.0);
214 *gl_net.entry(account).or_insert(0.0) += delta;
215 }
216 }
217
218 for (account, delta) in gl_net {
221 *self
222 .state
223 .balance_state
224 .gl_balances
225 .entry(account)
226 .or_insert(0.0) += delta;
227 }
228
229 self.state.balance_state.ar_total = self
234 .state
235 .balance_state
236 .gl_balances
237 .get("1100")
238 .copied()
239 .unwrap_or(0.0)
240 .max(0.0);
241 self.state.balance_state.ap_total = (-self
242 .state
243 .balance_state
244 .gl_balances
245 .get("2000")
246 .copied()
247 .unwrap_or(0.0))
248 .max(0.0);
249
250 let retained: f64 = self
253 .state
254 .balance_state
255 .gl_balances
256 .iter()
257 .filter_map(|(acct, &bal)| {
258 acct.parse::<u32>()
259 .ok()
260 .filter(|&n| (4000..=8999).contains(&n))
261 .map(|_| -bal) })
263 .sum();
264 self.state.balance_state.retained_earnings += retained;
265
266 self.state.document_id_state.next_je_number += je_count as u64;
268 self.state.document_id_state.next_po_number +=
269 result.document_flows.purchase_orders.len() as u64;
270 self.state.document_id_state.next_so_number +=
271 result.document_flows.sales_orders.len() as u64;
272 self.state.document_id_state.next_invoice_number +=
273 (result.document_flows.vendor_invoices.len()
274 + result.document_flows.customer_invoices.len()) as u64;
275 self.state.document_id_state.next_payment_number +=
276 result.document_flows.payments.len() as u64;
277 self.state.document_id_state.next_gr_number +=
278 result.document_flows.goods_receipts.len() as u64;
279 }
280
281 self.state.generation_log.push(PeriodLog {
282 period_label: period.label.clone(),
283 journal_entries: je_count,
284 documents: doc_count,
285 anomalies: anomaly_count,
286 duration_secs: duration,
287 });
288
289 self.state.period_cursor += 1;
290
291 Ok(Some(PeriodResult {
292 period,
293 output_path,
294 journal_entry_count: je_count,
295 document_count: doc_count,
296 anomaly_count,
297 duration_secs: duration,
298 }))
299 }
300
301 pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
303 let mut results = Vec::new();
304 while let Some(result) = self.generate_next_period()? {
305 results.push(result);
306 }
307 Ok(results)
308 }
309
310 pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
312 let last_end = if let Some(last_period) = self.periods.last() {
313 add_months(last_period.end_date, 1)
314 } else {
315 chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
316 .map_err(|e| SynthError::generation(format!("Invalid start_date: {e}")))?
317 };
318
319 let fy_months = self
320 .config
321 .global
322 .fiscal_year_months
323 .unwrap_or(self.config.global.period_months);
324 let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
325
326 let base_index = self.periods.len();
327 let new_periods: Vec<GenerationPeriod> = new_periods
328 .into_iter()
329 .enumerate()
330 .map(|(i, mut p)| {
331 p.index = base_index + i;
332 p
333 })
334 .collect();
335
336 self.periods.extend(new_periods);
337 self.generate_all()
338 }
339
340 pub fn state(&self) -> &SessionState {
342 &self.state
343 }
344
345 pub fn periods(&self) -> &[GenerationPeriod] {
347 &self.periods
348 }
349
350 pub fn remaining_periods(&self) -> usize {
352 self.periods.len().saturating_sub(self.state.period_cursor)
353 }
354
355 fn compute_config_hash(config: &GeneratorConfig) -> String {
357 use std::hash::{Hash, Hasher};
358 let json = serde_json::to_string(config).unwrap_or_default();
359 let mut hasher = std::collections::hash_map::DefaultHasher::new();
360 json.hash(&mut hasher);
361 format!("{:016x}", hasher.finish())
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Checkpoint path in the system temp dir that is removed on drop, so it is
    /// cleaned up even when an assertion fails part-way through a test.
    struct TempCheckpoint(PathBuf);

    impl TempCheckpoint {
        /// Reserves `<temp dir>/<pid>_<name>` so concurrent test runs do not collide.
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{name}", std::process::id())))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempCheckpoint {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    /// Output root for sessions that are never run. Constructing a session does
    /// not touch the filesystem, so the directory does not need to exist.
    fn scratch_dir(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    fn minimal_config() -> GeneratorConfig {
        serde_yaml::from_str(
            r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#,
        )
        .expect("minimal config should parse")
    }

    #[test]
    fn test_session_new_single_period() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_session_single")).unwrap();
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let mut config = minimal_config();
        config.global.period_months = 36;
        config.global.fiscal_year_months = Some(12);
        let session = GenerationSession::new(config, scratch_dir("test_session_multi")).unwrap();
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_new_rejects_invalid_start_date() {
        let mut config = minimal_config();
        config.global.start_date = "2024-13-01".to_string();
        let err = GenerationSession::new(config, scratch_dir("test_session_bad_date")).unwrap_err();
        assert!(
            err.to_string().contains("Invalid start_date"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_session_save_and_resume() {
        let config = minimal_config();
        let session =
            GenerationSession::new(config.clone(), scratch_dir("test_session_save")).unwrap();
        let checkpoint = TempCheckpoint::new("test_gen_session.dss");
        session.save(checkpoint.path()).unwrap();

        // Saving must not leave a temporary sidecar next to the checkpoint.
        let mut sidecar = checkpoint.path().as_os_str().to_owned();
        sidecar.push(".tmp");
        assert!(!Path::new(&sidecar).exists());

        let resumed = GenerationSession::resume(checkpoint.path(), config).unwrap();
        assert_eq!(resumed.state().period_cursor, 0);
        assert_eq!(resumed.state().rng_seed, 42);
        assert_eq!(resumed.state().config_hash, session.state().config_hash);
    }

    #[test]
    fn test_session_resume_multi_period_writes_next_to_checkpoint() {
        let mut config = minimal_config();
        config.global.period_months = 24;
        config.global.fiscal_year_months = Some(12);
        let session =
            GenerationSession::new(config.clone(), scratch_dir("test_session_resume_multi"))
                .unwrap();
        let checkpoint = TempCheckpoint::new("test_gen_session_multi.dss");
        session.save(checkpoint.path()).unwrap();

        let resumed = GenerationSession::resume(checkpoint.path(), config).unwrap();
        assert_eq!(resumed.periods().len(), 2);
        match &resumed.output_mode {
            OutputMode::MultiPeriod(dir) => {
                assert_eq!(dir.as_path(), checkpoint.path().parent().unwrap());
            }
            other => panic!("expected MultiPeriod output, got {other:?}"),
        }
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let config = minimal_config();
        let session =
            GenerationSession::new(config.clone(), scratch_dir("test_session_mismatch")).unwrap();
        let checkpoint = TempCheckpoint::new("test_gen_session_mismatch.dss");
        session.save(checkpoint.path()).unwrap();

        let mut different = config;
        different.global.seed = Some(999);
        let result = GenerationSession::resume(checkpoint.path(), different);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_session_remaining_periods() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_session_remaining"))
                .unwrap();
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let config = minimal_config();
        let h1 = GenerationSession::compute_config_hash(&config);
        let h2 = GenerationSession::compute_config_hash(&config);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let config = minimal_config();
        let h1 = GenerationSession::compute_config_hash(&config);
        let mut modified = config;
        modified.global.seed = Some(999);
        let h2 = GenerationSession::compute_config_hash(&modified);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session =
            GenerationSession::new(minimal_config(), scratch_dir("test_batch_mode")).unwrap();
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let mut config = minimal_config();
        config.global.period_months = 24;
        config.global.fiscal_year_months = Some(12);
        let session = GenerationSession::new(config, scratch_dir("test_multi_mode")).unwrap();
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}