1use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use uuid::Uuid;
10
11use crate::{
12 adaptive_params::{AdaptiveParameterTuner, ParameterConfig, ParameterRange, ParameterStats},
13 context::RequestContext,
14 contextual_bandit::LinUCB,
15 errors::{DecisionError, Result},
16 parameter_search::{GridSearchConfig, ParameterSearchManager, SearchStrategy},
17 reward::{RewardCalculator, RewardWeights, ResponseMetrics, UserFeedback},
18};
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22pub enum OptimizationMode {
23 Explore,
25 Exploit,
27 Balanced,
29}
30
31#[derive(Debug, Clone)]
33pub struct OptimizationPolicy {
34 pub name: String,
36 pub range: ParameterRange,
38 pub mode: OptimizationMode,
40 pub exploration_rate: f64,
42}
43
44impl OptimizationPolicy {
45 pub fn new(name: impl Into<String>, range: ParameterRange, mode: OptimizationMode) -> Self {
47 Self {
48 name: name.into(),
49 range,
50 mode,
51 exploration_rate: 0.2,
52 }
53 }
54
55 pub fn with_exploration_rate(mut self, rate: f64) -> Self {
57 self.exploration_rate = rate.clamp(0.0, 1.0);
58 self
59 }
60}
61
62pub struct ParameterOptimizer {
64 tuners: Arc<DashMap<String, AdaptiveParameterTuner>>,
66 bandits: Arc<DashMap<String, LinUCB>>,
68 policies: Arc<DashMap<String, OptimizationPolicy>>,
70 reward_calculator: RewardCalculator,
72 feature_dimension: usize,
74 alpha: f64,
76}
77
78impl ParameterOptimizer {
79 pub fn new(reward_weights: RewardWeights, alpha: f64) -> Self {
81 let feature_dim = RequestContext::feature_dimension();
82
83 Self {
84 tuners: Arc::new(DashMap::new()),
85 bandits: Arc::new(DashMap::new()),
86 policies: Arc::new(DashMap::new()),
87 reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
88 feature_dimension: feature_dim,
89 alpha,
90 }
91 }
92
93 pub fn with_defaults() -> Self {
95 Self::new(RewardWeights::default_weights(), 1.0)
96 }
97
98 pub fn create_policy(&self, policy: OptimizationPolicy) -> Result<()> {
100 let policy_name = policy.name.clone();
101
102 let tuner = AdaptiveParameterTuner::new(policy.range.clone());
104 self.tuners.insert(policy_name.clone(), tuner);
105
106 let bandit = LinUCB::new(self.alpha, self.feature_dimension);
108 self.bandits.insert(policy_name.clone(), bandit);
109
110 self.policies.insert(policy_name, policy);
112
113 Ok(())
114 }
115
116 pub fn initialize_with_search(
118 &self,
119 policy_name: &str,
120 strategy: SearchStrategy,
121 num_configs: usize,
122 ) -> Result<Vec<Uuid>> {
123 let policy = self
124 .policies
125 .get(policy_name)
126 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
127
128 let mut tuner = self
129 .tuners
130 .get_mut(policy_name)
131 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
132
133 let mut bandit = self
134 .bandits
135 .get_mut(policy_name)
136 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
137
138 let configs = match strategy {
140 SearchStrategy::Grid => {
141 let grid_config = GridSearchConfig {
142 temp_steps: (num_configs as f64).cbrt().ceil() as usize,
143 top_p_steps: (num_configs as f64).cbrt().ceil() as usize,
144 max_tokens_steps: (num_configs as f64).cbrt().ceil() as usize,
145 };
146 let search = ParameterSearchManager::with_grid_search(policy.range.clone(), grid_config);
147 search.grid_search.map(|s| s.all_configs()).unwrap_or_default()
148 }
149 SearchStrategy::Random => {
150 let mut search = ParameterSearchManager::with_random_search(policy.range.clone(), num_configs);
151 let mut configs = Vec::new();
152 while let Some(config) = search.next() {
153 configs.push(config);
154 }
155 configs
156 }
157 SearchStrategy::LatinHypercube => {
158 let search = ParameterSearchManager::with_lhs(policy.range.clone(), num_configs);
159 search.lhs_search.map(|s| s.all_configs()).unwrap_or_default()
160 }
161 SearchStrategy::Sobol => {
162 return Err(DecisionError::InvalidParameter(
163 "Sobol sequence not yet implemented".to_string(),
164 ));
165 }
166 };
167
168 let mut config_ids = Vec::new();
170 for config in configs {
171 let config_id = tuner.register_config(config)?;
172 bandit.add_arm(config_id);
173 config_ids.push(config_id);
174 }
175
176 Ok(config_ids)
177 }
178
179 pub fn select_parameters(
181 &self,
182 policy_name: &str,
183 context: &RequestContext,
184 ) -> Result<(Uuid, ParameterConfig)> {
185 let policy = self
186 .policies
187 .get(policy_name)
188 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
189
190 match policy.mode {
191 OptimizationMode::Explore => self.select_explore(policy_name, context),
192 OptimizationMode::Exploit => self.select_exploit(policy_name, context),
193 OptimizationMode::Balanced => {
194 if rand::random::<f64>() < policy.exploration_rate {
196 self.select_explore(policy_name, context)
197 } else {
198 self.select_exploit(policy_name, context)
199 }
200 }
201 }
202 }
203
204 fn select_explore(
206 &self,
207 policy_name: &str,
208 context: &RequestContext,
209 ) -> Result<(Uuid, ParameterConfig)> {
210 let bandit = self
211 .bandits
212 .get(policy_name)
213 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
214
215 let tuner = self
216 .tuners
217 .get(policy_name)
218 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
219
220 let config_id = bandit.select_arm(context)?;
221 let stats = tuner
222 .get_stats(&config_id)
223 .ok_or_else(|| DecisionError::VariantNotFound(config_id.to_string()))?;
224
225 Ok((config_id, stats.config.clone()))
226 }
227
228 fn select_exploit(
230 &self,
231 policy_name: &str,
232 context: &RequestContext,
233 ) -> Result<(Uuid, ParameterConfig)> {
234 let tuner = self
235 .tuners
236 .get(policy_name)
237 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
238
239 if let Some(task_type) = &context.task_type {
241 if let Some((config_id, config)) = tuner.get_best_for_task(task_type) {
242 return Ok((config_id, config));
243 }
244 }
245
246 let all_stats = tuner.get_all_stats();
248 let best = all_stats
249 .iter()
250 .max_by(|a, b| {
251 a.average_reward
252 .partial_cmp(&b.average_reward)
253 .unwrap_or(std::cmp::Ordering::Equal)
254 })
255 .ok_or_else(|| DecisionError::InvalidState("No configurations available".to_string()))?;
256
257 Ok((best.config_id, best.config.clone()))
258 }
259
260 pub fn update_performance(
262 &self,
263 policy_name: &str,
264 config_id: &Uuid,
265 context: &RequestContext,
266 metrics: &ResponseMetrics,
267 feedback: Option<&UserFeedback>,
268 ) -> Result<()> {
269 let reward = if let Some(fb) = feedback {
271 self.reward_calculator.calculate_reward(metrics, fb)
272 } else {
273 self.reward_calculator.calculate_reward_metrics_only(metrics)
274 };
275
276 let mut tuner = self
278 .tuners
279 .get_mut(policy_name)
280 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
281
282 tuner.update_config(config_id, reward, metrics, feedback)?;
283
284 let mut bandit = self
286 .bandits
287 .get_mut(policy_name)
288 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
289
290 bandit.update(config_id, context, reward)?;
291
292 Ok(())
293 }
294
295 pub fn get_performance_stats(&self, policy_name: &str) -> Result<Vec<ParameterStats>> {
297 let tuner = self
298 .tuners
299 .get(policy_name)
300 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
301
302 Ok(tuner.get_all_stats())
303 }
304
305 pub fn get_best_for_task(
307 &self,
308 policy_name: &str,
309 task_type: &str,
310 ) -> Result<Option<(Uuid, ParameterConfig)>> {
311 let tuner = self
312 .tuners
313 .get(policy_name)
314 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
315
316 Ok(tuner.get_best_for_task(task_type))
317 }
318
319 pub fn update_task_bests(&self, policy_name: &str, task_types: &[String]) -> Result<()> {
321 let mut tuner = self
322 .tuners
323 .get_mut(policy_name)
324 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
325
326 for task_type in task_types {
327 tuner.update_task_best(task_type.clone());
328 }
329
330 Ok(())
331 }
332
333 pub fn set_mode(&self, policy_name: &str, mode: OptimizationMode) -> Result<()> {
335 let mut policy = self
336 .policies
337 .get_mut(policy_name)
338 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
339
340 policy.mode = mode;
341 Ok(())
342 }
343
344 pub fn get_mode(&self, policy_name: &str) -> Result<OptimizationMode> {
346 let policy = self
347 .policies
348 .get(policy_name)
349 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
350
351 Ok(policy.mode)
352 }
353
354 pub fn set_reward_weights(&mut self, weights: RewardWeights) {
356 self.reward_calculator.set_weights(weights);
357 }
358
359 pub fn num_policies(&self) -> usize {
361 self.policies.len()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy name shared by most tests.
    const POLICY: &str = "test_policy";

    /// Builds a default optimizer with `POLICY` registered over the default range.
    fn optimizer_with_policy(mode: OptimizationMode) -> ParameterOptimizer {
        let optimizer = ParameterOptimizer::with_defaults();
        let policy = OptimizationPolicy::new(POLICY, ParameterRange::default(), mode);
        optimizer.create_policy(policy).unwrap();
        optimizer
    }

    /// Like `optimizer_with_policy`, then seeds `POLICY` with a search and
    /// returns the registered configuration ids alongside the optimizer.
    fn seeded_optimizer(
        mode: OptimizationMode,
        strategy: SearchStrategy,
        num_configs: usize,
    ) -> (ParameterOptimizer, Vec<Uuid>) {
        let optimizer = optimizer_with_policy(mode);
        let config_ids = optimizer
            .initialize_with_search(POLICY, strategy, num_configs)
            .unwrap();
        (optimizer, config_ids)
    }

    #[test]
    fn test_optimizer_creation() {
        assert_eq!(ParameterOptimizer::with_defaults().num_policies(), 0);
    }

    #[test]
    fn test_create_policy() {
        let optimizer = optimizer_with_policy(OptimizationMode::Balanced);
        assert_eq!(optimizer.num_policies(), 1);
    }

    #[test]
    fn test_initialize_with_grid_search() {
        let (_, config_ids) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::Grid, 8);
        assert!(!config_ids.is_empty());
    }

    #[test]
    fn test_initialize_with_random_search() {
        let (_, config_ids) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::Random, 10);
        assert_eq!(config_ids.len(), 10);
    }

    #[test]
    fn test_initialize_with_lhs() {
        let (_, config_ids) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::LatinHypercube, 15);
        assert!(!config_ids.is_empty());
    }

    #[test]
    fn test_select_parameters_explore() {
        let (optimizer, _) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::Random, 5);

        let context = RequestContext::new(100);
        let (config_id, config) = optimizer.select_parameters(POLICY, &context).unwrap();

        assert_ne!(config_id, Uuid::nil());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_select_parameters_balanced() {
        let (optimizer, _) =
            seeded_optimizer(OptimizationMode::Balanced, SearchStrategy::Random, 5);

        let context = RequestContext::new(100);
        let (_, config) = optimizer.select_parameters(POLICY, &context).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_update_performance() {
        let (optimizer, config_ids) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::Random, 3);
        let target = config_ids[0];

        let context = RequestContext::new(100);
        let metrics = ResponseMetrics {
            quality_score: 0.9,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        };

        optimizer
            .update_performance(POLICY, &target, &context, &metrics, None)
            .unwrap();

        let stats = optimizer.get_performance_stats(POLICY).unwrap();
        let updated = stats.iter().find(|s| s.config_id == target).unwrap();
        assert_eq!(updated.num_uses, 1);
    }

    #[test]
    fn test_optimizer_learning() {
        let (optimizer, config_ids) =
            seeded_optimizer(OptimizationMode::Explore, SearchStrategy::Random, 3);
        let (good_id, bad_id) = (config_ids[0], config_ids[1]);

        let context = RequestContext::new(100);
        let good_metrics = ResponseMetrics {
            quality_score: 0.95,
            cost: 0.05,
            latency_ms: 800.0,
            token_count: 400,
        };
        let bad_metrics = ResponseMetrics {
            quality_score: 0.4,
            cost: 0.3,
            latency_ms: 2000.0,
            token_count: 800,
        };

        // Each round reports the good outcome first, then the bad one.
        for _ in 0..20 {
            for (id, metrics) in [(good_id, &good_metrics), (bad_id, &bad_metrics)] {
                optimizer
                    .update_performance(POLICY, &id, &context, metrics, None)
                    .unwrap();
            }
        }

        let stats = optimizer.get_performance_stats(POLICY).unwrap();
        let reward_of = |id: Uuid| {
            stats
                .iter()
                .find(|s| s.config_id == id)
                .unwrap()
                .average_reward
        };

        assert!(reward_of(good_id) > reward_of(bad_id));
    }

    #[test]
    fn test_get_best_for_task() {
        let optimizer = ParameterOptimizer::with_defaults();
        let policy = OptimizationPolicy::new(
            "code_policy",
            ParameterRange::for_task_type("code"),
            OptimizationMode::Explore,
        );
        optimizer.create_policy(policy).unwrap();
        optimizer
            .initialize_with_search("code_policy", SearchStrategy::Random, 5)
            .unwrap();

        let context = RequestContext::new(100).with_task_type("code");
        let (config_id, _) = optimizer.select_parameters("code_policy", &context).unwrap();

        let metrics = ResponseMetrics {
            quality_score: 0.95,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        };

        for _ in 0..15 {
            optimizer
                .update_performance("code_policy", &config_id, &context, &metrics, None)
                .unwrap();
        }

        optimizer
            .update_task_bests("code_policy", &["code".to_string()])
            .unwrap();

        let best = optimizer.get_best_for_task("code_policy", "code").unwrap();
        assert!(best.is_some());
    }

    #[test]
    fn test_set_mode() {
        let optimizer = optimizer_with_policy(OptimizationMode::Explore);
        assert_eq!(optimizer.get_mode(POLICY).unwrap(), OptimizationMode::Explore);

        optimizer.set_mode(POLICY, OptimizationMode::Exploit).unwrap();
        assert_eq!(optimizer.get_mode(POLICY).unwrap(), OptimizationMode::Exploit);
    }

    #[test]
    fn test_policy_with_exploration_rate() {
        let policy =
            OptimizationPolicy::new("test", ParameterRange::default(), OptimizationMode::Balanced)
                .with_exploration_rate(0.3);

        assert_eq!(policy.exploration_rate, 0.3);
    }

    #[test]
    fn test_exploration_rate_clamping() {
        let policy =
            OptimizationPolicy::new("test", ParameterRange::default(), OptimizationMode::Balanced)
                .with_exploration_rate(1.5);

        assert_eq!(policy.exploration_rate, 1.0);
    }
}