// datasynth_eval/enhancement/ai_tuner.rs

//! AI-powered evaluation-driven tuning loop.
//!
//! Wraps [`AutoTuner`] and [`RecommendationEngine`] with an LLM interpretation
//! layer that adds gap analysis and config suggestions beyond what the
//! rule-based tuner can derive on its own.
//!
//! The tuning loop (driven by the caller, one [`AiTuner::analyze_iteration`]
//! call per cycle):
//! 1. Evaluate generated data against quality thresholds.
//! 2. [`AutoTuner`] produces rule-based config patches.
//! 3. The LLM interprets the remaining gaps and suggests additional patches.
//! 4. The patches are merged and applied, and the cycle repeats.
//! 5. The loop stops once the health score stabilizes or the maximum number of
//!    iterations is reached.
13
14use serde::{Deserialize, Serialize};
15
16use datasynth_core::llm::provider::{LlmProvider, LlmRequest};
17
18use super::auto_tuner::{AutoTuneResult, AutoTuner, ConfigPatch};
19use super::recommendation_engine::{EnhancementReport, RecommendationEngine};
20use crate::config::EvaluationThresholds;
21use crate::ComprehensiveEvaluation;
22
23/// Configuration for the AI tuning loop.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct AiTunerConfig {
26    /// Maximum number of tuning iterations.
27    #[serde(default = "default_max_iterations")]
28    pub max_iterations: usize,
29    /// Stop if health score improvement is below this threshold.
30    #[serde(default = "default_convergence_threshold")]
31    pub convergence_threshold: f64,
32    /// Minimum confidence to accept an LLM-suggested patch.
33    #[serde(default = "default_min_confidence")]
34    pub min_confidence: f64,
35    /// Whether to include LLM-generated patches (false = rule-based only).
36    #[serde(default = "default_use_llm")]
37    pub use_llm: bool,
38}
39
/// Default cap on tuning iterations.
const DEFAULT_MAX_ITERATIONS: usize = 5;
/// Default minimum health-score improvement before the loop counts as converged.
const DEFAULT_CONVERGENCE_THRESHOLD: f64 = 0.01;
/// Default minimum confidence for accepting an LLM-suggested patch.
const DEFAULT_MIN_CONFIDENCE: f64 = 0.5;
/// LLM suggestions are enabled unless explicitly turned off.
const DEFAULT_USE_LLM: bool = true;

/// Serde default for `AiTunerConfig::max_iterations`.
fn default_max_iterations() -> usize {
    DEFAULT_MAX_ITERATIONS
}

/// Serde default for `AiTunerConfig::convergence_threshold`.
fn default_convergence_threshold() -> f64 {
    DEFAULT_CONVERGENCE_THRESHOLD
}

/// Serde default for `AiTunerConfig::min_confidence`.
fn default_min_confidence() -> f64 {
    DEFAULT_MIN_CONFIDENCE
}

