1use serde::{Deserialize, Serialize};
15
16use datasynth_core::llm::provider::{LlmProvider, LlmRequest};
17
18use super::auto_tuner::{AutoTuneResult, AutoTuner, ConfigPatch};
19use super::recommendation_engine::{EnhancementReport, RecommendationEngine};
20use crate::config::EvaluationThresholds;
21use crate::ComprehensiveEvaluation;
22
/// Configuration for the AI-assisted tuning loop.
///
/// All fields have serde defaults, so a partially specified (or empty)
/// config deserializes to the same values as [`AiTunerConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiTunerConfig {
    /// Upper bound on tuning iterations. Not consumed in this module;
    /// presumably enforced by the driver loop — confirm at the call site.
    #[serde(default = "default_max_iterations")]
    pub max_iterations: usize,
    /// Convergence threshold on the health score. Not consumed in this
    /// module; presumably the minimum per-iteration improvement — confirm.
    #[serde(default = "default_convergence_threshold")]
    pub convergence_threshold: f64,
    /// Minimum confidence a patch needs to survive filtering/merging
    /// (see `parse_llm_patches` and `merge_patches`).
    #[serde(default = "default_min_confidence")]
    pub min_confidence: f64,
    /// When true, the LLM is consulted for metrics the rule-based
    /// tuner could not address.
    #[serde(default = "default_use_llm")]
    pub use_llm: bool,
}
39
// Serde default helpers for `AiTunerConfig`. serde's `#[serde(default = "…")]`
// attribute requires free functions, so these cannot be inline constants.
fn default_max_iterations() -> usize {
    5
}
fn default_convergence_threshold() -> f64 {
    0.01
}
fn default_min_confidence() -> f64 {
    0.5
}
fn default_use_llm() -> bool {
    true
}
52
53impl Default for AiTunerConfig {
54 fn default() -> Self {
55 Self {
56 max_iterations: default_max_iterations(),
57 convergence_threshold: default_convergence_threshold(),
58 min_confidence: default_min_confidence(),
59 use_llm: default_use_llm(),
60 }
61 }
62}
63
/// Record of a single tuning pass produced by [`AiTuner::analyze_iteration`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuningIteration {
    /// 1-based (caller-supplied) index of this pass.
    pub iteration: usize,
    /// Health score reported by the recommendation engine for this pass.
    pub health_score: f64,
    /// Number of evaluation failures observed in this pass.
    pub failure_count: usize,
    /// Patches proposed by the deterministic rule-based tuner.
    pub rule_patches: Vec<ConfigPatch>,
    /// Patches proposed by the LLM (empty when the LLM was not consulted).
    pub ai_patches: Vec<ConfigPatch>,
    /// Final merged patch set: all rule patches plus non-conflicting,
    /// sufficiently confident AI patches.
    pub applied_patches: Vec<ConfigPatch>,
}
80
/// Aggregate outcome of a complete tuning run (one or more iterations).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiTuneResult {
    /// Per-iteration records, in execution order.
    pub iterations: Vec<TuningIteration>,
    /// Patch set recommended after the final iteration.
    pub final_patches: Vec<ConfigPatch>,
    /// Health score before the first iteration.
    pub initial_health_score: f64,
    /// Health score after the last iteration.
    pub final_health_score: f64,
    /// Whether the run stopped because it converged (rather than hitting
    /// the iteration limit) — set by the driver, not in this module.
    pub converged: bool,
    /// Human-readable summary of the run.
    pub summary: String,
}
97
impl AiTuneResult {
    /// Net change in health score over the whole run; positive values mean
    /// the tuning improved data quality.
    pub fn improvement(&self) -> f64 {
        self.final_health_score - self.initial_health_score
    }
}
104
/// Hybrid tuner: a deterministic [`AutoTuner`] proposes rule-based patches
/// first, and the borrowed LLM provider is consulted only for the metrics
/// the rules could not address.
pub struct AiTuner<'a> {
    // Rule-based patch generator.
    auto_tuner: AutoTuner,
    // Produces the health-score/issue report fed into the LLM prompt.
    recommendation_engine: RecommendationEngine,
    // Borrowed LLM backend; `'a` ties the tuner's lifetime to the provider.
    provider: &'a dyn LlmProvider,
    config: AiTunerConfig,
}
112
113impl<'a> AiTuner<'a> {
114 pub fn new(provider: &'a dyn LlmProvider, config: AiTunerConfig) -> Self {
116 Self {
117 auto_tuner: AutoTuner::new(),
118 recommendation_engine: RecommendationEngine::new(),
119 provider,
120 config,
121 }
122 }
123
124 pub fn with_thresholds(
126 provider: &'a dyn LlmProvider,
127 config: AiTunerConfig,
128 thresholds: EvaluationThresholds,
129 ) -> Self {
130 Self {
131 auto_tuner: AutoTuner::with_thresholds(thresholds.clone()),
132 recommendation_engine: RecommendationEngine::with_thresholds(thresholds),
133 provider,
134 config,
135 }
136 }
137
138 pub fn analyze_iteration(
142 &mut self,
143 evaluation: &ComprehensiveEvaluation,
144 iteration: usize,
145 ) -> TuningIteration {
146 let auto_result = self.auto_tuner.analyze(evaluation);
148 let report = self.recommendation_engine.generate_report(evaluation);
149
150 let rule_patches = auto_result.patches.clone();
151
152 let ai_patches = if self.config.use_llm && !auto_result.unaddressable_metrics.is_empty() {
154 self.llm_analyze_gaps(&auto_result, &report)
155 } else {
156 vec![]
157 };
158
159 let applied_patches = merge_patches(&rule_patches, &ai_patches, self.config.min_confidence);
161
162 TuningIteration {
163 iteration,
164 health_score: report.health_score,
165 failure_count: evaluation.failures.len(),
166 rule_patches,
167 ai_patches,
168 applied_patches,
169 }
170 }
171
172 fn llm_analyze_gaps(
174 &self,
175 auto_result: &AutoTuneResult,
176 report: &EnhancementReport,
177 ) -> Vec<ConfigPatch> {
178 let prompt = self.build_gap_analysis_prompt(auto_result, report);
179
180 let request = LlmRequest::new(prompt)
181 .with_system(Self::tuning_system_prompt().to_string())
182 .with_temperature(0.3)
183 .with_max_tokens(2048);
184
185 match self.provider.complete(&request) {
186 Ok(response) => self.parse_llm_patches(&response.content),
187 Err(e) => {
188 tracing::warn!("LLM gap analysis failed: {e}");
189 vec![]
190 }
191 }
192 }
193
194 fn build_gap_analysis_prompt(
196 &self,
197 auto_result: &AutoTuneResult,
198 report: &EnhancementReport,
199 ) -> String {
200 let mut prompt = String::with_capacity(2048);
201
202 prompt
203 .push_str("Analyze these synthetic data quality gaps and suggest config patches.\n\n");
204
205 if !auto_result.unaddressable_metrics.is_empty() {
207 prompt.push_str("## Metrics the rule-based tuner could not address:\n");
208 for metric in &auto_result.unaddressable_metrics {
209 prompt.push_str(&format!("- {metric}\n"));
210 }
211 prompt.push('\n');
212 }
213
214 if !report.top_issues.is_empty() {
216 prompt.push_str("## Top issues:\n");
217 for issue in &report.top_issues {
218 prompt.push_str(&format!("- {issue}\n"));
219 }
220 prompt.push('\n');
221 }
222
223 if auto_result.has_patches() {
225 prompt.push_str("## Already suggested patches (do not repeat):\n");
226 for patch in &auto_result.patches {
227 prompt.push_str(&format!("- {}: {}\n", patch.path, patch.suggested_value));
228 }
229 prompt.push('\n');
230 }
231
232 prompt.push_str(&format!(
233 "Current health score: {:.2}\n",
234 report.health_score
235 ));
236 prompt
237 }
238
239 fn parse_llm_patches(&self, content: &str) -> Vec<ConfigPatch> {
241 let json_str = datasynth_core::llm::extract_json_array(content);
243
244 match json_str {
245 Some(json) => match serde_json::from_str::<Vec<LlmPatchSuggestion>>(json) {
246 Ok(suggestions) => suggestions
247 .into_iter()
248 .filter(|s| s.confidence >= self.config.min_confidence)
249 .map(|s| {
250 ConfigPatch::new(s.path, s.value)
251 .with_confidence(s.confidence)
252 .with_impact(s.reasoning)
253 })
254 .collect(),
255 Err(e) => {
256 tracing::debug!("Failed to parse LLM patches as JSON: {e}");
257 vec![]
258 }
259 },
260 None => {
261 tracing::debug!("No JSON array found in LLM response");
262 vec![]
263 }
264 }
265 }
266
267 fn tuning_system_prompt() -> &'static str {
269 concat!(
270 "You are a synthetic data quality tuner for DataSynth. ",
271 "Given evaluation gaps, suggest config patches to improve data quality.\n\n",
272 "Return a JSON array of patches. Each patch has:\n",
273 "- path: dot-separated config path (e.g., \"distributions.amounts.components[0].mu\")\n",
274 "- value: new value as string\n",
275 "- confidence: 0.0-1.0 confidence this will help\n",
276 "- reasoning: one sentence explaining why\n\n",
277 "Valid config paths include:\n",
278 "- transactions.count, transactions.anomaly_rate\n",
279 "- distributions.amounts.*, distributions.correlations.*\n",
280 "- temporal_patterns.period_end.*, temporal_patterns.intraday.*\n",
281 "- anomaly_injection.base_rate, anomaly_injection.types\n",
282 "- data_quality.missing_value_rate, data_quality.typo_rate\n",
283 "- fraud.injection_rate, fraud.types\n",
284 "- graph_export.ensure_connected\n\n",
285 "Rules:\n",
286 "- Only suggest patches for unaddressed metrics\n",
287 "- Don't repeat patches already applied\n",
288 "- Keep confidence realistic\n",
289 "- Return ONLY the JSON array, no other text\n"
290 )
291 }
292}
293
/// Shape of a single patch object in the LLM's JSON-array response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmPatchSuggestion {
    // Dot-separated config path, as requested in the system prompt.
    path: String,
    // New value, always a string (the prompt asks for string values).
    value: String,
    // Falls back to 0.5 when the model omits the field.
    #[serde(default = "default_llm_confidence")]
    confidence: f64,
    // Optional one-sentence rationale; empty when omitted.
    #[serde(default)]
    reasoning: String,
}
304
// Neutral confidence assigned when the LLM omits the `confidence` field;
// matches `default_min_confidence` so such patches sit exactly on the
// default filtering boundary.
fn default_llm_confidence() -> f64 {
    0.5
}
308
309fn merge_patches(
311 rule_patches: &[ConfigPatch],
312 ai_patches: &[ConfigPatch],
313 min_confidence: f64,
314) -> Vec<ConfigPatch> {
315 let mut merged = rule_patches.to_vec();
316
317 let rule_paths: std::collections::HashSet<&str> =
319 rule_patches.iter().map(|p| p.path.as_str()).collect();
320
321 for patch in ai_patches {
323 if patch.confidence >= min_confidence && !rule_paths.contains(patch.path.as_str()) {
324 merged.push(patch.clone());
325 }
326 }
327
328 merged
329}
330
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::llm::MockLlmProvider;

    // With `use_llm: false`, a pass must produce no AI patches and report
    // zero failures for an empty evaluation.
    #[test]
    fn test_ai_tuner_single_iteration() {
        let provider = MockLlmProvider::new(42);
        let config = AiTunerConfig {
            max_iterations: 1,
            use_llm: false,
            ..Default::default()
        };
        let mut tuner = AiTuner::new(&provider, config);

        let evaluation = ComprehensiveEvaluation::new();
        let iteration = tuner.analyze_iteration(&evaluation, 1);

        assert_eq!(iteration.iteration, 1);
        assert!(iteration.ai_patches.is_empty());
        assert_eq!(iteration.failure_count, 0);
    }

    // `Default` must agree with the serde default helpers.
    #[test]
    fn test_ai_tuner_config_defaults() {
        let config = AiTunerConfig::default();
        assert_eq!(config.max_iterations, 5);
        assert!((config.convergence_threshold - 0.01).abs() < 1e-10);
        assert!((config.min_confidence - 0.5).abs() < 1e-10);
        assert!(config.use_llm);
    }

    // Disjoint paths: the low-confidence AI patch is dropped, the rest merge.
    #[test]
    fn test_merge_patches_no_conflicts() {
        let rule = vec![
            ConfigPatch::new("path.a", "1").with_confidence(0.9),
            ConfigPatch::new("path.b", "2").with_confidence(0.8),
        ];
        let ai = vec![
            ConfigPatch::new("path.c", "3").with_confidence(0.7),
            ConfigPatch::new("path.d", "4").with_confidence(0.3),
        ];

        let merged = merge_patches(&rule, &ai, 0.5);
        assert_eq!(merged.len(), 3);
    }

    // On a path conflict, the rule patch wins over the AI patch.
    #[test]
    fn test_merge_patches_with_conflicts() {
        let rule = vec![ConfigPatch::new("path.a", "1").with_confidence(0.9)];
        let ai = vec![
            ConfigPatch::new("path.a", "2").with_confidence(0.8),
            ConfigPatch::new("path.b", "3").with_confidence(0.7),
        ];

        let merged = merge_patches(&rule, &ai, 0.5);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].suggested_value, "1");
    }

    // A well-formed JSON array parses into patches with all fields mapped.
    #[test]
    fn test_parse_llm_patches_valid() {
        let provider = MockLlmProvider::new(42);
        let config = AiTunerConfig::default();
        let tuner = AiTuner::new(&provider, config);

        let json = r#"[{"path": "transactions.count", "value": "10000", "confidence": 0.8, "reasoning": "More samples improve distribution fidelity"}]"#;
        let patches = tuner.parse_llm_patches(json);
        assert_eq!(patches.len(), 1);
        assert_eq!(patches[0].path, "transactions.count");
        assert_eq!(patches[0].suggested_value, "10000");
        assert!((patches[0].confidence - 0.8).abs() < 1e-10);
    }

    // Suggestions below `min_confidence` are filtered during parsing.
    #[test]
    fn test_parse_llm_patches_filters_low_confidence() {
        let provider = MockLlmProvider::new(42);
        let config = AiTunerConfig {
            min_confidence: 0.6,
            ..Default::default()
        };
        let tuner = AiTuner::new(&provider, config);

        let json = r#"[
            {"path": "a", "value": "1", "confidence": 0.8},
            {"path": "b", "value": "2", "confidence": 0.3}
        ]"#;
        let patches = tuner.parse_llm_patches(json);
        assert_eq!(patches.len(), 1);
        assert_eq!(patches[0].path, "a");
    }

    // `improvement` is simply final minus initial health score.
    #[test]
    fn test_ai_tune_result_improvement() {
        let result = AiTuneResult {
            iterations: vec![],
            final_patches: vec![],
            initial_health_score: 0.6,
            final_health_score: 0.85,
            converged: true,
            summary: String::new(),
        };
        assert!((result.improvement() - 0.25).abs() < 1e-10);
    }
}