1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
10use crate::{
11 context::RequestContext,
12 errors::{DecisionError, Result},
13 reward::{ResponseMetrics, UserFeedback},
14};
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct ParameterConfig {
19 pub temperature: f64,
21 pub top_p: f64,
23 pub max_tokens: usize,
25}
26
27impl Default for ParameterConfig {
28 fn default() -> Self {
29 Self {
30 temperature: 0.7,
31 top_p: 0.9,
32 max_tokens: 2048,
33 }
34 }
35}
36
37impl ParameterConfig {
38 pub fn new(temperature: f64, top_p: f64, max_tokens: usize) -> Result<Self> {
40 let config = Self {
41 temperature,
42 top_p,
43 max_tokens,
44 };
45 config.validate()?;
46 Ok(config)
47 }
48
49 pub fn validate(&self) -> Result<()> {
51 if self.temperature < 0.0 || self.temperature > 2.0 {
52 return Err(DecisionError::InvalidParameter(format!(
53 "Temperature {} out of range [0.0, 2.0]",
54 self.temperature
55 )));
56 }
57
58 if self.top_p < 0.0 || self.top_p > 1.0 {
59 return Err(DecisionError::InvalidParameter(format!(
60 "Top-p {} out of range [0.0, 1.0]",
61 self.top_p
62 )));
63 }
64
65 if self.max_tokens == 0 {
66 return Err(DecisionError::InvalidParameter(
67 "Max tokens must be greater than 0".to_string(),
68 ));
69 }
70
71 Ok(())
72 }
73
74 pub fn creative() -> Self {
76 Self {
77 temperature: 1.2,
78 top_p: 0.95,
79 max_tokens: 2048,
80 }
81 }
82
83 pub fn analytical() -> Self {
85 Self {
86 temperature: 0.3,
87 top_p: 0.85,
88 max_tokens: 1024,
89 }
90 }
91
92 pub fn code_generation() -> Self {
94 Self {
95 temperature: 0.2,
96 top_p: 0.9,
97 max_tokens: 2048,
98 }
99 }
100
101 pub fn balanced() -> Self {
103 Self::default()
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ParameterRange {
110 pub temp_min: f64,
112 pub temp_max: f64,
114 pub top_p_min: f64,
116 pub top_p_max: f64,
118 pub max_tokens_min: usize,
120 pub max_tokens_max: usize,
122}
123
124impl Default for ParameterRange {
125 fn default() -> Self {
126 Self {
127 temp_min: 0.0,
128 temp_max: 2.0,
129 top_p_min: 0.7,
130 top_p_max: 1.0,
131 max_tokens_min: 256,
132 max_tokens_max: 8192,
133 }
134 }
135}
136
137impl ParameterRange {
138 pub fn new(
140 temp_min: f64,
141 temp_max: f64,
142 top_p_min: f64,
143 top_p_max: f64,
144 max_tokens_min: usize,
145 max_tokens_max: usize,
146 ) -> Result<Self> {
147 if temp_min >= temp_max {
148 return Err(DecisionError::InvalidParameter(
149 "Temperature min must be less than max".to_string(),
150 ));
151 }
152
153 if top_p_min >= top_p_max {
154 return Err(DecisionError::InvalidParameter(
155 "Top-p min must be less than max".to_string(),
156 ));
157 }
158
159 if max_tokens_min >= max_tokens_max {
160 return Err(DecisionError::InvalidParameter(
161 "Max tokens min must be less than max".to_string(),
162 ));
163 }
164
165 Ok(Self {
166 temp_min,
167 temp_max,
168 top_p_min,
169 top_p_max,
170 max_tokens_min,
171 max_tokens_max,
172 })
173 }
174
175 pub fn contains(&self, config: &ParameterConfig) -> bool {
177 config.temperature >= self.temp_min
178 && config.temperature <= self.temp_max
179 && config.top_p >= self.top_p_min
180 && config.top_p <= self.top_p_max
181 && config.max_tokens >= self.max_tokens_min
182 && config.max_tokens <= self.max_tokens_max
183 }
184
185 pub fn clamp(&self, config: &ParameterConfig) -> ParameterConfig {
187 ParameterConfig {
188 temperature: config.temperature.clamp(self.temp_min, self.temp_max),
189 top_p: config.top_p.clamp(self.top_p_min, self.top_p_max),
190 max_tokens: config
191 .max_tokens
192 .clamp(self.max_tokens_min, self.max_tokens_max),
193 }
194 }
195
196 pub fn for_task_type(task_type: &str) -> Self {
198 match task_type {
199 "creative" | "storytelling" | "brainstorming" => Self {
200 temp_min: 0.8,
201 temp_max: 1.5,
202 top_p_min: 0.9,
203 top_p_max: 0.98,
204 max_tokens_min: 512,
205 max_tokens_max: 4096,
206 },
207 "code" | "programming" | "technical" => Self {
208 temp_min: 0.0,
209 temp_max: 0.5,
210 top_p_min: 0.85,
211 top_p_max: 0.95,
212 max_tokens_min: 256,
213 max_tokens_max: 4096,
214 },
215 "analytical" | "reasoning" | "math" => Self {
216 temp_min: 0.0,
217 temp_max: 0.4,
218 top_p_min: 0.8,
219 top_p_max: 0.9,
220 max_tokens_min: 512,
221 max_tokens_max: 2048,
222 },
223 _ => Self::default(),
224 }
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ParameterStats {
231 pub config_id: Uuid,
233 pub config: ParameterConfig,
235 pub num_uses: u64,
237 pub total_reward: f64,
239 pub average_reward: f64,
241 pub avg_quality: f64,
243 pub avg_cost: f64,
245 pub avg_latency: f64,
247 pub success_rate: f64,
249}
250
251impl ParameterStats {
252 pub fn new(config_id: Uuid, config: ParameterConfig) -> Self {
254 Self {
255 config_id,
256 config,
257 num_uses: 0,
258 total_reward: 0.0,
259 average_reward: 0.0,
260 avg_quality: 0.0,
261 avg_cost: 0.0,
262 avg_latency: 0.0,
263 success_rate: 0.0,
264 }
265 }
266
267 pub fn update(&mut self, reward: f64, metrics: &ResponseMetrics, success: bool) {
269 let n = self.num_uses as f64;
270 let n_plus_1 = (self.num_uses + 1) as f64;
271
272 self.total_reward += reward;
274 self.average_reward = (self.average_reward * n + reward) / n_plus_1;
275 self.avg_quality = (self.avg_quality * n + metrics.quality_score) / n_plus_1;
276 self.avg_cost = (self.avg_cost * n + metrics.cost) / n_plus_1;
277 self.avg_latency = (self.avg_latency * n + metrics.latency_ms) / n_plus_1;
278
279 let success_count = (self.success_rate * n) + if success { 1.0 } else { 0.0 };
280 self.success_rate = success_count / n_plus_1;
281
282 self.num_uses += 1;
283 }
284
285 pub fn confidence_width(&self, exploration_factor: f64) -> f64 {
287 if self.num_uses == 0 {
288 return f64::INFINITY;
289 }
290 exploration_factor * (2.0 * (self.num_uses as f64).ln()).sqrt() / (self.num_uses as f64)
291 }
292
293 pub fn ucb(&self, exploration_factor: f64) -> f64 {
295 self.average_reward + self.confidence_width(exploration_factor)
296 }
297}
298
299pub struct AdaptiveParameterTuner {
301 range: ParameterRange,
303 config_stats: HashMap<Uuid, ParameterStats>,
305 task_best_configs: HashMap<String, Uuid>,
307 exploration_factor: f64,
309 learning_rate: f64,
311 min_uses_for_stability: u64,
313}
314
315impl AdaptiveParameterTuner {
316 pub fn new(range: ParameterRange) -> Self {
318 Self {
319 range,
320 config_stats: HashMap::new(),
321 task_best_configs: HashMap::new(),
322 exploration_factor: 2.0,
323 learning_rate: 0.1,
324 min_uses_for_stability: 10,
325 }
326 }
327
328 pub fn with_defaults() -> Self {
330 Self::new(ParameterRange::default())
331 }
332
333 pub fn with_exploration_factor(mut self, factor: f64) -> Self {
335 self.exploration_factor = factor;
336 self
337 }
338
339 pub fn with_learning_rate(mut self, rate: f64) -> Self {
341 self.learning_rate = rate;
342 self
343 }
344
345 pub fn register_config(&mut self, config: ParameterConfig) -> Result<Uuid> {
347 config.validate()?;
348 if !self.range.contains(&config) {
349 return Err(DecisionError::InvalidParameter(
350 "Configuration outside allowed range".to_string(),
351 ));
352 }
353
354 let config_id = Uuid::new_v4();
355 self.config_stats
356 .insert(config_id, ParameterStats::new(config_id, config));
357 Ok(config_id)
358 }
359
360 pub fn select_config(&self, context: &RequestContext) -> Result<(Uuid, ParameterConfig)> {
362 if self.config_stats.is_empty() {
363 return Err(DecisionError::InvalidState(
364 "No configurations registered".to_string(),
365 ));
366 }
367
368 if let Some(task_type) = &context.task_type {
370 if let Some(config_id) = self.task_best_configs.get(task_type) {
371 if let Some(stats) = self.config_stats.get(config_id) {
372 if stats.num_uses >= self.min_uses_for_stability {
373 if rand::random::<f64>() < 0.8 {
375 return Ok((*config_id, stats.config.clone()));
376 }
377 }
378 }
379 }
380 }
381
382 let (best_id, best_stats) = self
384 .config_stats
385 .iter()
386 .max_by(|(_, a), (_, b)| {
387 let ucb_a = a.ucb(self.exploration_factor);
388 let ucb_b = b.ucb(self.exploration_factor);
389 ucb_a.partial_cmp(&ucb_b).unwrap_or(std::cmp::Ordering::Equal)
390 })
391 .ok_or_else(|| DecisionError::InvalidState("No configurations available".to_string()))?;
392
393 Ok((*best_id, best_stats.config.clone()))
394 }
395
396 pub fn update_config(
398 &mut self,
399 config_id: &Uuid,
400 reward: f64,
401 metrics: &ResponseMetrics,
402 feedback: Option<&UserFeedback>,
403 ) -> Result<()> {
404 let stats = self.config_stats.get_mut(config_id).ok_or_else(|| {
405 DecisionError::InvalidParameter(format!("Configuration {} not found", config_id))
406 })?;
407
408 let success = feedback.map(|f| f.task_completed).unwrap_or(true);
409 stats.update(reward, metrics, success);
410
411 Ok(())
412 }
413
414 pub fn get_best_for_task(&self, task_type: &str) -> Option<(Uuid, ParameterConfig)> {
416 let task_range = ParameterRange::for_task_type(task_type);
418
419 self.config_stats
420 .iter()
421 .filter(|(_, stats)| {
422 stats.num_uses >= self.min_uses_for_stability
423 && task_range.contains(&stats.config)
424 })
425 .max_by(|(_, a), (_, b)| {
426 a.average_reward
427 .partial_cmp(&b.average_reward)
428 .unwrap_or(std::cmp::Ordering::Equal)
429 })
430 .map(|(id, stats)| (*id, stats.config.clone()))
431 }
432
433 pub fn update_task_best(&mut self, task_type: String) {
435 if let Some((config_id, _)) = self.get_best_for_task(&task_type) {
436 self.task_best_configs.insert(task_type, config_id);
437 }
438 }
439
440 pub fn suggest_improvement(&self, config_id: &Uuid) -> Result<ParameterConfig> {
442 let stats = self.config_stats.get(config_id).ok_or_else(|| {
443 DecisionError::InvalidParameter(format!("Configuration {} not found", config_id))
444 })?;
445
446 if stats.num_uses < self.min_uses_for_stability {
447 return Err(DecisionError::InvalidState(
448 "Not enough data for improvement suggestion".to_string(),
449 ));
450 }
451
452 let mut new_config = stats.config.clone();
456
457 if stats.avg_quality < 0.7 {
458 new_config.temperature *= 1.0 - self.learning_rate;
460 new_config.top_p *= 1.0 - self.learning_rate * 0.5;
461 } else if stats.avg_quality > 0.9 && stats.success_rate > 0.8 {
462 new_config.temperature *= 1.0 + self.learning_rate * 0.5;
464 new_config.top_p = (new_config.top_p + 0.05).min(1.0);
465 }
466
467 new_config = self.range.clamp(&new_config);
469 new_config.validate()?;
470
471 Ok(new_config)
472 }
473
474 pub fn get_all_stats(&self) -> Vec<ParameterStats> {
476 self.config_stats.values().cloned().collect()
477 }
478
479 pub fn get_stats(&self, config_id: &Uuid) -> Option<&ParameterStats> {
481 self.config_stats.get(config_id)
482 }
483
484 pub fn num_configs(&self) -> usize {
486 self.config_stats.len()
487 }
488
489 pub fn reset(&mut self) {
491 self.config_stats.clear();
492 self.task_best_configs.clear();
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ResponseMetrics` fixture with a token count of 500.
    ///
    /// Override the token count with struct-update syntax where a test needs
    /// a different value.
    fn make_metrics(quality_score: f64, cost: f64, latency_ms: f64) -> ResponseMetrics {
        ResponseMetrics {
            quality_score,
            cost,
            latency_ms,
            token_count: 500,
        }
    }

    #[test]
    fn test_parameter_config_creation() {
        let config = ParameterConfig::new(0.7, 0.9, 1024).unwrap();
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.top_p, 0.9);
        assert_eq!(config.max_tokens, 1024);
    }

    #[test]
    fn test_parameter_config_validation() {
        assert!(ParameterConfig::new(-0.1, 0.9, 1024).is_err());
        assert!(ParameterConfig::new(2.5, 0.9, 1024).is_err());
        assert!(ParameterConfig::new(0.7, 1.5, 1024).is_err());
        assert!(ParameterConfig::new(0.7, 0.9, 0).is_err());
        // Bounds are inclusive.
        assert!(ParameterConfig::new(0.0, 0.0, 1).is_ok());
        assert!(ParameterConfig::new(2.0, 1.0, 1).is_ok());
    }

    #[test]
    fn test_preset_configs() {
        let creative = ParameterConfig::creative();
        assert!(creative.temperature > 1.0);
        assert!(creative.validate().is_ok());

        let analytical = ParameterConfig::analytical();
        assert!(analytical.temperature < 0.5);
        assert!(analytical.validate().is_ok());

        let code = ParameterConfig::code_generation();
        assert!(code.temperature < 0.3);
        assert!(code.validate().is_ok());

        assert_eq!(ParameterConfig::balanced(), ParameterConfig::default());
    }

    #[test]
    fn test_parameter_range_new() {
        assert!(ParameterRange::new(0.0, 1.0, 0.7, 1.0, 256, 4096).is_ok());
        assert!(ParameterRange::new(1.0, 1.0, 0.7, 1.0, 256, 4096).is_err());
        assert!(ParameterRange::new(0.0, 1.0, 1.0, 0.7, 256, 4096).is_err());
        assert!(ParameterRange::new(0.0, 1.0, 0.7, 1.0, 4096, 256).is_err());
    }

    #[test]
    fn test_parameter_range_contains() {
        let range = ParameterRange::default();
        let config = ParameterConfig::default();
        assert!(range.contains(&config));

        let out_of_range = ParameterConfig {
            temperature: 3.0,
            top_p: 0.9,
            max_tokens: 1024,
        };
        assert!(!range.contains(&out_of_range));
    }

    #[test]
    fn test_parameter_range_clamp() {
        let range = ParameterRange::default();
        let config = ParameterConfig {
            temperature: 3.0,
            top_p: 0.5,
            max_tokens: 10000,
        };

        let clamped = range.clamp(&config);
        assert_eq!(clamped.temperature, range.temp_max);
        assert_eq!(clamped.top_p, range.top_p_min);
        assert_eq!(clamped.max_tokens, range.max_tokens_max);
    }

    #[test]
    fn test_task_specific_ranges() {
        let creative_range = ParameterRange::for_task_type("creative");
        assert!(creative_range.temp_min >= 0.8);

        let code_range = ParameterRange::for_task_type("code");
        assert!(code_range.temp_max <= 0.5);

        let analytical_range = ParameterRange::for_task_type("analytical");
        assert!(analytical_range.temp_max <= 0.4);

        // Unknown labels fall back to the default range.
        let fallback = ParameterRange::for_task_type("unknown-task");
        let default_range = ParameterRange::default();
        assert_eq!(fallback.temp_max, default_range.temp_max);
        assert_eq!(fallback.top_p_min, default_range.top_p_min);
    }

    #[test]
    fn test_parameter_stats_creation() {
        let config_id = Uuid::new_v4();
        let config = ParameterConfig::default();
        let stats = ParameterStats::new(config_id, config.clone());

        assert_eq!(stats.config_id, config_id);
        assert_eq!(stats.config, config);
        assert_eq!(stats.num_uses, 0);
        assert_eq!(stats.average_reward, 0.0);
    }

    #[test]
    fn test_parameter_stats_update() {
        let mut stats = ParameterStats::new(Uuid::new_v4(), ParameterConfig::default());
        let metrics = make_metrics(0.9, 0.1, 1000.0);

        stats.update(0.8, &metrics, true);

        assert_eq!(stats.num_uses, 1);
        assert_eq!(stats.average_reward, 0.8);
        assert_eq!(stats.avg_quality, 0.9);
        assert_eq!(stats.success_rate, 1.0);
    }

    #[test]
    fn test_parameter_stats_running_average() {
        let mut stats = ParameterStats::new(Uuid::new_v4(), ParameterConfig::default());
        let first = make_metrics(0.8, 0.1, 1000.0);
        let second = ResponseMetrics {
            token_count: 600,
            ..make_metrics(1.0, 0.2, 1500.0)
        };

        stats.update(0.7, &first, true);
        stats.update(0.9, &second, true);

        assert_eq!(stats.num_uses, 2);
        assert_eq!(stats.average_reward, 0.8);
        assert_eq!(stats.avg_quality, 0.9);
        assert_eq!(stats.success_rate, 1.0);
    }

    #[test]
    fn test_ucb_calculation() {
        let unused = ParameterStats::new(Uuid::new_v4(), ParameterConfig::default());
        assert_eq!(unused.ucb(2.0), f64::INFINITY);

        let mut stats = ParameterStats::new(Uuid::new_v4(), ParameterConfig::default());
        let metrics = make_metrics(0.9, 0.1, 1000.0);

        for _ in 0..5 {
            stats.update(0.8, &metrics, true);
        }

        let ucb = stats.ucb(2.0);
        assert!(ucb >= stats.average_reward);
        assert_eq!(stats.num_uses, 5);
    }

    #[test]
    fn test_adaptive_tuner_creation() {
        let tuner = AdaptiveParameterTuner::with_defaults();
        assert_eq!(tuner.num_configs(), 0);
    }

    #[test]
    fn test_register_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config_id = tuner.register_config(ParameterConfig::default()).unwrap();
        assert_eq!(tuner.num_configs(), 1);
        assert!(tuner.get_stats(&config_id).is_some());
    }

    #[test]
    fn test_register_invalid_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config = ParameterConfig {
            temperature: 3.0,
            top_p: 0.9,
            max_tokens: 1024,
        };

        assert!(tuner.register_config(config).is_err());
    }

    #[test]
    fn test_select_config() {
        let context = RequestContext::new(100);

        let empty = AdaptiveParameterTuner::with_defaults();
        assert!(empty.select_config(&context).is_err());

        let mut tuner = AdaptiveParameterTuner::with_defaults();
        tuner.register_config(ParameterConfig::default()).unwrap();
        tuner.register_config(ParameterConfig::creative()).unwrap();

        let (config_id, _) = tuner.select_config(&context).unwrap();
        assert!(tuner.get_stats(&config_id).is_some());
    }

    #[test]
    fn test_update_config() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config_id = tuner.register_config(ParameterConfig::default()).unwrap();
        let metrics = make_metrics(0.9, 0.1, 1000.0);

        tuner.update_config(&config_id, 0.8, &metrics, None).unwrap();

        let stats = tuner.get_stats(&config_id).unwrap();
        assert_eq!(stats.num_uses, 1);
        assert_eq!(stats.average_reward, 0.8);

        assert!(tuner
            .update_config(&Uuid::new_v4(), 0.8, &metrics, None)
            .is_err());
    }

    #[test]
    fn test_tuner_learning() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let id1 = tuner.register_config(ParameterConfig::default()).unwrap();
        let id2 = tuner.register_config(ParameterConfig::creative()).unwrap();

        let good_metrics = make_metrics(0.95, 0.1, 1000.0);
        let bad_metrics = ResponseMetrics {
            token_count: 600,
            ..make_metrics(0.5, 0.2, 2000.0)
        };

        for _ in 0..20 {
            tuner.update_config(&id1, 0.9, &good_metrics, None).unwrap();
        }
        for _ in 0..20 {
            tuner.update_config(&id2, 0.3, &bad_metrics, None).unwrap();
        }

        let stats1 = tuner.get_stats(&id1).unwrap();
        let stats2 = tuner.get_stats(&id2).unwrap();
        assert!(stats1.average_reward > stats2.average_reward);
    }

    #[test]
    fn test_get_best_for_task() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config_id = tuner
            .register_config(ParameterConfig::code_generation())
            .unwrap();
        let good_metrics = make_metrics(0.95, 0.1, 1000.0);

        for _ in 0..15 {
            tuner.update_config(&config_id, 0.9, &good_metrics, None).unwrap();
        }

        tuner.update_task_best("code".to_string());
        let best = tuner.get_best_for_task("code");
        assert_eq!(best.map(|(id, _)| id), Some(config_id));
    }

    #[test]
    fn test_suggest_improvement() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config_id = tuner.register_config(ParameterConfig::default()).unwrap();
        let metrics = make_metrics(0.5, 0.1, 1000.0);

        for _ in 0..15 {
            tuner.update_config(&config_id, 0.6, &metrics, None).unwrap();
        }

        let improved = tuner.suggest_improvement(&config_id).unwrap();
        let original = tuner.get_stats(&config_id).unwrap();
        assert!(improved.temperature <= original.config.temperature);
    }

    #[test]
    fn test_suggest_improvement_requires_enough_data() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        let config_id = tuner.register_config(ParameterConfig::default()).unwrap();
        let metrics = make_metrics(0.5, 0.1, 1000.0);

        for _ in 0..3 {
            tuner.update_config(&config_id, 0.6, &metrics, None).unwrap();
        }

        assert!(tuner.suggest_improvement(&config_id).is_err());
        assert!(tuner.suggest_improvement(&Uuid::new_v4()).is_err());
    }

    #[test]
    fn test_reset() {
        let mut tuner = AdaptiveParameterTuner::with_defaults();
        tuner.register_config(ParameterConfig::default()).unwrap();

        assert_eq!(tuner.num_configs(), 1);
        tuner.reset();
        assert_eq!(tuner.num_configs(), 0);
    }
}