Skip to main content

trueno/tuner/evolution/
bandit.rs

//! Bandit-based kernel selection (MLT-13)
//!
//! UCB1 and Thompson-Sampling-style selectors that balance exploring new kernels
//! against exploiting the best kernel seen so far.
4
use std::cmp::Ordering;

use super::super::types::KernelType;
6
// ============================================================================
// KernelArm
// ============================================================================
10
/// Bandit arm for kernel selection (MLT-13)
///
/// Holds the running statistics of a single kernel: how often it was chosen,
/// the sum of the rewards it earned, and the sum of their squares.
#[derive(Debug, Clone, Default)]
pub struct KernelArm {
    /// Number of times this kernel was selected
    pub pulls: u32,
    /// Sum of rewards (normalized throughput)
    pub total_reward: f32,
    /// Sum of squared rewards (for variance estimation)
    pub total_reward_sq: f32,
}

impl KernelArm {
    /// Empirical mean reward of this arm.
    ///
    /// Returns `0.0` for an arm that has never been pulled, which also keeps
    /// the division below from ever dividing by zero.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // u32 -> f32 rounding is acceptable for a running average
    pub fn mean(&self) -> f32 {
        if self.pulls == 0 {
            return 0.0;
        }
        self.total_reward / self.pulls as f32
    }

