1use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured configuration intent extracted from a natural language description.
///
/// Every field is optional; missing values are filled with defaults when the
/// intent is rendered to YAML.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Industry identifier, e.g. `"retail"` or `"financial_services"`.
    pub industry: Option<String>,
    /// Country code, e.g. `"US"` or `"DE"`.
    pub country: Option<String>,
    /// Company size: `"small"`, `"medium"` or `"large"`.
    pub company_size: Option<String>,
    /// Length of the generated period, in months.
    pub period_months: Option<u32>,
    /// Optional feature modules to enable, e.g. `"fraud"` or `"audit"`.
    pub features: Vec<String>,
}
24
/// Turns natural language descriptions into DataSynth YAML configurations.
///
/// The type carries no state; all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33 pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43 if description.trim().is_empty() {
44 return Err(SynthError::generation(
45 "Natural language description cannot be empty",
46 ));
47 }
48
49 let intent = Self::parse_intent(description, provider)?;
50 Self::intent_to_yaml(&intent)
51 }
52
53 pub fn parse_intent(
58 description: &str,
59 provider: &dyn LlmProvider,
60 ) -> Result<ConfigIntent, SynthError> {
61 let llm_intent = Self::parse_with_llm(description, provider);
63
64 let keyword_intent = Self::parse_with_keywords(description);
66
67 match llm_intent {
69 Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70 Err(e) => {
71 tracing::warn!(
72 "LLM-based config parsing failed, falling back to keyword parsing: {}",
73 e
74 );
75 Ok(keyword_intent)
76 }
77 }
78 }
79
80 pub fn generate_full(
87 description: &str,
88 provider: &dyn LlmProvider,
89 ) -> Result<String, SynthError> {
90 if description.trim().is_empty() {
91 return Err(SynthError::generation(
92 "Natural language description cannot be empty",
93 ));
94 }
95
96 let system = Self::full_schema_system_prompt();
97 let request = LlmRequest::new(description.to_string())
98 .with_system(system)
99 .with_temperature(0.2)
100 .with_max_tokens(4096);
101
102 match provider.complete(&request) {
103 Ok(response) => {
104 let yaml_text = Self::extract_yaml(&response.content);
105 if let Ok(value) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_text) {
107 if let Some(map) = value.as_mapping() {
108 let known_keys = [
109 "global",
110 "companies",
111 "chart_of_accounts",
112 "transactions",
113 "output",
114 "fraud",
115 "audit_standards",
116 "banking",
117 "internal_controls",
118 "distributions",
119 "temporal_patterns",
120 "document_flows",
121 "intercompany",
122 "master_data",
123 "business_processes",
124 "hr",
125 "manufacturing",
126 "tax",
127 "treasury",
128 "esg",
129 "project_accounting",
130 "diffusion",
131 "llm",
132 "causal",
133 ];
134 let has_known = map
135 .keys()
136 .any(|k| k.as_str().map(|s| known_keys.contains(&s)).unwrap_or(false));
137 if has_known {
138 return Ok(yaml_text);
139 }
140 }
141 }
142 tracing::warn!(
144 "LLM full-config response did not contain valid DataSynth YAML; falling back to template"
145 );
146 Self::generate(description, provider)
147 }
148 Err(e) => {
149 tracing::warn!("LLM full-config generation failed: {e}; falling back to template");
150 Self::generate(description, provider)
151 }
152 }
153 }
154
155 pub fn extract_yaml(content: &str) -> String {
157 let trimmed = content.trim();
158
159 if let Some(start) = trimmed.find("```yaml") {
161 let after = &trimmed[start + 7..];
162 if let Some(end) = after.find("```") {
163 return after[..end].trim().to_string();
164 }
165 }
166
167 if let Some(start) = trimmed.find("```") {
169 let after = &trimmed[start + 3..];
170 if let Some(end) = after.find("```") {
171 return after[..end].trim().to_string();
172 }
173 }
174
175 trimmed.to_string()
177 }
178
179 pub fn full_schema_system_prompt() -> String {
183 concat!(
184 "You are a DataSynth configuration generator. Given a natural language description, ",
185 "produce a complete, valid DataSynth YAML configuration.\n\n",
186 "Top-level sections (all optional — include only what is relevant):\n\n",
187 "global:\n",
188 " industry: <retail|manufacturing|financial_services|healthcare|technology>\n",
189 " start_date: \"YYYY-MM-DD\"\n",
190 " period_months: <1-120>\n",
191 " seed: <integer>\n\n",
192 "companies:\n",
193 " - code: \"C001\"\n",
194 " name: \"...\"\n",
195 " currency: \"USD\"\n",
196 " country: \"US\"\n\n",
197 "chart_of_accounts:\n",
198 " complexity: <small|medium|large>\n\n",
199 "transactions:\n",
200 " count: <number>\n",
201 " anomaly_rate: <0.0-1.0>\n\n",
202 "output:\n",
203 " format: <csv|json|parquet>\n",
204 " compression: <true|false>\n\n",
205 "fraud:\n",
206 " enabled: true\n",
207 " types: [fictitious_transaction, duplicate_payment, split_transaction, ...]\n",
208 " injection_rate: <0.0-1.0>\n\n",
209 "internal_controls:\n",
210 " enabled: true\n",
211 " coso_enabled: true\n",
212 " target_maturity_level: <ad_hoc|repeatable|defined|managed|optimized>\n\n",
213 "distributions:\n",
214 " enabled: true\n",
215 " industry_profile: <industry>\n",
216 " amounts: { enabled: true, distribution_type: lognormal, benford_compliance: true }\n\n",
217 "temporal_patterns:\n",
218 " enabled: true\n",
219 " business_days: { enabled: true }\n",
220 " period_end: { model: exponential }\n\n",
221 "banking:\n",
222 " enabled: true\n",
223 " customer_count: <number>\n",
224 " kyc_enabled: true\n",
225 " aml_enabled: true\n\n",
226 "audit_standards:\n",
227 " enabled: true\n",
228 " isa_compliance: { enabled: true, compliance_level: standard }\n",
229 " sox: { enabled: true }\n\n",
230 "intercompany:\n",
231 " enabled: true\n\n",
232 "document_flows:\n",
233 " p2p: { enabled: true }\n",
234 " o2c: { enabled: true }\n\n",
235 "master_data:\n",
236 " vendors: { count: <number> }\n",
237 " customers: { count: <number> }\n\n",
238 "hr:\n",
239 " enabled: true\n",
240 " payroll: { enabled: true }\n\n",
241 "manufacturing:\n",
242 " enabled: true\n\n",
243 "tax:\n",
244 " enabled: true\n\n",
245 "treasury:\n",
246 " enabled: true\n\n",
247 "esg:\n",
248 " enabled: true\n\n",
249 "project_accounting:\n",
250 " enabled: true\n\n",
251 "diffusion:\n",
252 " enabled: true\n",
253 " backend: <statistical|neural|hybrid>\n\n",
254 "Return ONLY the YAML configuration (optionally inside ```yaml fences), no other text.\n"
255 ).to_string()
256 }
257
258 pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
260 let industry = intent.industry.as_deref().unwrap_or("manufacturing");
261 let country = intent.country.as_deref().unwrap_or("US");
262 let complexity = intent.company_size.as_deref().unwrap_or("medium");
263 let period_months = intent.period_months.unwrap_or(12);
264
265 if !(1..=120).contains(&period_months) {
267 return Err(SynthError::generation(format!(
268 "Period months must be between 1 and 120, got {period_months}"
269 )));
270 }
271
272 let valid_complexities = ["small", "medium", "large"];
273 if !valid_complexities.contains(&complexity) {
274 return Err(SynthError::generation(format!(
275 "Invalid company size '{complexity}', must be one of: small, medium, large"
276 )));
277 }
278
279 let currency = Self::country_to_currency(country);
280 let company_name = Self::industry_company_name(industry);
281
282 let mut yaml = String::with_capacity(2048);
283
284 yaml.push_str(&format!(
286 "global:\n industry: {industry}\n start_date: \"2024-01-01\"\n period_months: {period_months}\n seed: 42\n\n"
287 ));
288
289 yaml.push_str(&format!(
291 "companies:\n - code: \"C001\"\n name: \"{company_name}\"\n currency: \"{currency}\"\n country: \"{country}\"\n\n"
292 ));
293
294 yaml.push_str(&format!(
296 "chart_of_accounts:\n complexity: {complexity}\n\n"
297 ));
298
299 let tx_count = Self::complexity_to_tx_count(complexity);
301 yaml.push_str(&format!(
302 "transactions:\n count: {tx_count}\n anomaly_rate: 0.02\n\n"
303 ));
304
305 yaml.push_str("output:\n format: csv\n compression: false\n\n");
307
308 for feature in &intent.features {
310 match feature.as_str() {
311 "fraud" => {
312 yaml.push_str(
313 "fraud:\n enabled: true\n types:\n - fictitious_transaction\n - duplicate_payment\n - split_transaction\n injection_rate: 0.03\n\n",
314 );
315 }
316 "audit" => {
317 yaml.push_str(
318 "audit_standards:\n enabled: true\n isa_compliance:\n enabled: true\n compliance_level: standard\n framework: isa\n analytical_procedures:\n enabled: true\n procedures_per_account: 3\n confirmations:\n enabled: true\n positive_response_rate: 0.85\n sox:\n enabled: true\n materiality_threshold: 10000.0\n\n",
319 );
320 }
321 "banking" => {
322 yaml.push_str(
323 "banking:\n enabled: true\n customer_count: 100\n account_types:\n - checking\n - savings\n - loan\n kyc_enabled: true\n aml_enabled: true\n\n",
324 );
325 }
326 "controls" => {
327 yaml.push_str(
328 "internal_controls:\n enabled: true\n coso_enabled: true\n include_entity_level_controls: true\n target_maturity_level: \"managed\"\n exception_rate: 0.02\n sod_violation_rate: 0.01\n\n",
329 );
330 }
331 "process_mining" => {
332 yaml.push_str(
333 "business_processes:\n enabled: true\n ocel_export: true\n p2p:\n enabled: true\n o2c:\n enabled: true\n\n",
334 );
335 }
336 "intercompany" => {
337 yaml.push_str(
338 "intercompany:\n enabled: true\n matching_tolerance: 0.01\n elimination_enabled: true\n\n",
339 );
340 }
341 "distributions" => {
342 yaml.push_str(&format!(
343 "distributions:\n enabled: true\n industry_profile: {industry}\n amounts:\n enabled: true\n distribution_type: lognormal\n benford_compliance: true\n\n"
344 ));
345 }
346 other => {
347 tracing::warn!(
348 "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
349 other
350 );
351 }
352 }
353 }
354
355 Ok(yaml)
356 }
357
358 fn parse_with_llm(
360 description: &str,
361 provider: &dyn LlmProvider,
362 ) -> Result<ConfigIntent, SynthError> {
363 let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
364
365 let request = LlmRequest::new(description)
366 .with_system(system_prompt.to_string())
367 .with_temperature(0.1)
368 .with_max_tokens(512);
369
370 let response = provider.complete(&request)?;
371 Self::parse_llm_response(&response.content)
372 }
373
374 fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
376 let json_str = Self::extract_json(content)
378 .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
379
380 let value: serde_json::Value = serde_json::from_str(json_str)
381 .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {e}")))?;
382
383 let industry = value
384 .get("industry")
385 .and_then(|v| v.as_str())
386 .map(String::from);
387 let country = value
388 .get("country")
389 .and_then(|v| v.as_str())
390 .map(String::from);
391 let company_size = value
392 .get("company_size")
393 .and_then(|v| v.as_str())
394 .map(String::from);
395 let period_months = value
396 .get("period_months")
397 .and_then(serde_json::Value::as_u64)
398 .map(|v| v as u32);
399 let features = value
400 .get("features")
401 .and_then(|v| v.as_array())
402 .map(|arr| {
403 arr.iter()
404 .filter_map(|v| v.as_str().map(String::from))
405 .collect()
406 })
407 .unwrap_or_default();
408
409 Ok(ConfigIntent {
410 industry,
411 country,
412 company_size,
413 period_months,
414 features,
415 })
416 }
417
418 fn extract_json(content: &str) -> Option<&str> {
420 super::json_utils::extract_json_object(content)
421 }
422
423 fn parse_with_keywords(description: &str) -> ConfigIntent {
425 let lower = description.to_lowercase();
426
427 let industry = Self::extract_industry(&lower);
428 let country = Self::extract_country(&lower);
429 let company_size = Self::extract_size(&lower);
430 let period_months = Self::extract_period(&lower);
431 let features = Self::extract_features(&lower);
432
433 ConfigIntent {
434 industry,
435 country,
436 company_size,
437 period_months,
438 features,
439 }
440 }
441
442 fn extract_industry(text: &str) -> Option<String> {
449 let patterns: &[(&[&str], &str)] = &[
450 (
451 &["retail", "store", "shop", "e-commerce", "ecommerce"],
452 "retail",
453 ),
454 (
455 &["manufactur", "factory", "production", "assembly"],
456 "manufacturing",
457 ),
458 (
459 &[
460 "financial",
461 "finance",
462 "insurance",
463 "fintech",
464 "investment firm",
465 ],
466 "financial_services",
467 ),
468 (
469 &["health", "hospital", "medical", "pharma", "clinic"],
470 "healthcare",
471 ),
472 (
473 &["tech", "software", "saas", "startup", "digital"],
474 "technology",
475 ),
476 ];
477
478 let mut best: Option<(&str, usize)> = None;
479 for (keywords, industry) in patterns {
480 let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
481 if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
482 best = Some((industry, count));
483 }
484 }
485 best.map(|(industry, _)| industry.to_string())
486 }
487
488 fn extract_country(text: &str) -> Option<String> {
490 let name_patterns = [
494 (&["united states", "u.s.", "america"][..], "US"),
495 (&["germany", "german"][..], "DE"),
496 (&["united kingdom", "british", "england"][..], "GB"),
497 (&["china", "chinese"][..], "CN"),
498 (&["japan", "japanese"][..], "JP"),
499 (&["india", "indian"][..], "IN"),
500 (&["brazil", "brazilian"][..], "BR"),
501 (&["mexico", "mexican"][..], "MX"),
502 (&["australia", "australian"][..], "AU"),
503 (&["singapore", "singaporean"][..], "SG"),
504 (&["korea", "korean"][..], "KR"),
505 (&["france", "french"][..], "FR"),
506 (&["canada", "canadian"][..], "CA"),
507 ];
508
509 for (keywords, code) in &name_patterns {
510 if keywords.iter().any(|kw| text.contains(kw)) {
511 return Some(code.to_string());
512 }
513 }
514
515 let padded = format!(" {text} ");
519 let safe_codes = [
520 (" us ", "US"),
521 (" uk ", "GB"),
522 (" gb ", "GB"),
523 (" cn ", "CN"),
524 (" jp ", "JP"),
525 (" br ", "BR"),
526 (" mx ", "MX"),
527 (" au ", "AU"),
528 (" sg ", "SG"),
529 (" kr ", "KR"),
530 (" fr ", "FR"),
531 (" ca ", "CA"),
532 ];
533
534 for (code_pattern, code) in &safe_codes {
535 if padded.contains(code_pattern) {
536 return Some(code.to_string());
537 }
538 }
539
540 None
541 }
542
543 fn extract_size(text: &str) -> Option<String> {
545 if text.contains("small") || text.contains("startup") || text.contains("tiny") {
546 Some("small".to_string())
547 } else if text.contains("large")
548 || text.contains("enterprise")
549 || text.contains("big")
550 || text.contains("multinational")
551 || text.contains("fortune 500")
552 {
553 Some("large".to_string())
554 } else if text.contains("medium")
555 || text.contains("mid-size")
556 || text.contains("midsize")
557 || text.contains("mid size")
558 {
559 Some("medium".to_string())
560 } else {
561 None
562 }
563 }
564
565 fn extract_period(text: &str) -> Option<u32> {
567 let word_numbers = [
570 ("one", 1u32),
571 ("two", 2),
572 ("three", 3),
573 ("four", 4),
574 ("five", 5),
575 ("six", 6),
576 ("twelve", 12),
577 ("eighteen", 18),
578 ("twenty-four", 24),
579 ];
580
581 for (word, num) in &word_numbers {
583 if text.contains(&format!("{word} year")) {
584 return Some(num * 12);
585 }
586 if text.contains(&format!("{word} month")) {
587 return Some(*num);
588 }
589 }
590
591 let tokens: Vec<&str> = text.split_whitespace().collect();
593 for window in tokens.windows(2) {
594 if let Ok(num) = window[0].parse::<u32>() {
595 if window[1].starts_with("year") {
596 return Some(num * 12);
597 }
598 if window[1].starts_with("month") {
599 return Some(num);
600 }
601 }
602 }
603
604 None
605 }
606
607 fn extract_features(text: &str) -> Vec<String> {
609 let mut features = Vec::new();
610
611 let feature_patterns = [
612 (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
613 (&["audit", "auditing", "assurance"][..], "audit"),
614 (&["banking", "bank account", "kyc", "aml"][..], "banking"),
615 (
616 &["control", "sox", "sod", "segregation of duties", "coso"][..],
617 "controls",
618 ),
619 (
620 &["process mining", "ocel", "event log"][..],
621 "process_mining",
622 ),
623 (
624 &["intercompany", "inter-company", "consolidation"][..],
625 "intercompany",
626 ),
627 (
628 &["distribution", "benford", "statistical"][..],
629 "distributions",
630 ),
631 ];
632
633 for (keywords, feature) in &feature_patterns {
634 if keywords.iter().any(|kw| text.contains(kw)) {
635 features.push(feature.to_string());
636 }
637 }
638
639 features
640 }
641
642 fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
644 ConfigIntent {
645 industry: primary.industry.or(fallback.industry),
646 country: primary.country.or(fallback.country),
647 company_size: primary.company_size.or(fallback.company_size),
648 period_months: primary.period_months.or(fallback.period_months),
649 features: if primary.features.is_empty() {
650 fallback.features
651 } else {
652 primary.features
653 },
654 }
655 }
656
657 fn country_to_currency(country: &str) -> &'static str {
659 match country {
660 "US" | "CA" => "USD",
661 "DE" | "FR" => "EUR",
662 "GB" => "GBP",
663 "CN" => "CNY",
664 "JP" => "JPY",
665 "IN" => "INR",
666 "BR" => "BRL",
667 "MX" => "MXN",
668 "AU" => "AUD",
669 "SG" => "SGD",
670 "KR" => "KRW",
671 _ => "USD",
672 }
673 }
674
675 fn industry_company_name(industry: &str) -> &'static str {
677 match industry {
678 "retail" => "Retail Corp",
679 "manufacturing" => "Manufacturing Industries Inc",
680 "financial_services" => "Financial Services Group",
681 "healthcare" => "HealthCare Solutions",
682 "technology" => "TechCorp Solutions",
683 _ => "DataSynth Corp",
684 }
685 }
686
687 fn complexity_to_tx_count(complexity: &str) -> u32 {
689 match complexity {
690 "small" => 1000,
691 "medium" => 5000,
692 "large" => 25000,
693 _ => 5000,
694 }
695 }
696}
697
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::llm::mock_provider::MockLlmProvider;

    /// Fresh, deterministically seeded mock provider.
    fn mock() -> MockLlmProvider {
        MockLlmProvider::new(42)
    }

    /// Parses `description` with the mock provider, failing the test on error.
    fn parse(description: &str) -> ConfigIntent {
        NlConfigGenerator::parse_intent(description, &mock()).expect("should parse successfully")
    }

    /// Renders `intent` to YAML, failing the test on error.
    fn render(intent: &ConfigIntent) -> String {
        NlConfigGenerator::intent_to_yaml(intent).expect("should generate YAML")
    }

    /// True when `intent.features` lists `name`.
    fn has_feature(intent: &ConfigIntent, name: &str) -> bool {
        intent.features.iter().any(|f| f == name)
    }

    #[test]
    fn test_parse_retail_description() {
        let parsed = parse("Generate 1 year of retail data for a medium US company");

        assert_eq!(parsed.industry.as_deref(), Some("retail"));
        assert_eq!(parsed.country.as_deref(), Some("US"));
        assert_eq!(parsed.company_size.as_deref(), Some("medium"));
        assert_eq!(parsed.period_months, Some(12));
    }

    #[test]
    fn test_parse_manufacturing_with_fraud() {
        let parsed = parse(
            "Create 6 months of manufacturing data for a large German company with fraud detection",
        );

        assert_eq!(parsed.industry.as_deref(), Some("manufacturing"));
        assert_eq!(parsed.country.as_deref(), Some("DE"));
        assert_eq!(parsed.company_size.as_deref(), Some("large"));
        assert_eq!(parsed.period_months, Some(6));
        assert!(has_feature(&parsed, "fraud"));
    }

    #[test]
    fn test_parse_financial_services_with_audit() {
        let parsed =
            parse("I need 2 years of financial services data for audit testing with SOX controls");

        assert_eq!(parsed.industry.as_deref(), Some("financial_services"));
        assert_eq!(parsed.period_months, Some(24));
        assert!(has_feature(&parsed, "audit"));
        assert!(has_feature(&parsed, "controls"));
    }

    #[test]
    fn test_parse_healthcare_small() {
        let parsed = parse("Small healthcare company in Japan, 3 months of data");

        assert_eq!(parsed.industry.as_deref(), Some("healthcare"));
        assert_eq!(parsed.country.as_deref(), Some("JP"));
        assert_eq!(parsed.company_size.as_deref(), Some("small"));
        assert_eq!(parsed.period_months, Some(3));
    }

    #[test]
    fn test_parse_technology_with_banking() {
        let parsed = parse("Generate data for a technology startup in Singapore with banking and KYC");

        assert_eq!(parsed.industry.as_deref(), Some("technology"));
        assert_eq!(parsed.country.as_deref(), Some("SG"));
        assert_eq!(parsed.company_size.as_deref(), Some("small"));
        assert!(has_feature(&parsed, "banking"));
    }

    #[test]
    fn test_parse_word_numbers() {
        let parsed = parse("Generate two years of retail data");
        assert_eq!(parsed.period_months, Some(24));
    }

    #[test]
    fn test_parse_multiple_features() {
        let parsed = parse(
            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
        );

        assert_eq!(parsed.industry.as_deref(), Some("manufacturing"));
        for expected in ["fraud", "audit", "process_mining", "intercompany"] {
            assert!(has_feature(&parsed, expected), "missing feature {expected}");
        }
    }

    #[test]
    fn test_intent_to_yaml_basic() {
        let yaml = render(&ConfigIntent {
            industry: Some("retail".to_string()),
            country: Some("US".to_string()),
            company_size: Some("medium".to_string()),
            period_months: Some(12),
            features: vec![],
        });

        for needle in [
            "industry: retail",
            "period_months: 12",
            "currency: \"USD\"",
            "country: \"US\"",
            "complexity: medium",
            "count: 5000",
        ] {
            assert!(yaml.contains(needle), "YAML lacks {needle:?}");
        }
    }

    #[test]
    fn test_intent_to_yaml_with_features() {
        let yaml = render(&ConfigIntent {
            industry: Some("manufacturing".to_string()),
            country: Some("DE".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(24),
            features: ["fraud", "audit", "controls"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
        });

        for needle in [
            "industry: manufacturing",
            "currency: \"EUR\"",
            "complexity: large",
            "count: 25000",
            "fraud:",
            "audit_standards:",
            "internal_controls:",
        ] {
            assert!(yaml.contains(needle), "YAML lacks {needle:?}");
        }
    }

    #[test]
    fn test_intent_to_yaml_defaults() {
        let yaml = render(&ConfigIntent::default());

        assert!(yaml.contains("industry: manufacturing"));
        assert!(yaml.contains("period_months: 12"));
        assert!(yaml.contains("complexity: medium"));
    }

    #[test]
    fn test_intent_to_yaml_invalid_period() {
        for months in [0, 121] {
            let intent = ConfigIntent {
                period_months: Some(months),
                ..ConfigIntent::default()
            };
            assert!(
                NlConfigGenerator::intent_to_yaml(&intent).is_err(),
                "period {months} should be rejected"
            );
        }
    }

    #[test]
    fn test_generate_end_to_end() {
        let yaml = NlConfigGenerator::generate(
            "Generate 1 year of retail data for a medium US company with fraud detection",
            &mock(),
        )
        .expect("should generate YAML");

        for needle in [
            "industry: retail",
            "period_months: 12",
            "currency: \"USD\"",
            "fraud:",
            "complexity: medium",
        ] {
            assert!(yaml.contains(needle), "YAML lacks {needle:?}");
        }
    }

    #[test]
    fn test_generate_empty_description() {
        for blank in ["", " "] {
            assert!(NlConfigGenerator::generate(blank, &mock()).is_err());
        }
    }

    #[test]
    fn test_extract_json_from_response() {
        let reply = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;
        assert_eq!(
            NlConfigGenerator::extract_json(reply),
            Some(r#"{"industry": "retail", "country": "US"}"#)
        );
    }

    #[test]
    fn test_extract_json_nested() {
        let reply = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
        assert!(NlConfigGenerator::extract_json(reply).is_some());
    }

    #[test]
    fn test_extract_json_missing() {
        assert!(NlConfigGenerator::extract_json("No JSON here at all").is_none());
    }

    #[test]
    fn test_parse_llm_response_valid() {
        let reply = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
        let parsed =
            NlConfigGenerator::parse_llm_response(reply).expect("should parse valid JSON");

        assert_eq!(parsed.industry.as_deref(), Some("retail"));
        assert_eq!(parsed.country.as_deref(), Some("US"));
        assert_eq!(parsed.company_size.as_deref(), Some("medium"));
        assert_eq!(parsed.period_months, Some(12));
        assert_eq!(parsed.features, ["fraud"]);
    }

    #[test]
    fn test_parse_llm_response_partial() {
        let parsed = NlConfigGenerator::parse_llm_response(r#"{"industry": "retail"}"#)
            .expect("should parse partial JSON");

        assert_eq!(parsed.industry.as_deref(), Some("retail"));
        assert!(parsed.country.is_none());
        assert!(parsed.features.is_empty());
    }

    #[test]
    fn test_country_to_currency_mapping() {
        let cases = [
            ("US", "USD"),
            ("DE", "EUR"),
            ("GB", "GBP"),
            ("JP", "JPY"),
            ("CN", "CNY"),
            ("BR", "BRL"),
            ("XX", "USD"),
        ];
        for (country, currency) in cases {
            assert_eq!(
                NlConfigGenerator::country_to_currency(country),
                currency,
                "currency for {country}"
            );
        }
    }

    #[test]
    fn test_merge_intents() {
        let llm_side = ConfigIntent {
            industry: Some("retail".to_string()),
            country: None,
            company_size: None,
            period_months: Some(12),
            features: vec![],
        };
        let keyword_side = ConfigIntent {
            industry: Some("manufacturing".to_string()),
            country: Some("DE".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(6),
            features: vec!["fraud".to_string()],
        };

        let merged = NlConfigGenerator::merge_intents(llm_side, keyword_side);

        assert_eq!(merged.industry.as_deref(), Some("retail"));
        assert_eq!(merged.country.as_deref(), Some("DE"));
        assert_eq!(merged.company_size.as_deref(), Some("large"));
        assert_eq!(merged.period_months, Some(12));
        assert_eq!(merged.features, ["fraud"]);
    }

    #[test]
    fn test_parse_uk_country() {
        let parsed = parse("Generate data for a UK manufacturing company");
        assert_eq!(parsed.country.as_deref(), Some("GB"));
    }

    #[test]
    fn test_intent_to_yaml_banking_feature() {
        let yaml = render(&ConfigIntent {
            industry: Some("financial_services".to_string()),
            country: Some("US".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(12),
            features: vec!["banking".to_string()],
        });

        for needle in ["banking:", "kyc_enabled: true", "aml_enabled: true"] {
            assert!(yaml.contains(needle), "YAML lacks {needle:?}");
        }
    }

    #[test]
    fn test_intent_to_yaml_process_mining_feature() {
        let yaml = render(&ConfigIntent {
            features: vec!["process_mining".to_string()],
            ..ConfigIntent::default()
        });

        assert!(yaml.contains("business_processes:"));
        assert!(yaml.contains("ocel_export: true"));
    }

    #[test]
    fn test_intent_to_yaml_distributions_feature() {
        let yaml = render(&ConfigIntent {
            industry: Some("retail".to_string()),
            features: vec!["distributions".to_string()],
            ..ConfigIntent::default()
        });

        for needle in [
            "distributions:",
            "industry_profile: retail",
            "benford_compliance: true",
        ] {
            assert!(yaml.contains(needle), "YAML lacks {needle:?}");
        }
    }

    #[test]
    fn test_extract_yaml_from_fenced_block() {
        let extracted = NlConfigGenerator::extract_yaml(
            "Here is the config:\n```yaml\nglobal:\n industry: retail\n```\nDone.",
        );

        assert!(extracted.contains("global:"));
        assert!(extracted.contains("industry: retail"));
        assert!(!extracted.contains("```"));
    }

    #[test]
    fn test_extract_yaml_plain_fences() {
        let extracted = NlConfigGenerator::extract_yaml("```\nglobal:\n seed: 42\n```");

        assert!(extracted.contains("global:"));
        assert!(extracted.contains("seed: 42"));
        assert!(!extracted.contains("```"));
    }

    #[test]
    fn test_extract_yaml_no_fences() {
        let extracted = NlConfigGenerator::extract_yaml("global:\n industry: manufacturing\n");

        assert!(extracted.contains("global:"));
        assert!(extracted.contains("industry: manufacturing"));
    }

    #[test]
    fn test_generate_full_falls_back_to_template() {
        let yaml = NlConfigGenerator::generate_full(
            "Generate 1 year of retail data for a medium US company",
            &mock(),
        )
        .expect("should fall back to template-based generation");

        assert!(yaml.contains("industry: retail"));
        assert!(yaml.contains("period_months: 12"));
    }

    #[test]
    fn test_generate_full_empty_description() {
        for blank in ["", " "] {
            assert!(NlConfigGenerator::generate_full(blank, &mock()).is_err());
        }
    }

    #[test]
    fn test_full_schema_system_prompt_covers_key_sections() {
        let prompt = NlConfigGenerator::full_schema_system_prompt();
        for section in [
            "global:",
            "companies:",
            "chart_of_accounts:",
            "transactions:",
            "fraud:",
            "banking:",
            "distributions:",
            "diffusion:",
        ] {
            assert!(prompt.contains(section), "prompt lacks {section:?}");
        }
    }
}