// trueno/tuner/evolution/bandit.rs
use std::cmp::Ordering;

use super::super::types::KernelType;

/// Running reward statistics for a single kernel arm of the bandit.
#[derive(Debug, Clone, Default)]
pub struct KernelArm {
    /// Number of rewards recorded for this arm.
    pub pulls: u32,
    /// Sum of all rewards recorded for this arm.
    pub total_reward: f32,
    /// Sum of the squares of all rewards recorded for this arm.
    pub total_reward_sq: f32,
}

impl KernelArm {
    /// Empirical mean reward, or `0.0` when the arm has never been pulled.
    // `u32 -> f32` rounds above 2^24 pulls; acceptable for a running average.
    #[allow(clippy::cast_precision_loss)]
    pub fn mean(&self) -> f32 {
        match self.pulls {
            0 => 0.0,
            n => self.total_reward / n as f32,
        }
    }

    /// UCB1-style upper confidence bound: `mean + c * sqrt(2 * ln(N) / n)`.
    ///
    /// `total_pulls` (N) is clamped to at least 1 so the logarithm is never
    /// negative or undefined. An arm with no pulls scores `f32::INFINITY`, so
    /// callers that maximise this score try untried arms first.
    // Same rationale as `mean`: the u32 -> f32 casts may round for huge counts.
    #[allow(clippy::cast_precision_loss)]
    pub fn ucb(&self, total_pulls: u32, c: f32) -> f32 {
        if self.pulls == 0 {
            return f32::INFINITY;
        }
        let log_total = (total_pulls.max(1) as f32).ln();
        let bonus = (2.0 * log_total / self.pulls as f32).sqrt();
        self.mean() + c * bonus
    }
}

44#[derive(Debug, Clone, Default)]
53pub struct KernelBandit {
54 pub(crate) arms: Vec<KernelArm>,
56 pub(crate) total_pulls: u32,
58 pub(crate) exploration_c: f32,
60 pub(crate) use_thompson: bool,
62}
63
64impl KernelBandit {
65 pub const NUM_KERNELS: usize = 12;
67
68 pub fn new() -> Self {
70 Self {
71 arms: vec![KernelArm::default(); Self::NUM_KERNELS],
72 total_pulls: 0,
73 exploration_c: 2.0, use_thompson: false,
75 }
76 }
77
78 pub fn with_thompson_sampling() -> Self {
80 Self {
81 arms: vec![KernelArm::default(); Self::NUM_KERNELS],
82 total_pulls: 0,
83 exploration_c: 2.0,
84 use_thompson: true,
85 }
86 }
87
88 pub fn select(&self) -> KernelType {
90 let idx = if self.use_thompson { self.select_thompson() } else { self.select_ucb() };
91 KernelType::from_index(idx)
92 }
93
94 #[allow(clippy::cast_precision_loss)] fn select_ucb(&self) -> usize {
96 self.arms
97 .iter()
98 .enumerate()
99 .max_by(|(_, a), (_, b)| {
100 a.ucb(self.total_pulls, self.exploration_c)
101 .partial_cmp(&b.ucb(self.total_pulls, self.exploration_c))
102 .unwrap_or(std::cmp::Ordering::Equal)
103 })
104 .map(|(i, _)| i)
105 .unwrap_or(0)
106 }
107
108 #[allow(clippy::cast_precision_loss)]
109 fn select_thompson(&self) -> usize {
111 use std::collections::hash_map::DefaultHasher;
114 use std::hash::{Hash, Hasher};
115
116 let mut hasher = DefaultHasher::new();
118 self.total_pulls.hash(&mut hasher);
119 let seed = hasher.finish();
120
121 self.arms
122 .iter()
123 .enumerate()
124 .max_by(|(i, a), (j, b)| {
125 let sample_a =
126 a.mean() + 0.1 * ((seed.wrapping_add(*i as u64) % 1000) as f32 / 1000.0 - 0.5);
127 let sample_b =
128 b.mean() + 0.1 * ((seed.wrapping_add(*j as u64) % 1000) as f32 / 1000.0 - 0.5);
129 sample_a.partial_cmp(&sample_b).unwrap_or(std::cmp::Ordering::Equal)
130 })
131 .map(|(i, _)| i)
132 .unwrap_or(0)
133 }
134
135 pub fn update(&mut self, kernel: KernelType, reward: f32) {
137 contract_pre_update!();
138 let idx = kernel.to_index();
139 if idx < self.arms.len() {
140 self.arms[idx].pulls += 1;
141 self.arms[idx].total_reward += reward;
142 self.arms[idx].total_reward_sq += reward * reward;
143 self.total_pulls += 1;
144 }
145 }
146
147 pub fn best_kernel(&self) -> KernelType {
149 let idx = self
150 .arms
151 .iter()
152 .enumerate()
153 .max_by(|(_, a), (_, b)| {
154 a.mean().partial_cmp(&b.mean()).unwrap_or(std::cmp::Ordering::Equal)
155 })
156 .map(|(i, _)| i)
157 .unwrap_or(0);
158 KernelType::from_index(idx)
159 }
160
161 #[allow(clippy::cast_precision_loss)] pub fn exploration_rate(&self) -> f32 {
164 if self.total_pulls == 0 {
165 return 1.0;
166 }
167 let best_pulls = self.arms.iter().map(|a| a.pulls).max().unwrap_or(0);
168 1.0 - (best_pulls as f32 / self.total_pulls as f32)
169 }
170
171 #[allow(clippy::cast_precision_loss)] pub fn estimated_regret(&self) -> f32 {
174 let best_mean = self.arms.iter().map(|a| a.mean()).fold(0.0f32, f32::max);
175 self.arms.iter().map(|a| (best_mean - a.mean()) * a.pulls as f32).sum()
176 }
177}