Skip to main content

llm_optimizer_decision/
adaptive_params.rs

1//! Adaptive Parameter Tuning for LLM configurations
2//!
3//! This module provides adaptive tuning for temperature, top-p, and max tokens
4//! based on task context, historical performance, and user feedback.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
10use crate::{
11    context::RequestContext,
12    errors::{DecisionError, Result},
13    reward::{ResponseMetrics, UserFeedback},
14};
15
16/// Parameter configuration for LLM generation
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct ParameterConfig {
19    /// Sampling temperature (0.0 - 2.0)
20    pub temperature: f64,
21    /// Nucleus sampling threshold (0.0 - 1.0)
22    pub top_p: f64,
23    /// Maximum output tokens
24    pub max_tokens: usize,
25}
26
27impl Default for ParameterConfig {
28    fn default() -> Self {
29        Self {
30            temperature: 0.7,
31            top_p: 0.9,
32            max_tokens: 2048,
33        }
34    }
35}
36
37impl ParameterConfig {
38    /// Create new parameter configuration
39    pub fn new(temperature: f64, top_p: f64, max_tokens: usize) -> Result<Self> {
40        let config = Self {
41            temperature,
42            top_p,
43            max_tokens,
44        };
45        config.validate()?;
46        Ok(config)
47    }
48
49    /// Validate parameter ranges
50    pub fn validate(&self) -> Result<()> {
51        if self.temperature < 0.0 || self.temperature > 2.0 {
52            return Err(DecisionError::InvalidParameter(format!(
53                "Temperature {} out of range [0.0, 2.0]",
54                self.temperature
55            )));
56        }
57
58        if self.top_p < 0.0 || self.top_p > 1.0 {
59            return Err(DecisionError::InvalidParameter(format!(
60                "Top-p {} out of range [0.0, 1.0]",
61                self.top_p
62            )));
63        }
64
65        if self.max_tokens == 0 {
66            return Err(DecisionError::InvalidParameter(
67                "Max tokens must be greater than 0".to_string(),
68            ));
69        }
70
71        Ok(())
72    }
73
74    /// Create configuration for creative tasks
75    pub fn creative() -> Self {
76        Self {
77            temperature: 1.2,
78            top_p: 0.95,
79            max_tokens: 2048,
80        }
81    }
82
83    /// Create configuration for analytical tasks
84    pub fn analytical() -> Self {
85        Self {
86            temperature: 0.3,
87            top_p: 0.85,
88            max_tokens: 1024,
89        }
90    }
91
92    /// Create configuration for code generation
93    pub fn code_generation() -> Self {
94        Self {
95            temperature: 0.2,
96            top_p: 0.9,
97            max_tokens: 2048,
98        }
99    }
100
101    /// Create configuration for balanced general use
102    pub fn balanced() -> Self {
103        Self::default()
104    }
105}
106
107/// Parameter range constraints
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ParameterRange {
110    /// Minimum temperature
111    pub temp_min: f64,
112    /// Maximum temperature
113    pub temp_max: f64,
114    /// Minimum top-p
115    pub top_p_min: f64,
116    /// Maximum top-p
117    pub top_p_max: f64,
118    /// Minimum max tokens
119    pub max_tokens_min: usize,
120    /// Maximum max tokens
121    pub max_tokens_max: usize,
122}
123
124impl Default for ParameterRange {
125    fn default() -> Self {
126        Self {
127            temp_min: 0.0,
128            temp_max: 2.0,
129            top_p_min: 0.7,
130            top_p_max: 1.0,
131            max_tokens_min: 256,
132            max_tokens_max: 8192,
133        }
134    }
135}
136
137impl ParameterRange {
138    /// Create new parameter range
139    pub fn new(
140        temp_min: f64,
141        temp_max: f64,
142        top_p_min: f64,
143        top_p_max: f64,
144        max_tokens_min: usize,
145        max_tokens_max: usize,
146    ) -> Result<Self> {
147        if temp_min >= temp_max {
148            return Err(DecisionError::InvalidParameter(
149                "Temperature min must be less than max".to_string(),
150            ));
151        }
152
153        if top_p_min >= top_p_max {
154            return Err(DecisionError::InvalidParameter(
155                "Top-p min must be less than max".to_string(),
156            ));
157        }
158
159        if max_tokens_min >= max_tokens_max {
160            return Err(DecisionError::InvalidParameter(
161                "Max tokens min must be less than max".to_string(),
162            ));
163        }
164
165        Ok(Self {
166            temp_min,
167            temp_max,
168            top_p_min,
169            top_p_max,
170            max_tokens_min,
171            max_tokens_max,
172        })
173    }
174
175    /// Check if configuration is within range
176    pub fn contains(&self, config: &ParameterConfig) -> bool {
177        config.temperature >= self.temp_min
178            && config.temperature <= self.temp_max
179            && config.top_p >= self.top_p_min
180            && config.top_p <= self.top_p_max
181            && config.max_tokens >= self.max_tokens_min
182            && config.max_tokens <= self.max_tokens_max
183    }
184
185    /// Clamp configuration to range
186    pub fn clamp(&self, config: &ParameterConfig) -> ParameterConfig {
187        ParameterConfig {
188            temperature: config.temperature.clamp(self.temp_min, self.temp_max),
189            top_p: config.top_p.clamp(self.top_p_min, self.top_p_max),
190            max_tokens: config
191                .max_tokens
192                .clamp(self.max_tokens_min, self.max_tokens_max),
193        }
194    }
195
196    /// Create restricted range for specific task type
197    pub fn for_task_type(task_type: &str) -> Self {
198        match task_type {
199            "creative" | "storytelling" | "brainstorming" => Self {
200                temp_min: 0.8,
201                temp_max: 1.5,
202                top_p_min: 0.9,
203                top_p_max: 0.98,
204                max_tokens_min: 512,
205                max_tokens_max: 4096,
206            },
207            "code" | "programming" | "technical" => Self {
208                temp_min: 0.0,
209                temp_max: 0.5,
210                top_p_min: 0.85,
211                top_p_max: 0.95,
212                max_tokens_min: 256,
213                max_tokens_max: 4096,
214            },
215            "analytical" | "reasoning" | "math" => Self {
216                temp_min: 0.0,
217                temp_max: 0.4,
218                top_p_min: 0.8,
219                top_p_max: 0.9,
220                max_tokens_min: 512,
221                max_tokens_max: 2048,
222            },
223            _ => Self::default(),
224        }
225    }
226}
227
228/// Performance statistics for a parameter configuration
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ParameterStats {
231    /// Configuration ID
232    pub config_id: Uuid,
233    /// Parameter configuration
234    pub config: ParameterConfig,
235    /// Number of times used
236    pub num_uses: u64,
237    /// Total reward accumulated
238    pub total_reward: f64,
239    /// Average reward
240    pub average_reward: f64,
241    /// Average quality score
242    pub avg_quality: f64,
243    /// Average cost
244    pub avg_cost: f64,
245    /// Average latency
246    pub avg_latency: f64,
247    /// Success rate (task completion)
248    pub success_rate: f64,
249}
250
251impl ParameterStats {
252    /// Create new parameter statistics
253    pub fn new(config_id: Uuid, config: ParameterConfig) -> Self {
254        Self {
255            config_id,
256            config,
257            num_uses: 0,
258            total_reward: 0.0,
259            average_reward: 0.0,
260            avg_quality: 0.0,
261            avg_cost: 0.0,
262            avg_latency: 0.0,
263            success_rate: 0.0,
264        }
265    }
266
267    /// Update statistics with new observation
268    pub fn update(&mut self, reward: f64, metrics: &ResponseMetrics, success: bool) {
269        let n = self.num_uses as f64;
270        let n_plus_1 = (self.num_uses + 1) as f64;
271
272        // Running average updates
273        self.total_reward += reward;
274        self.average_reward = (self.average_reward * n + reward) / n_plus_1;
275        self.avg_quality = (self.avg_quality * n + metrics.quality_score) / n_plus_1;
276        self.avg_cost = (self.avg_cost * n + metrics.cost) / n_plus_1;
277        self.avg_latency = (self.avg_latency * n + metrics.latency_ms) / n_plus_1;
278
279        let success_count = (self.success_rate * n) + if success { 1.0 } else { 0.0 };
280        self.success_rate = success_count / n_plus_1;
281
282        self.num_uses += 1;
283    }
284
285    /// Get confidence interval width (exploration bonus)
286    pub fn confidence_width(&self, exploration_factor: f64) -> f64 {
287        if self.num_uses == 0 {
288            return f64::INFINITY;
289        }
290        exploration_factor * (2.0 * (self.num_uses as f64).ln()).sqrt() / (self.num_uses as f64)
291    }
292
293    /// Upper confidence bound for this configuration
294    pub fn ucb(&self, exploration_factor: f64) -> f64 {
295        self.average_reward + self.confidence_width(exploration_factor)
296    }
297}
298
299/// Adaptive parameter tuning engine
300pub struct AdaptiveParameterTuner {
301    /// Parameter range constraints
302    range: ParameterRange,
303    /// Statistics for each configuration
304    config_stats: HashMap<Uuid, ParameterStats>,
305    /// Task-specific best configurations
306    task_best_configs: HashMap<String, Uuid>,
307    /// Exploration factor for UCB
308    exploration_factor: f64,
309    /// Learning rate for gradient-based updates
310    learning_rate: f64,
311    /// Minimum uses before considering a configuration stable
312    min_uses_for_stability: u64,
313}
314
315impl AdaptiveParameterTuner {
316    /// Create new adaptive parameter tuner
317    pub fn new(range: ParameterRange) -> Self {
318        Self {
319            range,
320            config_stats: HashMap::new(),
321            task_best_configs: HashMap::new(),
322            exploration_factor: 2.0,
323            learning_rate: 0.1,
324            min_uses_for_stability: 10,
325        }
326    }
327
328    /// Create with default range
329    pub fn with_defaults() -> Self {
330        Self::new(ParameterRange::default())
331    }
332
333    /// Set exploration factor
334    pub fn with_exploration_factor(mut self, factor: f64) -> Self {
335        self.exploration_factor = factor;
336        self
337    }
338
339    /// Set learning rate
340    pub fn with_learning_rate(mut self, rate: f64) -> Self {
341        self.learning_rate = rate;
342        self
343    }
344
345    /// Register a new parameter configuration
346    pub fn register_config(&mut self, config: ParameterConfig) -> Result<Uuid> {
347        config.validate()?;
348        if !self.range.contains(&config) {
349            return Err(DecisionError::InvalidParameter(
350                "Configuration outside allowed range".to_string(),
351            ));
352        }
353
354        let config_id = Uuid::new_v4();
355        self.config_stats
356            .insert(config_id, ParameterStats::new(config_id, config));
357        Ok(config_id)
358    }
359
360    /// Select best configuration for given context using UCB
361    pub fn select_config(&self, context: &RequestContext) -> Result<(Uuid, ParameterConfig)> {
362        if self.config_stats.is_empty() {
363            return Err(DecisionError::InvalidState(
364                "No configurations registered".to_string(),
365            ));
366        }
367
368        // Check for task-specific best configuration
369        if let Some(task_type) = &context.task_type {
370            if let Some(config_id) = self.task_best_configs.get(task_type) {
371                if let Some(stats) = self.config_stats.get(config_id) {
372                    if stats.num_uses >= self.min_uses_for_stability {
373                        // Exploit with 80% probability, explore with 20%
374                        if rand::random::<f64>() < 0.8 {
375                            return Ok((*config_id, stats.config.clone()));
376                        }
377                    }
378                }
379            }
380        }
381
382        // Use UCB for selection
383        let (best_id, best_stats) = self
384            .config_stats
385            .iter()
386            .max_by(|(_, a), (_, b)| {
387                let ucb_a = a.ucb(self.exploration_factor);
388                let ucb_b = b.ucb(self.exploration_factor);
389                ucb_a.partial_cmp(&ucb_b).unwrap_or(std::cmp::Ordering::Equal)
390            })
391            .ok_or_else(|| DecisionError::InvalidState("No configurations available".to_string()))?;
392
393        Ok((*best_id, best_stats.config.clone()))
394    }
395
396    /// Update configuration performance
397    pub fn update_config(
398        &mut self,
399        config_id: &Uuid,
400        reward: f64,
401        metrics: &ResponseMetrics,
402        feedback: Option<&UserFeedback>,
403    ) -> Result<()> {
404        let stats = self.config_stats.get_mut(config_id).ok_or_else(|| {
405            DecisionError::InvalidParameter(format!("Configuration {} not found", config_id))
406        })?;
407
408        let success = feedback.map(|f| f.task_completed).unwrap_or(true);
409        stats.update(reward, metrics, success);
410
411        Ok(())
412    }
413
414    /// Get best configuration for a task type
415    pub fn get_best_for_task(&self, task_type: &str) -> Option<(Uuid, ParameterConfig)> {
416        // Filter configurations by task suitability and find best
417        let task_range = ParameterRange::for_task_type(task_type);
418
419        self.config_stats
420            .iter()
421            .filter(|(_, stats)| {
422                stats.num_uses >= self.min_uses_for_stability
423                    && task_range.contains(&stats.config)
424            })
425            .max_by(|(_, a), (_, b)| {
426                a.average_reward
427                    .partial_cmp(&b.average_reward)
428                    .unwrap_or(std::cmp::Ordering::Equal)
429            })
430            .map(|(id, stats)| (*id, stats.config.clone()))
431    }
432
433    /// Update task-specific best configuration
434    pub fn update_task_best(&mut self, task_type: String) {
435        if let Some((config_id, _)) = self.get_best_for_task(&task_type) {
436            self.task_best_configs.insert(task_type, config_id);
437        }
438    }
439
440    /// Suggest new configuration based on gradient
441    pub fn suggest_improvement(&self, config_id: &Uuid) -> Result<ParameterConfig> {
442        let stats = self.config_stats.get(config_id).ok_or_else(|| {
443            DecisionError::InvalidParameter(format!("Configuration {} not found", config_id))
444        })?;
445
446        if stats.num_uses < self.min_uses_for_stability {
447            return Err(DecisionError::InvalidState(
448                "Not enough data for improvement suggestion".to_string(),
449            ));
450        }
451
452        // Simple gradient-based suggestion
453        // If quality is low, decrease temperature (more focused)
454        // If quality is high but diversity might help, try increasing slightly
455        let mut new_config = stats.config.clone();
456
457        if stats.avg_quality < 0.7 {
458            // Low quality - make more deterministic
459            new_config.temperature *= 1.0 - self.learning_rate;
460            new_config.top_p *= 1.0 - self.learning_rate * 0.5;
461        } else if stats.avg_quality > 0.9 && stats.success_rate > 0.8 {
462            // High quality - try slight exploration
463            new_config.temperature *= 1.0 + self.learning_rate * 0.5;
464            new_config.top_p = (new_config.top_p + 0.05).min(1.0);
465        }
466
467        // Clamp to range
468        new_config = self.range.clamp(&new_config);
469        new_config.validate()?;
470
471        Ok(new_config)
472    }
473
474    /// Get statistics for all configurations
475    pub fn get_all_stats(&self) -> Vec<ParameterStats> {
476        self.config_stats.values().cloned().collect()
477    }
478
479    /// Get statistics for specific configuration
480    pub fn get_stats(&self, config_id: &Uuid) -> Option<&ParameterStats> {
481        self.config_stats.get(config_id)
482    }
483
484    /// Get number of registered configurations
485    pub fn num_configs(&self) -> usize {
486        self.config_stats.len()
487    }
488
489    /// Clear all statistics (for reset)
490    pub fn reset(&mut self) {
491        self.config_stats.clear();
492        self.task_best_configs.clear();
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `ResponseMetrics` with a token count of 500. Tests that need a
    /// different count override it with struct-update syntax.
    fn metrics(quality_score: f64, cost: f64, latency_ms: f64) -> ResponseMetrics {
        ResponseMetrics {
            quality_score,
            cost,
            latency_ms,
            token_count: 500,
        }
    }

    /// Empty statistics for the default configuration under a fresh id.
    fn default_stats() -> ParameterStats {
        ParameterStats::new(Uuid::new_v4(), ParameterConfig::default())
    }

    #[test]
    fn test_parameter_config_creation() {
        let cfg = ParameterConfig::new(0.7, 0.9, 1024).unwrap();
        assert_eq!(
            (cfg.temperature, cfg.top_p, cfg.max_tokens),
            (0.7, 0.9, 1024)
        );
    }

    #[test]
    fn test_parameter_config_validation() {
        let rejected = [
            (-0.1, 0.9, 1024),
            (2.5, 0.9, 1024),
            (0.7, 1.5, 1024),
            (0.7, 0.9, 0),
        ];
        for &(temperature, top_p, max_tokens) in rejected.iter() {
            assert!(ParameterConfig::new(temperature, top_p, max_tokens).is_err());
        }
    }

    #[test]
    fn test_preset_configs() {
        let creative = ParameterConfig::creative();
        let analytical = ParameterConfig::analytical();
        let code = ParameterConfig::code_generation();

        assert!(creative.temperature > 1.0);
        assert!(analytical.temperature < 0.5);
        assert!(code.temperature < 0.3);

        for preset in [&creative, &analytical, &code].iter() {
            assert!(preset.validate().is_ok());
        }
    }

    #[test]
    fn test_parameter_range_contains() {
        let range = ParameterRange::default();
        assert!(range.contains(&ParameterConfig::default()));

        let too_hot = ParameterConfig {
            temperature: 3.0,
            top_p: 0.9,
            max_tokens: 1024,
        };
        assert!(!range.contains(&too_hot));
    }

    #[test]
    fn test_parameter_range_clamp() {
        let range = ParameterRange::default();
        let clamped = range.clamp(&ParameterConfig {
            temperature: 3.0,
            top_p: 0.5,
            max_tokens: 10000,
        });

        assert_eq!(clamped.temperature, range.temp_max);
        assert_eq!(clamped.top_p, 0.7); // raised to the range minimum
        assert_eq!(clamped.max_tokens, range.max_tokens_max);
    }

    #[test]
    fn test_task_specific_ranges() {
        assert!(ParameterRange::for_task_type("creative").temp_min >= 0.8);
        assert!(ParameterRange::for_task_type("code").temp_max <= 0.5);
        assert!(ParameterRange::for_task_type("analytical").temp_max <= 0.4);
    }

    #[test]
    fn test_parameter_stats_creation() {
        let id = Uuid::new_v4();
        let stats = ParameterStats::new(id, ParameterConfig::default());

        assert_eq!(stats.config_id, id);
        assert_eq!(stats.num_uses, 0);
        assert_eq!(stats.average_reward, 0.0);
    }

    #[test]
    fn test_parameter_stats_update() {
        let mut stats = default_stats();
        stats.update(0.8, &metrics(0.9, 0.1, 1000.0), true);

        assert_eq!(stats.num_uses, 1);
        assert_eq!(stats.average_reward, 0.8);
        assert_eq!(stats.avg_quality, 0.9);
        assert_eq!(stats.success_rate, 1.0);
    }

    #[test]
    fn test_parameter_stats_running_average() {
        let mut stats = default_stats();
        let first = metrics(0.8, 0.1, 1000.0);
        let second = ResponseMetrics {
            token_count: 600,
            ..metrics(1.0, 0.2, 1500.0)
        };

        stats.update(0.7, &first, true);
        stats.update(0.9, &second, true);

        assert_eq!(stats.num_uses, 2);
        assert_eq!(stats.average_reward, 0.8);
        assert_eq!(stats.avg_quality, 0.9);
        assert_eq!(stats.success_rate, 1.0);
    }

    #[test]
    fn test_ucb_calculation() {
        let mut stats = default_stats();
        let observed = metrics(0.9, 0.1, 1000.0);

        // Several updates so the exploration bonus is meaningful.
        (0..5).for_each(|_| stats.update(0.8, &observed, true));

        assert!(stats.ucb(2.0) >= stats.average_reward);
        assert!(stats.num_uses == 5);
    }

    #[test]
    fn test_adaptive_tuner_creation() {
        assert_eq!(AdaptiveParameterTuner::with_defaults().num_configs(), 0);
    }

    #[test]
    fn test_register_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let id = tuner.register_config(ParameterConfig::default()).unwrap();

        assert_eq!(tuner.num_configs(), 1);
        assert!(tuner.get_stats(&id).is_some());
    }

    #[test]
    fn test_register_invalid_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let out_of_bounds = ParameterConfig {
            temperature: 3.0,
            top_p: 0.9,
            max_tokens: 1024,
        };

        assert!(tuner.register_config(out_of_bounds).is_err());
    }

    #[test]
    fn test_select_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        for cfg in [ParameterConfig::default(), ParameterConfig::creative()] {
            tuner.register_config(cfg).unwrap();
        }

        let (chosen, _) = tuner.select_config(&RequestContext::new(100)).unwrap();
        assert!(tuner.get_stats(&chosen).is_some());
    }

    #[test]
    fn test_update_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let id = tuner.register_config(ParameterConfig::default()).unwrap();

        tuner
            .update_config(&id, 0.8, &metrics(0.9, 0.1, 1000.0), None)
            .unwrap();

        let stats = tuner.get_stats(&id).unwrap();
        assert_eq!(stats.num_uses, 1);
        assert_eq!(stats.average_reward, 0.8);
    }

    #[test]
    fn test_tuner_learning() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let strong = tuner.register_config(ParameterConfig::default()).unwrap();
        let weak = tuner.register_config(ParameterConfig::creative()).unwrap();

        let good = metrics(0.95, 0.1, 1000.0);
        let bad = ResponseMetrics {
            token_count: 600,
            ..metrics(0.5, 0.2, 2000.0)
        };

        // Twenty strong observations for the first configuration...
        for _ in 0..20 {
            tuner.update_config(&strong, 0.9, &good, None).unwrap();
        }
        // ...and twenty weak ones for the second.
        for _ in 0..20 {
            tuner.update_config(&weak, 0.3, &bad, None).unwrap();
        }

        let strong_avg = tuner.get_stats(&strong).unwrap().average_reward;
        let weak_avg = tuner.get_stats(&weak).unwrap().average_reward;
        assert!(strong_avg > weak_avg);
    }

    #[test]
    fn test_get_best_for_task() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let id = tuner
            .register_config(ParameterConfig::code_generation())
            .unwrap();
        let good = metrics(0.95, 0.1, 1000.0);

        // Enough samples to pass the stability threshold.
        for _ in 0..15 {
            tuner.update_config(&id, 0.9, &good, None).unwrap();
        }

        tuner.update_task_best("code".to_string());
        assert!(tuner.get_best_for_task("code").is_some());
    }

    #[test]
    fn test_suggest_improvement() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let id = tuner.register_config(ParameterConfig::default()).unwrap();
        let poor = metrics(0.5, 0.1, 1000.0);

        // Enough samples for a suggestion to be allowed.
        for _ in 0..15 {
            tuner.update_config(&id, 0.6, &poor, None).unwrap();
        }

        let improved = tuner.suggest_improvement(&id).unwrap();
        let original = tuner.get_stats(&id).unwrap();

        // Low quality should lead to a cooler (or equal) temperature.
        assert!(improved.temperature <= original.config.temperature);
    }

    #[test]
    fn test_reset() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        tuner.register_config(ParameterConfig::default()).unwrap();
        assert_eq!(tuner.num_configs(), 1);

        tuner.reset();
        assert_eq!(tuner.num_configs(), 0);
    }
}