Skip to main content

datasynth_core/llm/
nl_config.rs

1//! Natural language to YAML configuration generator.
2//!
3//! Takes a free-text description of desired synthetic data (e.g., "Generate 1 year of
4//! retail data for a medium US company with fraud detection") and produces a valid
5//! `GeneratorConfig` YAML string.
6
7use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured representation of user intent extracted from natural language.
///
/// Every field is optional: `None` (or an empty `features` list) means the
/// description did not mention it, and [`NlConfigGenerator::intent_to_yaml`]
/// substitutes its own default for the missing value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Target industry (e.g., "retail", "manufacturing", "financial_services").
    pub industry: Option<String>,
    /// Country code (e.g., "US", "DE", "GB").
    pub country: Option<String>,
    /// Company size: "small", "medium", or "large".
    pub company_size: Option<String>,
    /// Duration in months.
    pub period_months: Option<u32>,
    /// Requested feature flags (e.g., "fraud", "audit", "banking", "controls").
    pub features: Vec<String>,
}
24
/// Generates YAML configuration from natural language descriptions.
///
/// The generator uses a two-phase approach:
/// 1. Parse the natural language description into a structured [`ConfigIntent`].
/// 2. Map the intent to a YAML configuration string using preset templates.
///
/// The type carries no state; all functionality is exposed through associated
/// functions, so no instance is needed to use it.
#[derive(Debug, Clone, Copy, Default)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33    /// Generate a YAML configuration from a natural language description.
34    ///
35    /// Uses the provided LLM provider to help parse the description, with
36    /// keyword-based fallback parsing for reliability.
37    ///
38    /// # Errors
39    ///
40    /// Returns `SynthError::GenerationError` if the description cannot be parsed
41    /// or the resulting configuration is invalid.
42    pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43        if description.trim().is_empty() {
44            return Err(SynthError::generation(
45                "Natural language description cannot be empty",
46            ));
47        }
48
49        let intent = Self::parse_intent(description, provider)?;
50        Self::intent_to_yaml(&intent)
51    }
52
53    /// Parse a natural language description into a structured [`ConfigIntent`].
54    ///
55    /// Attempts to use the LLM provider first, then falls back to keyword-based
56    /// extraction for reliability.
57    pub fn parse_intent(
58        description: &str,
59        provider: &dyn LlmProvider,
60    ) -> Result<ConfigIntent, SynthError> {
61        // Try LLM-based parsing first
62        let llm_intent = Self::parse_with_llm(description, provider);
63
64        // Always run keyword-based parsing as fallback/supplement
65        let keyword_intent = Self::parse_with_keywords(description);
66
67        // Merge: prefer LLM results where available, fall back to keywords
68        match llm_intent {
69            Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70            Err(_) => Ok(keyword_intent),
71        }
72    }
73
74    /// Map a [`ConfigIntent`] to a YAML configuration string.
75    pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
76        let industry = intent.industry.as_deref().unwrap_or("manufacturing");
77        let country = intent.country.as_deref().unwrap_or("US");
78        let complexity = intent.company_size.as_deref().unwrap_or("medium");
79        let period_months = intent.period_months.unwrap_or(12);
80
81        // Validate inputs
82        if !(1..=120).contains(&period_months) {
83            return Err(SynthError::generation(format!(
84                "Period months must be between 1 and 120, got {}",
85                period_months
86            )));
87        }
88
89        let valid_complexities = ["small", "medium", "large"];
90        if !valid_complexities.contains(&complexity) {
91            return Err(SynthError::generation(format!(
92                "Invalid company size '{}', must be one of: small, medium, large",
93                complexity
94            )));
95        }
96
97        let currency = Self::country_to_currency(country);
98        let company_name = Self::industry_company_name(industry);
99
100        let mut yaml = String::with_capacity(2048);
101
102        // Global settings
103        yaml.push_str(&format!(
104            "global:\n  industry: {}\n  start_date: \"2024-01-01\"\n  period_months: {}\n  seed: 42\n\n",
105            industry, period_months
106        ));
107
108        // Companies
109        yaml.push_str(&format!(
110            "companies:\n  - code: \"C001\"\n    name: \"{}\"\n    currency: \"{}\"\n    country: \"{}\"\n\n",
111            company_name, currency, country
112        ));
113
114        // Chart of accounts
115        yaml.push_str(&format!(
116            "chart_of_accounts:\n  complexity: {}\n\n",
117            complexity
118        ));
119
120        // Transactions
121        let tx_count = Self::complexity_to_tx_count(complexity);
122        yaml.push_str(&format!(
123            "transactions:\n  count: {}\n  anomaly_rate: 0.02\n\n",
124            tx_count
125        ));
126
127        // Output
128        yaml.push_str("output:\n  format: csv\n  compression: false\n\n");
129
130        // Feature-specific sections
131        for feature in &intent.features {
132            match feature.as_str() {
133                "fraud" => {
134                    yaml.push_str(
135                        "fraud:\n  enabled: true\n  types:\n    - fictitious_transaction\n    - duplicate_payment\n    - split_transaction\n  injection_rate: 0.03\n\n",
136                    );
137                }
138                "audit" => {
139                    yaml.push_str(
140                        "audit_standards:\n  enabled: true\n  isa_compliance:\n    enabled: true\n    compliance_level: standard\n    framework: isa\n  analytical_procedures:\n    enabled: true\n    procedures_per_account: 3\n  confirmations:\n    enabled: true\n    positive_response_rate: 0.85\n  sox:\n    enabled: true\n    materiality_threshold: 10000.0\n\n",
141                    );
142                }
143                "banking" => {
144                    yaml.push_str(
145                        "banking:\n  enabled: true\n  customer_count: 100\n  account_types:\n    - checking\n    - savings\n    - loan\n  kyc_enabled: true\n  aml_enabled: true\n\n",
146                    );
147                }
148                "controls" => {
149                    yaml.push_str(
150                        "internal_controls:\n  enabled: true\n  coso_enabled: true\n  include_entity_level_controls: true\n  target_maturity_level: \"managed\"\n  exception_rate: 0.02\n  sod_violation_rate: 0.01\n\n",
151                    );
152                }
153                "process_mining" => {
154                    yaml.push_str(
155                        "business_processes:\n  enabled: true\n  ocel_export: true\n  p2p:\n    enabled: true\n  o2c:\n    enabled: true\n\n",
156                    );
157                }
158                "intercompany" => {
159                    yaml.push_str(
160                        "intercompany:\n  enabled: true\n  matching_tolerance: 0.01\n  elimination_enabled: true\n\n",
161                    );
162                }
163                "distributions" => {
164                    yaml.push_str(&format!(
165                        "distributions:\n  enabled: true\n  industry_profile: {}\n  amounts:\n    enabled: true\n    distribution_type: lognormal\n    benford_compliance: true\n\n",
166                        industry
167                    ));
168                }
169                _ => {} // Unknown features are silently ignored
170            }
171        }
172
173        Ok(yaml)
174    }
175
176    /// Attempt LLM-based parsing of the description.
177    fn parse_with_llm(
178        description: &str,
179        provider: &dyn LlmProvider,
180    ) -> Result<ConfigIntent, SynthError> {
181        let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
182
183        let request = LlmRequest::new(description)
184            .with_system(system_prompt.to_string())
185            .with_temperature(0.1)
186            .with_max_tokens(512);
187
188        let response = provider.complete(&request)?;
189        Self::parse_llm_response(&response.content)
190    }
191
192    /// Parse the LLM response JSON into a ConfigIntent.
193    fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
194        // Try to find JSON in the response
195        let json_str = Self::extract_json(content)
196            .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
197
198        let value: serde_json::Value = serde_json::from_str(json_str)
199            .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {}", e)))?;
200
201        let industry = value
202            .get("industry")
203            .and_then(|v| v.as_str())
204            .map(String::from);
205        let country = value
206            .get("country")
207            .and_then(|v| v.as_str())
208            .map(String::from);
209        let company_size = value
210            .get("company_size")
211            .and_then(|v| v.as_str())
212            .map(String::from);
213        let period_months = value
214            .get("period_months")
215            .and_then(|v| v.as_u64())
216            .map(|v| v as u32);
217        let features = value
218            .get("features")
219            .and_then(|v| v.as_array())
220            .map(|arr| {
221                arr.iter()
222                    .filter_map(|v| v.as_str().map(String::from))
223                    .collect()
224            })
225            .unwrap_or_default();
226
227        Ok(ConfigIntent {
228            industry,
229            country,
230            company_size,
231            period_months,
232            features,
233        })
234    }
235
236    /// Extract a JSON object substring from potentially noisy LLM output.
237    fn extract_json(content: &str) -> Option<&str> {
238        // Find the first '{' and matching '}'
239        let start = content.find('{')?;
240        let mut depth = 0i32;
241        for (i, ch) in content[start..].char_indices() {
242            match ch {
243                '{' => depth += 1,
244                '}' => {
245                    depth -= 1;
246                    if depth == 0 {
247                        return Some(&content[start..start + i + 1]);
248                    }
249                }
250                _ => {}
251            }
252        }
253        None
254    }
255
256    /// Keyword-based parsing as a reliable fallback.
257    fn parse_with_keywords(description: &str) -> ConfigIntent {
258        let lower = description.to_lowercase();
259
260        let industry = Self::extract_industry(&lower);
261        let country = Self::extract_country(&lower);
262        let company_size = Self::extract_size(&lower);
263        let period_months = Self::extract_period(&lower);
264        let features = Self::extract_features(&lower);
265
266        ConfigIntent {
267            industry,
268            country,
269            company_size,
270            period_months,
271            features,
272        }
273    }
274
275    /// Extract industry from lowercased text.
276    ///
277    /// Uses a scoring approach: each industry gets points for keyword matches,
278    /// and the highest-scoring industry wins. This avoids order-dependent
279    /// issues where "banking" in a feature context incorrectly triggers
280    /// "financial_services" over "technology".
281    fn extract_industry(text: &str) -> Option<String> {
282        let patterns: &[(&[&str], &str)] = &[
283            (
284                &["retail", "store", "shop", "e-commerce", "ecommerce"],
285                "retail",
286            ),
287            (
288                &["manufactur", "factory", "production", "assembly"],
289                "manufacturing",
290            ),
291            (
292                &[
293                    "financial",
294                    "finance",
295                    "insurance",
296                    "fintech",
297                    "investment firm",
298                ],
299                "financial_services",
300            ),
301            (
302                &["health", "hospital", "medical", "pharma", "clinic"],
303                "healthcare",
304            ),
305            (
306                &["tech", "software", "saas", "startup", "digital"],
307                "technology",
308            ),
309        ];
310
311        let mut best: Option<(&str, usize)> = None;
312        for (keywords, industry) in patterns {
313            let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
314            if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
315                best = Some((industry, count));
316            }
317        }
318        best.map(|(industry, _)| industry.to_string())
319    }
320
321    /// Extract country from lowercased text.
322    fn extract_country(text: &str) -> Option<String> {
323        // Check full country names first (most reliable), then short codes.
324        // Short codes like "in", "de", "us" can clash with English words,
325        // so we only use unambiguous short codes.
326        let name_patterns = [
327            (&["united states", "u.s.", "america"][..], "US"),
328            (&["germany", "german"][..], "DE"),
329            (&["united kingdom", "british", "england"][..], "GB"),
330            (&["china", "chinese"][..], "CN"),
331            (&["japan", "japanese"][..], "JP"),
332            (&["india", "indian"][..], "IN"),
333            (&["brazil", "brazilian"][..], "BR"),
334            (&["mexico", "mexican"][..], "MX"),
335            (&["australia", "australian"][..], "AU"),
336            (&["singapore", "singaporean"][..], "SG"),
337            (&["korea", "korean"][..], "KR"),
338            (&["france", "french"][..], "FR"),
339            (&["canada", "canadian"][..], "CA"),
340        ];
341
342        for (keywords, code) in &name_patterns {
343            if keywords.iter().any(|kw| text.contains(kw)) {
344                return Some(code.to_string());
345            }
346        }
347
348        // Fall back to short codes (padded with spaces).
349        // Excluded: "in" (India - clashes with preposition "in"),
350        //           "de" (Germany - clashes with various uses).
351        let padded = format!(" {} ", text);
352        let safe_codes = [
353            (" us ", "US"),
354            (" uk ", "GB"),
355            (" gb ", "GB"),
356            (" cn ", "CN"),
357            (" jp ", "JP"),
358            (" br ", "BR"),
359            (" mx ", "MX"),
360            (" au ", "AU"),
361            (" sg ", "SG"),
362            (" kr ", "KR"),
363            (" fr ", "FR"),
364            (" ca ", "CA"),
365        ];
366
367        for (code_pattern, code) in &safe_codes {
368            if padded.contains(code_pattern) {
369                return Some(code.to_string());
370            }
371        }
372
373        None
374    }
375
376    /// Extract company size from lowercased text.
377    fn extract_size(text: &str) -> Option<String> {
378        if text.contains("small") || text.contains("startup") || text.contains("tiny") {
379            Some("small".to_string())
380        } else if text.contains("large")
381            || text.contains("enterprise")
382            || text.contains("big")
383            || text.contains("multinational")
384            || text.contains("fortune 500")
385        {
386            Some("large".to_string())
387        } else if text.contains("medium")
388            || text.contains("mid-size")
389            || text.contains("midsize")
390            || text.contains("mid size")
391        {
392            Some("medium".to_string())
393        } else {
394            None
395        }
396    }
397
398    /// Extract period in months from lowercased text.
399    fn extract_period(text: &str) -> Option<u32> {
400        // Match patterns like "1 year", "2 years", "6 months", "18 months"
401        // Also handle "one year", "two years", etc.
402        let word_numbers = [
403            ("one", 1u32),
404            ("two", 2),
405            ("three", 3),
406            ("four", 4),
407            ("five", 5),
408            ("six", 6),
409            ("twelve", 12),
410            ("eighteen", 18),
411            ("twenty-four", 24),
412        ];
413
414        // Try "N year(s)" pattern
415        for (word, num) in &word_numbers {
416            if text.contains(&format!("{} year", word)) {
417                return Some(num * 12);
418            }
419            if text.contains(&format!("{} month", word)) {
420                return Some(*num);
421            }
422        }
423
424        // Try numeric patterns: "N year(s)", "N month(s)"
425        let tokens: Vec<&str> = text.split_whitespace().collect();
426        for window in tokens.windows(2) {
427            if let Ok(num) = window[0].parse::<u32>() {
428                if window[1].starts_with("year") {
429                    return Some(num * 12);
430                }
431                if window[1].starts_with("month") {
432                    return Some(num);
433                }
434            }
435        }
436
437        None
438    }
439
440    /// Extract feature flags from lowercased text.
441    fn extract_features(text: &str) -> Vec<String> {
442        let mut features = Vec::new();
443
444        let feature_patterns = [
445            (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
446            (&["audit", "auditing", "assurance"][..], "audit"),
447            (&["banking", "bank account", "kyc", "aml"][..], "banking"),
448            (
449                &["control", "sox", "sod", "segregation of duties", "coso"][..],
450                "controls",
451            ),
452            (
453                &["process mining", "ocel", "event log"][..],
454                "process_mining",
455            ),
456            (
457                &["intercompany", "inter-company", "consolidation"][..],
458                "intercompany",
459            ),
460            (
461                &["distribution", "benford", "statistical"][..],
462                "distributions",
463            ),
464        ];
465
466        for (keywords, feature) in &feature_patterns {
467            if keywords.iter().any(|kw| text.contains(kw)) {
468                features.push(feature.to_string());
469            }
470        }
471
472        features
473    }
474
475    /// Merge two ConfigIntents, preferring the primary where available.
476    fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
477        ConfigIntent {
478            industry: primary.industry.or(fallback.industry),
479            country: primary.country.or(fallback.country),
480            company_size: primary.company_size.or(fallback.company_size),
481            period_months: primary.period_months.or(fallback.period_months),
482            features: if primary.features.is_empty() {
483                fallback.features
484            } else {
485                primary.features
486            },
487        }
488    }
489
490    /// Map country code to default currency.
491    fn country_to_currency(country: &str) -> &'static str {
492        match country {
493            "US" | "CA" => "USD",
494            "DE" | "FR" => "EUR",
495            "GB" => "GBP",
496            "CN" => "CNY",
497            "JP" => "JPY",
498            "IN" => "INR",
499            "BR" => "BRL",
500            "MX" => "MXN",
501            "AU" => "AUD",
502            "SG" => "SGD",
503            "KR" => "KRW",
504            _ => "USD",
505        }
506    }
507
508    /// Generate a company name based on industry.
509    fn industry_company_name(industry: &str) -> &'static str {
510        match industry {
511            "retail" => "Retail Corp",
512            "manufacturing" => "Manufacturing Industries Inc",
513            "financial_services" => "Financial Services Group",
514            "healthcare" => "HealthCare Solutions",
515            "technology" => "TechCorp Solutions",
516            _ => "DataSynth Corp",
517        }
518    }
519
520    /// Map complexity to an appropriate transaction count.
521    fn complexity_to_tx_count(complexity: &str) -> u32 {
522        match complexity {
523            "small" => 1000,
524            "medium" => 5000,
525            "large" => 25000,
526            _ => 5000,
527        }
528    }
529}
530
#[cfg(test)]
// Tests may unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::llm::mock_provider::MockLlmProvider;

    /// Runs `parse_intent` on `description` with a mock provider seeded with 42.
    fn parse(description: &str) -> ConfigIntent {
        let provider = MockLlmProvider::new(42);
        NlConfigGenerator::parse_intent(description, &provider).expect("should parse successfully")
    }

    /// Renders `intent` to YAML, failing the test if rendering errors.
    fn render(intent: &ConfigIntent) -> String {
        NlConfigGenerator::intent_to_yaml(intent).expect("should generate YAML")
    }

    /// Shorthand for an owned `Some(String)` used in field comparisons.
    fn some(text: &str) -> Option<String> {
        Some(text.to_owned())
    }

    /// Whether `intent` lists the feature called `name`.
    fn wants(intent: &ConfigIntent, name: &str) -> bool {
        intent.features.iter().any(|feature| feature == name)
    }

    /// Asserts that every snippet appears somewhere in `yaml`.
    fn assert_contains_all(yaml: &str, snippets: &[&str]) {
        for &snippet in snippets {
            assert!(yaml.contains(snippet), "missing {:?} in:\n{}", snippet, yaml);
        }
    }

    #[test]
    fn test_parse_retail_description() {
        let intent = parse("Generate 1 year of retail data for a medium US company");

        assert_eq!(intent.industry, some("retail"));
        assert_eq!(intent.country, some("US"));
        assert_eq!(intent.company_size, some("medium"));
        assert_eq!(intent.period_months, Some(12));
    }

    #[test]
    fn test_parse_manufacturing_with_fraud() {
        let intent = parse(
            "Create 6 months of manufacturing data for a large German company with fraud detection",
        );

        assert_eq!(intent.industry, some("manufacturing"));
        assert_eq!(intent.country, some("DE"));
        assert_eq!(intent.company_size, some("large"));
        assert_eq!(intent.period_months, Some(6));
        assert!(wants(&intent, "fraud"));
    }

    #[test]
    fn test_parse_financial_services_with_audit() {
        let intent =
            parse("I need 2 years of financial services data for audit testing with SOX controls");

        assert_eq!(intent.industry, some("financial_services"));
        assert_eq!(intent.period_months, Some(24));
        assert!(wants(&intent, "audit"));
        assert!(wants(&intent, "controls"));
    }

    #[test]
    fn test_parse_healthcare_small() {
        let intent = parse("Small healthcare company in Japan, 3 months of data");

        assert_eq!(intent.industry, some("healthcare"));
        assert_eq!(intent.country, some("JP"));
        assert_eq!(intent.company_size, some("small"));
        assert_eq!(intent.period_months, Some(3));
    }

    #[test]
    fn test_parse_technology_with_banking() {
        let intent = parse("Generate data for a technology startup in Singapore with banking and KYC");

        assert_eq!(intent.industry, some("technology"));
        assert_eq!(intent.country, some("SG"));
        assert_eq!(intent.company_size, some("small"));
        assert!(wants(&intent, "banking"));
    }

    #[test]
    fn test_parse_word_numbers() {
        assert_eq!(
            parse("Generate two years of retail data").period_months,
            Some(24)
        );
    }

    #[test]
    fn test_parse_multiple_features() {
        let intent = parse(
            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
        );

        assert_eq!(intent.industry, some("manufacturing"));
        for feature in ["fraud", "audit", "process_mining", "intercompany"] {
            assert!(wants(&intent, feature), "expected feature {}", feature);
        }
    }

    #[test]
    fn test_intent_to_yaml_basic() {
        let yaml = render(&ConfigIntent {
            industry: some("retail"),
            country: some("US"),
            company_size: some("medium"),
            period_months: Some(12),
            features: Vec::new(),
        });

        assert_contains_all(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "country: \"US\"",
                "complexity: medium",
                "count: 5000",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_with_features() {
        let yaml = render(&ConfigIntent {
            industry: some("manufacturing"),
            country: some("DE"),
            company_size: some("large"),
            period_months: Some(24),
            features: ["fraud", "audit", "controls"]
                .iter()
                .map(|name| name.to_string())
                .collect(),
        });

        assert_contains_all(
            &yaml,
            &[
                "industry: manufacturing",
                "currency: \"EUR\"",
                "complexity: large",
                "count: 25000",
                "fraud:",
                "audit_standards:",
                "internal_controls:",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_defaults() {
        let yaml = render(&ConfigIntent::default());

        // An empty intent must fall back to the documented defaults.
        assert_contains_all(
            &yaml,
            &["industry: manufacturing", "period_months: 12", "complexity: medium"],
        );
    }

    #[test]
    fn test_intent_to_yaml_invalid_period() {
        for months in [0, 121] {
            let intent = ConfigIntent {
                period_months: Some(months),
                ..ConfigIntent::default()
            };
            assert!(
                NlConfigGenerator::intent_to_yaml(&intent).is_err(),
                "period {} should be rejected",
                months
            );
        }
    }

    #[test]
    fn test_generate_end_to_end() {
        let provider = MockLlmProvider::new(42);
        let yaml = NlConfigGenerator::generate(
            "Generate 1 year of retail data for a medium US company with fraud detection",
            &provider,
        )
        .expect("should generate YAML");

        assert_contains_all(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "fraud:",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_generate_empty_description() {
        let provider = MockLlmProvider::new(42);
        for blank in ["", "   "] {
            assert!(
                NlConfigGenerator::generate(blank, &provider).is_err(),
                "blank description {:?} should be rejected",
                blank
            );
        }
    }

    #[test]
    fn test_extract_json_from_response() {
        let content = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;

        assert_eq!(
            NlConfigGenerator::extract_json(content),
            Some(r#"{"industry": "retail", "country": "US"}"#)
        );
    }

    #[test]
    fn test_extract_json_nested() {
        let content = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
        assert!(NlConfigGenerator::extract_json(content).is_some());
    }

    #[test]
    fn test_extract_json_missing() {
        assert!(NlConfigGenerator::extract_json("No JSON here at all").is_none());
    }

    #[test]
    fn test_parse_llm_response_valid() {
        let content = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
        let intent =
            NlConfigGenerator::parse_llm_response(content).expect("should parse valid JSON");

        assert_eq!(intent.industry, some("retail"));
        assert_eq!(intent.country, some("US"));
        assert_eq!(intent.company_size, some("medium"));
        assert_eq!(intent.period_months, Some(12));
        assert_eq!(intent.features, vec!["fraud".to_string()]);
    }

    #[test]
    fn test_parse_llm_response_partial() {
        let intent = NlConfigGenerator::parse_llm_response(r#"{"industry": "retail"}"#)
            .expect("should parse partial JSON");

        assert_eq!(intent.industry, some("retail"));
        assert_eq!(intent.country, None);
        assert!(intent.features.is_empty());
    }

    #[test]
    fn test_country_to_currency_mapping() {
        let cases = [
            ("US", "USD"),
            ("DE", "EUR"),
            ("GB", "GBP"),
            ("JP", "JPY"),
            ("CN", "CNY"),
            ("BR", "BRL"),
            ("XX", "USD"), // Unknown defaults to USD
        ];
        for (country, currency) in cases {
            assert_eq!(
                NlConfigGenerator::country_to_currency(country),
                currency,
                "country {}",
                country
            );
        }
    }

    #[test]
    fn test_merge_intents() {
        let primary = ConfigIntent {
            industry: some("retail"),
            country: None,
            company_size: None,
            period_months: Some(12),
            features: Vec::new(),
        };
        let fallback = ConfigIntent {
            industry: some("manufacturing"),
            country: some("DE"),
            company_size: some("large"),
            period_months: Some(6),
            features: vec!["fraud".to_string()],
        };

        let merged = NlConfigGenerator::merge_intents(primary, fallback);

        assert_eq!(merged.industry, some("retail")); // primary wins
        assert_eq!(merged.country, some("DE")); // fallback fills gap
        assert_eq!(merged.company_size, some("large")); // fallback fills gap
        assert_eq!(merged.period_months, Some(12)); // primary wins
        assert_eq!(merged.features, vec!["fraud".to_string()]); // fallback since primary empty
    }

    #[test]
    fn test_parse_uk_country() {
        assert_eq!(
            parse("Generate data for a UK manufacturing company").country,
            some("GB")
        );
    }

    #[test]
    fn test_intent_to_yaml_banking_feature() {
        let yaml = render(&ConfigIntent {
            industry: some("financial_services"),
            country: some("US"),
            company_size: some("large"),
            period_months: Some(12),
            features: vec!["banking".to_string()],
        });

        assert_contains_all(&yaml, &["banking:", "kyc_enabled: true", "aml_enabled: true"]);
    }

    #[test]
    fn test_intent_to_yaml_process_mining_feature() {
        let yaml = render(&ConfigIntent {
            features: vec!["process_mining".to_string()],
            ..ConfigIntent::default()
        });

        assert_contains_all(&yaml, &["business_processes:", "ocel_export: true"]);
    }

    #[test]
    fn test_intent_to_yaml_distributions_feature() {
        let yaml = render(&ConfigIntent {
            industry: some("retail"),
            features: vec!["distributions".to_string()],
            ..ConfigIntent::default()
        });

        assert_contains_all(
            &yaml,
            &[
                "distributions:",
                "industry_profile: retail",
                "benford_compliance: true",
            ],
        );
    }
}