1use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured intent extracted from a natural-language configuration request.
///
/// Every field is optional; `NlConfigGenerator::intent_to_yaml` substitutes a
/// default for any field left unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Industry identifier, e.g. `"retail"` or `"financial_services"`.
    pub industry: Option<String>,
    /// Country code, e.g. `"US"` or `"DE"`.
    pub country: Option<String>,
    /// Company size, which also drives chart-of-accounts complexity
    /// (`"small"`, `"medium"` or `"large"`).
    pub company_size: Option<String>,
    /// Length of the generated period, in months.
    pub period_months: Option<u32>,
    /// Requested optional features such as `"fraud"` or `"audit"`.
    pub features: Vec<String>,
}
24
/// Turns natural-language descriptions into DataSynth YAML configurations.
///
/// The type carries no state; all functionality is exposed through associated
/// functions such as `NlConfigGenerator::generate`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33 pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43 if description.trim().is_empty() {
44 return Err(SynthError::generation(
45 "Natural language description cannot be empty",
46 ));
47 }
48
49 let intent = Self::parse_intent(description, provider)?;
50 Self::intent_to_yaml(&intent)
51 }
52
53 pub fn parse_intent(
58 description: &str,
59 provider: &dyn LlmProvider,
60 ) -> Result<ConfigIntent, SynthError> {
61 let llm_intent = Self::parse_with_llm(description, provider);
63
64 let keyword_intent = Self::parse_with_keywords(description);
66
67 match llm_intent {
69 Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70 Err(e) => {
71 tracing::warn!(
72 "LLM-based config parsing failed, falling back to keyword parsing: {}",
73 e
74 );
75 Ok(keyword_intent)
76 }
77 }
78 }
79
80 pub fn generate_full(
87 description: &str,
88 provider: &dyn LlmProvider,
89 ) -> Result<String, SynthError> {
90 if description.trim().is_empty() {
91 return Err(SynthError::generation(
92 "Natural language description cannot be empty",
93 ));
94 }
95
96 let system = Self::full_schema_system_prompt();
97 let request = LlmRequest::new(description.to_string())
98 .with_system(system)
99 .with_temperature(0.2)
100 .with_max_tokens(4096);
101
102 match provider.complete(&request) {
103 Ok(response) => {
104 let yaml_text = Self::extract_yaml(&response.content);
105 if let Ok(value) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_text) {
107 if let Some(map) = value.as_mapping() {
108 let known_keys = [
109 "global",
110 "companies",
111 "chart_of_accounts",
112 "transactions",
113 "output",
114 "fraud",
115 "audit_standards",
116 "banking",
117 "internal_controls",
118 "distributions",
119 "temporal_patterns",
120 "document_flows",
121 "intercompany",
122 "master_data",
123 "business_processes",
124 "hr",
125 "manufacturing",
126 "tax",
127 "treasury",
128 "esg",
129 "project_accounting",
130 "diffusion",
131 "llm",
132 "causal",
133 ];
134 let has_known = map
135 .keys()
136 .any(|k| k.as_str().map(|s| known_keys.contains(&s)).unwrap_or(false));
137 if has_known {
138 return Ok(yaml_text);
139 }
140 }
141 }
142 tracing::warn!(
144 "LLM full-config response did not contain valid DataSynth YAML; falling back to template"
145 );
146 Self::generate(description, provider)
147 }
148 Err(e) => {
149 tracing::warn!("LLM full-config generation failed: {e}; falling back to template");
150 Self::generate(description, provider)
151 }
152 }
153 }
154
155 pub fn extract_yaml(content: &str) -> String {
157 let trimmed = content.trim();
158
159 if let Some(start) = trimmed.find("```yaml") {
161 let after = &trimmed[start + 7..];
162 if let Some(end) = after.find("```") {
163 return after[..end].trim().to_string();
164 }
165 }
166
167 if let Some(start) = trimmed.find("```") {
169 let after = &trimmed[start + 3..];
170 if let Some(end) = after.find("```") {
171 return after[..end].trim().to_string();
172 }
173 }
174
175 trimmed.to_string()
177 }
178
179 pub fn full_schema_system_prompt() -> String {
183 concat!(
184 "You are a DataSynth configuration generator. Given a natural language description, ",
185 "produce a complete, valid DataSynth YAML configuration.\n\n",
186 "Top-level sections (all optional — include only what is relevant):\n\n",
187 "global:\n",
188 " industry: <retail|manufacturing|financial_services|healthcare|technology>\n",
189 " start_date: \"YYYY-MM-DD\"\n",
190 " period_months: <1-120>\n",
191 " seed: <integer>\n\n",
192 "companies:\n",
193 " - code: \"C001\"\n",
194 " name: \"...\"\n",
195 " currency: \"USD\"\n",
196 " country: \"US\"\n\n",
197 "chart_of_accounts:\n",
198 " complexity: <small|medium|large>\n\n",
199 "transactions:\n",
200 " count: <number>\n",
201 " anomaly_rate: <0.0-1.0>\n\n",
202 "output:\n",
203 " format: <csv|json|parquet>\n",
204 " compression: <true|false>\n\n",
205 "fraud:\n",
206 " enabled: true\n",
207 " types: [fictitious_transaction, duplicate_payment, split_transaction, ...]\n",
208 " injection_rate: <0.0-1.0>\n\n",
209 "internal_controls:\n",
210 " enabled: true\n",
211 " coso_enabled: true\n",
212 " target_maturity_level: <ad_hoc|repeatable|defined|managed|optimized>\n\n",
213 "distributions:\n",
214 " enabled: true\n",
215 " industry_profile: <industry>\n",
216 " amounts: { enabled: true, distribution_type: lognormal, benford_compliance: true }\n\n",
217 "temporal_patterns:\n",
218 " enabled: true\n",
219 " business_days: { enabled: true }\n",
220 " period_end: { model: exponential }\n\n",
221 "banking:\n",
222 " enabled: true\n",
223 " customer_count: <number>\n",
224 " kyc_enabled: true\n",
225 " aml_enabled: true\n\n",
226 "audit_standards:\n",
227 " enabled: true\n",
228 " isa_compliance: { enabled: true, compliance_level: standard }\n",
229 " sox: { enabled: true }\n\n",
230 "intercompany:\n",
231 " enabled: true\n\n",
232 "document_flows:\n",
233 " p2p: { enabled: true }\n",
234 " o2c: { enabled: true }\n\n",
235 "master_data:\n",
236 " vendors: { count: <number> }\n",
237 " customers: { count: <number> }\n\n",
238 "hr:\n",
239 " enabled: true\n",
240 " payroll: { enabled: true }\n\n",
241 "manufacturing:\n",
242 " enabled: true\n\n",
243 "tax:\n",
244 " enabled: true\n\n",
245 "treasury:\n",
246 " enabled: true\n\n",
247 "esg:\n",
248 " enabled: true\n\n",
249 "project_accounting:\n",
250 " enabled: true\n\n",
251 "diffusion:\n",
252 " enabled: true\n",
253 " backend: <statistical|neural|hybrid>\n\n",
254 "Return ONLY the YAML configuration (optionally inside ```yaml fences), no other text.\n"
255 ).to_string()
256 }
257
258 pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
260 let industry = intent.industry.as_deref().unwrap_or("manufacturing");
261 let country = intent.country.as_deref().unwrap_or("US");
262 let complexity = intent.company_size.as_deref().unwrap_or("medium");
263 let period_months = intent.period_months.unwrap_or(12);
264
265 if !(1..=120).contains(&period_months) {
267 return Err(SynthError::generation(format!(
268 "Period months must be between 1 and 120, got {period_months}"
269 )));
270 }
271
272 let valid_complexities = ["small", "medium", "large"];
273 if !valid_complexities.contains(&complexity) {
274 return Err(SynthError::generation(format!(
275 "Invalid company size '{complexity}', must be one of: small, medium, large"
276 )));
277 }
278
279 let currency = Self::country_to_currency(country);
280 let company_name = Self::industry_company_name(industry);
281
282 let mut yaml = String::with_capacity(2048);
283
284 yaml.push_str(&format!(
286 "global:\n industry: {industry}\n start_date: \"2024-01-01\"\n period_months: {period_months}\n seed: 42\n\n"
287 ));
288
289 yaml.push_str(&format!(
291 "companies:\n - code: \"C001\"\n name: \"{company_name}\"\n currency: \"{currency}\"\n country: \"{country}\"\n\n"
292 ));
293
294 yaml.push_str(&format!(
296 "chart_of_accounts:\n complexity: {complexity}\n\n"
297 ));
298
299 let tx_count = Self::complexity_to_tx_count(complexity);
301 yaml.push_str(&format!(
302 "transactions:\n count: {tx_count}\n anomaly_rate: 0.02\n\n"
303 ));
304
305 yaml.push_str("output:\n format: csv\n compression: false\n\n");
307
308 for feature in &intent.features {
310 match feature.as_str() {
311 "fraud" => {
312 yaml.push_str(
313 "fraud:\n enabled: true\n types:\n - fictitious_transaction\n - duplicate_payment\n - split_transaction\n injection_rate: 0.03\n\n",
314 );
315 }
316 "audit" => {
317 yaml.push_str(
318 "audit_standards:\n enabled: true\n isa_compliance:\n enabled: true\n compliance_level: standard\n framework: isa\n analytical_procedures:\n enabled: true\n procedures_per_account: 3\n confirmations:\n enabled: true\n positive_response_rate: 0.85\n sox:\n enabled: true\n materiality_threshold: 10000.0\n\n",
319 );
320 }
321 "banking" => {
322 yaml.push_str(
323 "banking:\n enabled: true\n customer_count: 100\n account_types:\n - checking\n - savings\n - loan\n kyc_enabled: true\n aml_enabled: true\n\n",
324 );
325 }
326 "controls" => {
327 yaml.push_str(
328 "internal_controls:\n enabled: true\n coso_enabled: true\n include_entity_level_controls: true\n target_maturity_level: \"managed\"\n exception_rate: 0.02\n sod_violation_rate: 0.01\n\n",
329 );
330 }
331 "process_mining" => {
332 yaml.push_str(
333 "business_processes:\n enabled: true\n ocel_export: true\n p2p:\n enabled: true\n o2c:\n enabled: true\n\n",
334 );
335 }
336 "intercompany" => {
337 yaml.push_str(
338 "intercompany:\n enabled: true\n matching_tolerance: 0.01\n elimination_enabled: true\n\n",
339 );
340 }
341 "distributions" => {
342 yaml.push_str(&format!(
343 "distributions:\n enabled: true\n industry_profile: {industry}\n amounts:\n enabled: true\n distribution_type: lognormal\n benford_compliance: true\n\n"
344 ));
345 }
346 other => {
347 tracing::warn!(
348 "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
349 other
350 );
351 }
352 }
353 }
354
355 Ok(yaml)
356 }
357
358 fn parse_with_llm(
360 description: &str,
361 provider: &dyn LlmProvider,
362 ) -> Result<ConfigIntent, SynthError> {
363 let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
364
365 let request = LlmRequest::new(description)
366 .with_system(system_prompt.to_string())
367 .with_temperature(0.1)
368 .with_max_tokens(512);
369
370 let response = provider.complete(&request)?;
371 Self::parse_llm_response(&response.content)
372 }
373
374 fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
376 let json_str = Self::extract_json(content)
378 .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
379
380 let value: serde_json::Value = serde_json::from_str(json_str)
381 .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {e}")))?;
382
383 let industry = value
384 .get("industry")
385 .and_then(|v| v.as_str())
386 .map(String::from);
387 let country = value
388 .get("country")
389 .and_then(|v| v.as_str())
390 .map(String::from);
391 let company_size = value
392 .get("company_size")
393 .and_then(|v| v.as_str())
394 .map(String::from);
395 let period_months = value
396 .get("period_months")
397 .and_then(serde_json::Value::as_u64)
398 .map(|v| v as u32);
399 let features = value
400 .get("features")
401 .and_then(|v| v.as_array())
402 .map(|arr| {
403 arr.iter()
404 .filter_map(|v| v.as_str().map(String::from))
405 .collect()
406 })
407 .unwrap_or_default();
408
409 Ok(ConfigIntent {
410 industry,
411 country,
412 company_size,
413 period_months,
414 features,
415 })
416 }
417
418 fn extract_json(content: &str) -> Option<&str> {
420 super::json_utils::extract_json_object(content)
421 }
422
423 fn parse_with_keywords(description: &str) -> ConfigIntent {
425 let lower = description.to_lowercase();
426
427 let industry = Self::extract_industry(&lower);
428 let country = Self::extract_country(&lower);
429 let company_size = Self::extract_size(&lower);
430 let period_months = Self::extract_period(&lower);
431 let features = Self::extract_features(&lower);
432
433 ConfigIntent {
434 industry,
435 country,
436 company_size,
437 period_months,
438 features,
439 }
440 }
441
442 fn extract_industry(text: &str) -> Option<String> {
449 let patterns: &[(&[&str], &str)] = &[
450 (
451 &["retail", "store", "shop", "e-commerce", "ecommerce"],
452 "retail",
453 ),
454 (
455 &["manufactur", "factory", "production", "assembly"],
456 "manufacturing",
457 ),
458 (
459 &[
460 "financial",
461 "finance",
462 "insurance",
463 "fintech",
464 "investment firm",
465 ],
466 "financial_services",
467 ),
468 (
469 &["health", "hospital", "medical", "pharma", "clinic"],
470 "healthcare",
471 ),
472 (
473 &["tech", "software", "saas", "startup", "digital"],
474 "technology",
475 ),
476 ];
477
478 let mut best: Option<(&str, usize)> = None;
479 for (keywords, industry) in patterns {
480 let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
481 if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
482 best = Some((industry, count));
483 }
484 }
485 best.map(|(industry, _)| industry.to_string())
486 }
487
488 fn extract_country(text: &str) -> Option<String> {
490 let name_patterns = [
494 (&["united states", "u.s.", "america"][..], "US"),
495 (&["germany", "german"][..], "DE"),
496 (&["united kingdom", "british", "england"][..], "GB"),
497 (&["china", "chinese"][..], "CN"),
498 (&["japan", "japanese"][..], "JP"),
499 (&["india", "indian"][..], "IN"),
500 (&["brazil", "brazilian"][..], "BR"),
501 (&["mexico", "mexican"][..], "MX"),
502 (&["australia", "australian"][..], "AU"),
503 (&["singapore", "singaporean"][..], "SG"),
504 (&["korea", "korean"][..], "KR"),
505 (&["france", "french"][..], "FR"),
506 (&["canada", "canadian"][..], "CA"),
507 ];
508
509 for (keywords, code) in &name_patterns {
510 if keywords.iter().any(|kw| text.contains(kw)) {
511 return Some(code.to_string());
512 }
513 }
514
515 let padded = format!(" {text} ");
519 let safe_codes = [
520 (" us ", "US"),
521 (" uk ", "GB"),
522 (" gb ", "GB"),
523 (" cn ", "CN"),
524 (" jp ", "JP"),
525 (" br ", "BR"),
526 (" mx ", "MX"),
527 (" au ", "AU"),
528 (" sg ", "SG"),
529 (" kr ", "KR"),
530 (" fr ", "FR"),
531 (" ca ", "CA"),
532 ];
533
534 for (code_pattern, code) in &safe_codes {
535 if padded.contains(code_pattern) {
536 return Some(code.to_string());
537 }
538 }
539
540 None
541 }
542
543 fn extract_size(text: &str) -> Option<String> {
545 if text.contains("small") || text.contains("startup") || text.contains("tiny") {
546 Some("small".to_string())
547 } else if text.contains("large")
548 || text.contains("enterprise")
549 || text.contains("big")
550 || text.contains("multinational")
551 || text.contains("fortune 500")
552 {
553 Some("large".to_string())
554 } else if text.contains("medium")
555 || text.contains("mid-size")
556 || text.contains("midsize")
557 || text.contains("mid size")
558 {
559 Some("medium".to_string())
560 } else {
561 None
562 }
563 }
564
565 fn extract_period(text: &str) -> Option<u32> {
567 let word_numbers = [
570 ("one", 1u32),
571 ("two", 2),
572 ("three", 3),
573 ("four", 4),
574 ("five", 5),
575 ("six", 6),
576 ("twelve", 12),
577 ("eighteen", 18),
578 ("twenty-four", 24),
579 ];
580
581 for (word, num) in &word_numbers {
583 if text.contains(&format!("{word} year")) {
584 return Some(num * 12);
585 }
586 if text.contains(&format!("{word} month")) {
587 return Some(*num);
588 }
589 }
590
591 let tokens: Vec<&str> = text.split_whitespace().collect();
593 for window in tokens.windows(2) {
594 if let Ok(num) = window[0].parse::<u32>() {
595 if window[1].starts_with("year") {
596 return Some(num * 12);
597 }
598 if window[1].starts_with("month") {
599 return Some(num);
600 }
601 }
602 }
603
604 None
605 }
606
607 fn extract_features(text: &str) -> Vec<String> {
609 let mut features = Vec::new();
610
611 let feature_patterns = [
612 (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
613 (&["audit", "auditing", "assurance"][..], "audit"),
614 (&["banking", "bank account", "kyc", "aml"][..], "banking"),
615 (
616 &["control", "sox", "sod", "segregation of duties", "coso"][..],
617 "controls",
618 ),
619 (
620 &["process mining", "ocel", "event log"][..],
621 "process_mining",
622 ),
623 (
624 &["intercompany", "inter-company", "consolidation"][..],
625 "intercompany",
626 ),
627 (
628 &["distribution", "benford", "statistical"][..],
629 "distributions",
630 ),
631 ];
632
633 for (keywords, feature) in &feature_patterns {
634 if keywords.iter().any(|kw| text.contains(kw)) {
635 features.push(feature.to_string());
636 }
637 }
638
639 features
640 }
641
642 fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
644 ConfigIntent {
645 industry: primary.industry.or(fallback.industry),
646 country: primary.country.or(fallback.country),
647 company_size: primary.company_size.or(fallback.company_size),
648 period_months: primary.period_months.or(fallback.period_months),
649 features: if primary.features.is_empty() {
650 fallback.features
651 } else {
652 primary.features
653 },
654 }
655 }
656
657 fn country_to_currency(country: &str) -> &'static str {
659 match country {
660 "US" | "CA" => "USD",
661 "DE" | "FR" => "EUR",
662 "GB" => "GBP",
663 "CN" => "CNY",
664 "JP" => "JPY",
665 "IN" => "INR",
666 "BR" => "BRL",
667 "MX" => "MXN",
668 "AU" => "AUD",
669 "SG" => "SGD",
670 "KR" => "KRW",
671 _ => "USD",
672 }
673 }
674
675 fn industry_company_name(industry: &str) -> &'static str {
677 match industry {
678 "retail" => "Retail Corp",
679 "manufacturing" => "Manufacturing Industries Inc",
680 "financial_services" => "Financial Services Group",
681 "healthcare" => "HealthCare Solutions",
682 "technology" => "TechCorp Solutions",
683 _ => "DataSynth Corp",
684 }
685 }
686
687 fn complexity_to_tx_count(complexity: &str) -> u32 {
689 match complexity {
690 "small" => 1000,
691 "medium" => 5000,
692 "large" => 25000,
693 _ => 5000,
694 }
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::mock_provider::MockLlmProvider;

    /// Deterministic mock provider shared by every test (fixed seed 42).
    fn mock_provider() -> MockLlmProvider {
        MockLlmProvider::new(42)
    }

    /// Runs `parse_intent` against a fresh mock provider, failing on error.
    fn intent_for(description: &str) -> ConfigIntent {
        let provider = mock_provider();
        NlConfigGenerator::parse_intent(description, &provider)
            .expect("should parse successfully")
    }

    /// Renders `intent` to YAML, failing the test on error.
    fn render(intent: &ConfigIntent) -> String {
        NlConfigGenerator::intent_to_yaml(intent).expect("should generate YAML")
    }

    /// True when `intent` lists the feature called `name`.
    fn wants(intent: &ConfigIntent, name: &str) -> bool {
        intent.features.iter().any(|feature| feature == name)
    }

    /// Asserts that every snippet in `needles` occurs in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "missing {needle:?} in:\n{haystack}"
            );
        }
    }

    #[test]
    fn test_parse_retail_description() {
        let intent = intent_for("Generate 1 year of retail data for a medium US company");

        assert_eq!(intent.industry.as_deref(), Some("retail"));
        assert_eq!(intent.country.as_deref(), Some("US"));
        assert_eq!(intent.company_size.as_deref(), Some("medium"));
        assert_eq!(intent.period_months, Some(12));
    }

    #[test]
    fn test_parse_manufacturing_with_fraud() {
        let intent = intent_for(
            "Create 6 months of manufacturing data for a large German company with fraud detection",
        );

        assert_eq!(intent.industry.as_deref(), Some("manufacturing"));
        assert_eq!(intent.country.as_deref(), Some("DE"));
        assert_eq!(intent.company_size.as_deref(), Some("large"));
        assert_eq!(intent.period_months, Some(6));
        assert!(wants(&intent, "fraud"));
    }

    #[test]
    fn test_parse_financial_services_with_audit() {
        let intent = intent_for(
            "I need 2 years of financial services data for audit testing with SOX controls",
        );

        assert_eq!(intent.industry.as_deref(), Some("financial_services"));
        assert_eq!(intent.period_months, Some(24));
        assert!(wants(&intent, "audit"));
        assert!(wants(&intent, "controls"));
    }

    #[test]
    fn test_parse_healthcare_small() {
        let intent = intent_for("Small healthcare company in Japan, 3 months of data");

        assert_eq!(intent.industry.as_deref(), Some("healthcare"));
        assert_eq!(intent.country.as_deref(), Some("JP"));
        assert_eq!(intent.company_size.as_deref(), Some("small"));
        assert_eq!(intent.period_months, Some(3));
    }

    #[test]
    fn test_parse_technology_with_banking() {
        let intent =
            intent_for("Generate data for a technology startup in Singapore with banking and KYC");

        assert_eq!(intent.industry.as_deref(), Some("technology"));
        assert_eq!(intent.country.as_deref(), Some("SG"));
        assert_eq!(intent.company_size.as_deref(), Some("small"));
        assert!(wants(&intent, "banking"));
    }

    #[test]
    fn test_parse_word_numbers() {
        let intent = intent_for("Generate two years of retail data");
        assert_eq!(intent.period_months, Some(24));
    }

    #[test]
    fn test_parse_multiple_features() {
        let intent = intent_for(
            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
        );

        assert_eq!(intent.industry.as_deref(), Some("manufacturing"));
        for feature in ["fraud", "audit", "process_mining", "intercompany"] {
            assert!(wants(&intent, feature), "expected feature {feature}");
        }
    }

    #[test]
    fn test_intent_to_yaml_basic() {
        let intent = ConfigIntent {
            industry: Some("retail".to_string()),
            country: Some("US".to_string()),
            company_size: Some("medium".to_string()),
            period_months: Some(12),
            features: Vec::new(),
        };

        assert_contains_all(
            &render(&intent),
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "country: \"US\"",
                "complexity: medium",
                "count: 5000",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_with_features() {
        let intent = ConfigIntent {
            industry: Some("manufacturing".to_string()),
            country: Some("DE".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(24),
            features: ["fraud", "audit", "controls"]
                .iter()
                .map(|name| name.to_string())
                .collect(),
        };

        assert_contains_all(
            &render(&intent),
            &[
                "industry: manufacturing",
                "currency: \"EUR\"",
                "complexity: large",
                "count: 25000",
                "fraud:",
                "audit_standards:",
                "internal_controls:",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_defaults() {
        let yaml = render(&ConfigIntent::default());

        assert_contains_all(
            &yaml,
            &[
                "industry: manufacturing",
                "period_months: 12",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_invalid_period() {
        for months in [0, 121] {
            let intent = ConfigIntent {
                period_months: Some(months),
                ..ConfigIntent::default()
            };
            assert!(
                NlConfigGenerator::intent_to_yaml(&intent).is_err(),
                "{months} months should be rejected"
            );
        }
    }

    #[test]
    fn test_generate_end_to_end() {
        let provider = mock_provider();
        let yaml = NlConfigGenerator::generate(
            "Generate 1 year of retail data for a medium US company with fraud detection",
            &provider,
        )
        .expect("should generate YAML");

        assert_contains_all(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "fraud:",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_generate_empty_description() {
        let provider = mock_provider();
        for blank in ["", "   "] {
            assert!(
                NlConfigGenerator::generate(blank, &provider).is_err(),
                "blank description {blank:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_extract_json_from_response() {
        let content = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;

        assert_eq!(
            NlConfigGenerator::extract_json(content),
            Some(r#"{"industry": "retail", "country": "US"}"#)
        );
    }

    #[test]
    fn test_extract_json_nested() {
        let content = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
        assert!(NlConfigGenerator::extract_json(content).is_some());
    }

    #[test]
    fn test_extract_json_missing() {
        assert!(NlConfigGenerator::extract_json("No JSON here at all").is_none());
    }

    #[test]
    fn test_parse_llm_response_valid() {
        let content = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
        let intent =
            NlConfigGenerator::parse_llm_response(content).expect("should parse valid JSON");

        assert_eq!(intent.industry.as_deref(), Some("retail"));
        assert_eq!(intent.country.as_deref(), Some("US"));
        assert_eq!(intent.company_size.as_deref(), Some("medium"));
        assert_eq!(intent.period_months, Some(12));
        assert_eq!(intent.features, ["fraud"]);
    }

    #[test]
    fn test_parse_llm_response_partial() {
        let intent = NlConfigGenerator::parse_llm_response(r#"{"industry": "retail"}"#)
            .expect("should parse partial JSON");

        assert_eq!(intent.industry.as_deref(), Some("retail"));
        assert!(intent.country.is_none());
        assert!(intent.features.is_empty());
    }

    #[test]
    fn test_country_to_currency_mapping() {
        let expectations = [
            ("US", "USD"),
            ("DE", "EUR"),
            ("GB", "GBP"),
            ("JP", "JPY"),
            ("CN", "CNY"),
            ("BR", "BRL"),
            // Unknown codes fall back to USD.
            ("XX", "USD"),
        ];
        for (country, currency) in expectations {
            assert_eq!(
                NlConfigGenerator::country_to_currency(country),
                currency,
                "currency for {country}"
            );
        }
    }

    #[test]
    fn test_merge_intents() {
        let primary = ConfigIntent {
            industry: Some("retail".to_string()),
            country: None,
            company_size: None,
            period_months: Some(12),
            features: Vec::new(),
        };
        let fallback = ConfigIntent {
            industry: Some("manufacturing".to_string()),
            country: Some("DE".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(6),
            features: vec!["fraud".to_string()],
        };

        let merged = NlConfigGenerator::merge_intents(primary, fallback);

        // Primary wins where set; gaps come from the fallback.
        assert_eq!(merged.industry.as_deref(), Some("retail"));
        assert_eq!(merged.country.as_deref(), Some("DE"));
        assert_eq!(merged.company_size.as_deref(), Some("large"));
        assert_eq!(merged.period_months, Some(12));
        assert_eq!(merged.features, ["fraud"]);
    }

    #[test]
    fn test_parse_uk_country() {
        let intent = intent_for("Generate data for a UK manufacturing company");
        assert_eq!(intent.country.as_deref(), Some("GB"));
    }

    #[test]
    fn test_intent_to_yaml_banking_feature() {
        let intent = ConfigIntent {
            industry: Some("financial_services".to_string()),
            country: Some("US".to_string()),
            company_size: Some("large".to_string()),
            period_months: Some(12),
            features: vec!["banking".to_string()],
        };

        assert_contains_all(
            &render(&intent),
            &["banking:", "kyc_enabled: true", "aml_enabled: true"],
        );
    }

    #[test]
    fn test_intent_to_yaml_process_mining_feature() {
        let intent = ConfigIntent {
            features: vec!["process_mining".to_string()],
            ..ConfigIntent::default()
        };

        assert_contains_all(
            &render(&intent),
            &["business_processes:", "ocel_export: true"],
        );
    }

    #[test]
    fn test_intent_to_yaml_distributions_feature() {
        let intent = ConfigIntent {
            industry: Some("retail".to_string()),
            features: vec!["distributions".to_string()],
            ..ConfigIntent::default()
        };

        assert_contains_all(
            &render(&intent),
            &[
                "distributions:",
                "industry_profile: retail",
                "benford_compliance: true",
            ],
        );
    }

    #[test]
    fn test_extract_yaml_from_fenced_block() {
        let content = "Here is the config:\n```yaml\nglobal:\n  industry: retail\n```\nDone.";
        let yaml = NlConfigGenerator::extract_yaml(content);

        assert_contains_all(&yaml, &["global:", "industry: retail"]);
        assert!(!yaml.contains("```"));
    }

    #[test]
    fn test_extract_yaml_plain_fences() {
        let yaml = NlConfigGenerator::extract_yaml("```\nglobal:\n  seed: 42\n```");

        assert_contains_all(&yaml, &["global:", "seed: 42"]);
        assert!(!yaml.contains("```"));
    }

    #[test]
    fn test_extract_yaml_no_fences() {
        let yaml = NlConfigGenerator::extract_yaml("global:\n  industry: manufacturing\n");
        assert_contains_all(&yaml, &["global:", "industry: manufacturing"]);
    }

    #[test]
    fn test_generate_full_falls_back_to_template() {
        // The mock provider is not expected to return DataSynth YAML, so the
        // template path should produce the result.
        let provider = mock_provider();
        let yaml = NlConfigGenerator::generate_full(
            "Generate 1 year of retail data for a medium US company",
            &provider,
        )
        .expect("should fall back to template-based generation");

        assert_contains_all(&yaml, &["industry: retail", "period_months: 12"]);
    }

    #[test]
    fn test_generate_full_empty_description() {
        let provider = mock_provider();
        for blank in ["", "   "] {
            assert!(
                NlConfigGenerator::generate_full(blank, &provider).is_err(),
                "blank description {blank:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_full_schema_system_prompt_covers_key_sections() {
        let prompt = NlConfigGenerator::full_schema_system_prompt();

        assert_contains_all(
            &prompt,
            &[
                "global:",
                "companies:",
                "chart_of_accounts:",
                "transactions:",
                "fraud:",
                "banking:",
                "distributions:",
                "diffusion:",
            ],
        );
    }
}