    /// Upper Confidence Bound score: `mean + c * sqrt(2 * ln(total_pulls) / pulls)`.
    ///
    /// * `total_pulls` - pulls across all arms; clamped to at least 1 so the
    ///   logarithm is never taken of zero (`ln(1) = 0`, i.e. no bonus).
    /// * `c` - multiplier applied to the exploration bonus; `0.0` yields the
    ///   plain mean.
    ///
    /// An arm that has never been pulled scores `f32::INFINITY`, so it always
    /// outranks every explored arm.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // u32 -> f32 rounding is acceptable inside a heuristic score
    pub fn ucb(&self, total_pulls: u32, c: f32) -> f32 {
        if self.pulls == 0 {
            return f32::INFINITY; // Unexplored arms have infinite UCB
        }
        let log_total = (total_pulls.max(1) as f32).ln();
        let bonus = (2.0 * log_total / self.pulls as f32).sqrt();
        self.mean() + c * bonus
    }
}
43
// ============================================================================
// KernelBandit
// ============================================================================
47
48/// Bandit-based kernel selector (MLT-13)
49///
50/// Uses UCB1 algorithm for exploration vs exploitation.
51/// Reference: Li et al. (2010) "A Contextual-Bandit Approach"
52#[derive(Debug, Clone, Default)]
53pub struct KernelBandit {
54    /// Arms for each kernel type
55    pub(crate) arms: Vec<KernelArm>,
56    /// Total number of pulls across all arms
57    pub(crate) total_pulls: u32,
58    /// Exploration parameter (higher = more exploration)
59    pub(crate) exploration_c: f32,
60    /// Whether to use Thompson Sampling (alternative to UCB)
61    pub(crate) use_thompson: bool,
62}
63
64impl KernelBandit {
65    /// Number of kernel types
66    pub const NUM_KERNELS: usize = 12;
67
68    /// Create a new bandit with default exploration
69    pub fn new() -> Self {
70        Self {
71            arms: vec![KernelArm::default(); Self::NUM_KERNELS],
72            total_pulls: 0,
73            exploration_c: 2.0, // sqrt(2) is theoretically optimal
74            use_thompson: false,
75        }
76    }
77
78    /// Create a bandit with Thompson Sampling
79    pub fn with_thompson_sampling() -> Self {
80        Self {
81            arms: vec![KernelArm::default(); Self::NUM_KERNELS],
82            total_pulls: 0,
83            exploration_c: 2.0,
84            use_thompson: true,
85        }
86    }
87
88    /// Select kernel using UCB1 or Thompson Sampling
89    pub fn select(&self) -> KernelType {
90        let idx = if self.use_thompson { self.select_thompson() } else { self.select_ucb() };
91        KernelType::from_index(idx)
92    }
93
94    #[allow(clippy::cast_precision_loss)] // Index/count values << 2^24; modulo bounds ensure safe range
95    fn select_ucb(&self) -> usize {
96        self.arms
97            .iter()
98            .enumerate()
99            .max_by(|(_, a), (_, b)| {
100                a.ucb(self.total_pulls, self.exploration_c)
101                    .partial_cmp(&b.ucb(self.total_pulls, self.exploration_c))
102                    .unwrap_or(std::cmp::Ordering::Equal)
103            })
104            .map(|(i, _)| i)
105            .unwrap_or(0)
106    }
107
108    #[allow(clippy::cast_precision_loss)]
109    // SAFETY: usize index <= 11 (NUM_KERNELS) and u64 modulo 1000 -- both lossless in f32.
110    fn select_thompson(&self) -> usize {
111        // Thompson Sampling with Beta distribution approximation
112        // For each arm, sample from Beta(successes+1, failures+1)
113        use std::collections::hash_map::DefaultHasher;
114        use std::hash::{Hash, Hasher};
115
116        // Simple pseudo-random based on current state
117        let mut hasher = DefaultHasher::new();
118        self.total_pulls.hash(&mut hasher);
119        let seed = hasher.finish();
120
121        self.arms
122            .iter()
123            .enumerate()
124            .max_by(|(i, a), (j, b)| {
125                let sample_a =
126                    a.mean() + 0.1 * ((seed.wrapping_add(*i as u64) % 1000) as f32 / 1000.0 - 0.5);
127                let sample_b =
128                    b.mean() + 0.1 * ((seed.wrapping_add(*j as u64) % 1000) as f32 / 1000.0 - 0.5);
129                sample_a.partial_cmp(&sample_b).unwrap_or(std::cmp::Ordering::Equal)
130            })
131            .map(|(i, _)| i)
132            .unwrap_or(0)
133    }
134
135    /// Update arm with observed reward
136    pub fn update(&mut self, kernel: KernelType, reward: f32) {
137        contract_pre_update!();
138        let idx = kernel.to_index();
139        if idx < self.arms.len() {
140            self.arms[idx].pulls += 1;
141            self.arms[idx].total_reward += reward;
142            self.arms[idx].total_reward_sq += reward * reward;
143            self.total_pulls += 1;
144        }
145    }
146
147    /// Get the best kernel based on empirical mean
148    pub fn best_kernel(&self) -> KernelType {
149        let idx = self
150            .arms
151            .iter()
152            .enumerate()
153            .max_by(|(_, a), (_, b)| {
154                a.mean().partial_cmp(&b.mean()).unwrap_or(std::cmp::Ordering::Equal)
155            })
156            .map(|(i, _)| i)
157            .unwrap_or(0);
158        KernelType::from_index(idx)
159    }
160
161    /// Get exploration rate (fraction of pulls that were exploratory)
162    #[allow(clippy::cast_precision_loss)] // Pull counts << 2^24
163    pub fn exploration_rate(&self) -> f32 {
164        if self.total_pulls == 0 {
165            return 1.0;
166        }
167        let best_pulls = self.arms.iter().map(|a| a.pulls).max().unwrap_or(0);
168        1.0 - (best_pulls as f32 / self.total_pulls as f32)
169    }
170
171    /// Get regret estimate (cumulative regret vs oracle)
172    #[allow(clippy::cast_precision_loss)] // Pull counts << 2^24
173    pub fn estimated_regret(&self) -> f32 {
174        let best_mean = self.arms.iter().map(|a| a.mean()).fold(0.0f32, f32::max);
175        self.arms.iter().map(|a| (best_mean - a.mean()) * a.pulls as f32).sum()
176    }
177}