Skip to main content

datasynth_core/llm/
nl_config.rs

1//! Natural language to YAML configuration generator.
2//!
3//! Takes a free-text description of desired synthetic data (e.g., "Generate 1 year of
4//! retail data for a medium US company with fraud detection") and produces a valid
5//! `GeneratorConfig` YAML string.
6
7use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured representation of what the user asked for, as extracted from a
/// natural language description.
///
/// Every scalar field is optional: `None` means the description did not state
/// a value. `Default` yields an intent with every field unset and no features.
///
/// `PartialEq`/`Eq` are derived so whole intents can be compared directly
/// (for example in tests) instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Target industry (e.g., "retail", "manufacturing", "financial_services").
    pub industry: Option<String>,
    /// Country code (e.g., "US", "DE", "GB").
    pub country: Option<String>,
    /// Company size: "small", "medium", or "large".
    pub company_size: Option<String>,
    /// Duration of the generated period, in months.
    pub period_months: Option<u32>,
    /// Requested feature flags (e.g., "fraud", "audit", "banking", "controls").
    pub features: Vec<String>,
}
24
/// Generates YAML configuration from natural language descriptions.
///
/// The generator works in two phases:
/// 1. Parse the natural language description into a structured [`ConfigIntent`].
/// 2. Map the intent to a YAML configuration string using preset templates.
///
/// The type carries no state; all functionality is exposed through associated
/// functions. The common traits are derived so the marker type can be logged,
/// copied and default-constructed like any other public type.
#[derive(Debug, Default, Clone, Copy)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33    /// Generate a YAML configuration from a natural language description.
34    ///
35    /// Uses the provided LLM provider to help parse the description, with
36    /// keyword-based fallback parsing for reliability.
37    ///
38    /// # Errors
39    ///
40    /// Returns `SynthError::GenerationError` if the description cannot be parsed
41    /// or the resulting configuration is invalid.
42    pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43        if description.trim().is_empty() {
44            return Err(SynthError::generation(
45                "Natural language description cannot be empty",
46            ));
47        }
48
49        let intent = Self::parse_intent(description, provider)?;
50        Self::intent_to_yaml(&intent)
51    }
52
53    /// Parse a natural language description into a structured [`ConfigIntent`].
54    ///
55    /// Attempts to use the LLM provider first, then falls back to keyword-based
56    /// extraction for reliability.
57    pub fn parse_intent(
58        description: &str,
59        provider: &dyn LlmProvider,
60    ) -> Result<ConfigIntent, SynthError> {
61        // Try LLM-based parsing first
62        let llm_intent = Self::parse_with_llm(description, provider);
63
64        // Always run keyword-based parsing as fallback/supplement
65        let keyword_intent = Self::parse_with_keywords(description);
66
67        // Merge: prefer LLM results where available, fall back to keywords
68        match llm_intent {
69            Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70            Err(e) => {
71                tracing::warn!(
72                    "LLM-based config parsing failed, falling back to keyword parsing: {}",
73                    e
74                );
75                Ok(keyword_intent)
76            }
77        }
78    }
79
80    /// Map a [`ConfigIntent`] to a YAML configuration string.
81    pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
82        let industry = intent.industry.as_deref().unwrap_or("manufacturing");
83        let country = intent.country.as_deref().unwrap_or("US");
84        let complexity = intent.company_size.as_deref().unwrap_or("medium");
85        let period_months = intent.period_months.unwrap_or(12);
86
87        // Validate inputs
88        if !(1..=120).contains(&period_months) {
89            return Err(SynthError::generation(format!(
90                "Period months must be between 1 and 120, got {period_months}"
91            )));
92        }
93
94        let valid_complexities = ["small", "medium", "large"];
95        if !valid_complexities.contains(&complexity) {
96            return Err(SynthError::generation(format!(
97                "Invalid company size '{complexity}', must be one of: small, medium, large"
98            )));
99        }
100
101        let currency = Self::country_to_currency(country);
102        let company_name = Self::industry_company_name(industry);
103
104        let mut yaml = String::with_capacity(2048);
105
106        // Global settings
107        yaml.push_str(&format!(
108            "global:\n  industry: {industry}\n  start_date: \"2024-01-01\"\n  period_months: {period_months}\n  seed: 42\n\n"
109        ));
110
111        // Companies
112        yaml.push_str(&format!(
113            "companies:\n  - code: \"C001\"\n    name: \"{company_name}\"\n    currency: \"{currency}\"\n    country: \"{country}\"\n\n"
114        ));
115
116        // Chart of accounts
117        yaml.push_str(&format!(
118            "chart_of_accounts:\n  complexity: {complexity}\n\n"
119        ));
120
121        // Transactions
122        let tx_count = Self::complexity_to_tx_count(complexity);
123        yaml.push_str(&format!(
124            "transactions:\n  count: {tx_count}\n  anomaly_rate: 0.02\n\n"
125        ));
126
127        // Output
128        yaml.push_str("output:\n  format: csv\n  compression: false\n\n");
129
130        // Feature-specific sections
131        for feature in &intent.features {
132            match feature.as_str() {
133                "fraud" => {
134                    yaml.push_str(
135                        "fraud:\n  enabled: true\n  types:\n    - fictitious_transaction\n    - duplicate_payment\n    - split_transaction\n  injection_rate: 0.03\n\n",
136                    );
137                }
138                "audit" => {
139                    yaml.push_str(
140                        "audit_standards:\n  enabled: true\n  isa_compliance:\n    enabled: true\n    compliance_level: standard\n    framework: isa\n  analytical_procedures:\n    enabled: true\n    procedures_per_account: 3\n  confirmations:\n    enabled: true\n    positive_response_rate: 0.85\n  sox:\n    enabled: true\n    materiality_threshold: 10000.0\n\n",
141                    );
142                }
143                "banking" => {
144                    yaml.push_str(
145                        "banking:\n  enabled: true\n  customer_count: 100\n  account_types:\n    - checking\n    - savings\n    - loan\n  kyc_enabled: true\n  aml_enabled: true\n\n",
146                    );
147                }
148                "controls" => {
149                    yaml.push_str(
150                        "internal_controls:\n  enabled: true\n  coso_enabled: true\n  include_entity_level_controls: true\n  target_maturity_level: \"managed\"\n  exception_rate: 0.02\n  sod_violation_rate: 0.01\n\n",
151                    );
152                }
153                "process_mining" => {
154                    yaml.push_str(
155                        "business_processes:\n  enabled: true\n  ocel_export: true\n  p2p:\n    enabled: true\n  o2c:\n    enabled: true\n\n",
156                    );
157                }
158                "intercompany" => {
159                    yaml.push_str(
160                        "intercompany:\n  enabled: true\n  matching_tolerance: 0.01\n  elimination_enabled: true\n\n",
161                    );
162                }
163                "distributions" => {
164                    yaml.push_str(&format!(
165                        "distributions:\n  enabled: true\n  industry_profile: {industry}\n  amounts:\n    enabled: true\n    distribution_type: lognormal\n    benford_compliance: true\n\n"
166                    ));
167                }
168                other => {
169                    tracing::warn!(
170                        "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
171                        other
172                    );
173                }
174            }
175        }
176
177        Ok(yaml)
178    }
179
180    /// Attempt LLM-based parsing of the description.
181    fn parse_with_llm(
182        description: &str,
183        provider: &dyn LlmProvider,
184    ) -> Result<ConfigIntent, SynthError> {
185        let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
186
187        let request = LlmRequest::new(description)
188            .with_system(system_prompt.to_string())
189            .with_temperature(0.1)
190            .with_max_tokens(512);
191
192        let response = provider.complete(&request)?;
193        Self::parse_llm_response(&response.content)
194    }
195
196    /// Parse the LLM response JSON into a ConfigIntent.
197    fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
198        // Try to find JSON in the response
199        let json_str = Self::extract_json(content)
200            .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
201
202        let value: serde_json::Value = serde_json::from_str(json_str)
203            .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {e}")))?;
204
205        let industry = value
206            .get("industry")
207            .and_then(|v| v.as_str())
208            .map(String::from);
209        let country = value
210            .get("country")
211            .and_then(|v| v.as_str())
212            .map(String::from);
213        let company_size = value
214            .get("company_size")
215            .and_then(|v| v.as_str())
216            .map(String::from);
217        let period_months = value
218            .get("period_months")
219            .and_then(serde_json::Value::as_u64)
220            .map(|v| v as u32);
221        let features = value
222            .get("features")
223            .and_then(|v| v.as_array())
224            .map(|arr| {
225                arr.iter()
226                    .filter_map(|v| v.as_str().map(String::from))
227                    .collect()
228            })
229            .unwrap_or_default();
230
231        Ok(ConfigIntent {
232            industry,
233            country,
234            company_size,
235            period_months,
236            features,
237        })
238    }
239
/// Extract the first JSON object substring from potentially noisy LLM output.
///
/// Scanning starts at the first `{` and ends at the `}` that balances it.
/// Braces that appear inside JSON string literals are ignored (including
/// after escaped quotes such as `\"`), so a value like `"}"` does not end
/// the object early.
///
/// Returns `None` when `content` has no `{` or the object is never closed.
/// The returned slice is not validated as JSON.
fn extract_json(content: &str) -> Option<&str> {
    let start = content.find('{')?;

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, ch) in content[start..].char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash is consumed verbatim.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                // The scan starts on '{', so depth is at least 1 here; we
                // return as soon as it reaches 0, so this cannot underflow.
                depth -= 1;
                if depth == 0 {
                    // '}' is one byte long, so the inclusive end is a char boundary.
                    return Some(&content[start..=start + offset]);
                }
            }
            _ => {}
        }
    }

    None
}
259
260    /// Keyword-based parsing as a reliable fallback.
261    fn parse_with_keywords(description: &str) -> ConfigIntent {
262        let lower = description.to_lowercase();
263
264        let industry = Self::extract_industry(&lower);
265        let country = Self::extract_country(&lower);
266        let company_size = Self::extract_size(&lower);
267        let period_months = Self::extract_period(&lower);
268        let features = Self::extract_features(&lower);
269
270        ConfigIntent {
271            industry,
272            country,
273            company_size,
274            period_months,
275            features,
276        }
277    }
278
/// Extract industry from lowercased text.
///
/// Uses a scoring approach: each industry gets one point per keyword that is
/// mentioned, and the highest-scoring industry wins. On a tie the industry
/// listed first in the table wins. Returns `None` when nothing matches.
///
/// Single-word keywords must match at the *start* of a word (so "tech"
/// matches "technology" but not "fintech", and "store" does not match
/// "restore"). Keywords containing spaces or hyphens are matched as plain
/// substrings.
fn extract_industry(text: &str) -> Option<String> {
    const PATTERNS: &[(&[&str], &str)] = &[
        (
            &["retail", "store", "shop", "e-commerce", "ecommerce"],
            "retail",
        ),
        (
            &["manufactur", "factory", "production", "assembly"],
            "manufacturing",
        ),
        (
            &[
                "financial",
                "finance",
                "insurance",
                "fintech",
                "investment firm",
            ],
            "financial_services",
        ),
        (
            &["health", "hospital", "medical", "pharma", "clinic"],
            "healthcare",
        ),
        (
            &["tech", "software", "saas", "startup", "digital"],
            "technology",
        ),
    ];

    let words: Vec<&str> = text.split(|c: char| !c.is_alphanumeric()).collect();
    let mentions = |keyword: &str| -> bool {
        if keyword.chars().all(char::is_alphanumeric) {
            words.iter().any(|word| word.starts_with(keyword))
        } else {
            text.contains(keyword)
        }
    };

    let mut best: Option<(&str, usize)> = None;
    for (keywords, industry) in PATTERNS {
        let score = keywords.iter().filter(|kw| mentions(**kw)).count();
        // With no leader yet the bar is 0, so only a positive score can win;
        // the strict comparison keeps the earliest industry on ties.
        let score_to_beat = best.map(|(_, top)| top).unwrap_or(0);
        if score > score_to_beat {
            best = Some((*industry, score));
        }
    }

    best.map(|(industry, _)| industry.to_string())
}
324
/// Extract a country code from lowercased text.
///
/// Full country names and demonyms are checked first as substrings, in
/// table order. If none is found, a small set of two-letter codes is
/// matched against the whole words of the text (words are split on any
/// non-alphanumeric character, so "uk.", "(us)" and "us-based" all count).
///
/// Codes that collide with common words are deliberately not matched:
/// "in" (India) and "de" (Germany). Returns `None` when nothing matches.
fn extract_country(text: &str) -> Option<String> {
    const NAME_PATTERNS: &[(&[&str], &str)] = &[
        (&["united states", "u.s.", "america"], "US"),
        (&["germany", "german"], "DE"),
        (&["united kingdom", "british", "england"], "GB"),
        (&["china", "chinese"], "CN"),
        (&["japan", "japanese"], "JP"),
        (&["india", "indian"], "IN"),
        (&["brazil", "brazilian"], "BR"),
        (&["mexico", "mexican"], "MX"),
        (&["australia", "australian"], "AU"),
        (&["singapore", "singaporean"], "SG"),
        (&["korea", "korean"], "KR"),
        (&["france", "french"], "FR"),
        (&["canada", "canadian"], "CA"),
    ];

    // Short codes that are safe to match as standalone words.
    const SAFE_CODES: &[(&str, &str)] = &[
        ("us", "US"),
        ("uk", "GB"),
        ("gb", "GB"),
        ("cn", "CN"),
        ("jp", "JP"),
        ("br", "BR"),
        ("mx", "MX"),
        ("au", "AU"),
        ("sg", "SG"),
        ("kr", "KR"),
        ("fr", "FR"),
        ("ca", "CA"),
    ];

    let by_name = NAME_PATTERNS
        .iter()
        .find(|(names, _)| names.iter().any(|name| text.contains(*name)));
    if let Some((_, code)) = by_name {
        return Some((*code).to_string());
    }

    // Splitting on punctuation (not just spaces) lets a code be recognised
    // at the end of a sentence or inside a hyphenated compound.
    let words: Vec<&str> = text.split(|c: char| !c.is_alphanumeric()).collect();
    SAFE_CODES
        .iter()
        .find(|(token, _)| words.contains(token))
        .map(|(_, code)| (*code).to_string())
}
379
/// Extract company size from lowercased text.
///
/// Size groups are tried in the order small, large, medium; the first group
/// with a mentioned keyword wins. Returns `None` when no keyword is found.
///
/// Single-word keywords must match at the start of a word, so "big" no
/// longer fires on "ambiguous" and "large" no longer fires on "enlarge",
/// while inflections such as "startups" or "smaller" still match. Keywords
/// containing spaces or hyphens ("fortune 500", "mid-size", "mid size") are
/// matched as plain substrings.
fn extract_size(text: &str) -> Option<String> {
    const SIZES: &[(&str, &[&str])] = &[
        ("small", &["small", "startup", "tiny"]),
        (
            "large",
            &["large", "enterprise", "big", "multinational", "fortune 500"],
        ),
        ("medium", &["medium", "mid-size", "midsize", "mid size"]),
    ];

    let words: Vec<&str> = text.split(|c: char| !c.is_alphanumeric()).collect();
    let mentions = |keyword: &str| -> bool {
        if keyword.chars().all(char::is_alphanumeric) {
            words.iter().any(|word| word.starts_with(keyword))
        } else {
            text.contains(keyword)
        }
    };

    SIZES
        .iter()
        .find(|(_, keywords)| keywords.iter().any(|kw| mentions(*kw)))
        .map(|(size, _)| (*size).to_string())
}
401
/// Extract the requested period, in months, from lowercased text.
///
/// Recognises "<count> year(s)" and "<count> month(s)" where the count is
/// either a spelled-out number from a fixed table ("one" to "twelve",
/// "eighteen", "twenty-four", "thirty-six") or a decimal integer.
/// Spelled-out counts are searched first, then numeric ones; within each
/// pass the first match in reading order wins.
///
/// Matching is done on whole whitespace-separated tokens with surrounding
/// punctuation removed, so "(six months)" matches while "twenty-one years"
/// is not mistaken for "one year". A year count whose month value would
/// overflow `u32` is ignored instead of panicking or wrapping.
///
/// Returns `None` when no period is found.
fn extract_period(text: &str) -> Option<u32> {
    const WORD_NUMBERS: &[(&str, u32)] = &[
        ("one", 1),
        ("two", 2),
        ("three", 3),
        ("four", 4),
        ("five", 5),
        ("six", 6),
        ("seven", 7),
        ("eight", 8),
        ("nine", 9),
        ("ten", 10),
        ("eleven", 11),
        ("twelve", 12),
        ("eighteen", 18),
        ("twenty-four", 24),
        ("thirty-six", 36),
    ];
    const MONTHS_PER_YEAR: u32 = 12;

    /// Convert `count` of `unit` into months; `None` for other units or on overflow.
    fn to_months(count: u32, unit: &str) -> Option<u32> {
        if unit.starts_with("year") {
            count.checked_mul(MONTHS_PER_YEAR)
        } else if unit.starts_with("month") {
            Some(count)
        } else {
            None
        }
    }

    let tokens: Vec<&str> = text
        .split_whitespace()
        .map(|token| token.trim_matches(|c: char| !c.is_alphanumeric()))
        .collect();

    // Pass 1: spelled-out counts ("two years", "six months").
    let spelled = tokens.windows(2).find_map(|pair| {
        let count = WORD_NUMBERS
            .iter()
            .find(|(word, _)| *word == pair[0])
            .map(|(_, value)| *value)?;
        to_months(count, pair[1])
    });

    // Pass 2: numeric counts ("1 year", "18 months").
    spelled.or_else(|| {
        tokens.windows(2).find_map(|pair| {
            let count = pair[0].parse::<u32>().ok()?;
            to_months(count, pair[1])
        })
    })
}
443
/// Extract feature flags from lowercased text.
///
/// Each feature has two kinds of triggers:
/// - *stems*: single words that must match the start of a word in the text
///   (so "audit" covers "auditing"), or multi-word phrases matched as plain
///   substrings ("process mining", "bank account");
/// - *acronyms*: short tokens that must match a whole word exactly, so
///   "aml" does not fire on "seamless", "sod" not on "episode", and "ocel"
///   not on "ocelot".
///
/// Features are returned in table order, each at most once.
fn extract_features(text: &str) -> Vec<String> {
    // (feature, stems / phrases, exact-word acronyms)
    const FEATURE_PATTERNS: &[(&str, &[&str], &[&str])] = &[
        ("fraud", &["fraud", "suspicious"], &[]),
        ("audit", &["audit", "assurance"], &[]),
        ("banking", &["banking", "bank account"], &["kyc", "aml"]),
        (
            "controls",
            &["control", "segregation of duties"],
            &["sox", "sod", "coso"],
        ),
        ("process_mining", &["process mining", "event log"], &["ocel"]),
        (
            "intercompany",
            &["intercompany", "inter-company", "consolidation"],
            &[],
        ),
        (
            "distributions",
            &["distribution", "benford", "statistical"],
            &[],
        ),
    ];

    let words: Vec<&str> = text.split(|c: char| !c.is_alphanumeric()).collect();
    let has_stem = |stem: &str| -> bool {
        if stem.chars().all(char::is_alphanumeric) {
            words.iter().any(|word| word.starts_with(stem))
        } else {
            text.contains(stem)
        }
    };
    let has_word = |acronym: &str| -> bool { words.iter().any(|word| *word == acronym) };

    FEATURE_PATTERNS
        .iter()
        .filter(|(_, stems, acronyms)| {
            stems.iter().any(|stem| has_stem(*stem))
                || acronyms.iter().any(|acronym| has_word(*acronym))
        })
        .map(|(feature, _, _)| (*feature).to_string())
        .collect()
}
478
479    /// Merge two ConfigIntents, preferring the primary where available.
480    fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
481        ConfigIntent {
482            industry: primary.industry.or(fallback.industry),
483            country: primary.country.or(fallback.country),
484            company_size: primary.company_size.or(fallback.company_size),
485            period_months: primary.period_months.or(fallback.period_months),
486            features: if primary.features.is_empty() {
487                fallback.features
488            } else {
489                primary.features
490            },
491        }
492    }
493
/// Map an upper-case two-letter country code to its default currency code.
///
/// Covers the countries the keyword parser can detect. Canada maps to `CAD`
/// (it was previously lumped in with the US dollar). Any unrecognised code,
/// including lower-case input, falls back to `USD`.
fn country_to_currency(country: &str) -> &'static str {
    match country {
        "US" => "USD",
        "CA" => "CAD",
        "DE" | "FR" => "EUR",
        "GB" => "GBP",
        "CN" => "CNY",
        "JP" => "JPY",
        "IN" => "INR",
        "BR" => "BRL",
        "MX" => "MXN",
        "AU" => "AUD",
        "SG" => "SGD",
        "KR" => "KRW",
        // Unknown countries default to US dollars.
        _ => "USD",
    }
}
511
/// Pick the placeholder company name used for an industry.
///
/// The lookup is exact and case-sensitive; industries that are not in the
/// table get the generic name "DataSynth Corp".
fn industry_company_name(industry: &str) -> &'static str {
    const COMPANY_NAMES: &[(&str, &str)] = &[
        ("retail", "Retail Corp"),
        ("manufacturing", "Manufacturing Industries Inc"),
        ("financial_services", "Financial Services Group"),
        ("healthcare", "HealthCare Solutions"),
        ("technology", "TechCorp Solutions"),
    ];

    COMPANY_NAMES
        .iter()
        .find(|(key, _)| *key == industry)
        .map(|(_, name)| *name)
        .unwrap_or("DataSynth Corp")
}
523
/// Choose the number of transactions to generate for a company size.
///
/// "small" yields 1000 and "large" yields 25000. Every other value,
/// including "medium", yields 5000.
fn complexity_to_tx_count(complexity: &str) -> u32 {
    const MEDIUM_TX_COUNT: u32 = 5000;

    match complexity {
        "small" => 1000,
        "large" => 25000,
        // "medium" and any unrecognised size share the same count.
        _ => MEDIUM_TX_COUNT,
    }
}
533}
534
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for the natural-language configuration generator.
    //!
    //! Parsing tests go through `parse_intent` with a mock provider seeded
    //! with 42; rendering tests call `intent_to_yaml` directly.

    use super::*;
    use crate::llm::mock_provider::MockLlmProvider;

    /// Parse `description` with the seeded mock provider.
    fn parse(description: &str) -> ConfigIntent {
        let provider = MockLlmProvider::new(42);
        NlConfigGenerator::parse_intent(description, &provider).expect("should parse successfully")
    }

    /// Render `intent` to YAML, panicking if rendering fails.
    fn render(intent: &ConfigIntent) -> String {
        NlConfigGenerator::intent_to_yaml(intent).expect("should generate YAML")
    }

    /// Shorthand for `Some(value.to_string())`.
    fn owned(value: &str) -> Option<String> {
        Some(value.to_string())
    }

    /// Assert that every feature in `expected` is present on `intent`.
    fn assert_has_features(intent: &ConfigIntent, expected: &[&str]) {
        for name in expected {
            assert!(intent.features.contains(&name.to_string()));
        }
    }

    /// Assert that `yaml` contains every fragment in `fragments`.
    fn assert_yaml_contains(yaml: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(yaml.contains(*fragment));
        }
    }

    #[test]
    fn test_parse_retail_description() {
        let intent = parse("Generate 1 year of retail data for a medium US company");

        assert_eq!(intent.industry, owned("retail"));
        assert_eq!(intent.country, owned("US"));
        assert_eq!(intent.company_size, owned("medium"));
        assert_eq!(intent.period_months, Some(12));
    }

    #[test]
    fn test_parse_manufacturing_with_fraud() {
        let intent = parse(
            "Create 6 months of manufacturing data for a large German company with fraud detection",
        );

        assert_eq!(intent.industry, owned("manufacturing"));
        assert_eq!(intent.country, owned("DE"));
        assert_eq!(intent.company_size, owned("large"));
        assert_eq!(intent.period_months, Some(6));
        assert_has_features(&intent, &["fraud"]);
    }

    #[test]
    fn test_parse_financial_services_with_audit() {
        let intent =
            parse("I need 2 years of financial services data for audit testing with SOX controls");

        assert_eq!(intent.industry, owned("financial_services"));
        assert_eq!(intent.period_months, Some(24));
        assert_has_features(&intent, &["audit", "controls"]);
    }

    #[test]
    fn test_parse_healthcare_small() {
        let intent = parse("Small healthcare company in Japan, 3 months of data");

        assert_eq!(intent.industry, owned("healthcare"));
        assert_eq!(intent.country, owned("JP"));
        assert_eq!(intent.company_size, owned("small"));
        assert_eq!(intent.period_months, Some(3));
    }

    #[test]
    fn test_parse_technology_with_banking() {
        let intent =
            parse("Generate data for a technology startup in Singapore with banking and KYC");

        assert_eq!(intent.industry, owned("technology"));
        assert_eq!(intent.country, owned("SG"));
        assert_eq!(intent.company_size, owned("small"));
        assert_has_features(&intent, &["banking"]);
    }

    #[test]
    fn test_parse_word_numbers() {
        let intent = parse("Generate two years of retail data");

        assert_eq!(intent.period_months, Some(24));
    }

    #[test]
    fn test_parse_multiple_features() {
        let intent = parse(
            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
        );

        assert_eq!(intent.industry, owned("manufacturing"));
        assert_has_features(
            &intent,
            &["fraud", "audit", "process_mining", "intercompany"],
        );
    }

    #[test]
    fn test_intent_to_yaml_basic() {
        let yaml = render(&ConfigIntent {
            industry: owned("retail"),
            country: owned("US"),
            company_size: owned("medium"),
            period_months: Some(12),
            features: vec![],
        });

        assert_yaml_contains(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "country: \"US\"",
                "complexity: medium",
                "count: 5000",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_with_features() {
        let yaml = render(&ConfigIntent {
            industry: owned("manufacturing"),
            country: owned("DE"),
            company_size: owned("large"),
            period_months: Some(24),
            features: vec![
                "fraud".to_string(),
                "audit".to_string(),
                "controls".to_string(),
            ],
        });

        assert_yaml_contains(
            &yaml,
            &[
                "industry: manufacturing",
                "currency: \"EUR\"",
                "complexity: large",
                "count: 25000",
                "fraud:",
                "audit_standards:",
                "internal_controls:",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_defaults() {
        let yaml = render(&ConfigIntent::default());

        // Should use defaults
        assert_yaml_contains(
            &yaml,
            &[
                "industry: manufacturing",
                "period_months: 12",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_invalid_period() {
        // One value below and one above the accepted 1..=120 range.
        for months in [0u32, 121] {
            let intent = ConfigIntent {
                period_months: Some(months),
                ..ConfigIntent::default()
            };

            let result = NlConfigGenerator::intent_to_yaml(&intent);
            assert!(result.is_err());
        }
    }

    #[test]
    fn test_generate_end_to_end() {
        let provider = MockLlmProvider::new(42);
        let yaml = NlConfigGenerator::generate(
            "Generate 1 year of retail data for a medium US company with fraud detection",
            &provider,
        )
        .expect("should generate YAML");

        assert_yaml_contains(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "fraud:",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_generate_empty_description() {
        let provider = MockLlmProvider::new(42);

        for blank in ["", "   "] {
            let result = NlConfigGenerator::generate(blank, &provider);
            assert!(result.is_err());
        }
    }

    #[test]
    fn test_extract_json_from_response() {
        let content = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;
        let json = NlConfigGenerator::extract_json(content);
        assert!(json.is_some());
        assert_eq!(
            json.expect("json should be present"),
            r#"{"industry": "retail", "country": "US"}"#
        );
    }

    #[test]
    fn test_extract_json_nested() {
        let content = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
        assert!(NlConfigGenerator::extract_json(content).is_some());
    }

    #[test]
    fn test_extract_json_missing() {
        assert!(NlConfigGenerator::extract_json("No JSON here at all").is_none());
    }

    #[test]
    fn test_parse_llm_response_valid() {
        let content = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
        let intent =
            NlConfigGenerator::parse_llm_response(content).expect("should parse valid JSON");

        assert_eq!(intent.industry, owned("retail"));
        assert_eq!(intent.country, owned("US"));
        assert_eq!(intent.company_size, owned("medium"));
        assert_eq!(intent.period_months, Some(12));
        assert_eq!(intent.features, vec!["fraud".to_string()]);
    }

    #[test]
    fn test_parse_llm_response_partial() {
        let intent = NlConfigGenerator::parse_llm_response(r#"{"industry": "retail"}"#)
            .expect("should parse partial JSON");

        assert_eq!(intent.industry, owned("retail"));
        assert_eq!(intent.country, None);
        assert!(intent.features.is_empty());
    }

    #[test]
    fn test_country_to_currency_mapping() {
        let expectations = [
            ("US", "USD"),
            ("DE", "EUR"),
            ("GB", "GBP"),
            ("JP", "JPY"),
            ("CN", "CNY"),
            ("BR", "BRL"),
            ("XX", "USD"), // Unknown defaults to USD
        ];

        for (country, currency) in expectations {
            assert_eq!(NlConfigGenerator::country_to_currency(country), currency);
        }
    }

    #[test]
    fn test_merge_intents() {
        let primary = ConfigIntent {
            industry: owned("retail"),
            country: None,
            company_size: None,
            period_months: Some(12),
            features: vec![],
        };
        let fallback = ConfigIntent {
            industry: owned("manufacturing"),
            country: owned("DE"),
            company_size: owned("large"),
            period_months: Some(6),
            features: vec!["fraud".to_string()],
        };

        let merged = NlConfigGenerator::merge_intents(primary, fallback);
        assert_eq!(merged.industry, owned("retail")); // primary wins
        assert_eq!(merged.country, owned("DE")); // fallback fills gap
        assert_eq!(merged.company_size, owned("large")); // fallback fills gap
        assert_eq!(merged.period_months, Some(12)); // primary wins
        assert_eq!(merged.features, vec!["fraud".to_string()]); // fallback since primary empty
    }

    #[test]
    fn test_parse_uk_country() {
        let intent = parse("Generate data for a UK manufacturing company");

        assert_eq!(intent.country, owned("GB"));
    }

    #[test]
    fn test_intent_to_yaml_banking_feature() {
        let yaml = render(&ConfigIntent {
            industry: owned("financial_services"),
            country: owned("US"),
            company_size: owned("large"),
            period_months: Some(12),
            features: vec!["banking".to_string()],
        });

        assert_yaml_contains(
            &yaml,
            &["banking:", "kyc_enabled: true", "aml_enabled: true"],
        );
    }

    #[test]
    fn test_intent_to_yaml_process_mining_feature() {
        let yaml = render(&ConfigIntent {
            features: vec!["process_mining".to_string()],
            ..ConfigIntent::default()
        });

        assert_yaml_contains(&yaml, &["business_processes:", "ocel_export: true"]);
    }

    #[test]
    fn test_intent_to_yaml_distributions_feature() {
        let yaml = render(&ConfigIntent {
            industry: owned("retail"),
            features: vec!["distributions".to_string()],
            ..ConfigIntent::default()
        });

        assert_yaml_contains(
            &yaml,
            &[
                "distributions:",
                "industry_profile: retail",
                "benford_compliance: true",
            ],
        );
    }
}