Skip to main content

datasynth_core/llm/
nl_config.rs

1//! Natural language to YAML configuration generator.
2//!
3//! Takes a free-text description of desired synthetic data (e.g., "Generate 1 year of
4//! retail data for a medium US company with fraud detection") and produces a valid
5//! `GeneratorConfig` YAML string.
6
7use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// User intent recovered from a free-text request, before it is turned into YAML.
///
/// Each field is optional; a `None` (or empty `features`) means the description
/// did not say anything about that aspect and a default will be used later.
#[derive(Clone, Debug, Default)]
pub struct ConfigIntent {
    /// Industry identifier such as "retail", "manufacturing" or "financial_services".
    pub industry: Option<String>,
    /// Two-letter country code such as "US", "DE" or "GB".
    pub country: Option<String>,
    /// Company size bucket; expected to be "small", "medium" or "large".
    pub company_size: Option<String>,
    /// Length of the simulated period, in months.
    pub period_months: Option<u32>,
    /// Feature names to enable, e.g. "fraud", "audit", "banking", "controls",
    /// "process_mining", "intercompany" or "distributions".
    pub features: Vec<String>,
}
24
/// Generates YAML configuration from natural language descriptions.
///
/// Two entry points are provided:
///
/// * `generate` — template-based. The description is parsed into a
///   structured [`ConfigIntent`] (LLM-assisted, with keyword-based fallback),
///   and the intent is then mapped to YAML using preset section templates.
/// * `generate_full` — asks the LLM for a complete YAML document directly and
///   falls back to the template path when the response is not recognisable
///   DataSynth YAML.
///
/// The type carries no state; all functionality is exposed as associated functions.
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33    /// Generate a YAML configuration from a natural language description.
34    ///
35    /// Uses the provided LLM provider to help parse the description, with
36    /// keyword-based fallback parsing for reliability.
37    ///
38    /// # Errors
39    ///
40    /// Returns `SynthError::GenerationError` if the description cannot be parsed
41    /// or the resulting configuration is invalid.
42    pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43        if description.trim().is_empty() {
44            return Err(SynthError::generation(
45                "Natural language description cannot be empty",
46            ));
47        }
48
49        let intent = Self::parse_intent(description, provider)?;
50        Self::intent_to_yaml(&intent)
51    }
52
53    /// Parse a natural language description into a structured [`ConfigIntent`].
54    ///
55    /// Attempts to use the LLM provider first, then falls back to keyword-based
56    /// extraction for reliability.
57    pub fn parse_intent(
58        description: &str,
59        provider: &dyn LlmProvider,
60    ) -> Result<ConfigIntent, SynthError> {
61        // Try LLM-based parsing first
62        let llm_intent = Self::parse_with_llm(description, provider);
63
64        // Always run keyword-based parsing as fallback/supplement
65        let keyword_intent = Self::parse_with_keywords(description);
66
67        // Merge: prefer LLM results where available, fall back to keywords
68        match llm_intent {
69            Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70            Err(e) => {
71                tracing::warn!(
72                    "LLM-based config parsing failed, falling back to keyword parsing: {}",
73                    e
74                );
75                Ok(keyword_intent)
76            }
77        }
78    }
79
80    /// Generate a complete YAML configuration from a natural language description.
81    ///
82    /// Unlike [`generate`], which maps to a template via structured intent, this
83    /// method asks the LLM to produce the full YAML directly using the complete
84    /// DataSynth config schema as guidance.  Falls back to [`generate`] if the
85    /// LLM response is not valid YAML or does not contain expected top-level keys.
86    pub fn generate_full(
87        description: &str,
88        provider: &dyn LlmProvider,
89    ) -> Result<String, SynthError> {
90        if description.trim().is_empty() {
91            return Err(SynthError::generation(
92                "Natural language description cannot be empty",
93            ));
94        }
95
96        let system = Self::full_schema_system_prompt();
97        let request = LlmRequest::new(description.to_string())
98            .with_system(system)
99            .with_temperature(0.2)
100            .with_max_tokens(4096);
101
102        match provider.complete(&request) {
103            Ok(response) => {
104                let yaml_text = Self::extract_yaml(&response.content);
105                // Validate that it parses as a YAML mapping with at least one known key
106                if let Ok(value) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_text) {
107                    if let Some(map) = value.as_mapping() {
108                        let known_keys = [
109                            "global",
110                            "companies",
111                            "chart_of_accounts",
112                            "transactions",
113                            "output",
114                            "fraud",
115                            "audit_standards",
116                            "banking",
117                            "internal_controls",
118                            "distributions",
119                            "temporal_patterns",
120                            "document_flows",
121                            "intercompany",
122                            "master_data",
123                            "business_processes",
124                            "hr",
125                            "manufacturing",
126                            "tax",
127                            "treasury",
128                            "esg",
129                            "project_accounting",
130                            "diffusion",
131                            "llm",
132                            "causal",
133                        ];
134                        let has_known = map
135                            .keys()
136                            .any(|k| k.as_str().map(|s| known_keys.contains(&s)).unwrap_or(false));
137                        if has_known {
138                            return Ok(yaml_text);
139                        }
140                    }
141                }
142                // Fallback to template-based generation
143                tracing::warn!(
144                    "LLM full-config response did not contain valid DataSynth YAML; falling back to template"
145                );
146                Self::generate(description, provider)
147            }
148            Err(e) => {
149                tracing::warn!("LLM full-config generation failed: {e}; falling back to template");
150                Self::generate(description, provider)
151            }
152        }
153    }
154
155    /// Extract YAML content from an LLM response, stripping ``` fences if present.
156    pub fn extract_yaml(content: &str) -> String {
157        let trimmed = content.trim();
158
159        // Try to extract from ```yaml ... ``` fenced block
160        if let Some(start) = trimmed.find("```yaml") {
161            let after = &trimmed[start + 7..];
162            if let Some(end) = after.find("```") {
163                return after[..end].trim().to_string();
164            }
165        }
166
167        // Try plain ``` ... ``` fenced block
168        if let Some(start) = trimmed.find("```") {
169            let after = &trimmed[start + 3..];
170            if let Some(end) = after.find("```") {
171                return after[..end].trim().to_string();
172            }
173        }
174
175        // No fences — return as-is
176        trimmed.to_string()
177    }
178
179    /// System prompt describing the full DataSynth configuration schema.
180    ///
181    /// Used by [`generate_full`] so the LLM can produce a complete config.
182    pub fn full_schema_system_prompt() -> String {
183        concat!(
184            "You are a DataSynth configuration generator. Given a natural language description, ",
185            "produce a complete, valid DataSynth YAML configuration.\n\n",
186            "Top-level sections (all optional — include only what is relevant):\n\n",
187            "global:\n",
188            "  industry: <retail|manufacturing|financial_services|healthcare|technology>\n",
189            "  start_date: \"YYYY-MM-DD\"\n",
190            "  period_months: <1-120>\n",
191            "  seed: <integer>\n\n",
192            "companies:\n",
193            "  - code: \"C001\"\n",
194            "    name: \"...\"\n",
195            "    currency: \"USD\"\n",
196            "    country: \"US\"\n\n",
197            "chart_of_accounts:\n",
198            "  complexity: <small|medium|large>\n\n",
199            "transactions:\n",
200            "  count: <number>\n",
201            "  anomaly_rate: <0.0-1.0>\n\n",
202            "output:\n",
203            "  format: <csv|json|parquet>\n",
204            "  compression: <true|false>\n\n",
205            "fraud:\n",
206            "  enabled: true\n",
207            "  types: [fictitious_transaction, duplicate_payment, split_transaction, ...]\n",
208            "  injection_rate: <0.0-1.0>\n\n",
209            "internal_controls:\n",
210            "  enabled: true\n",
211            "  coso_enabled: true\n",
212            "  target_maturity_level: <ad_hoc|repeatable|defined|managed|optimized>\n\n",
213            "distributions:\n",
214            "  enabled: true\n",
215            "  industry_profile: <industry>\n",
216            "  amounts: { enabled: true, distribution_type: lognormal, benford_compliance: true }\n\n",
217            "temporal_patterns:\n",
218            "  enabled: true\n",
219            "  business_days: { enabled: true }\n",
220            "  period_end: { model: exponential }\n\n",
221            "banking:\n",
222            "  enabled: true\n",
223            "  customer_count: <number>\n",
224            "  kyc_enabled: true\n",
225            "  aml_enabled: true\n\n",
226            "audit_standards:\n",
227            "  enabled: true\n",
228            "  isa_compliance: { enabled: true, compliance_level: standard }\n",
229            "  sox: { enabled: true }\n\n",
230            "intercompany:\n",
231            "  enabled: true\n\n",
232            "document_flows:\n",
233            "  p2p: { enabled: true }\n",
234            "  o2c: { enabled: true }\n\n",
235            "master_data:\n",
236            "  vendors: { count: <number> }\n",
237            "  customers: { count: <number> }\n\n",
238            "hr:\n",
239            "  enabled: true\n",
240            "  payroll: { enabled: true }\n\n",
241            "manufacturing:\n",
242            "  enabled: true\n\n",
243            "tax:\n",
244            "  enabled: true\n\n",
245            "treasury:\n",
246            "  enabled: true\n\n",
247            "esg:\n",
248            "  enabled: true\n\n",
249            "project_accounting:\n",
250            "  enabled: true\n\n",
251            "diffusion:\n",
252            "  enabled: true\n",
253            "  backend: <statistical|neural|hybrid>\n\n",
254            "Return ONLY the YAML configuration (optionally inside ```yaml fences), no other text.\n"
255        ).to_string()
256    }
257
258    /// Map a [`ConfigIntent`] to a YAML configuration string.
259    pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
260        let industry = intent.industry.as_deref().unwrap_or("manufacturing");
261        let country = intent.country.as_deref().unwrap_or("US");
262        let complexity = intent.company_size.as_deref().unwrap_or("medium");
263        let period_months = intent.period_months.unwrap_or(12);
264
265        // Validate inputs
266        if !(1..=120).contains(&period_months) {
267            return Err(SynthError::generation(format!(
268                "Period months must be between 1 and 120, got {period_months}"
269            )));
270        }
271
272        let valid_complexities = ["small", "medium", "large"];
273        if !valid_complexities.contains(&complexity) {
274            return Err(SynthError::generation(format!(
275                "Invalid company size '{complexity}', must be one of: small, medium, large"
276            )));
277        }
278
279        let currency = Self::country_to_currency(country);
280        let company_name = Self::industry_company_name(industry);
281
282        let mut yaml = String::with_capacity(2048);
283
284        // Global settings
285        yaml.push_str(&format!(
286            "global:\n  industry: {industry}\n  start_date: \"2024-01-01\"\n  period_months: {period_months}\n  seed: 42\n\n"
287        ));
288
289        // Companies
290        yaml.push_str(&format!(
291            "companies:\n  - code: \"C001\"\n    name: \"{company_name}\"\n    currency: \"{currency}\"\n    country: \"{country}\"\n\n"
292        ));
293
294        // Chart of accounts
295        yaml.push_str(&format!(
296            "chart_of_accounts:\n  complexity: {complexity}\n\n"
297        ));
298
299        // Transactions
300        let tx_count = Self::complexity_to_tx_count(complexity);
301        yaml.push_str(&format!(
302            "transactions:\n  count: {tx_count}\n  anomaly_rate: 0.02\n\n"
303        ));
304
305        // Output
306        yaml.push_str("output:\n  format: csv\n  compression: false\n\n");
307
308        // Feature-specific sections
309        for feature in &intent.features {
310            match feature.as_str() {
311                "fraud" => {
312                    yaml.push_str(
313                        "fraud:\n  enabled: true\n  types:\n    - fictitious_transaction\n    - duplicate_payment\n    - split_transaction\n  injection_rate: 0.03\n\n",
314                    );
315                }
316                "audit" => {
317                    yaml.push_str(
318                        "audit_standards:\n  enabled: true\n  isa_compliance:\n    enabled: true\n    compliance_level: standard\n    framework: isa\n  analytical_procedures:\n    enabled: true\n    procedures_per_account: 3\n  confirmations:\n    enabled: true\n    positive_response_rate: 0.85\n  sox:\n    enabled: true\n    materiality_threshold: 10000.0\n\n",
319                    );
320                }
321                "banking" => {
322                    yaml.push_str(
323                        "banking:\n  enabled: true\n  customer_count: 100\n  account_types:\n    - checking\n    - savings\n    - loan\n  kyc_enabled: true\n  aml_enabled: true\n\n",
324                    );
325                }
326                "controls" => {
327                    yaml.push_str(
328                        "internal_controls:\n  enabled: true\n  coso_enabled: true\n  include_entity_level_controls: true\n  target_maturity_level: \"managed\"\n  exception_rate: 0.02\n  sod_violation_rate: 0.01\n\n",
329                    );
330                }
331                "process_mining" => {
332                    yaml.push_str(
333                        "business_processes:\n  enabled: true\n  ocel_export: true\n  p2p:\n    enabled: true\n  o2c:\n    enabled: true\n\n",
334                    );
335                }
336                "intercompany" => {
337                    yaml.push_str(
338                        "intercompany:\n  enabled: true\n  matching_tolerance: 0.01\n  elimination_enabled: true\n\n",
339                    );
340                }
341                "distributions" => {
342                    yaml.push_str(&format!(
343                        "distributions:\n  enabled: true\n  industry_profile: {industry}\n  amounts:\n    enabled: true\n    distribution_type: lognormal\n    benford_compliance: true\n\n"
344                    ));
345                }
346                other => {
347                    tracing::warn!(
348                        "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
349                        other
350                    );
351                }
352            }
353        }
354
355        Ok(yaml)
356    }
357
358    /// Attempt LLM-based parsing of the description.
359    fn parse_with_llm(
360        description: &str,
361        provider: &dyn LlmProvider,
362    ) -> Result<ConfigIntent, SynthError> {
363        let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
364
365        let request = LlmRequest::new(description)
366            .with_system(system_prompt.to_string())
367            .with_temperature(0.1)
368            .with_max_tokens(512);
369
370        let response = provider.complete(&request)?;
371        Self::parse_llm_response(&response.content)
372    }
373
374    /// Parse the LLM response JSON into a ConfigIntent.
375    fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
376        // Try to find JSON in the response
377        let json_str = Self::extract_json(content)
378            .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
379
380        let value: serde_json::Value = serde_json::from_str(json_str)
381            .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {e}")))?;
382
383        let industry = value
384            .get("industry")
385            .and_then(|v| v.as_str())
386            .map(String::from);
387        let country = value
388            .get("country")
389            .and_then(|v| v.as_str())
390            .map(String::from);
391        let company_size = value
392            .get("company_size")
393            .and_then(|v| v.as_str())
394            .map(String::from);
395        let period_months = value
396            .get("period_months")
397            .and_then(serde_json::Value::as_u64)
398            .map(|v| v as u32);
399        let features = value
400            .get("features")
401            .and_then(|v| v.as_array())
402            .map(|arr| {
403                arr.iter()
404                    .filter_map(|v| v.as_str().map(String::from))
405                    .collect()
406            })
407            .unwrap_or_default();
408
409        Ok(ConfigIntent {
410            industry,
411            country,
412            company_size,
413            period_months,
414            features,
415        })
416    }
417
418    /// Extract a JSON object substring from potentially noisy LLM output.
419    fn extract_json(content: &str) -> Option<&str> {
420        super::json_utils::extract_json_object(content)
421    }
422
423    /// Keyword-based parsing as a reliable fallback.
424    fn parse_with_keywords(description: &str) -> ConfigIntent {
425        let lower = description.to_lowercase();
426
427        let industry = Self::extract_industry(&lower);
428        let country = Self::extract_country(&lower);
429        let company_size = Self::extract_size(&lower);
430        let period_months = Self::extract_period(&lower);
431        let features = Self::extract_features(&lower);
432
433        ConfigIntent {
434            industry,
435            country,
436            company_size,
437            period_months,
438            features,
439        }
440    }
441
442    /// Extract industry from lowercased text.
443    ///
444    /// Uses a scoring approach: each industry gets points for keyword matches,
445    /// and the highest-scoring industry wins. This avoids order-dependent
446    /// issues where "banking" in a feature context incorrectly triggers
447    /// "financial_services" over "technology".
448    fn extract_industry(text: &str) -> Option<String> {
449        let patterns: &[(&[&str], &str)] = &[
450            (
451                &["retail", "store", "shop", "e-commerce", "ecommerce"],
452                "retail",
453            ),
454            (
455                &["manufactur", "factory", "production", "assembly"],
456                "manufacturing",
457            ),
458            (
459                &[
460                    "financial",
461                    "finance",
462                    "insurance",
463                    "fintech",
464                    "investment firm",
465                ],
466                "financial_services",
467            ),
468            (
469                &["health", "hospital", "medical", "pharma", "clinic"],
470                "healthcare",
471            ),
472            (
473                &["tech", "software", "saas", "startup", "digital"],
474                "technology",
475            ),
476        ];
477
478        let mut best: Option<(&str, usize)> = None;
479        for (keywords, industry) in patterns {
480            let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
481            if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
482                best = Some((industry, count));
483            }
484        }
485        best.map(|(industry, _)| industry.to_string())
486    }
487
488    /// Extract country from lowercased text.
489    fn extract_country(text: &str) -> Option<String> {
490        // Check full country names first (most reliable), then short codes.
491        // Short codes like "in", "de", "us" can clash with English words,
492        // so we only use unambiguous short codes.
493        let name_patterns = [
494            (&["united states", "u.s.", "america"][..], "US"),
495            (&["germany", "german"][..], "DE"),
496            (&["united kingdom", "british", "england"][..], "GB"),
497            (&["china", "chinese"][..], "CN"),
498            (&["japan", "japanese"][..], "JP"),
499            (&["india", "indian"][..], "IN"),
500            (&["brazil", "brazilian"][..], "BR"),
501            (&["mexico", "mexican"][..], "MX"),
502            (&["australia", "australian"][..], "AU"),
503            (&["singapore", "singaporean"][..], "SG"),
504            (&["korea", "korean"][..], "KR"),
505            (&["france", "french"][..], "FR"),
506            (&["canada", "canadian"][..], "CA"),
507        ];
508
509        for (keywords, code) in &name_patterns {
510            if keywords.iter().any(|kw| text.contains(kw)) {
511                return Some(code.to_string());
512            }
513        }
514
515        // Fall back to short codes (padded with spaces).
516        // Excluded: "in" (India - clashes with preposition "in"),
517        //           "de" (Germany - clashes with various uses).
518        let padded = format!(" {text} ");
519        let safe_codes = [
520            (" us ", "US"),
521            (" uk ", "GB"),
522            (" gb ", "GB"),
523            (" cn ", "CN"),
524            (" jp ", "JP"),
525            (" br ", "BR"),
526            (" mx ", "MX"),
527            (" au ", "AU"),
528            (" sg ", "SG"),
529            (" kr ", "KR"),
530            (" fr ", "FR"),
531            (" ca ", "CA"),
532        ];
533
534        for (code_pattern, code) in &safe_codes {
535            if padded.contains(code_pattern) {
536                return Some(code.to_string());
537            }
538        }
539
540        None
541    }
542
543    /// Extract company size from lowercased text.
544    fn extract_size(text: &str) -> Option<String> {
545        if text.contains("small") || text.contains("startup") || text.contains("tiny") {
546            Some("small".to_string())
547        } else if text.contains("large")
548            || text.contains("enterprise")
549            || text.contains("big")
550            || text.contains("multinational")
551            || text.contains("fortune 500")
552        {
553            Some("large".to_string())
554        } else if text.contains("medium")
555            || text.contains("mid-size")
556            || text.contains("midsize")
557            || text.contains("mid size")
558        {
559            Some("medium".to_string())
560        } else {
561            None
562        }
563    }
564
565    /// Extract period in months from lowercased text.
566    fn extract_period(text: &str) -> Option<u32> {
567        // Match patterns like "1 year", "2 years", "6 months", "18 months"
568        // Also handle "one year", "two years", etc.
569        let word_numbers = [
570            ("one", 1u32),
571            ("two", 2),
572            ("three", 3),
573            ("four", 4),
574            ("five", 5),
575            ("six", 6),
576            ("twelve", 12),
577            ("eighteen", 18),
578            ("twenty-four", 24),
579        ];
580
581        // Try "N year(s)" pattern
582        for (word, num) in &word_numbers {
583            if text.contains(&format!("{word} year")) {
584                return Some(num * 12);
585            }
586            if text.contains(&format!("{word} month")) {
587                return Some(*num);
588            }
589        }
590
591        // Try numeric patterns: "N year(s)", "N month(s)"
592        let tokens: Vec<&str> = text.split_whitespace().collect();
593        for window in tokens.windows(2) {
594            if let Ok(num) = window[0].parse::<u32>() {
595                if window[1].starts_with("year") {
596                    return Some(num * 12);
597                }
598                if window[1].starts_with("month") {
599                    return Some(num);
600                }
601            }
602        }
603
604        None
605    }
606
607    /// Extract feature flags from lowercased text.
608    fn extract_features(text: &str) -> Vec<String> {
609        let mut features = Vec::new();
610
611        let feature_patterns = [
612            (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
613            (&["audit", "auditing", "assurance"][..], "audit"),
614            (&["banking", "bank account", "kyc", "aml"][..], "banking"),
615            (
616                &["control", "sox", "sod", "segregation of duties", "coso"][..],
617                "controls",
618            ),
619            (
620                &["process mining", "ocel", "event log"][..],
621                "process_mining",
622            ),
623            (
624                &["intercompany", "inter-company", "consolidation"][..],
625                "intercompany",
626            ),
627            (
628                &["distribution", "benford", "statistical"][..],
629                "distributions",
630            ),
631        ];
632
633        for (keywords, feature) in &feature_patterns {
634            if keywords.iter().any(|kw| text.contains(kw)) {
635                features.push(feature.to_string());
636            }
637        }
638
639        features
640    }
641
642    /// Merge two ConfigIntents, preferring the primary where available.
643    fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
644        ConfigIntent {
645            industry: primary.industry.or(fallback.industry),
646            country: primary.country.or(fallback.country),
647            company_size: primary.company_size.or(fallback.company_size),
648            period_months: primary.period_months.or(fallback.period_months),
649            features: if primary.features.is_empty() {
650                fallback.features
651            } else {
652                primary.features
653            },
654        }
655    }
656
657    /// Map country code to default currency.
658    fn country_to_currency(country: &str) -> &'static str {
659        match country {
660            "US" | "CA" => "USD",
661            "DE" | "FR" => "EUR",
662            "GB" => "GBP",
663            "CN" => "CNY",
664            "JP" => "JPY",
665            "IN" => "INR",
666            "BR" => "BRL",
667            "MX" => "MXN",
668            "AU" => "AUD",
669            "SG" => "SGD",
670            "KR" => "KRW",
671            _ => "USD",
672        }
673    }
674
675    /// Generate a company name based on industry.
676    fn industry_company_name(industry: &str) -> &'static str {
677        match industry {
678            "retail" => "Retail Corp",
679            "manufacturing" => "Manufacturing Industries Inc",
680            "financial_services" => "Financial Services Group",
681            "healthcare" => "HealthCare Solutions",
682            "technology" => "TechCorp Solutions",
683            _ => "DataSynth Corp",
684        }
685    }
686
687    /// Map complexity to an appropriate transaction count.
688    fn complexity_to_tx_count(complexity: &str) -> u32 {
689        match complexity {
690            "small" => 1000,
691            "medium" => 5000,
692            "large" => 25000,
693            _ => 5000,
694        }
695    }
696}
697
698#[cfg(test)]
699#[allow(clippy::unwrap_used)]
700mod tests {
701    use super::*;
702    use crate::llm::mock_provider::MockLlmProvider;
703
704    #[test]
705    fn test_parse_retail_description() {
706        let provider = MockLlmProvider::new(42);
707        let intent = NlConfigGenerator::parse_intent(
708            "Generate 1 year of retail data for a medium US company",
709            &provider,
710        )
711        .expect("should parse successfully");
712
713        assert_eq!(intent.industry, Some("retail".to_string()));
714        assert_eq!(intent.country, Some("US".to_string()));
715        assert_eq!(intent.company_size, Some("medium".to_string()));
716        assert_eq!(intent.period_months, Some(12));
717    }
718
719    #[test]
720    fn test_parse_manufacturing_with_fraud() {
721        let provider = MockLlmProvider::new(42);
722        let intent = NlConfigGenerator::parse_intent(
723            "Create 6 months of manufacturing data for a large German company with fraud detection",
724            &provider,
725        )
726        .expect("should parse successfully");
727
728        assert_eq!(intent.industry, Some("manufacturing".to_string()));
729        assert_eq!(intent.country, Some("DE".to_string()));
730        assert_eq!(intent.company_size, Some("large".to_string()));
731        assert_eq!(intent.period_months, Some(6));
732        assert!(intent.features.contains(&"fraud".to_string()));
733    }
734
735    #[test]
736    fn test_parse_financial_services_with_audit() {
737        let provider = MockLlmProvider::new(42);
738        let intent = NlConfigGenerator::parse_intent(
739            "I need 2 years of financial services data for audit testing with SOX controls",
740            &provider,
741        )
742        .expect("should parse successfully");
743
744        assert_eq!(intent.industry, Some("financial_services".to_string()));
745        assert_eq!(intent.period_months, Some(24));
746        assert!(intent.features.contains(&"audit".to_string()));
747        assert!(intent.features.contains(&"controls".to_string()));
748    }
749
750    #[test]
751    fn test_parse_healthcare_small() {
752        let provider = MockLlmProvider::new(42);
753        let intent = NlConfigGenerator::parse_intent(
754            "Small healthcare company in Japan, 3 months of data",
755            &provider,
756        )
757        .expect("should parse successfully");
758
759        assert_eq!(intent.industry, Some("healthcare".to_string()));
760        assert_eq!(intent.country, Some("JP".to_string()));
761        assert_eq!(intent.company_size, Some("small".to_string()));
762        assert_eq!(intent.period_months, Some(3));
763    }
764
765    #[test]
766    fn test_parse_technology_with_banking() {
767        let provider = MockLlmProvider::new(42);
768        let intent = NlConfigGenerator::parse_intent(
769            "Generate data for a technology startup in Singapore with banking and KYC",
770            &provider,
771        )
772        .expect("should parse successfully");
773
774        assert_eq!(intent.industry, Some("technology".to_string()));
775        assert_eq!(intent.country, Some("SG".to_string()));
776        assert_eq!(intent.company_size, Some("small".to_string()));
777        assert!(intent.features.contains(&"banking".to_string()));
778    }
779
780    #[test]
781    fn test_parse_word_numbers() {
782        let provider = MockLlmProvider::new(42);
783        let intent =
784            NlConfigGenerator::parse_intent("Generate two years of retail data", &provider)
785                .expect("should parse successfully");
786
787        assert_eq!(intent.period_months, Some(24));
788    }
789
790    #[test]
791    fn test_parse_multiple_features() {
792        let provider = MockLlmProvider::new(42);
793        let intent = NlConfigGenerator::parse_intent(
794            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
795            &provider,
796        )
797        .expect("should parse successfully");
798
799        assert_eq!(intent.industry, Some("manufacturing".to_string()));
800        assert!(intent.features.contains(&"fraud".to_string()));
801        assert!(intent.features.contains(&"audit".to_string()));
802        assert!(intent.features.contains(&"process_mining".to_string()));
803        assert!(intent.features.contains(&"intercompany".to_string()));
804    }
805
806    #[test]
807    fn test_intent_to_yaml_basic() {
808        let intent = ConfigIntent {
809            industry: Some("retail".to_string()),
810            country: Some("US".to_string()),
811            company_size: Some("medium".to_string()),
812            period_months: Some(12),
813            features: vec![],
814        };
815
816        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
817
818        assert!(yaml.contains("industry: retail"));
819        assert!(yaml.contains("period_months: 12"));
820        assert!(yaml.contains("currency: \"USD\""));
821        assert!(yaml.contains("country: \"US\""));
822        assert!(yaml.contains("complexity: medium"));
823        assert!(yaml.contains("count: 5000"));
824    }
825
826    #[test]
827    fn test_intent_to_yaml_with_features() {
828        let intent = ConfigIntent {
829            industry: Some("manufacturing".to_string()),
830            country: Some("DE".to_string()),
831            company_size: Some("large".to_string()),
832            period_months: Some(24),
833            features: vec![
834                "fraud".to_string(),
835                "audit".to_string(),
836                "controls".to_string(),
837            ],
838        };
839
840        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
841
842        assert!(yaml.contains("industry: manufacturing"));
843        assert!(yaml.contains("currency: \"EUR\""));
844        assert!(yaml.contains("complexity: large"));
845        assert!(yaml.contains("count: 25000"));
846        assert!(yaml.contains("fraud:"));
847        assert!(yaml.contains("audit_standards:"));
848        assert!(yaml.contains("internal_controls:"));
849    }
850
851    #[test]
852    fn test_intent_to_yaml_defaults() {
853        let intent = ConfigIntent::default();
854
855        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
856
857        // Should use defaults
858        assert!(yaml.contains("industry: manufacturing"));
859        assert!(yaml.contains("period_months: 12"));
860        assert!(yaml.contains("complexity: medium"));
861    }
862
863    #[test]
864    fn test_intent_to_yaml_invalid_period() {
865        let intent = ConfigIntent {
866            period_months: Some(0),
867            ..ConfigIntent::default()
868        };
869
870        let result = NlConfigGenerator::intent_to_yaml(&intent);
871        assert!(result.is_err());
872
873        let intent = ConfigIntent {
874            period_months: Some(121),
875            ..ConfigIntent::default()
876        };
877
878        let result = NlConfigGenerator::intent_to_yaml(&intent);
879        assert!(result.is_err());
880    }
881
882    #[test]
883    fn test_generate_end_to_end() {
884        let provider = MockLlmProvider::new(42);
885        let yaml = NlConfigGenerator::generate(
886            "Generate 1 year of retail data for a medium US company with fraud detection",
887            &provider,
888        )
889        .expect("should generate YAML");
890
891        assert!(yaml.contains("industry: retail"));
892        assert!(yaml.contains("period_months: 12"));
893        assert!(yaml.contains("currency: \"USD\""));
894        assert!(yaml.contains("fraud:"));
895        assert!(yaml.contains("complexity: medium"));
896    }
897
898    #[test]
899    fn test_generate_empty_description() {
900        let provider = MockLlmProvider::new(42);
901        let result = NlConfigGenerator::generate("", &provider);
902        assert!(result.is_err());
903
904        let result = NlConfigGenerator::generate("   ", &provider);
905        assert!(result.is_err());
906    }
907
908    #[test]
909    fn test_extract_json_from_response() {
910        let content = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;
911        let json = NlConfigGenerator::extract_json(content);
912        assert!(json.is_some());
913        assert_eq!(
914            json.expect("json should be present"),
915            r#"{"industry": "retail", "country": "US"}"#
916        );
917    }
918
919    #[test]
920    fn test_extract_json_nested() {
921        let content = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
922        let json = NlConfigGenerator::extract_json(content);
923        assert!(json.is_some());
924    }
925
926    #[test]
927    fn test_extract_json_missing() {
928        let content = "No JSON here at all";
929        let json = NlConfigGenerator::extract_json(content);
930        assert!(json.is_none());
931    }
932
933    #[test]
934    fn test_parse_llm_response_valid() {
935        let content = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
936        let intent =
937            NlConfigGenerator::parse_llm_response(content).expect("should parse valid JSON");
938
939        assert_eq!(intent.industry, Some("retail".to_string()));
940        assert_eq!(intent.country, Some("US".to_string()));
941        assert_eq!(intent.company_size, Some("medium".to_string()));
942        assert_eq!(intent.period_months, Some(12));
943        assert_eq!(intent.features, vec!["fraud".to_string()]);
944    }
945
946    #[test]
947    fn test_parse_llm_response_partial() {
948        let content = r#"{"industry": "retail"}"#;
949        let intent =
950            NlConfigGenerator::parse_llm_response(content).expect("should parse partial JSON");
951
952        assert_eq!(intent.industry, Some("retail".to_string()));
953        assert_eq!(intent.country, None);
954        assert!(intent.features.is_empty());
955    }
956
957    #[test]
958    fn test_country_to_currency_mapping() {
959        assert_eq!(NlConfigGenerator::country_to_currency("US"), "USD");
960        assert_eq!(NlConfigGenerator::country_to_currency("DE"), "EUR");
961        assert_eq!(NlConfigGenerator::country_to_currency("GB"), "GBP");
962        assert_eq!(NlConfigGenerator::country_to_currency("JP"), "JPY");
963        assert_eq!(NlConfigGenerator::country_to_currency("CN"), "CNY");
964        assert_eq!(NlConfigGenerator::country_to_currency("BR"), "BRL");
965        assert_eq!(NlConfigGenerator::country_to_currency("XX"), "USD"); // Unknown defaults to USD
966    }
967
968    #[test]
969    fn test_merge_intents() {
970        let primary = ConfigIntent {
971            industry: Some("retail".to_string()),
972            country: None,
973            company_size: None,
974            period_months: Some(12),
975            features: vec![],
976        };
977        let fallback = ConfigIntent {
978            industry: Some("manufacturing".to_string()),
979            country: Some("DE".to_string()),
980            company_size: Some("large".to_string()),
981            period_months: Some(6),
982            features: vec!["fraud".to_string()],
983        };
984
985        let merged = NlConfigGenerator::merge_intents(primary, fallback);
986        assert_eq!(merged.industry, Some("retail".to_string())); // primary wins
987        assert_eq!(merged.country, Some("DE".to_string())); // fallback fills gap
988        assert_eq!(merged.company_size, Some("large".to_string())); // fallback fills gap
989        assert_eq!(merged.period_months, Some(12)); // primary wins
990        assert_eq!(merged.features, vec!["fraud".to_string()]); // fallback since primary empty
991    }
992
993    #[test]
994    fn test_parse_uk_country() {
995        let provider = MockLlmProvider::new(42);
996        let intent = NlConfigGenerator::parse_intent(
997            "Generate data for a UK manufacturing company",
998            &provider,
999        )
1000        .expect("should parse successfully");
1001
1002        assert_eq!(intent.country, Some("GB".to_string()));
1003    }
1004
1005    #[test]
1006    fn test_intent_to_yaml_banking_feature() {
1007        let intent = ConfigIntent {
1008            industry: Some("financial_services".to_string()),
1009            country: Some("US".to_string()),
1010            company_size: Some("large".to_string()),
1011            period_months: Some(12),
1012            features: vec!["banking".to_string()],
1013        };
1014
1015        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1016
1017        assert!(yaml.contains("banking:"));
1018        assert!(yaml.contains("kyc_enabled: true"));
1019        assert!(yaml.contains("aml_enabled: true"));
1020    }
1021
1022    #[test]
1023    fn test_intent_to_yaml_process_mining_feature() {
1024        let intent = ConfigIntent {
1025            features: vec!["process_mining".to_string()],
1026            ..ConfigIntent::default()
1027        };
1028
1029        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1030
1031        assert!(yaml.contains("business_processes:"));
1032        assert!(yaml.contains("ocel_export: true"));
1033    }
1034
1035    #[test]
1036    fn test_intent_to_yaml_distributions_feature() {
1037        let intent = ConfigIntent {
1038            industry: Some("retail".to_string()),
1039            features: vec!["distributions".to_string()],
1040            ..ConfigIntent::default()
1041        };
1042
1043        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1044
1045        assert!(yaml.contains("distributions:"));
1046        assert!(yaml.contains("industry_profile: retail"));
1047        assert!(yaml.contains("benford_compliance: true"));
1048    }
1049
1050    #[test]
1051    fn test_extract_yaml_from_fenced_block() {
1052        let content = "Here is the config:\n```yaml\nglobal:\n  industry: retail\n```\nDone.";
1053        let yaml = NlConfigGenerator::extract_yaml(content);
1054        assert!(yaml.contains("global:"));
1055        assert!(yaml.contains("industry: retail"));
1056        assert!(!yaml.contains("```"));
1057    }
1058
1059    #[test]
1060    fn test_extract_yaml_plain_fences() {
1061        let content = "```\nglobal:\n  seed: 42\n```";
1062        let yaml = NlConfigGenerator::extract_yaml(content);
1063        assert!(yaml.contains("global:"));
1064        assert!(yaml.contains("seed: 42"));
1065        assert!(!yaml.contains("```"));
1066    }
1067
1068    #[test]
1069    fn test_extract_yaml_no_fences() {
1070        let content = "global:\n  industry: manufacturing\n";
1071        let yaml = NlConfigGenerator::extract_yaml(content);
1072        assert!(yaml.contains("global:"));
1073        assert!(yaml.contains("industry: manufacturing"));
1074    }
1075
1076    #[test]
1077    fn test_generate_full_falls_back_to_template() {
1078        // MockLlmProvider returns a fixed response that won't parse as valid
1079        // DataSynth YAML, so generate_full should fall back to template-based
1080        let provider = MockLlmProvider::new(42);
1081        let yaml = NlConfigGenerator::generate_full(
1082            "Generate 1 year of retail data for a medium US company",
1083            &provider,
1084        )
1085        .expect("should fall back to template-based generation");
1086
1087        assert!(yaml.contains("industry: retail"));
1088        assert!(yaml.contains("period_months: 12"));
1089    }
1090
1091    #[test]
1092    fn test_generate_full_empty_description() {
1093        let provider = MockLlmProvider::new(42);
1094        let result = NlConfigGenerator::generate_full("", &provider);
1095        assert!(result.is_err());
1096
1097        let result = NlConfigGenerator::generate_full("   ", &provider);
1098        assert!(result.is_err());
1099    }
1100
1101    #[test]
1102    fn test_full_schema_system_prompt_covers_key_sections() {
1103        let prompt = NlConfigGenerator::full_schema_system_prompt();
1104        assert!(prompt.contains("global:"));
1105        assert!(prompt.contains("companies:"));
1106        assert!(prompt.contains("chart_of_accounts:"));
1107        assert!(prompt.contains("transactions:"));
1108        assert!(prompt.contains("fraud:"));
1109        assert!(prompt.contains("banking:"));
1110        assert!(prompt.contains("distributions:"));
1111        assert!(prompt.contains("diffusion:"));
1112    }
1113}