/// Serde default for `AiTunerConfig::use_llm`.
fn default_use_llm() -> bool {
    DEFAULT_USE_LLM
}
52
53impl Default for AiTunerConfig {
54    fn default() -> Self {
55        Self {
56            max_iterations: default_max_iterations(),
57            convergence_threshold: default_convergence_threshold(),
58            min_confidence: default_min_confidence(),
59            use_llm: default_use_llm(),
60        }
61    }
62}
63
64/// Result of a single tuning iteration.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct TuningIteration {
67    /// Iteration number (1-based).
68    pub iteration: usize,
69    /// Health score at the start of this iteration.
70    pub health_score: f64,
71    /// Number of failures at the start.
72    pub failure_count: usize,
73    /// Rule-based patches from AutoTuner.
74    pub rule_patches: Vec<ConfigPatch>,
75    /// AI-suggested patches from LLM interpretation.
76    pub ai_patches: Vec<ConfigPatch>,
77    /// Combined patches applied this iteration.
78    pub applied_patches: Vec<ConfigPatch>,
79}
80
81/// Complete result of the AI tuning loop.
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct AiTuneResult {
84    /// All iterations performed.
85    pub iterations: Vec<TuningIteration>,
86    /// Final combined patches (union of all iterations).
87    pub final_patches: Vec<ConfigPatch>,
88    /// Initial health score.
89    pub initial_health_score: f64,
90    /// Final health score.
91    pub final_health_score: f64,
92    /// Whether convergence was reached (vs max iterations).
93    pub converged: bool,
94    /// Human-readable summary.
95    pub summary: String,
96}
97
98impl AiTuneResult {
99    /// Health score improvement from initial to final.
100    pub fn improvement(&self) -> f64 {
101        self.final_health_score - self.initial_health_score
102    }
103}
104
105/// AI-powered auto-tuner that combines rule-based analysis with LLM interpretation.
106pub struct AiTuner<'a> {
107    auto_tuner: AutoTuner,
108    recommendation_engine: RecommendationEngine,
109    provider: &'a dyn LlmProvider,
110    config: AiTunerConfig,
111}
112
113impl<'a> AiTuner<'a> {
114    /// Create a new AI tuner with default thresholds.
115    pub fn new(provider: &'a dyn LlmProvider, config: AiTunerConfig) -> Self {
116        Self {
117            auto_tuner: AutoTuner::new(),
118            recommendation_engine: RecommendationEngine::new(),
119            provider,
120            config,
121        }
122    }
123
124    /// Create with custom evaluation thresholds.
125    pub fn with_thresholds(
126        provider: &'a dyn LlmProvider,
127        config: AiTunerConfig,
128        thresholds: EvaluationThresholds,
129    ) -> Self {
130        Self {
131            auto_tuner: AutoTuner::with_thresholds(thresholds.clone()),
132            recommendation_engine: RecommendationEngine::with_thresholds(thresholds),
133            provider,
134            config,
135        }
136    }
137
138    /// Run a single tuning iteration: analyze evaluation results and produce patches.
139    ///
140    /// This is the core method for integrating into a generation loop.
141    pub fn analyze_iteration(
142        &mut self,
143        evaluation: &ComprehensiveEvaluation,
144        iteration: usize,
145    ) -> TuningIteration {
146        // Rule-based analysis
147        let auto_result = self.auto_tuner.analyze(evaluation);
148        let report = self.recommendation_engine.generate_report(evaluation);
149
150        let rule_patches = auto_result.patches.clone();
151
152        // LLM-powered analysis of remaining gaps
153        let ai_patches = if self.config.use_llm && !auto_result.unaddressable_metrics.is_empty() {
154            self.llm_analyze_gaps(&auto_result, &report)
155        } else {
156            vec![]
157        };
158
159        // Merge patches: rule-based first, then AI suggestions that don't conflict
160        let applied_patches = merge_patches(&rule_patches, &ai_patches, self.config.min_confidence);
161
162        TuningIteration {
163            iteration,
164            health_score: report.health_score,
165            failure_count: evaluation.failures.len(),
166            rule_patches,
167            ai_patches,
168            applied_patches,
169        }
170    }
171
172    /// Use LLM to interpret gaps that the rule-based tuner couldn't address.
173    fn llm_analyze_gaps(
174        &self,
175        auto_result: &AutoTuneResult,
176        report: &EnhancementReport,
177    ) -> Vec<ConfigPatch> {
178        let prompt = self.build_gap_analysis_prompt(auto_result, report);
179
180        let request = LlmRequest::new(prompt)
181            .with_system(Self::tuning_system_prompt().to_string())
182            .with_temperature(0.3)
183            .with_max_tokens(2048);
184
185        match self.provider.complete(&request) {
186            Ok(response) => self.parse_llm_patches(&response.content),
187            Err(e) => {
188                tracing::warn!("LLM gap analysis failed: {e}");
189                vec![]
190            }
191        }
192    }
193
194    /// Build a structured prompt describing the evaluation gaps.
195    fn build_gap_analysis_prompt(
196        &self,
197        auto_result: &AutoTuneResult,
198        report: &EnhancementReport,
199    ) -> String {
200        let mut prompt = String::with_capacity(2048);
201
202        prompt
203            .push_str("Analyze these synthetic data quality gaps and suggest config patches.\n\n");
204
205        // Unaddressable metrics
206        if !auto_result.unaddressable_metrics.is_empty() {
207            prompt.push_str("## Metrics the rule-based tuner could not address:\n");
208            for metric in &auto_result.unaddressable_metrics {
209                prompt.push_str(&format!("- {metric}\n"));
210            }
211            prompt.push('\n');
212        }
213
214        // Top issues
215        if !report.top_issues.is_empty() {
216            prompt.push_str("## Top issues:\n");
217            for issue in &report.top_issues {
218                prompt.push_str(&format!("- {issue}\n"));
219            }
220            prompt.push('\n');
221        }
222
223        // Already applied patches (to avoid duplicates)
224        if auto_result.has_patches() {
225            prompt.push_str("## Already suggested patches (do not repeat):\n");
226            for patch in &auto_result.patches {
227                prompt.push_str(&format!("- {}: {}\n", patch.path, patch.suggested_value));
228            }
229            prompt.push('\n');
230        }
231
232        prompt.push_str(&format!(
233            "Current health score: {:.2}\n",
234            report.health_score
235        ));
236        prompt
237    }
238
239    /// Parse LLM response into config patches.
240    fn parse_llm_patches(&self, content: &str) -> Vec<ConfigPatch> {
241        // Try to find JSON array of patches in the response
242        let json_str = datasynth_core::llm::extract_json_array(content);
243
244        match json_str {
245            Some(json) => match serde_json::from_str::<Vec<LlmPatchSuggestion>>(json) {
246                Ok(suggestions) => suggestions
247                    .into_iter()
248                    .filter(|s| s.confidence >= self.config.min_confidence)
249                    .map(|s| {
250                        ConfigPatch::new(s.path, s.value)
251                            .with_confidence(s.confidence)
252                            .with_impact(s.reasoning)
253                    })
254                    .collect(),
255                Err(e) => {
256                    tracing::debug!("Failed to parse LLM patches as JSON: {e}");
257                    vec![]
258                }
259            },
260            None => {
261                tracing::debug!("No JSON array found in LLM response");
262                vec![]
263            }
264        }
265    }
266
267    /// System prompt for the LLM gap analyzer.
268    fn tuning_system_prompt() -> &'static str {
269        concat!(
270            "You are a synthetic data quality tuner for DataSynth. ",
271            "Given evaluation gaps, suggest config patches to improve data quality.\n\n",
272            "Return a JSON array of patches. Each patch has:\n",
273            "- path: dot-separated config path (e.g., \"distributions.amounts.components[0].mu\")\n",
274            "- value: new value as string\n",
275            "- confidence: 0.0-1.0 confidence this will help\n",
276            "- reasoning: one sentence explaining why\n\n",
277            "Valid config paths include:\n",
278            "- transactions.count, transactions.anomaly_rate\n",
279            "- distributions.amounts.*, distributions.correlations.*\n",
280            "- temporal_patterns.period_end.*, temporal_patterns.intraday.*\n",
281            "- anomaly_injection.base_rate, anomaly_injection.types\n",
282            "- data_quality.missing_value_rate, data_quality.typo_rate\n",
283            "- fraud.injection_rate, fraud.types\n",
284            "- graph_export.ensure_connected\n\n",
285            "Rules:\n",
286            "- Only suggest patches for unaddressed metrics\n",
287            "- Don't repeat patches already applied\n",
288            "- Keep confidence realistic\n",
289            "- Return ONLY the JSON array, no other text\n"
290        )
291    }
292}
293
294/// An LLM-suggested patch (deserialized from JSON).
295#[derive(Debug, Clone, Serialize, Deserialize)]
296struct LlmPatchSuggestion {
297    path: String,
298    value: String,
299    #[serde(default = "default_llm_confidence")]
300    confidence: f64,
301    #[serde(default)]
302    reasoning: String,
303}
304
/// Confidence assumed for an LLM suggestion that omits the `confidence` field.
fn default_llm_confidence() -> f64 {
    const FALLBACK_CONFIDENCE: f64 = 0.5;
    FALLBACK_CONFIDENCE
}
308
309/// Merge rule-based and AI patches, filtering by confidence and removing conflicts.
310fn merge_patches(
311    rule_patches: &[ConfigPatch],
312    ai_patches: &[ConfigPatch],
313    min_confidence: f64,
314) -> Vec<ConfigPatch> {
315    let mut merged = rule_patches.to_vec();
316
317    // Track paths already covered by rule-based patches
318    let rule_paths: std::collections::HashSet<&str> =
319        rule_patches.iter().map(|p| p.path.as_str()).collect();
320
321    // Add AI patches that don't conflict with rule-based ones
322    for patch in ai_patches {
323        if patch.confidence >= min_confidence && !rule_paths.contains(patch.path.as_str()) {
324            merged.push(patch.clone());
325        }
326    }
327
328    merged
329}
330
#[cfg(test)]
// Test code may unwrap freely.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::llm::MockLlmProvider;

    /// Tolerance for floating-point comparisons below.
    const EPS: f64 = 1e-10;

    /// Shorthand for a patch with an explicit confidence.
    fn patch(path: &'static str, value: &'static str, confidence: f64) -> ConfigPatch {
        ConfigPatch::new(path, value).with_confidence(confidence)
    }

    #[test]
    fn test_ai_tuner_single_iteration() {
        let mock = MockLlmProvider::new(42);
        // Rule-based only, so the outcome does not depend on the mock LLM.
        let mut tuner = AiTuner::new(
            &mock,
            AiTunerConfig {
                max_iterations: 1,
                use_llm: false,
                ..Default::default()
            },
        );

        let result = tuner.analyze_iteration(&ComprehensiveEvaluation::new(), 1);

        assert_eq!(result.iteration, 1);
        assert!(result.ai_patches.is_empty());
        // A fresh evaluation records no failures.
        assert_eq!(result.failure_count, 0);
    }

    #[test]
    fn test_ai_tuner_config_defaults() {
        let AiTunerConfig {
            max_iterations,
            convergence_threshold,
            min_confidence,
            use_llm,
        } = AiTunerConfig::default();

        assert_eq!(max_iterations, 5);
        assert!((convergence_threshold - 0.01).abs() < EPS);
        assert!((min_confidence - 0.5).abs() < EPS);
        assert!(use_llm);
    }

    #[test]
    fn test_merge_patches_no_conflicts() {
        let rule = [patch("path.a", "1", 0.9), patch("path.b", "2", 0.8)];
        // path.d sits below the 0.5 threshold and must be dropped.
        let ai = [patch("path.c", "3", 0.7), patch("path.d", "4", 0.3)];

        assert_eq!(merge_patches(&rule, &ai, 0.5).len(), 3);
    }

    #[test]
    fn test_merge_patches_with_conflicts() {
        let rule = [patch("path.a", "1", 0.9)];
        // The AI's path.a clashes with the rule patch; path.b is new.
        let ai = [patch("path.a", "2", 0.8), patch("path.b", "3", 0.7)];

        let merged = merge_patches(&rule, &ai, 0.5);

        assert_eq!(merged.len(), 2);
        // The rule-based value survives for the contested path.
        assert_eq!(merged[0].suggested_value, "1");
    }

    // JSON extraction tests moved to datasynth_core::llm::json_utils

    #[test]
    fn test_parse_llm_patches_valid() {
        let mock = MockLlmProvider::new(42);
        let tuner = AiTuner::new(&mock, AiTunerConfig::default());
        let response = r#"[{"path": "transactions.count", "value": "10000", "confidence": 0.8, "reasoning": "More samples improve distribution fidelity"}]"#;

        let patches = tuner.parse_llm_patches(response);

        assert_eq!(patches.len(), 1);
        let only = &patches[0];
        assert_eq!(only.path, "transactions.count");
        assert_eq!(only.suggested_value, "10000");
        assert!((only.confidence - 0.8).abs() < EPS);
    }

    #[test]
    fn test_parse_llm_patches_filters_low_confidence() {
        let mock = MockLlmProvider::new(42);
        let tuner = AiTuner::new(
            &mock,
            AiTunerConfig {
                min_confidence: 0.6,
                ..Default::default()
            },
        );
        let response = r#"[
            {"path": "a", "value": "1", "confidence": 0.8},
            {"path": "b", "value": "2", "confidence": 0.3}
        ]"#;

        let patches = tuner.parse_llm_patches(response);
        let paths: Vec<&str> = patches.iter().map(|p| p.path.as_str()).collect();

        assert_eq!(paths, ["a"]);
    }

    #[test]
    fn test_ai_tune_result_improvement() {
        let result = AiTuneResult {
            iterations: Vec::new(),
            final_patches: Vec::new(),
            initial_health_score: 0.6,
            final_health_score: 0.85,
            converged: true,
            summary: String::new(),
        };

        assert!((result.improvement() - 0.25).abs() < EPS);
    }
}