Skip to main content

datasynth_core/llm/
nl_config.rs

1//! Natural language to YAML configuration generator.
2//!
3//! Takes a free-text description of desired synthetic data (e.g., "Generate 1 year of
4//! retail data for a medium US company with fraud detection") and produces a valid
5//! `GeneratorConfig` YAML string.
6
7use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured representation of user intent extracted from natural language.
///
/// Every field is optional: `None` (or an empty `features` list) means "not
/// mentioned", and [`NlConfigGenerator::intent_to_yaml`] substitutes its defaults.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Target industry (e.g., "retail", "manufacturing", "financial_services").
    pub industry: Option<String>,
    /// Country code (e.g., "US", "DE", "GB").
    pub country: Option<String>,
    /// Company size: "small", "medium", or "large".
    pub company_size: Option<String>,
    /// Duration in months.
    pub period_months: Option<u32>,
    /// Requested feature flags (e.g., "fraud", "audit", "banking", "controls").
    pub features: Vec<String>,
}
24
/// Generates YAML configuration from natural language descriptions.
///
/// The generator uses a two-phase approach:
/// 1. Parse the natural language description into a structured [`ConfigIntent`].
/// 2. Map the intent to a YAML configuration string using preset templates.
///
/// The type carries no state; all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33    /// Generate a YAML configuration from a natural language description.
34    ///
35    /// Uses the provided LLM provider to help parse the description, with
36    /// keyword-based fallback parsing for reliability.
37    ///
38    /// # Errors
39    ///
40    /// Returns `SynthError::GenerationError` if the description cannot be parsed
41    /// or the resulting configuration is invalid.
42    pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43        if description.trim().is_empty() {
44            return Err(SynthError::generation(
45                "Natural language description cannot be empty",
46            ));
47        }
48
49        let intent = Self::parse_intent(description, provider)?;
50        Self::intent_to_yaml(&intent)
51    }
52
53    /// Parse a natural language description into a structured [`ConfigIntent`].
54    ///
55    /// Attempts to use the LLM provider first, then falls back to keyword-based
56    /// extraction for reliability.
57    pub fn parse_intent(
58        description: &str,
59        provider: &dyn LlmProvider,
60    ) -> Result<ConfigIntent, SynthError> {
61        // Try LLM-based parsing first
62        let llm_intent = Self::parse_with_llm(description, provider);
63
64        // Always run keyword-based parsing as fallback/supplement
65        let keyword_intent = Self::parse_with_keywords(description);
66
67        // Merge: prefer LLM results where available, fall back to keywords
68        match llm_intent {
69            Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70            Err(e) => {
71                tracing::warn!(
72                    "LLM-based config parsing failed, falling back to keyword parsing: {}",
73                    e
74                );
75                Ok(keyword_intent)
76            }
77        }
78    }
79
80    /// Map a [`ConfigIntent`] to a YAML configuration string.
81    pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
82        let industry = intent.industry.as_deref().unwrap_or("manufacturing");
83        let country = intent.country.as_deref().unwrap_or("US");
84        let complexity = intent.company_size.as_deref().unwrap_or("medium");
85        let period_months = intent.period_months.unwrap_or(12);
86
87        // Validate inputs
88        if !(1..=120).contains(&period_months) {
89            return Err(SynthError::generation(format!(
90                "Period months must be between 1 and 120, got {}",
91                period_months
92            )));
93        }
94
95        let valid_complexities = ["small", "medium", "large"];
96        if !valid_complexities.contains(&complexity) {
97            return Err(SynthError::generation(format!(
98                "Invalid company size '{}', must be one of: small, medium, large",
99                complexity
100            )));
101        }
102
103        let currency = Self::country_to_currency(country);
104        let company_name = Self::industry_company_name(industry);
105
106        let mut yaml = String::with_capacity(2048);
107
108        // Global settings
109        yaml.push_str(&format!(
110            "global:\n  industry: {}\n  start_date: \"2024-01-01\"\n  period_months: {}\n  seed: 42\n\n",
111            industry, period_months
112        ));
113
114        // Companies
115        yaml.push_str(&format!(
116            "companies:\n  - code: \"C001\"\n    name: \"{}\"\n    currency: \"{}\"\n    country: \"{}\"\n\n",
117            company_name, currency, country
118        ));
119
120        // Chart of accounts
121        yaml.push_str(&format!(
122            "chart_of_accounts:\n  complexity: {}\n\n",
123            complexity
124        ));
125
126        // Transactions
127        let tx_count = Self::complexity_to_tx_count(complexity);
128        yaml.push_str(&format!(
129            "transactions:\n  count: {}\n  anomaly_rate: 0.02\n\n",
130            tx_count
131        ));
132
133        // Output
134        yaml.push_str("output:\n  format: csv\n  compression: false\n\n");
135
136        // Feature-specific sections
137        for feature in &intent.features {
138            match feature.as_str() {
139                "fraud" => {
140                    yaml.push_str(
141                        "fraud:\n  enabled: true\n  types:\n    - fictitious_transaction\n    - duplicate_payment\n    - split_transaction\n  injection_rate: 0.03\n\n",
142                    );
143                }
144                "audit" => {
145                    yaml.push_str(
146                        "audit_standards:\n  enabled: true\n  isa_compliance:\n    enabled: true\n    compliance_level: standard\n    framework: isa\n  analytical_procedures:\n    enabled: true\n    procedures_per_account: 3\n  confirmations:\n    enabled: true\n    positive_response_rate: 0.85\n  sox:\n    enabled: true\n    materiality_threshold: 10000.0\n\n",
147                    );
148                }
149                "banking" => {
150                    yaml.push_str(
151                        "banking:\n  enabled: true\n  customer_count: 100\n  account_types:\n    - checking\n    - savings\n    - loan\n  kyc_enabled: true\n  aml_enabled: true\n\n",
152                    );
153                }
154                "controls" => {
155                    yaml.push_str(
156                        "internal_controls:\n  enabled: true\n  coso_enabled: true\n  include_entity_level_controls: true\n  target_maturity_level: \"managed\"\n  exception_rate: 0.02\n  sod_violation_rate: 0.01\n\n",
157                    );
158                }
159                "process_mining" => {
160                    yaml.push_str(
161                        "business_processes:\n  enabled: true\n  ocel_export: true\n  p2p:\n    enabled: true\n  o2c:\n    enabled: true\n\n",
162                    );
163                }
164                "intercompany" => {
165                    yaml.push_str(
166                        "intercompany:\n  enabled: true\n  matching_tolerance: 0.01\n  elimination_enabled: true\n\n",
167                    );
168                }
169                "distributions" => {
170                    yaml.push_str(&format!(
171                        "distributions:\n  enabled: true\n  industry_profile: {}\n  amounts:\n    enabled: true\n    distribution_type: lognormal\n    benford_compliance: true\n\n",
172                        industry
173                    ));
174                }
175                other => {
176                    tracing::warn!(
177                        "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
178                        other
179                    );
180                }
181            }
182        }
183
184        Ok(yaml)
185    }
186
187    /// Attempt LLM-based parsing of the description.
188    fn parse_with_llm(
189        description: &str,
190        provider: &dyn LlmProvider,
191    ) -> Result<ConfigIntent, SynthError> {
192        let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
193
194        let request = LlmRequest::new(description)
195            .with_system(system_prompt.to_string())
196            .with_temperature(0.1)
197            .with_max_tokens(512);
198
199        let response = provider.complete(&request)?;
200        Self::parse_llm_response(&response.content)
201    }
202
203    /// Parse the LLM response JSON into a ConfigIntent.
204    fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
205        // Try to find JSON in the response
206        let json_str = Self::extract_json(content)
207            .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
208
209        let value: serde_json::Value = serde_json::from_str(json_str)
210            .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {}", e)))?;
211
212        let industry = value
213            .get("industry")
214            .and_then(|v| v.as_str())
215            .map(String::from);
216        let country = value
217            .get("country")
218            .and_then(|v| v.as_str())
219            .map(String::from);
220        let company_size = value
221            .get("company_size")
222            .and_then(|v| v.as_str())
223            .map(String::from);
224        let period_months = value
225            .get("period_months")
226            .and_then(|v| v.as_u64())
227            .map(|v| v as u32);
228        let features = value
229            .get("features")
230            .and_then(|v| v.as_array())
231            .map(|arr| {
232                arr.iter()
233                    .filter_map(|v| v.as_str().map(String::from))
234                    .collect()
235            })
236            .unwrap_or_default();
237
238        Ok(ConfigIntent {
239            industry,
240            country,
241            company_size,
242            period_months,
243            features,
244        })
245    }
246
247    /// Extract a JSON object substring from potentially noisy LLM output.
248    fn extract_json(content: &str) -> Option<&str> {
249        // Find the first '{' and matching '}'
250        let start = content.find('{')?;
251        let mut depth = 0i32;
252        for (i, ch) in content[start..].char_indices() {
253            match ch {
254                '{' => depth += 1,
255                '}' => {
256                    depth -= 1;
257                    if depth == 0 {
258                        return Some(&content[start..start + i + 1]);
259                    }
260                }
261                _ => {}
262            }
263        }
264        None
265    }
266
267    /// Keyword-based parsing as a reliable fallback.
268    fn parse_with_keywords(description: &str) -> ConfigIntent {
269        let lower = description.to_lowercase();
270
271        let industry = Self::extract_industry(&lower);
272        let country = Self::extract_country(&lower);
273        let company_size = Self::extract_size(&lower);
274        let period_months = Self::extract_period(&lower);
275        let features = Self::extract_features(&lower);
276
277        ConfigIntent {
278            industry,
279            country,
280            company_size,
281            period_months,
282            features,
283        }
284    }
285
286    /// Extract industry from lowercased text.
287    ///
288    /// Uses a scoring approach: each industry gets points for keyword matches,
289    /// and the highest-scoring industry wins. This avoids order-dependent
290    /// issues where "banking" in a feature context incorrectly triggers
291    /// "financial_services" over "technology".
292    fn extract_industry(text: &str) -> Option<String> {
293        let patterns: &[(&[&str], &str)] = &[
294            (
295                &["retail", "store", "shop", "e-commerce", "ecommerce"],
296                "retail",
297            ),
298            (
299                &["manufactur", "factory", "production", "assembly"],
300                "manufacturing",
301            ),
302            (
303                &[
304                    "financial",
305                    "finance",
306                    "insurance",
307                    "fintech",
308                    "investment firm",
309                ],
310                "financial_services",
311            ),
312            (
313                &["health", "hospital", "medical", "pharma", "clinic"],
314                "healthcare",
315            ),
316            (
317                &["tech", "software", "saas", "startup", "digital"],
318                "technology",
319            ),
320        ];
321
322        let mut best: Option<(&str, usize)> = None;
323        for (keywords, industry) in patterns {
324            let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
325            if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
326                best = Some((industry, count));
327            }
328        }
329        best.map(|(industry, _)| industry.to_string())
330    }
331
332    /// Extract country from lowercased text.
333    fn extract_country(text: &str) -> Option<String> {
334        // Check full country names first (most reliable), then short codes.
335        // Short codes like "in", "de", "us" can clash with English words,
336        // so we only use unambiguous short codes.
337        let name_patterns = [
338            (&["united states", "u.s.", "america"][..], "US"),
339            (&["germany", "german"][..], "DE"),
340            (&["united kingdom", "british", "england"][..], "GB"),
341            (&["china", "chinese"][..], "CN"),
342            (&["japan", "japanese"][..], "JP"),
343            (&["india", "indian"][..], "IN"),
344            (&["brazil", "brazilian"][..], "BR"),
345            (&["mexico", "mexican"][..], "MX"),
346            (&["australia", "australian"][..], "AU"),
347            (&["singapore", "singaporean"][..], "SG"),
348            (&["korea", "korean"][..], "KR"),
349            (&["france", "french"][..], "FR"),
350            (&["canada", "canadian"][..], "CA"),
351        ];
352
353        for (keywords, code) in &name_patterns {
354            if keywords.iter().any(|kw| text.contains(kw)) {
355                return Some(code.to_string());
356            }
357        }
358
359        // Fall back to short codes (padded with spaces).
360        // Excluded: "in" (India - clashes with preposition "in"),
361        //           "de" (Germany - clashes with various uses).
362        let padded = format!(" {} ", text);
363        let safe_codes = [
364            (" us ", "US"),
365            (" uk ", "GB"),
366            (" gb ", "GB"),
367            (" cn ", "CN"),
368            (" jp ", "JP"),
369            (" br ", "BR"),
370            (" mx ", "MX"),
371            (" au ", "AU"),
372            (" sg ", "SG"),
373            (" kr ", "KR"),
374            (" fr ", "FR"),
375            (" ca ", "CA"),
376        ];
377
378        for (code_pattern, code) in &safe_codes {
379            if padded.contains(code_pattern) {
380                return Some(code.to_string());
381            }
382        }
383
384        None
385    }
386
387    /// Extract company size from lowercased text.
388    fn extract_size(text: &str) -> Option<String> {
389        if text.contains("small") || text.contains("startup") || text.contains("tiny") {
390            Some("small".to_string())
391        } else if text.contains("large")
392            || text.contains("enterprise")
393            || text.contains("big")
394            || text.contains("multinational")
395            || text.contains("fortune 500")
396        {
397            Some("large".to_string())
398        } else if text.contains("medium")
399            || text.contains("mid-size")
400            || text.contains("midsize")
401            || text.contains("mid size")
402        {
403            Some("medium".to_string())
404        } else {
405            None
406        }
407    }
408
409    /// Extract period in months from lowercased text.
410    fn extract_period(text: &str) -> Option<u32> {
411        // Match patterns like "1 year", "2 years", "6 months", "18 months"
412        // Also handle "one year", "two years", etc.
413        let word_numbers = [
414            ("one", 1u32),
415            ("two", 2),
416            ("three", 3),
417            ("four", 4),
418            ("five", 5),
419            ("six", 6),
420            ("twelve", 12),
421            ("eighteen", 18),
422            ("twenty-four", 24),
423        ];
424
425        // Try "N year(s)" pattern
426        for (word, num) in &word_numbers {
427            if text.contains(&format!("{} year", word)) {
428                return Some(num * 12);
429            }
430            if text.contains(&format!("{} month", word)) {
431                return Some(*num);
432            }
433        }
434
435        // Try numeric patterns: "N year(s)", "N month(s)"
436        let tokens: Vec<&str> = text.split_whitespace().collect();
437        for window in tokens.windows(2) {
438            if let Ok(num) = window[0].parse::<u32>() {
439                if window[1].starts_with("year") {
440                    return Some(num * 12);
441                }
442                if window[1].starts_with("month") {
443                    return Some(num);
444                }
445            }
446        }
447
448        None
449    }
450
451    /// Extract feature flags from lowercased text.
452    fn extract_features(text: &str) -> Vec<String> {
453        let mut features = Vec::new();
454
455        let feature_patterns = [
456            (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
457            (&["audit", "auditing", "assurance"][..], "audit"),
458            (&["banking", "bank account", "kyc", "aml"][..], "banking"),
459            (
460                &["control", "sox", "sod", "segregation of duties", "coso"][..],
461                "controls",
462            ),
463            (
464                &["process mining", "ocel", "event log"][..],
465                "process_mining",
466            ),
467            (
468                &["intercompany", "inter-company", "consolidation"][..],
469                "intercompany",
470            ),
471            (
472                &["distribution", "benford", "statistical"][..],
473                "distributions",
474            ),
475        ];
476
477        for (keywords, feature) in &feature_patterns {
478            if keywords.iter().any(|kw| text.contains(kw)) {
479                features.push(feature.to_string());
480            }
481        }
482
483        features
484    }
485
486    /// Merge two ConfigIntents, preferring the primary where available.
487    fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
488        ConfigIntent {
489            industry: primary.industry.or(fallback.industry),
490            country: primary.country.or(fallback.country),
491            company_size: primary.company_size.or(fallback.company_size),
492            period_months: primary.period_months.or(fallback.period_months),
493            features: if primary.features.is_empty() {
494                fallback.features
495            } else {
496                primary.features
497            },
498        }
499    }
500
501    /// Map country code to default currency.
502    fn country_to_currency(country: &str) -> &'static str {
503        match country {
504            "US" | "CA" => "USD",
505            "DE" | "FR" => "EUR",
506            "GB" => "GBP",
507            "CN" => "CNY",
508            "JP" => "JPY",
509            "IN" => "INR",
510            "BR" => "BRL",
511            "MX" => "MXN",
512            "AU" => "AUD",
513            "SG" => "SGD",
514            "KR" => "KRW",
515            _ => "USD",
516        }
517    }
518
519    /// Generate a company name based on industry.
520    fn industry_company_name(industry: &str) -> &'static str {
521        match industry {
522            "retail" => "Retail Corp",
523            "manufacturing" => "Manufacturing Industries Inc",
524            "financial_services" => "Financial Services Group",
525            "healthcare" => "HealthCare Solutions",
526            "technology" => "TechCorp Solutions",
527            _ => "DataSynth Corp",
528        }
529    }
530
531    /// Map complexity to an appropriate transaction count.
532    fn complexity_to_tx_count(complexity: &str) -> u32 {
533        match complexity {
534            "small" => 1000,
535            "medium" => 5000,
536            "large" => 25000,
537            _ => 5000,
538        }
539    }
540}
541
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Unit tests for NlConfigGenerator: keyword parsing through the full intent
    // pipeline (with a seeded mock provider), YAML rendering, JSON extraction,
    // LLM-response parsing and intent merging.
    use super::*;
    use crate::llm::mock_provider::MockLlmProvider;

    /// Runs `description` through `parse_intent` with a mock provider seeded with 42.
    fn intent_for(description: &str) -> ConfigIntent {
        let provider = MockLlmProvider::new(42);
        NlConfigGenerator::parse_intent(description, &provider).expect("should parse successfully")
    }

    /// Shorthand for `Some(value.to_string())`.
    fn some(value: &str) -> Option<String> {
        Some(value.to_string())
    }

    /// Renders `intent` to YAML, failing the test if rendering errors.
    fn yaml_for(intent: &ConfigIntent) -> String {
        NlConfigGenerator::intent_to_yaml(intent).expect("should generate YAML")
    }

    /// Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected `{}` in:\n{}",
                needle,
                haystack
            );
        }
    }

    /// Asserts that `intent` requested each feature in `expected`.
    fn assert_has_features(intent: &ConfigIntent, expected: &[&str]) {
        for wanted in expected {
            assert!(
                intent.features.iter().any(|f| f.as_str() == *wanted),
                "missing feature `{}` in {:?}",
                wanted,
                intent.features
            );
        }
    }

    #[test]
    fn test_parse_retail_description() {
        let parsed = intent_for("Generate 1 year of retail data for a medium US company");
        assert_eq!(parsed.industry, some("retail"));
        assert_eq!(parsed.country, some("US"));
        assert_eq!(parsed.company_size, some("medium"));
        assert_eq!(parsed.period_months, Some(12));
    }

    #[test]
    fn test_parse_manufacturing_with_fraud() {
        let parsed = intent_for(
            "Create 6 months of manufacturing data for a large German company with fraud detection",
        );
        assert_eq!(parsed.industry, some("manufacturing"));
        assert_eq!(parsed.country, some("DE"));
        assert_eq!(parsed.company_size, some("large"));
        assert_eq!(parsed.period_months, Some(6));
        assert_has_features(&parsed, &["fraud"]);
    }

    #[test]
    fn test_parse_financial_services_with_audit() {
        let parsed = intent_for(
            "I need 2 years of financial services data for audit testing with SOX controls",
        );
        assert_eq!(parsed.industry, some("financial_services"));
        assert_eq!(parsed.period_months, Some(24));
        assert_has_features(&parsed, &["audit", "controls"]);
    }

    #[test]
    fn test_parse_healthcare_small() {
        let parsed = intent_for("Small healthcare company in Japan, 3 months of data");
        assert_eq!(parsed.industry, some("healthcare"));
        assert_eq!(parsed.country, some("JP"));
        assert_eq!(parsed.company_size, some("small"));
        assert_eq!(parsed.period_months, Some(3));
    }

    #[test]
    fn test_parse_technology_with_banking() {
        let parsed =
            intent_for("Generate data for a technology startup in Singapore with banking and KYC");
        assert_eq!(parsed.industry, some("technology"));
        assert_eq!(parsed.country, some("SG"));
        assert_eq!(parsed.company_size, some("small"));
        assert_has_features(&parsed, &["banking"]);
    }

    #[test]
    fn test_parse_word_numbers() {
        let parsed = intent_for("Generate two years of retail data");
        assert_eq!(parsed.period_months, Some(24));
    }

    #[test]
    fn test_parse_multiple_features() {
        let parsed = intent_for(
            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
        );
        assert_eq!(parsed.industry, some("manufacturing"));
        assert_has_features(&parsed, &["fraud", "audit", "process_mining", "intercompany"]);
    }

    #[test]
    fn test_intent_to_yaml_basic() {
        let yaml = yaml_for(&ConfigIntent {
            industry: some("retail"),
            country: some("US"),
            company_size: some("medium"),
            period_months: Some(12),
            features: Vec::new(),
        });
        assert_contains_all(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "country: \"US\"",
                "complexity: medium",
                "count: 5000",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_with_features() {
        let yaml = yaml_for(&ConfigIntent {
            industry: some("manufacturing"),
            country: some("DE"),
            company_size: some("large"),
            period_months: Some(24),
            features: ["fraud", "audit", "controls"]
                .iter()
                .map(|f| f.to_string())
                .collect(),
        });
        assert_contains_all(
            &yaml,
            &[
                "industry: manufacturing",
                "currency: \"EUR\"",
                "complexity: large",
                "count: 25000",
                "fraud:",
                "audit_standards:",
                "internal_controls:",
            ],
        );
    }

    #[test]
    fn test_intent_to_yaml_defaults() {
        // An empty intent falls back to the built-in defaults.
        let yaml = yaml_for(&ConfigIntent::default());
        assert_contains_all(
            &yaml,
            &["industry: manufacturing", "period_months: 12", "complexity: medium"],
        );
    }

    #[test]
    fn test_intent_to_yaml_invalid_period() {
        for months in [0, 121] {
            let out_of_range = ConfigIntent {
                period_months: Some(months),
                ..ConfigIntent::default()
            };
            assert!(
                NlConfigGenerator::intent_to_yaml(&out_of_range).is_err(),
                "period {} should be rejected",
                months
            );
        }
    }

    #[test]
    fn test_generate_end_to_end() {
        let provider = MockLlmProvider::new(42);
        let yaml = NlConfigGenerator::generate(
            "Generate 1 year of retail data for a medium US company with fraud detection",
            &provider,
        )
        .expect("should generate YAML");
        assert_contains_all(
            &yaml,
            &[
                "industry: retail",
                "period_months: 12",
                "currency: \"USD\"",
                "fraud:",
                "complexity: medium",
            ],
        );
    }

    #[test]
    fn test_generate_empty_description() {
        let provider = MockLlmProvider::new(42);
        for blank in ["", "   "] {
            assert!(NlConfigGenerator::generate(blank, &provider).is_err());
        }
    }

    #[test]
    fn test_extract_json_from_response() {
        let noisy = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;
        assert_eq!(
            NlConfigGenerator::extract_json(noisy),
            Some(r#"{"industry": "retail", "country": "US"}"#)
        );
    }

    #[test]
    fn test_extract_json_nested() {
        let nested = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
        assert!(NlConfigGenerator::extract_json(nested).is_some());
    }

    #[test]
    fn test_extract_json_missing() {
        assert!(NlConfigGenerator::extract_json("No JSON here at all").is_none());
    }

    #[test]
    fn test_parse_llm_response_valid() {
        let raw = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
        let parsed =
            NlConfigGenerator::parse_llm_response(raw).expect("should parse valid JSON");
        assert_eq!(parsed.industry, some("retail"));
        assert_eq!(parsed.country, some("US"));
        assert_eq!(parsed.company_size, some("medium"));
        assert_eq!(parsed.period_months, Some(12));
        assert_eq!(parsed.features, vec!["fraud".to_string()]);
    }

    #[test]
    fn test_parse_llm_response_partial() {
        let parsed = NlConfigGenerator::parse_llm_response(r#"{"industry": "retail"}"#)
            .expect("should parse partial JSON");
        assert_eq!(parsed.industry, some("retail"));
        assert_eq!(parsed.country, None);
        assert!(parsed.features.is_empty());
    }

    #[test]
    fn test_country_to_currency_mapping() {
        let expectations = [
            ("US", "USD"),
            ("DE", "EUR"),
            ("GB", "GBP"),
            ("JP", "JPY"),
            ("CN", "CNY"),
            ("BR", "BRL"),
            ("XX", "USD"), // Unknown defaults to USD
        ];
        for (country, currency) in expectations {
            assert_eq!(
                NlConfigGenerator::country_to_currency(country),
                currency,
                "currency for {}",
                country
            );
        }
    }

    #[test]
    fn test_merge_intents() {
        let primary = ConfigIntent {
            industry: some("retail"),
            country: None,
            company_size: None,
            period_months: Some(12),
            features: Vec::new(),
        };
        let fallback = ConfigIntent {
            industry: some("manufacturing"),
            country: some("DE"),
            company_size: some("large"),
            period_months: Some(6),
            features: vec!["fraud".to_string()],
        };

        let merged = NlConfigGenerator::merge_intents(primary, fallback);
        assert_eq!(merged.industry, some("retail")); // primary wins
        assert_eq!(merged.country, some("DE")); // fallback fills gap
        assert_eq!(merged.company_size, some("large")); // fallback fills gap
        assert_eq!(merged.period_months, Some(12)); // primary wins
        assert_eq!(merged.features, vec!["fraud".to_string()]); // fallback since primary empty
    }

    #[test]
    fn test_parse_uk_country() {
        let parsed = intent_for("Generate data for a UK manufacturing company");
        assert_eq!(parsed.country, some("GB"));
    }

    #[test]
    fn test_intent_to_yaml_banking_feature() {
        let yaml = yaml_for(&ConfigIntent {
            industry: some("financial_services"),
            country: some("US"),
            company_size: some("large"),
            period_months: Some(12),
            features: vec!["banking".to_string()],
        });
        assert_contains_all(&yaml, &["banking:", "kyc_enabled: true", "aml_enabled: true"]);
    }

    #[test]
    fn test_intent_to_yaml_process_mining_feature() {
        let yaml = yaml_for(&ConfigIntent {
            features: vec!["process_mining".to_string()],
            ..ConfigIntent::default()
        });
        assert_contains_all(&yaml, &["business_processes:", "ocel_export: true"]);
    }

    #[test]
    fn test_intent_to_yaml_distributions_feature() {
        let yaml = yaml_for(&ConfigIntent {
            industry: some("retail"),
            features: vec!["distributions".to_string()],
            ..ConfigIntent::default()
        });
        assert_contains_all(
            &yaml,
            &[
                "distributions:",
                "industry_profile: retail",
                "benford_compliance: true",
            ],
        );
    }
}