Skip to main content

datasynth_core/llm/
nl_config.rs

1//! Natural language to YAML configuration generator.
2//!
3//! Takes a free-text description of desired synthetic data (e.g., "Generate 1 year of
4//! retail data for a medium US company with fraud detection") and produces a valid
5//! `GeneratorConfig` YAML string.
6
7use super::provider::{LlmProvider, LlmRequest};
8use crate::error::SynthError;
9
/// Structured representation of user intent extracted from natural language.
///
/// Every scalar field is optional: `None` means the description did not
/// mention that aspect. `features` is empty when no optional feature was
/// requested.
///
/// Implements `PartialEq`/`Eq` so two intents can be compared directly
/// (e.g. in tests or when checking whether a merge changed anything).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConfigIntent {
    /// Target industry (e.g., "retail", "manufacturing", "financial_services").
    pub industry: Option<String>,
    /// Country code (e.g., "US", "DE", "GB").
    pub country: Option<String>,
    /// Company size: "small", "medium", or "large".
    pub company_size: Option<String>,
    /// Duration in months.
    pub period_months: Option<u32>,
    /// Requested feature flags (e.g., "fraud", "audit", "banking", "controls").
    pub features: Vec<String>,
}
24
/// Generates YAML configuration from natural language descriptions.
///
/// The generator uses a two-phase approach:
/// 1. Parse the natural language description into a structured [`ConfigIntent`].
/// 2. Map the intent to a YAML configuration string using preset templates.
///
/// The type carries no state; all functionality is exposed through associated
/// functions. The derives make it usable in `{:?}` output and freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct NlConfigGenerator;
31
32impl NlConfigGenerator {
33    /// Generate a YAML configuration from a natural language description.
34    ///
35    /// Uses the provided LLM provider to help parse the description, with
36    /// keyword-based fallback parsing for reliability.
37    ///
38    /// # Errors
39    ///
40    /// Returns `SynthError::GenerationError` if the description cannot be parsed
41    /// or the resulting configuration is invalid.
42    pub fn generate(description: &str, provider: &dyn LlmProvider) -> Result<String, SynthError> {
43        if description.trim().is_empty() {
44            return Err(SynthError::generation(
45                "Natural language description cannot be empty",
46            ));
47        }
48
49        let intent = Self::parse_intent(description, provider)?;
50        Self::intent_to_yaml(&intent)
51    }
52
53    /// Parse a natural language description into a structured [`ConfigIntent`].
54    ///
55    /// Attempts to use the LLM provider first, then falls back to keyword-based
56    /// extraction for reliability.
57    pub fn parse_intent(
58        description: &str,
59        provider: &dyn LlmProvider,
60    ) -> Result<ConfigIntent, SynthError> {
61        // Try LLM-based parsing first
62        let llm_intent = Self::parse_with_llm(description, provider);
63
64        // Always run keyword-based parsing as fallback/supplement
65        let keyword_intent = Self::parse_with_keywords(description);
66
67        // Merge: prefer LLM results where available, fall back to keywords
68        match llm_intent {
69            Ok(llm) => Ok(Self::merge_intents(llm, keyword_intent)),
70            Err(e) => {
71                tracing::warn!(
72                    "LLM-based config parsing failed, falling back to keyword parsing: {}",
73                    e
74                );
75                Ok(keyword_intent)
76            }
77        }
78    }
79
80    /// Generate a complete YAML configuration from a natural language description.
81    ///
82    /// Unlike `generate`, which maps to a template via structured intent, this
83    /// method asks the LLM to produce the full YAML directly using the complete
84    /// DataSynth config schema as guidance.  Falls back to `generate` if the
85    /// LLM response is not valid YAML or does not contain expected top-level keys.
86    pub fn generate_full(
87        description: &str,
88        provider: &dyn LlmProvider,
89    ) -> Result<String, SynthError> {
90        if description.trim().is_empty() {
91            return Err(SynthError::generation(
92                "Natural language description cannot be empty",
93            ));
94        }
95
96        let system = Self::full_schema_system_prompt();
97        let request = LlmRequest::new(description.to_string())
98            .with_system(system)
99            .with_temperature(0.2)
100            .with_max_tokens(4096);
101
102        match provider.complete(&request) {
103            Ok(response) => {
104                let yaml_text = Self::extract_yaml(&response.content);
105                // Validate that it parses as a YAML mapping with at least one known key
106                if let Ok(value) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_text) {
107                    if let Some(map) = value.as_mapping() {
108                        let known_keys = [
109                            "global",
110                            "companies",
111                            "chart_of_accounts",
112                            "transactions",
113                            "output",
114                            "fraud",
115                            "audit_standards",
116                            "banking",
117                            "internal_controls",
118                            "distributions",
119                            "temporal_patterns",
120                            "document_flows",
121                            "intercompany",
122                            "master_data",
123                            "business_processes",
124                            "hr",
125                            "manufacturing",
126                            "tax",
127                            "treasury",
128                            "esg",
129                            "project_accounting",
130                            "diffusion",
131                            "llm",
132                            "causal",
133                        ];
134                        let has_known = map
135                            .keys()
136                            .any(|k| k.as_str().map(|s| known_keys.contains(&s)).unwrap_or(false));
137                        if has_known {
138                            return Ok(yaml_text);
139                        }
140                    }
141                }
142                // Fallback to template-based generation
143                tracing::warn!(
144                    "LLM full-config response did not contain valid DataSynth YAML; falling back to template"
145                );
146                Self::generate(description, provider)
147            }
148            Err(e) => {
149                tracing::warn!("LLM full-config generation failed: {e}; falling back to template");
150                Self::generate(description, provider)
151            }
152        }
153    }
154
155    /// Extract YAML content from an LLM response, stripping ``` fences if present.
156    pub fn extract_yaml(content: &str) -> String {
157        let trimmed = content.trim();
158
159        // Try to extract from ```yaml ... ``` fenced block
160        if let Some(start) = trimmed.find("```yaml") {
161            let after = &trimmed[start + 7..];
162            if let Some(end) = after.find("```") {
163                return after[..end].trim().to_string();
164            }
165        }
166
167        // Try plain ``` ... ``` fenced block
168        if let Some(start) = trimmed.find("```") {
169            let after = &trimmed[start + 3..];
170            if let Some(end) = after.find("```") {
171                return after[..end].trim().to_string();
172            }
173        }
174
175        // No fences — return as-is
176        trimmed.to_string()
177    }
178
179    /// System prompt describing the full DataSynth configuration schema.
180    ///
181    /// Used by `generate_full` so the LLM can produce a complete config.
182    pub fn full_schema_system_prompt() -> String {
183        concat!(
184            "You are a DataSynth configuration generator. Given a natural language description, ",
185            "produce a complete, valid DataSynth YAML configuration.\n\n",
186            "Top-level sections (all optional — include only what is relevant):\n\n",
187            "global:\n",
188            "  industry: <retail|manufacturing|financial_services|healthcare|technology>\n",
189            "  start_date: \"YYYY-MM-DD\"\n",
190            "  period_months: <1-120>\n",
191            "  seed: <integer>\n\n",
192            "companies:\n",
193            "  - code: \"C001\"\n",
194            "    name: \"...\"\n",
195            "    currency: \"USD\"\n",
196            "    country: \"US\"\n\n",
197            "chart_of_accounts:\n",
198            "  complexity: <small|medium|large>\n\n",
199            "transactions:\n",
200            "  count: <number>\n",
201            "  anomaly_rate: <0.0-1.0>\n\n",
202            "output:\n",
203            "  format: <csv|json|parquet>\n",
204            "  compression: <true|false>\n\n",
205            "fraud:\n",
206            "  enabled: true\n",
207            "  types: [fictitious_transaction, duplicate_payment, split_transaction, ...]\n",
208            "  injection_rate: <0.0-1.0>\n\n",
209            "internal_controls:\n",
210            "  enabled: true\n",
211            "  coso_enabled: true\n",
212            "  target_maturity_level: <ad_hoc|repeatable|defined|managed|optimized>\n\n",
213            "distributions:\n",
214            "  enabled: true\n",
215            "  industry_profile: <industry>\n",
216            "  amounts: { enabled: true, distribution_type: lognormal, benford_compliance: true }\n\n",
217            "temporal_patterns:\n",
218            "  enabled: true\n",
219            "  business_days: { enabled: true }\n",
220            "  period_end: { model: exponential }\n\n",
221            "banking:\n",
222            "  enabled: true\n",
223            "  customer_count: <number>\n",
224            "  kyc_enabled: true\n",
225            "  aml_enabled: true\n\n",
226            "audit_standards:\n",
227            "  enabled: true\n",
228            "  isa_compliance: { enabled: true, compliance_level: standard }\n",
229            "  sox: { enabled: true }\n\n",
230            "intercompany:\n",
231            "  enabled: true\n\n",
232            "document_flows:\n",
233            "  p2p: { enabled: true }\n",
234            "  o2c: { enabled: true }\n\n",
235            "master_data:\n",
236            "  vendors: { count: <number> }\n",
237            "  customers: { count: <number> }\n\n",
238            "hr:\n",
239            "  enabled: true\n",
240            "  payroll: { enabled: true }\n\n",
241            "manufacturing:\n",
242            "  enabled: true\n\n",
243            "tax:\n",
244            "  enabled: true\n\n",
245            "treasury:\n",
246            "  enabled: true\n\n",
247            "esg:\n",
248            "  enabled: true\n\n",
249            "project_accounting:\n",
250            "  enabled: true\n\n",
251            "diffusion:\n",
252            "  enabled: true\n",
253            "  backend: <statistical|neural|hybrid>\n\n",
254            "Return ONLY the YAML configuration (optionally inside ```yaml fences), no other text.\n"
255        ).to_string()
256    }
257
258    /// Map a [`ConfigIntent`] to a YAML configuration string.
259    pub fn intent_to_yaml(intent: &ConfigIntent) -> Result<String, SynthError> {
260        let industry = intent.industry.as_deref().unwrap_or("manufacturing");
261        let country = intent.country.as_deref().unwrap_or("US");
262        let complexity = intent.company_size.as_deref().unwrap_or("medium");
263        let period_months = intent.period_months.unwrap_or(12);
264
265        // Validate inputs
266        if !(1..=120).contains(&period_months) {
267            return Err(SynthError::generation(format!(
268                "Period months must be between 1 and 120, got {period_months}"
269            )));
270        }
271
272        let valid_complexities = ["small", "medium", "large"];
273        if !valid_complexities.contains(&complexity) {
274            return Err(SynthError::generation(format!(
275                "Invalid company size '{complexity}', must be one of: small, medium, large"
276            )));
277        }
278
279        let currency = Self::country_to_currency(country);
280        let company_name = Self::industry_company_name(industry);
281
282        let mut yaml = String::with_capacity(2048);
283
284        // Global settings
285        yaml.push_str(&format!(
286            "global:\n  industry: {industry}\n  start_date: \"2024-01-01\"\n  period_months: {period_months}\n  seed: 42\n\n"
287        ));
288
289        // Companies
290        yaml.push_str(&format!(
291            "companies:\n  - code: \"C001\"\n    name: \"{company_name}\"\n    currency: \"{currency}\"\n    country: \"{country}\"\n\n"
292        ));
293
294        // Chart of accounts
295        yaml.push_str(&format!(
296            "chart_of_accounts:\n  complexity: {complexity}\n\n"
297        ));
298
299        // Transactions
300        let tx_count = Self::complexity_to_tx_count(complexity);
301        yaml.push_str(&format!(
302            "transactions:\n  count: {tx_count}\n  anomaly_rate: 0.02\n\n"
303        ));
304
305        // Output
306        yaml.push_str("output:\n  format: csv\n  compression: false\n\n");
307
308        // Feature-specific sections
309        for feature in &intent.features {
310            match feature.as_str() {
311                "fraud" => {
312                    yaml.push_str(
313                        "fraud:\n  enabled: true\n  types:\n    - fictitious_transaction\n    - duplicate_payment\n    - split_transaction\n  injection_rate: 0.03\n\n",
314                    );
315                }
316                "audit" => {
317                    yaml.push_str(
318                        "audit_standards:\n  enabled: true\n  isa_compliance:\n    enabled: true\n    compliance_level: standard\n    framework: isa\n  analytical_procedures:\n    enabled: true\n    procedures_per_account: 3\n  confirmations:\n    enabled: true\n    positive_response_rate: 0.85\n  sox:\n    enabled: true\n    materiality_threshold: 10000.0\n\n",
319                    );
320                }
321                "banking" => {
322                    yaml.push_str(
323                        "banking:\n  enabled: true\n  customer_count: 100\n  account_types:\n    - checking\n    - savings\n    - loan\n  kyc_enabled: true\n  aml_enabled: true\n\n",
324                    );
325                }
326                "controls" => {
327                    yaml.push_str(
328                        "internal_controls:\n  enabled: true\n  coso_enabled: true\n  include_entity_level_controls: true\n  target_maturity_level: \"managed\"\n  exception_rate: 0.02\n  sod_violation_rate: 0.01\n\n",
329                    );
330                }
331                "process_mining" => {
332                    yaml.push_str(
333                        "business_processes:\n  enabled: true\n  ocel_export: true\n  p2p:\n    enabled: true\n  o2c:\n    enabled: true\n\n",
334                    );
335                }
336                "intercompany" => {
337                    yaml.push_str(
338                        "intercompany:\n  enabled: true\n  matching_tolerance: 0.01\n  elimination_enabled: true\n\n",
339                    );
340                }
341                "distributions" => {
342                    yaml.push_str(&format!(
343                        "distributions:\n  enabled: true\n  industry_profile: {industry}\n  amounts:\n    enabled: true\n    distribution_type: lognormal\n    benford_compliance: true\n\n"
344                    ));
345                }
346                other => {
347                    tracing::warn!(
348                        "Unknown NL config feature '{}' ignored. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions",
349                        other
350                    );
351                }
352            }
353        }
354
355        Ok(yaml)
356    }
357
358    /// Attempt LLM-based parsing of the description.
359    fn parse_with_llm(
360        description: &str,
361        provider: &dyn LlmProvider,
362    ) -> Result<ConfigIntent, SynthError> {
363        let system_prompt = "You are a configuration parser. Extract structured fields from a natural language description of desired synthetic data generation. Return ONLY a JSON object with these fields: industry (string or null), country (string or null), company_size (string or null), period_months (number or null), features (array of strings). Valid industries: retail, manufacturing, financial_services, healthcare, technology. Valid sizes: small, medium, large. Valid features: fraud, audit, banking, controls, process_mining, intercompany, distributions.";
364
365        let request = LlmRequest::new(description)
366            .with_system(system_prompt.to_string())
367            .with_temperature(0.1)
368            .with_max_tokens(512);
369
370        let response = provider.complete(&request)?;
371        Self::parse_llm_response(&response.content)
372    }
373
374    /// Parse the LLM response JSON into a ConfigIntent.
375    fn parse_llm_response(content: &str) -> Result<ConfigIntent, SynthError> {
376        // Try to find JSON in the response
377        let json_str = Self::extract_json(content)
378            .ok_or_else(|| SynthError::generation("No JSON found in LLM response"))?;
379
380        let value: serde_json::Value = serde_json::from_str(json_str)
381            .map_err(|e| SynthError::generation(format!("Failed to parse LLM JSON: {e}")))?;
382
383        let industry = value
384            .get("industry")
385            .and_then(|v| v.as_str())
386            .map(String::from);
387        let country = value
388            .get("country")
389            .and_then(|v| v.as_str())
390            .map(String::from);
391        let company_size = value
392            .get("company_size")
393            .and_then(|v| v.as_str())
394            .map(String::from);
395        let period_months = value
396            .get("period_months")
397            .and_then(serde_json::Value::as_u64)
398            .map(|v| v as u32);
399        let features = value
400            .get("features")
401            .and_then(|v| v.as_array())
402            .map(|arr| {
403                arr.iter()
404                    .filter_map(|v| v.as_str().map(String::from))
405                    .collect()
406            })
407            .unwrap_or_default();
408
409        Ok(ConfigIntent {
410            industry,
411            country,
412            company_size,
413            period_months,
414            features,
415        })
416    }
417
418    /// Extract a JSON object substring from potentially noisy LLM output.
419    fn extract_json(content: &str) -> Option<&str> {
420        super::json_utils::extract_json_object(content)
421    }
422
423    /// Keyword-based parsing as a reliable fallback.
424    fn parse_with_keywords(description: &str) -> ConfigIntent {
425        let lower = description.to_lowercase();
426
427        let industry = Self::extract_industry(&lower);
428        let country = Self::extract_country(&lower);
429        let company_size = Self::extract_size(&lower);
430        let period_months = Self::extract_period(&lower);
431        let features = Self::extract_features(&lower);
432
433        ConfigIntent {
434            industry,
435            country,
436            company_size,
437            period_months,
438            features,
439        }
440    }
441
442    /// Extract industry from lowercased text.
443    ///
444    /// Uses a scoring approach: each industry gets points for keyword matches,
445    /// and the highest-scoring industry wins. This avoids order-dependent
446    /// issues where "banking" in a feature context incorrectly triggers
447    /// "financial_services" over "technology".
448    fn extract_industry(text: &str) -> Option<String> {
449        let patterns: &[(&[&str], &str)] = &[
450            (
451                &["retail", "store", "shop", "e-commerce", "ecommerce"],
452                "retail",
453            ),
454            (
455                &["manufactur", "factory", "production", "assembly"],
456                "manufacturing",
457            ),
458            (
459                &[
460                    "financial",
461                    "finance",
462                    "insurance",
463                    "fintech",
464                    "investment firm",
465                ],
466                "financial_services",
467            ),
468            (
469                &["health", "hospital", "medical", "pharma", "clinic"],
470                "healthcare",
471            ),
472            (
473                &["tech", "software", "saas", "startup", "digital"],
474                "technology",
475            ),
476        ];
477
478        let mut best: Option<(&str, usize)> = None;
479        for (keywords, industry) in patterns {
480            let count = keywords.iter().filter(|kw| text.contains(*kw)).count();
481            if count > 0 && (best.is_none() || count > best.expect("checked is_some").1) {
482                best = Some((industry, count));
483            }
484        }
485        best.map(|(industry, _)| industry.to_string())
486    }
487
488    /// Extract country from lowercased text.
489    fn extract_country(text: &str) -> Option<String> {
490        // Check full country names first (most reliable), then short codes.
491        // Short codes like "in", "de", "us" can clash with English words,
492        // so we only use unambiguous short codes.
493        let name_patterns = [
494            (&["united states", "u.s.", "america"][..], "US"),
495            (&["germany", "german"][..], "DE"),
496            (&["united kingdom", "british", "england"][..], "GB"),
497            (&["china", "chinese"][..], "CN"),
498            (&["japan", "japanese"][..], "JP"),
499            (&["india", "indian"][..], "IN"),
500            (&["brazil", "brazilian"][..], "BR"),
501            (&["mexico", "mexican"][..], "MX"),
502            (&["australia", "australian"][..], "AU"),
503            (&["singapore", "singaporean"][..], "SG"),
504            (&["korea", "korean"][..], "KR"),
505            (&["france", "french"][..], "FR"),
506            (&["canada", "canadian"][..], "CA"),
507        ];
508
509        for (keywords, code) in &name_patterns {
510            if keywords.iter().any(|kw| text.contains(kw)) {
511                return Some(code.to_string());
512            }
513        }
514
515        // Fall back to short codes (padded with spaces).
516        // Excluded: "in" (India - clashes with preposition "in"),
517        //           "de" (Germany - clashes with various uses).
518        let padded = format!(" {text} ");
519        let safe_codes = [
520            (" us ", "US"),
521            (" uk ", "GB"),
522            (" gb ", "GB"),
523            (" cn ", "CN"),
524            (" jp ", "JP"),
525            (" br ", "BR"),
526            (" mx ", "MX"),
527            (" au ", "AU"),
528            (" sg ", "SG"),
529            (" kr ", "KR"),
530            (" fr ", "FR"),
531            (" ca ", "CA"),
532        ];
533
534        for (code_pattern, code) in &safe_codes {
535            if padded.contains(code_pattern) {
536                return Some(code.to_string());
537            }
538        }
539
540        None
541    }
542
543    /// Extract company size from lowercased text.
544    fn extract_size(text: &str) -> Option<String> {
545        if text.contains("small") || text.contains("startup") || text.contains("tiny") {
546            Some("small".to_string())
547        } else if text.contains("large")
548            || text.contains("enterprise")
549            || text.contains("big")
550            || text.contains("multinational")
551            || text.contains("fortune 500")
552        {
553            Some("large".to_string())
554        } else if text.contains("medium")
555            || text.contains("mid-size")
556            || text.contains("midsize")
557            || text.contains("mid size")
558        {
559            Some("medium".to_string())
560        } else {
561            None
562        }
563    }
564
565    /// Extract period in months from lowercased text.
566    fn extract_period(text: &str) -> Option<u32> {
567        // Match patterns like "1 year", "2 years", "6 months", "18 months"
568        // Also handle "one year", "two years", etc.
569        let word_numbers = [
570            ("one", 1u32),
571            ("two", 2),
572            ("three", 3),
573            ("four", 4),
574            ("five", 5),
575            ("six", 6),
576            ("twelve", 12),
577            ("eighteen", 18),
578            ("twenty-four", 24),
579        ];
580
581        // Try "N year(s)" pattern
582        for (word, num) in &word_numbers {
583            if text.contains(&format!("{word} year")) {
584                return Some(num * 12);
585            }
586            if text.contains(&format!("{word} month")) {
587                return Some(*num);
588            }
589        }
590
591        // Try numeric patterns: "N year(s)", "N month(s)"
592        let tokens: Vec<&str> = text.split_whitespace().collect();
593        for window in tokens.windows(2) {
594            if let Ok(num) = window[0].parse::<u32>() {
595                if window[1].starts_with("year") {
596                    return Some(num * 12);
597                }
598                if window[1].starts_with("month") {
599                    return Some(num);
600                }
601            }
602        }
603
604        None
605    }
606
607    /// Extract feature flags from lowercased text.
608    fn extract_features(text: &str) -> Vec<String> {
609        let mut features = Vec::new();
610
611        let feature_patterns = [
612            (&["fraud", "fraudulent", "suspicious"][..], "fraud"),
613            (&["audit", "auditing", "assurance"][..], "audit"),
614            (&["banking", "bank account", "kyc", "aml"][..], "banking"),
615            (
616                &["control", "sox", "sod", "segregation of duties", "coso"][..],
617                "controls",
618            ),
619            (
620                &["process mining", "ocel", "event log"][..],
621                "process_mining",
622            ),
623            (
624                &["intercompany", "inter-company", "consolidation"][..],
625                "intercompany",
626            ),
627            (
628                &["distribution", "benford", "statistical"][..],
629                "distributions",
630            ),
631        ];
632
633        for (keywords, feature) in &feature_patterns {
634            if keywords.iter().any(|kw| text.contains(kw)) {
635                features.push(feature.to_string());
636            }
637        }
638
639        features
640    }
641
642    /// Merge two ConfigIntents, preferring the primary where available.
643    fn merge_intents(primary: ConfigIntent, fallback: ConfigIntent) -> ConfigIntent {
644        ConfigIntent {
645            industry: primary.industry.or(fallback.industry),
646            country: primary.country.or(fallback.country),
647            company_size: primary.company_size.or(fallback.company_size),
648            period_months: primary.period_months.or(fallback.period_months),
649            features: if primary.features.is_empty() {
650                fallback.features
651            } else {
652                primary.features
653            },
654        }
655    }
656
657    /// Map country code to default currency.
658    fn country_to_currency(country: &str) -> &'static str {
659        match country {
660            "US" | "CA" => "USD",
661            "DE" | "FR" => "EUR",
662            "GB" => "GBP",
663            "CN" => "CNY",
664            "JP" => "JPY",
665            "IN" => "INR",
666            "BR" => "BRL",
667            "MX" => "MXN",
668            "AU" => "AUD",
669            "SG" => "SGD",
670            "KR" => "KRW",
671            _ => "USD",
672        }
673    }
674
675    /// Generate a company name based on industry.
676    fn industry_company_name(industry: &str) -> &'static str {
677        match industry {
678            "retail" => "Retail Corp",
679            "manufacturing" => "Manufacturing Industries Inc",
680            "financial_services" => "Financial Services Group",
681            "healthcare" => "HealthCare Solutions",
682            "technology" => "TechCorp Solutions",
683            _ => "DataSynth Corp",
684        }
685    }
686
687    /// Map complexity to an appropriate transaction count.
688    fn complexity_to_tx_count(complexity: &str) -> u32 {
689        match complexity {
690            "small" => 1000,
691            "medium" => 5000,
692            "large" => 25000,
693            _ => 5000,
694        }
695    }
696}
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701    use crate::llm::mock_provider::MockLlmProvider;
702
703    #[test]
704    fn test_parse_retail_description() {
705        let provider = MockLlmProvider::new(42);
706        let intent = NlConfigGenerator::parse_intent(
707            "Generate 1 year of retail data for a medium US company",
708            &provider,
709        )
710        .expect("should parse successfully");
711
712        assert_eq!(intent.industry, Some("retail".to_string()));
713        assert_eq!(intent.country, Some("US".to_string()));
714        assert_eq!(intent.company_size, Some("medium".to_string()));
715        assert_eq!(intent.period_months, Some(12));
716    }
717
718    #[test]
719    fn test_parse_manufacturing_with_fraud() {
720        let provider = MockLlmProvider::new(42);
721        let intent = NlConfigGenerator::parse_intent(
722            "Create 6 months of manufacturing data for a large German company with fraud detection",
723            &provider,
724        )
725        .expect("should parse successfully");
726
727        assert_eq!(intent.industry, Some("manufacturing".to_string()));
728        assert_eq!(intent.country, Some("DE".to_string()));
729        assert_eq!(intent.company_size, Some("large".to_string()));
730        assert_eq!(intent.period_months, Some(6));
731        assert!(intent.features.contains(&"fraud".to_string()));
732    }
733
734    #[test]
735    fn test_parse_financial_services_with_audit() {
736        let provider = MockLlmProvider::new(42);
737        let intent = NlConfigGenerator::parse_intent(
738            "I need 2 years of financial services data for audit testing with SOX controls",
739            &provider,
740        )
741        .expect("should parse successfully");
742
743        assert_eq!(intent.industry, Some("financial_services".to_string()));
744        assert_eq!(intent.period_months, Some(24));
745        assert!(intent.features.contains(&"audit".to_string()));
746        assert!(intent.features.contains(&"controls".to_string()));
747    }
748
749    #[test]
750    fn test_parse_healthcare_small() {
751        let provider = MockLlmProvider::new(42);
752        let intent = NlConfigGenerator::parse_intent(
753            "Small healthcare company in Japan, 3 months of data",
754            &provider,
755        )
756        .expect("should parse successfully");
757
758        assert_eq!(intent.industry, Some("healthcare".to_string()));
759        assert_eq!(intent.country, Some("JP".to_string()));
760        assert_eq!(intent.company_size, Some("small".to_string()));
761        assert_eq!(intent.period_months, Some(3));
762    }
763
764    #[test]
765    fn test_parse_technology_with_banking() {
766        let provider = MockLlmProvider::new(42);
767        let intent = NlConfigGenerator::parse_intent(
768            "Generate data for a technology startup in Singapore with banking and KYC",
769            &provider,
770        )
771        .expect("should parse successfully");
772
773        assert_eq!(intent.industry, Some("technology".to_string()));
774        assert_eq!(intent.country, Some("SG".to_string()));
775        assert_eq!(intent.company_size, Some("small".to_string()));
776        assert!(intent.features.contains(&"banking".to_string()));
777    }
778
779    #[test]
780    fn test_parse_word_numbers() {
781        let provider = MockLlmProvider::new(42);
782        let intent =
783            NlConfigGenerator::parse_intent("Generate two years of retail data", &provider)
784                .expect("should parse successfully");
785
786        assert_eq!(intent.period_months, Some(24));
787    }
788
789    #[test]
790    fn test_parse_multiple_features() {
791        let provider = MockLlmProvider::new(42);
792        let intent = NlConfigGenerator::parse_intent(
793            "Manufacturing data with fraud detection, audit trail, process mining, and intercompany consolidation",
794            &provider,
795        )
796        .expect("should parse successfully");
797
798        assert_eq!(intent.industry, Some("manufacturing".to_string()));
799        assert!(intent.features.contains(&"fraud".to_string()));
800        assert!(intent.features.contains(&"audit".to_string()));
801        assert!(intent.features.contains(&"process_mining".to_string()));
802        assert!(intent.features.contains(&"intercompany".to_string()));
803    }
804
805    #[test]
806    fn test_intent_to_yaml_basic() {
807        let intent = ConfigIntent {
808            industry: Some("retail".to_string()),
809            country: Some("US".to_string()),
810            company_size: Some("medium".to_string()),
811            period_months: Some(12),
812            features: vec![],
813        };
814
815        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
816
817        assert!(yaml.contains("industry: retail"));
818        assert!(yaml.contains("period_months: 12"));
819        assert!(yaml.contains("currency: \"USD\""));
820        assert!(yaml.contains("country: \"US\""));
821        assert!(yaml.contains("complexity: medium"));
822        assert!(yaml.contains("count: 5000"));
823    }
824
825    #[test]
826    fn test_intent_to_yaml_with_features() {
827        let intent = ConfigIntent {
828            industry: Some("manufacturing".to_string()),
829            country: Some("DE".to_string()),
830            company_size: Some("large".to_string()),
831            period_months: Some(24),
832            features: vec![
833                "fraud".to_string(),
834                "audit".to_string(),
835                "controls".to_string(),
836            ],
837        };
838
839        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
840
841        assert!(yaml.contains("industry: manufacturing"));
842        assert!(yaml.contains("currency: \"EUR\""));
843        assert!(yaml.contains("complexity: large"));
844        assert!(yaml.contains("count: 25000"));
845        assert!(yaml.contains("fraud:"));
846        assert!(yaml.contains("audit_standards:"));
847        assert!(yaml.contains("internal_controls:"));
848    }
849
850    #[test]
851    fn test_intent_to_yaml_defaults() {
852        let intent = ConfigIntent::default();
853
854        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
855
856        // Should use defaults
857        assert!(yaml.contains("industry: manufacturing"));
858        assert!(yaml.contains("period_months: 12"));
859        assert!(yaml.contains("complexity: medium"));
860    }
861
862    #[test]
863    fn test_intent_to_yaml_invalid_period() {
864        let intent = ConfigIntent {
865            period_months: Some(0),
866            ..ConfigIntent::default()
867        };
868
869        let result = NlConfigGenerator::intent_to_yaml(&intent);
870        assert!(result.is_err());
871
872        let intent = ConfigIntent {
873            period_months: Some(121),
874            ..ConfigIntent::default()
875        };
876
877        let result = NlConfigGenerator::intent_to_yaml(&intent);
878        assert!(result.is_err());
879    }
880
881    #[test]
882    fn test_generate_end_to_end() {
883        let provider = MockLlmProvider::new(42);
884        let yaml = NlConfigGenerator::generate(
885            "Generate 1 year of retail data for a medium US company with fraud detection",
886            &provider,
887        )
888        .expect("should generate YAML");
889
890        assert!(yaml.contains("industry: retail"));
891        assert!(yaml.contains("period_months: 12"));
892        assert!(yaml.contains("currency: \"USD\""));
893        assert!(yaml.contains("fraud:"));
894        assert!(yaml.contains("complexity: medium"));
895    }
896
897    #[test]
898    fn test_generate_empty_description() {
899        let provider = MockLlmProvider::new(42);
900        let result = NlConfigGenerator::generate("", &provider);
901        assert!(result.is_err());
902
903        let result = NlConfigGenerator::generate("   ", &provider);
904        assert!(result.is_err());
905    }
906
907    #[test]
908    fn test_extract_json_from_response() {
909        let content = r#"Here is the parsed output: {"industry": "retail", "country": "US"} done"#;
910        let json = NlConfigGenerator::extract_json(content);
911        assert!(json.is_some());
912        assert_eq!(
913            json.expect("json should be present"),
914            r#"{"industry": "retail", "country": "US"}"#
915        );
916    }
917
918    #[test]
919    fn test_extract_json_nested() {
920        let content = r#"{"industry": "retail", "features": ["fraud", "audit"]}"#;
921        let json = NlConfigGenerator::extract_json(content);
922        assert!(json.is_some());
923    }
924
925    #[test]
926    fn test_extract_json_missing() {
927        let content = "No JSON here at all";
928        let json = NlConfigGenerator::extract_json(content);
929        assert!(json.is_none());
930    }
931
932    #[test]
933    fn test_parse_llm_response_valid() {
934        let content = r#"{"industry": "retail", "country": "US", "company_size": "medium", "period_months": 12, "features": ["fraud"]}"#;
935        let intent =
936            NlConfigGenerator::parse_llm_response(content).expect("should parse valid JSON");
937
938        assert_eq!(intent.industry, Some("retail".to_string()));
939        assert_eq!(intent.country, Some("US".to_string()));
940        assert_eq!(intent.company_size, Some("medium".to_string()));
941        assert_eq!(intent.period_months, Some(12));
942        assert_eq!(intent.features, vec!["fraud".to_string()]);
943    }
944
945    #[test]
946    fn test_parse_llm_response_partial() {
947        let content = r#"{"industry": "retail"}"#;
948        let intent =
949            NlConfigGenerator::parse_llm_response(content).expect("should parse partial JSON");
950
951        assert_eq!(intent.industry, Some("retail".to_string()));
952        assert_eq!(intent.country, None);
953        assert!(intent.features.is_empty());
954    }
955
956    #[test]
957    fn test_country_to_currency_mapping() {
958        assert_eq!(NlConfigGenerator::country_to_currency("US"), "USD");
959        assert_eq!(NlConfigGenerator::country_to_currency("DE"), "EUR");
960        assert_eq!(NlConfigGenerator::country_to_currency("GB"), "GBP");
961        assert_eq!(NlConfigGenerator::country_to_currency("JP"), "JPY");
962        assert_eq!(NlConfigGenerator::country_to_currency("CN"), "CNY");
963        assert_eq!(NlConfigGenerator::country_to_currency("BR"), "BRL");
964        assert_eq!(NlConfigGenerator::country_to_currency("XX"), "USD"); // Unknown defaults to USD
965    }
966
967    #[test]
968    fn test_merge_intents() {
969        let primary = ConfigIntent {
970            industry: Some("retail".to_string()),
971            country: None,
972            company_size: None,
973            period_months: Some(12),
974            features: vec![],
975        };
976        let fallback = ConfigIntent {
977            industry: Some("manufacturing".to_string()),
978            country: Some("DE".to_string()),
979            company_size: Some("large".to_string()),
980            period_months: Some(6),
981            features: vec!["fraud".to_string()],
982        };
983
984        let merged = NlConfigGenerator::merge_intents(primary, fallback);
985        assert_eq!(merged.industry, Some("retail".to_string())); // primary wins
986        assert_eq!(merged.country, Some("DE".to_string())); // fallback fills gap
987        assert_eq!(merged.company_size, Some("large".to_string())); // fallback fills gap
988        assert_eq!(merged.period_months, Some(12)); // primary wins
989        assert_eq!(merged.features, vec!["fraud".to_string()]); // fallback since primary empty
990    }
991
992    #[test]
993    fn test_parse_uk_country() {
994        let provider = MockLlmProvider::new(42);
995        let intent = NlConfigGenerator::parse_intent(
996            "Generate data for a UK manufacturing company",
997            &provider,
998        )
999        .expect("should parse successfully");
1000
1001        assert_eq!(intent.country, Some("GB".to_string()));
1002    }
1003
1004    #[test]
1005    fn test_intent_to_yaml_banking_feature() {
1006        let intent = ConfigIntent {
1007            industry: Some("financial_services".to_string()),
1008            country: Some("US".to_string()),
1009            company_size: Some("large".to_string()),
1010            period_months: Some(12),
1011            features: vec!["banking".to_string()],
1012        };
1013
1014        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1015
1016        assert!(yaml.contains("banking:"));
1017        assert!(yaml.contains("kyc_enabled: true"));
1018        assert!(yaml.contains("aml_enabled: true"));
1019    }
1020
1021    #[test]
1022    fn test_intent_to_yaml_process_mining_feature() {
1023        let intent = ConfigIntent {
1024            features: vec!["process_mining".to_string()],
1025            ..ConfigIntent::default()
1026        };
1027
1028        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1029
1030        assert!(yaml.contains("business_processes:"));
1031        assert!(yaml.contains("ocel_export: true"));
1032    }
1033
1034    #[test]
1035    fn test_intent_to_yaml_distributions_feature() {
1036        let intent = ConfigIntent {
1037            industry: Some("retail".to_string()),
1038            features: vec!["distributions".to_string()],
1039            ..ConfigIntent::default()
1040        };
1041
1042        let yaml = NlConfigGenerator::intent_to_yaml(&intent).expect("should generate YAML");
1043
1044        assert!(yaml.contains("distributions:"));
1045        assert!(yaml.contains("industry_profile: retail"));
1046        assert!(yaml.contains("benford_compliance: true"));
1047    }
1048
1049    #[test]
1050    fn test_extract_yaml_from_fenced_block() {
1051        let content = "Here is the config:\n```yaml\nglobal:\n  industry: retail\n```\nDone.";
1052        let yaml = NlConfigGenerator::extract_yaml(content);
1053        assert!(yaml.contains("global:"));
1054        assert!(yaml.contains("industry: retail"));
1055        assert!(!yaml.contains("```"));
1056    }
1057
1058    #[test]
1059    fn test_extract_yaml_plain_fences() {
1060        let content = "```\nglobal:\n  seed: 42\n```";
1061        let yaml = NlConfigGenerator::extract_yaml(content);
1062        assert!(yaml.contains("global:"));
1063        assert!(yaml.contains("seed: 42"));
1064        assert!(!yaml.contains("```"));
1065    }
1066
1067    #[test]
1068    fn test_extract_yaml_no_fences() {
1069        let content = "global:\n  industry: manufacturing\n";
1070        let yaml = NlConfigGenerator::extract_yaml(content);
1071        assert!(yaml.contains("global:"));
1072        assert!(yaml.contains("industry: manufacturing"));
1073    }
1074
1075    #[test]
1076    fn test_generate_full_falls_back_to_template() {
1077        // MockLlmProvider returns a fixed response that won't parse as valid
1078        // DataSynth YAML, so generate_full should fall back to template-based
1079        let provider = MockLlmProvider::new(42);
1080        let yaml = NlConfigGenerator::generate_full(
1081            "Generate 1 year of retail data for a medium US company",
1082            &provider,
1083        )
1084        .expect("should fall back to template-based generation");
1085
1086        assert!(yaml.contains("industry: retail"));
1087        assert!(yaml.contains("period_months: 12"));
1088    }
1089
1090    #[test]
1091    fn test_generate_full_empty_description() {
1092        let provider = MockLlmProvider::new(42);
1093        let result = NlConfigGenerator::generate_full("", &provider);
1094        assert!(result.is_err());
1095
1096        let result = NlConfigGenerator::generate_full("   ", &provider);
1097        assert!(result.is_err());
1098    }
1099
1100    #[test]
1101    fn test_full_schema_system_prompt_covers_key_sections() {
1102        let prompt = NlConfigGenerator::full_schema_system_prompt();
1103        assert!(prompt.contains("global:"));
1104        assert!(prompt.contains("companies:"));
1105        assert!(prompt.contains("chart_of_accounts:"));
1106        assert!(prompt.contains("transactions:"));
1107        assert!(prompt.contains("fraud:"));
1108        assert!(prompt.contains("banking:"));
1109        assert!(prompt.contains("distributions:"));
1110        assert!(prompt.contains("diffusion:"));
1111    }
1112}