1use std::fs;
8use std::path::{Path, PathBuf};
9
10use datasynth_config::GeneratorConfig;
11use datasynth_core::models::generation_session::{
12 add_months, advance_seed, BalanceState, DocumentIdState, EntityCounts, GenerationPeriod,
13 PeriodLog, SessionState,
14};
15use datasynth_core::SynthError;
16
17use crate::enhanced_orchestrator::{EnhancedOrchestrator, PhaseConfig};
18
/// Shorthand for results in this module whose error type is [`SynthError`].
type SynthResult<T> = Result<T, SynthError>;
20
/// Where a [`GenerationSession`] places the output of each generated period.
///
/// The variant is picked once, when the session is constructed, from the
/// number of planned periods.
#[derive(Debug, Clone)]
pub enum OutputMode {
    /// Used when at most one period is planned: output goes straight into the
    /// wrapped directory.
    Batch(PathBuf),
    /// Used when more than one period is planned: each period's output goes
    /// into a subdirectory of the wrapped directory named after the period's
    /// label.
    MultiPeriod(PathBuf),
}
29
/// Summary of one period produced by [`GenerationSession::generate_next_period`].
#[derive(Debug)]
pub struct PeriodResult {
    /// The period that was generated.
    pub period: GenerationPeriod,
    /// Directory that was created for this period and handed to the orchestrator.
    pub output_path: PathBuf,
    /// Number of journal entries in the orchestrator's result.
    pub journal_entry_count: usize,
    /// Combined size of the document-flow collections in the result (purchase
    /// orders, sales orders, goods receipts, vendor invoices, customer
    /// invoices, deliveries and payments).
    pub document_count: usize,
    /// Number of anomaly labels in the orchestrator's result.
    pub anomaly_count: usize,
    /// Wall-clock seconds spent on this period, measured from just before the
    /// per-period config is prepared until the orchestrator returns.
    pub duration_secs: f64,
}
46
/// Drives generation period by period and carries the state needed to
/// checkpoint and resume a run.
#[derive(Debug)]
pub struct GenerationSession {
    /// Base configuration; cloned and adjusted (start date, months, seed) for
    /// every generated period.
    config: GeneratorConfig,
    /// Serializable progress state: this is what `save` writes and `resume` reads.
    state: SessionState,
    /// Planned periods in generation order; `state.period_cursor` indexes into it.
    periods: Vec<GenerationPeriod>,
    /// Output directory layout, chosen when the session is constructed.
    output_mode: OutputMode,
    /// Phase configuration passed to every orchestrator; both constructors
    /// initialise it with `PhaseConfig::default()`.
    phase_config: PhaseConfig,
}
60
61impl GenerationSession {
62 pub fn new(config: GeneratorConfig, output_path: PathBuf) -> SynthResult<Self> {
68 let start_date =
69 chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
70 .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?;
71
72 let total_months = config.global.period_months;
73 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
74 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
75
76 let output_mode = if periods.len() > 1 {
77 OutputMode::MultiPeriod(output_path)
78 } else {
79 OutputMode::Batch(output_path)
80 };
81
82 let seed = config.global.seed.unwrap_or(42);
83 let config_hash = Self::compute_config_hash(&config);
84
85 let state = SessionState {
86 rng_seed: seed,
87 period_cursor: 0,
88 balance_state: BalanceState::default(),
89 document_id_state: DocumentIdState::default(),
90 entity_counts: EntityCounts::default(),
91 generation_log: Vec::new(),
92 config_hash,
93 };
94
95 Ok(Self {
96 config,
97 state,
98 periods,
99 output_mode,
100 phase_config: PhaseConfig::default(),
101 })
102 }
103
104 pub fn resume(path: &Path, config: GeneratorConfig) -> SynthResult<Self> {
109 let data = fs::read_to_string(path)
110 .map_err(|e| SynthError::generation(format!("Failed to read .dss: {}", e)))?;
111 let state: SessionState = serde_json::from_str(&data)
112 .map_err(|e| SynthError::generation(format!("Failed to parse .dss: {}", e)))?;
113
114 let current_hash = Self::compute_config_hash(&config);
115 if state.config_hash != current_hash {
116 return Err(SynthError::generation(
117 "Config has changed since last checkpoint. Cannot resume with different config."
118 .to_string(),
119 ));
120 }
121
122 let start_date =
123 chrono::NaiveDate::parse_from_str(&config.global.start_date, "%Y-%m-%d")
124 .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?;
125
126 let total_months = config.global.period_months;
127 let fy_months = config.global.fiscal_year_months.unwrap_or(total_months);
128 let periods = GenerationPeriod::compute_periods(start_date, total_months, fy_months);
129
130 let output_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
131 let output_mode = if periods.len() > 1 {
132 OutputMode::MultiPeriod(output_dir)
133 } else {
134 OutputMode::Batch(output_dir)
135 };
136
137 Ok(Self {
138 config,
139 state,
140 periods,
141 output_mode,
142 phase_config: PhaseConfig::default(),
143 })
144 }
145
146 pub fn save(&self, path: &Path) -> SynthResult<()> {
148 let data = serde_json::to_string_pretty(&self.state)
149 .map_err(|e| SynthError::generation(format!("Failed to serialize state: {}", e)))?;
150 fs::write(path, data)
151 .map_err(|e| SynthError::generation(format!("Failed to write .dss: {}", e)))?;
152 Ok(())
153 }
154
155 pub fn generate_next_period(&mut self) -> SynthResult<Option<PeriodResult>> {
159 if self.state.period_cursor >= self.periods.len() {
160 return Ok(None);
161 }
162
163 let period = self.periods[self.state.period_cursor].clone();
164 let start = std::time::Instant::now();
165
166 let period_seed = advance_seed(self.state.rng_seed, period.index);
167
168 let mut period_config = self.config.clone();
169 period_config.global.start_date = period.start_date.format("%Y-%m-%d").to_string();
170 period_config.global.period_months = period.months;
171 period_config.global.seed = Some(period_seed);
172
173 let output_path = match &self.output_mode {
174 OutputMode::Batch(p) => p.clone(),
175 OutputMode::MultiPeriod(p) => p.join(&period.label),
176 };
177
178 fs::create_dir_all(&output_path)
179 .map_err(|e| SynthError::generation(format!("Failed to create output dir: {}", e)))?;
180
181 let orchestrator = EnhancedOrchestrator::new(period_config, self.phase_config.clone())?;
182 let mut orchestrator = orchestrator.with_output_path(&output_path);
183 let result = orchestrator.generate()?;
184
185 let duration = start.elapsed().as_secs_f64();
186
187 let je_count = result.journal_entries.len();
189
190 let doc_count = result.document_flows.purchase_orders.len()
192 + result.document_flows.sales_orders.len()
193 + result.document_flows.goods_receipts.len()
194 + result.document_flows.vendor_invoices.len()
195 + result.document_flows.customer_invoices.len()
196 + result.document_flows.deliveries.len()
197 + result.document_flows.payments.len();
198
199 let anomaly_count = result.anomaly_labels.labels.len();
201
202 self.state.generation_log.push(PeriodLog {
203 period_label: period.label.clone(),
204 journal_entries: je_count,
205 documents: doc_count,
206 anomalies: anomaly_count,
207 duration_secs: duration,
208 });
209
210 self.state.period_cursor += 1;
211
212 Ok(Some(PeriodResult {
213 period,
214 output_path,
215 journal_entry_count: je_count,
216 document_count: doc_count,
217 anomaly_count,
218 duration_secs: duration,
219 }))
220 }
221
222 pub fn generate_all(&mut self) -> SynthResult<Vec<PeriodResult>> {
224 let mut results = Vec::new();
225 while let Some(result) = self.generate_next_period()? {
226 results.push(result);
227 }
228 Ok(results)
229 }
230
231 pub fn generate_delta(&mut self, additional_months: u32) -> SynthResult<Vec<PeriodResult>> {
233 let last_end = if let Some(last_period) = self.periods.last() {
234 add_months(last_period.end_date, 1)
235 } else {
236 chrono::NaiveDate::parse_from_str(&self.config.global.start_date, "%Y-%m-%d")
237 .map_err(|e| SynthError::generation(format!("Invalid start_date: {}", e)))?
238 };
239
240 let fy_months = self
241 .config
242 .global
243 .fiscal_year_months
244 .unwrap_or(self.config.global.period_months);
245 let new_periods = GenerationPeriod::compute_periods(last_end, additional_months, fy_months);
246
247 let base_index = self.periods.len();
248 let new_periods: Vec<GenerationPeriod> = new_periods
249 .into_iter()
250 .enumerate()
251 .map(|(i, mut p)| {
252 p.index = base_index + i;
253 p
254 })
255 .collect();
256
257 self.periods.extend(new_periods);
258 self.generate_all()
259 }
260
261 pub fn state(&self) -> &SessionState {
263 &self.state
264 }
265
266 pub fn periods(&self) -> &[GenerationPeriod] {
268 &self.periods
269 }
270
271 pub fn remaining_periods(&self) -> usize {
273 self.periods.len().saturating_sub(self.state.period_cursor)
274 }
275
276 fn compute_config_hash(config: &GeneratorConfig) -> String {
278 use std::hash::{Hash, Hasher};
279 let json = serde_json::to_string(config).unwrap_or_default();
280 let mut hasher = std::collections::hash_map::DefaultHasher::new();
281 json.hash(&mut hasher);
282 format!("{:016x}", hasher.finish())
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest config the tests need: seed 42, twelve months from
    /// 2024-01-01, one company.
    fn minimal_config() -> GeneratorConfig {
        let yaml = r#"
global:
  seed: 42
  industry: retail
  start_date: "2024-01-01"
  period_months: 12
companies:
  - code: "C001"
    name: "Test Corp"
    currency: "USD"
    country: "US"
    annual_transaction_volume: ten_k
chart_of_accounts:
  complexity: small
output:
  output_directory: "./output"
"#;
        serde_yaml::from_str(yaml).expect("minimal config should parse")
    }

    /// The minimal config stretched to `total_months`, with 12-month fiscal years.
    fn config_spanning(total_months: u32) -> GeneratorConfig {
        let mut cfg = minimal_config();
        cfg.global.period_months = total_months;
        cfg.global.fiscal_year_months = Some(12);
        cfg
    }

    /// Builds a session whose output root is `dir`; nothing is created on disk.
    fn session_in(cfg: GeneratorConfig, dir: &str) -> GenerationSession {
        GenerationSession::new(cfg, PathBuf::from(dir)).unwrap()
    }

    #[test]
    fn test_session_new_single_period() {
        let session = session_in(minimal_config(), "/tmp/test_session_single");
        assert_eq!(session.periods().len(), 1);
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_new_multi_period() {
        let session = session_in(config_spanning(36), "/tmp/test_session_multi");
        assert_eq!(session.periods().len(), 3);
        assert_eq!(session.remaining_periods(), 3);
    }

    #[test]
    fn test_session_save_and_resume() {
        let cfg = minimal_config();
        let original = session_in(cfg.clone(), "/tmp/test_session_save");

        let checkpoint = std::env::temp_dir().join("test_gen_session.dss");
        original.save(&checkpoint).unwrap();

        let resumed = GenerationSession::resume(&checkpoint, cfg).unwrap();
        assert_eq!(resumed.state().period_cursor, 0);
        assert_eq!(resumed.state().rng_seed, 42);

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_resume_config_mismatch() {
        let cfg = minimal_config();
        let original = session_in(cfg.clone(), "/tmp/test_session_mismatch");

        let checkpoint = std::env::temp_dir().join("test_gen_session_mismatch.dss");
        original.save(&checkpoint).unwrap();

        // Resuming with a different seed must be rejected as config drift.
        let mut drifted = cfg;
        drifted.global.seed = Some(999);
        let outcome = GenerationSession::resume(&checkpoint, drifted);
        assert!(outcome.is_err());

        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Config has changed"),
            "Expected config drift error, got: {}",
            err_msg
        );

        let _ = fs::remove_file(&checkpoint);
    }

    #[test]
    fn test_session_remaining_periods() {
        let session = session_in(minimal_config(), "/tmp/test_session_remaining");
        assert_eq!(session.remaining_periods(), 1);
    }

    #[test]
    fn test_session_config_hash_deterministic() {
        let cfg = minimal_config();
        let first = GenerationSession::compute_config_hash(&cfg);
        let second = GenerationSession::compute_config_hash(&cfg);
        assert_eq!(first, second);
    }

    #[test]
    fn test_session_config_hash_changes_on_mutation() {
        let cfg = minimal_config();
        let before = GenerationSession::compute_config_hash(&cfg);

        let mut mutated = cfg;
        mutated.global.seed = Some(999);
        let after = GenerationSession::compute_config_hash(&mutated);

        assert_ne!(before, after);
    }

    #[test]
    fn test_session_output_mode_batch_for_single_period() {
        let session = session_in(minimal_config(), "/tmp/test_batch_mode");
        assert!(matches!(session.output_mode, OutputMode::Batch(_)));
    }

    #[test]
    fn test_session_output_mode_multi_for_multiple_periods() {
        let session = session_in(config_spanning(24), "/tmp/test_multi_mode");
        assert!(matches!(session.output_mode, OutputMode::MultiPeriod(_)));
    }
}