Skip to main content

llm_optimizer_decision/
parameter_optimizer.rs

1//! Parameter Optimizer
2//!
3//! High-level API for adaptive parameter optimization integrating
4//! contextual bandits, search strategies, and performance tracking.
5
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use uuid::Uuid;
10
11use crate::{
12    adaptive_params::{AdaptiveParameterTuner, ParameterConfig, ParameterRange, ParameterStats},
13    context::RequestContext,
14    contextual_bandit::LinUCB,
15    errors::{DecisionError, Result},
16    parameter_search::{GridSearchConfig, ParameterSearchManager, SearchStrategy},
17    reward::{RewardCalculator, RewardWeights, ResponseMetrics, UserFeedback},
18};
19
20/// Optimization mode
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22pub enum OptimizationMode {
23    /// Explore parameter space (search phase)
24    Explore,
25    /// Exploit best known parameters (production phase)
26    Exploit,
27    /// Balanced exploration and exploitation
28    Balanced,
29}
30
31/// Parameter optimization policy
32#[derive(Debug, Clone)]
33pub struct OptimizationPolicy {
34    /// Policy name
35    pub name: String,
36    /// Parameter range
37    pub range: ParameterRange,
38    /// Optimization mode
39    pub mode: OptimizationMode,
40    /// Exploration rate (for balanced mode)
41    pub exploration_rate: f64,
42}
43
44impl OptimizationPolicy {
45    /// Create new optimization policy
46    pub fn new(name: impl Into<String>, range: ParameterRange, mode: OptimizationMode) -> Self {
47        Self {
48            name: name.into(),
49            range,
50            mode,
51            exploration_rate: 0.2,
52        }
53    }
54
55    /// Set exploration rate
56    pub fn with_exploration_rate(mut self, rate: f64) -> Self {
57        self.exploration_rate = rate.clamp(0.0, 1.0);
58        self
59    }
60}
61
62/// Parameter optimizer engine
63pub struct ParameterOptimizer {
64    /// Adaptive tuners per policy
65    tuners: Arc<DashMap<String, AdaptiveParameterTuner>>,
66    /// Contextual bandits for parameter selection
67    bandits: Arc<DashMap<String, LinUCB>>,
68    /// Optimization policies
69    policies: Arc<DashMap<String, OptimizationPolicy>>,
70    /// Reward calculator
71    reward_calculator: RewardCalculator,
72    /// Feature dimension for contextual bandits
73    feature_dimension: usize,
74    /// LinUCB exploration parameter
75    alpha: f64,
76}
77
78impl ParameterOptimizer {
79    /// Create new parameter optimizer
80    pub fn new(reward_weights: RewardWeights, alpha: f64) -> Self {
81        let feature_dim = RequestContext::feature_dimension();
82
83        Self {
84            tuners: Arc::new(DashMap::new()),
85            bandits: Arc::new(DashMap::new()),
86            policies: Arc::new(DashMap::new()),
87            reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
88            feature_dimension: feature_dim,
89            alpha,
90        }
91    }
92
93    /// Create with default configuration
94    pub fn with_defaults() -> Self {
95        Self::new(RewardWeights::default_weights(), 1.0)
96    }
97
98    /// Create new optimization policy
99    pub fn create_policy(&self, policy: OptimizationPolicy) -> Result<()> {
100        let policy_name = policy.name.clone();
101
102        // Create adaptive tuner for this policy
103        let tuner = AdaptiveParameterTuner::new(policy.range.clone());
104        self.tuners.insert(policy_name.clone(), tuner);
105
106        // Create contextual bandit for parameter selection
107        let bandit = LinUCB::new(self.alpha, self.feature_dimension);
108        self.bandits.insert(policy_name.clone(), bandit);
109
110        // Store policy
111        self.policies.insert(policy_name, policy);
112
113        Ok(())
114    }
115
116    /// Initialize policy with search strategy
117    pub fn initialize_with_search(
118        &self,
119        policy_name: &str,
120        strategy: SearchStrategy,
121        num_configs: usize,
122    ) -> Result<Vec<Uuid>> {
123        let policy = self
124            .policies
125            .get(policy_name)
126            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
127
128        let mut tuner = self
129            .tuners
130            .get_mut(policy_name)
131            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
132
133        let mut bandit = self
134            .bandits
135            .get_mut(policy_name)
136            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
137
138        // Generate configurations based on strategy
139        let configs = match strategy {
140            SearchStrategy::Grid => {
141                let grid_config = GridSearchConfig {
142                    temp_steps: (num_configs as f64).cbrt().ceil() as usize,
143                    top_p_steps: (num_configs as f64).cbrt().ceil() as usize,
144                    max_tokens_steps: (num_configs as f64).cbrt().ceil() as usize,
145                };
146                let search = ParameterSearchManager::with_grid_search(policy.range.clone(), grid_config);
147                search.grid_search.map(|s| s.all_configs()).unwrap_or_default()
148            }
149            SearchStrategy::Random => {
150                let mut search = ParameterSearchManager::with_random_search(policy.range.clone(), num_configs);
151                let mut configs = Vec::new();
152                while let Some(config) = search.next() {
153                    configs.push(config);
154                }
155                configs
156            }
157            SearchStrategy::LatinHypercube => {
158                let search = ParameterSearchManager::with_lhs(policy.range.clone(), num_configs);
159                search.lhs_search.map(|s| s.all_configs()).unwrap_or_default()
160            }
161            SearchStrategy::Sobol => {
162                return Err(DecisionError::InvalidParameter(
163                    "Sobol sequence not yet implemented".to_string(),
164                ));
165            }
166        };
167
168        // Register all configurations
169        let mut config_ids = Vec::new();
170        for config in configs {
171            let config_id = tuner.register_config(config)?;
172            bandit.add_arm(config_id);
173            config_ids.push(config_id);
174        }
175
176        Ok(config_ids)
177    }
178
179    /// Select parameter configuration for request
180    pub fn select_parameters(
181        &self,
182        policy_name: &str,
183        context: &RequestContext,
184    ) -> Result<(Uuid, ParameterConfig)> {
185        let policy = self
186            .policies
187            .get(policy_name)
188            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
189
190        match policy.mode {
191            OptimizationMode::Explore => self.select_explore(policy_name, context),
192            OptimizationMode::Exploit => self.select_exploit(policy_name, context),
193            OptimizationMode::Balanced => {
194                // Randomly choose between explore and exploit based on rate
195                if rand::random::<f64>() < policy.exploration_rate {
196                    self.select_explore(policy_name, context)
197                } else {
198                    self.select_exploit(policy_name, context)
199                }
200            }
201        }
202    }
203
204    /// Select using exploration (contextual bandit)
205    fn select_explore(
206        &self,
207        policy_name: &str,
208        context: &RequestContext,
209    ) -> Result<(Uuid, ParameterConfig)> {
210        let bandit = self
211            .bandits
212            .get(policy_name)
213            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
214
215        let tuner = self
216            .tuners
217            .get(policy_name)
218            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
219
220        let config_id = bandit.select_arm(context)?;
221        let stats = tuner
222            .get_stats(&config_id)
223            .ok_or_else(|| DecisionError::VariantNotFound(config_id.to_string()))?;
224
225        Ok((config_id, stats.config.clone()))
226    }
227
228    /// Select using exploitation (best known)
229    fn select_exploit(
230        &self,
231        policy_name: &str,
232        context: &RequestContext,
233    ) -> Result<(Uuid, ParameterConfig)> {
234        let tuner = self
235            .tuners
236            .get(policy_name)
237            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
238
239        // Use task-specific best if available
240        if let Some(task_type) = &context.task_type {
241            if let Some((config_id, config)) = tuner.get_best_for_task(task_type) {
242                return Ok((config_id, config));
243            }
244        }
245
246        // Fall back to overall best
247        let all_stats = tuner.get_all_stats();
248        let best = all_stats
249            .iter()
250            .max_by(|a, b| {
251                a.average_reward
252                    .partial_cmp(&b.average_reward)
253                    .unwrap_or(std::cmp::Ordering::Equal)
254            })
255            .ok_or_else(|| DecisionError::InvalidState("No configurations available".to_string()))?;
256
257        Ok((best.config_id, best.config.clone()))
258    }
259
260    /// Update with performance feedback
261    pub fn update_performance(
262        &self,
263        policy_name: &str,
264        config_id: &Uuid,
265        context: &RequestContext,
266        metrics: &ResponseMetrics,
267        feedback: Option<&UserFeedback>,
268    ) -> Result<()> {
269        // Calculate reward
270        let reward = if let Some(fb) = feedback {
271            self.reward_calculator.calculate_reward(metrics, fb)
272        } else {
273            self.reward_calculator.calculate_reward_metrics_only(metrics)
274        };
275
276        // Update adaptive tuner
277        let mut tuner = self
278            .tuners
279            .get_mut(policy_name)
280            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
281
282        tuner.update_config(config_id, reward, metrics, feedback)?;
283
284        // Update contextual bandit
285        let mut bandit = self
286            .bandits
287            .get_mut(policy_name)
288            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
289
290        bandit.update(config_id, context, reward)?;
291
292        Ok(())
293    }
294
295    /// Get performance statistics
296    pub fn get_performance_stats(&self, policy_name: &str) -> Result<Vec<ParameterStats>> {
297        let tuner = self
298            .tuners
299            .get(policy_name)
300            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
301
302        Ok(tuner.get_all_stats())
303    }
304
305    /// Get best configuration for task type
306    pub fn get_best_for_task(
307        &self,
308        policy_name: &str,
309        task_type: &str,
310    ) -> Result<Option<(Uuid, ParameterConfig)>> {
311        let tuner = self
312            .tuners
313            .get(policy_name)
314            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
315
316        Ok(tuner.get_best_for_task(task_type))
317    }
318
319    /// Update task-specific best configurations
320    pub fn update_task_bests(&self, policy_name: &str, task_types: &[String]) -> Result<()> {
321        let mut tuner = self
322            .tuners
323            .get_mut(policy_name)
324            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
325
326        for task_type in task_types {
327            tuner.update_task_best(task_type.clone());
328        }
329
330        Ok(())
331    }
332
333    /// Change optimization mode
334    pub fn set_mode(&self, policy_name: &str, mode: OptimizationMode) -> Result<()> {
335        let mut policy = self
336            .policies
337            .get_mut(policy_name)
338            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
339
340        policy.mode = mode;
341        Ok(())
342    }
343
344    /// Get current optimization mode
345    pub fn get_mode(&self, policy_name: &str) -> Result<OptimizationMode> {
346        let policy = self
347            .policies
348            .get(policy_name)
349            .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
350
351        Ok(policy.mode)
352    }
353
354    /// Update reward weights
355    pub fn set_reward_weights(&mut self, weights: RewardWeights) {
356        self.reward_calculator.set_weights(weights);
357    }
358
359    /// Get number of policies
360    pub fn num_policies(&self) -> usize {
361        self.policies.len()
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy name shared by most tests.
    const POLICY: &str = "test_policy";

    /// Build a default optimizer that already holds one policy.
    fn optimizer_with(
        name: &str,
        range: ParameterRange,
        mode: OptimizationMode,
    ) -> ParameterOptimizer {
        let optimizer = ParameterOptimizer::with_defaults();
        optimizer
            .create_policy(OptimizationPolicy::new(name, range, mode))
            .unwrap();
        optimizer
    }

    /// Build a default-range `POLICY` seeded with `count` configurations from `strategy`.
    fn seeded(
        mode: OptimizationMode,
        strategy: SearchStrategy,
        count: usize,
    ) -> (ParameterOptimizer, Vec<Uuid>) {
        let optimizer = optimizer_with(POLICY, ParameterRange::default(), mode);
        let ids = optimizer
            .initialize_with_search(POLICY, strategy, count)
            .unwrap();
        (optimizer, ids)
    }

    #[test]
    fn test_optimizer_creation() {
        assert_eq!(ParameterOptimizer::with_defaults().num_policies(), 0);
    }

    #[test]
    fn test_create_policy() {
        let optimizer =
            optimizer_with(POLICY, ParameterRange::default(), OptimizationMode::Balanced);
        assert_eq!(optimizer.num_policies(), 1);
    }

    #[test]
    fn test_initialize_with_grid_search() {
        let (_, ids) = seeded(OptimizationMode::Explore, SearchStrategy::Grid, 8);
        assert!(!ids.is_empty());
    }

    #[test]
    fn test_initialize_with_random_search() {
        let (_, ids) = seeded(OptimizationMode::Explore, SearchStrategy::Random, 10);
        assert_eq!(ids.len(), 10);
    }

    #[test]
    fn test_initialize_with_lhs() {
        let (_, ids) = seeded(OptimizationMode::Explore, SearchStrategy::LatinHypercube, 15);
        assert!(!ids.is_empty());
    }

    #[test]
    fn test_select_parameters_explore() {
        let (optimizer, _) = seeded(OptimizationMode::Explore, SearchStrategy::Random, 5);
        let context = RequestContext::new(100);

        let (config_id, config) = optimizer.select_parameters(POLICY, &context).unwrap();

        assert_ne!(config_id, Uuid::nil());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_select_parameters_balanced() {
        let (optimizer, _) = seeded(OptimizationMode::Balanced, SearchStrategy::Random, 5);
        let context = RequestContext::new(100);

        let (_, config) = optimizer.select_parameters(POLICY, &context).unwrap();

        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_update_performance() {
        let (optimizer, ids) = seeded(OptimizationMode::Explore, SearchStrategy::Random, 3);
        let target = ids[0];
        let context = RequestContext::new(100);
        let metrics = ResponseMetrics {
            quality_score: 0.9,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        };

        optimizer
            .update_performance(POLICY, &target, &context, &metrics, None)
            .unwrap();

        let stats = optimizer.get_performance_stats(POLICY).unwrap();
        let updated = stats.iter().find(|s| s.config_id == target).unwrap();
        assert_eq!(updated.num_uses, 1);
    }

    #[test]
    fn test_optimizer_learning() {
        let (optimizer, ids) = seeded(OptimizationMode::Explore, SearchStrategy::Random, 3);
        let (good_id, bad_id) = (ids[0], ids[1]);
        let context = RequestContext::new(100);

        let good_metrics = ResponseMetrics {
            quality_score: 0.95,
            cost: 0.05,
            latency_ms: 800.0,
            token_count: 400,
        };
        let bad_metrics = ResponseMetrics {
            quality_score: 0.4,
            cost: 0.3,
            latency_ms: 2000.0,
            token_count: 800,
        };

        // Interleave clearly better and clearly worse outcomes.
        for _ in 0..20 {
            optimizer
                .update_performance(POLICY, &good_id, &context, &good_metrics, None)
                .unwrap();
            optimizer
                .update_performance(POLICY, &bad_id, &context, &bad_metrics, None)
                .unwrap();
        }

        let stats = optimizer.get_performance_stats(POLICY).unwrap();
        let reward_of = |id: Uuid| {
            stats
                .iter()
                .find(|s| s.config_id == id)
                .unwrap()
                .average_reward
        };
        assert!(reward_of(good_id) > reward_of(bad_id));
    }

    #[test]
    fn test_get_best_for_task() {
        let name = "code_policy";
        let optimizer = optimizer_with(
            name,
            ParameterRange::for_task_type("code"),
            OptimizationMode::Explore,
        );
        optimizer
            .initialize_with_search(name, SearchStrategy::Random, 5)
            .unwrap();

        let context = RequestContext::new(100).with_task_type("code");
        let (config_id, _) = optimizer.select_parameters(name, &context).unwrap();

        let metrics = ResponseMetrics {
            quality_score: 0.95,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        };

        // The tuner needs enough samples before it records a task best.
        for _ in 0..15 {
            optimizer
                .update_performance(name, &config_id, &context, &metrics, None)
                .unwrap();
        }

        optimizer
            .update_task_bests(name, &["code".to_string()])
            .unwrap();

        assert!(optimizer.get_best_for_task(name, "code").unwrap().is_some());
    }

    #[test]
    fn test_set_mode() {
        let optimizer =
            optimizer_with(POLICY, ParameterRange::default(), OptimizationMode::Explore);
        assert_eq!(optimizer.get_mode(POLICY).unwrap(), OptimizationMode::Explore);

        optimizer
            .set_mode(POLICY, OptimizationMode::Exploit)
            .unwrap();
        assert_eq!(optimizer.get_mode(POLICY).unwrap(), OptimizationMode::Exploit);
    }

    #[test]
    fn test_policy_with_exploration_rate() {
        let policy =
            OptimizationPolicy::new("test", ParameterRange::default(), OptimizationMode::Balanced)
                .with_exploration_rate(0.3);

        assert_eq!(policy.exploration_rate, 0.3);
    }

    #[test]
    fn test_exploration_rate_clamping() {
        let policy =
            OptimizationPolicy::new("test", ParameterRange::default(), OptimizationMode::Balanced)
                .with_exploration_rate(1.5);

        assert_eq!(policy.exploration_rate, 1.0);
    }
}