1use std::sync::Arc;
7use dashmap::DashMap;
8use llm_optimizer_types::models::ModelConfig;
9use uuid::Uuid;
10
11use crate::{
12 context::RequestContext,
13 contextual_bandit::{ContextualThompson, LinUCB},
14 errors::{DecisionError, Result},
15 reward::{RewardCalculator, RewardWeights, ResponseMetrics, UserFeedback},
16};
17
/// Contextual-bandit strategy a `ReinforcementEngine` uses to choose between variants.
///
/// The value is fixed at construction time and decides which per-policy bandit store
/// the engine reads and writes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BanditAlgorithm {
    /// Variants are chosen by the `contextual_bandit::LinUCB` bandit.
    LinUCB,
    /// Variants are chosen by the `contextual_bandit::ContextualThompson` bandit.
    ContextualThompson,
}
26
27pub struct ReinforcementEngine {
29 algorithm: BanditAlgorithm,
31 linucb: Option<Arc<DashMap<String, LinUCB>>>,
33 contextual_thompson: Option<Arc<DashMap<String, ContextualThompson>>>,
35 reward_calculator: RewardCalculator,
37 variant_configs: Arc<DashMap<Uuid, ModelConfig>>,
39 feature_dimension: usize,
41 alpha: f64,
43}
44
45impl ReinforcementEngine {
46 pub fn with_linucb(alpha: f64, reward_weights: RewardWeights) -> Self {
48 let feature_dim = RequestContext::feature_dimension();
49
50 Self {
51 algorithm: BanditAlgorithm::LinUCB,
52 linucb: Some(Arc::new(DashMap::new())),
53 contextual_thompson: None,
54 reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
55 variant_configs: Arc::new(DashMap::new()),
56 feature_dimension: feature_dim,
57 alpha,
58 }
59 }
60
61 pub fn with_contextual_thompson(reward_weights: RewardWeights) -> Self {
63 let feature_dim = RequestContext::feature_dimension();
64
65 Self {
66 algorithm: BanditAlgorithm::ContextualThompson,
67 linucb: None,
68 contextual_thompson: Some(Arc::new(DashMap::new())),
69 reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
70 variant_configs: Arc::new(DashMap::new()),
71 feature_dimension: feature_dim,
72 alpha: 0.0, }
74 }
75
76 pub fn create_policy(
78 &self,
79 policy_name: impl Into<String>,
80 variants: Vec<(Uuid, ModelConfig)>,
81 ) -> Result<()> {
82 let name = policy_name.into();
83
84 for (variant_id, config) in &variants {
86 self.variant_configs.insert(*variant_id, config.clone());
87 }
88
89 match self.algorithm {
91 BanditAlgorithm::LinUCB => {
92 let mut bandit = LinUCB::new(self.alpha, self.feature_dimension);
93 for (variant_id, _) in variants {
94 bandit.add_arm(variant_id);
95 }
96 self.linucb
97 .as_ref()
98 .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?
99 .insert(name, bandit);
100 }
101 BanditAlgorithm::ContextualThompson => {
102 let mut bandit = ContextualThompson::new(self.feature_dimension);
103 for (variant_id, _) in variants {
104 bandit.add_arm(variant_id);
105 }
106 self.contextual_thompson
107 .as_ref()
108 .ok_or_else(|| {
109 DecisionError::InvalidState("ContextualThompson not initialized".to_string())
110 })?
111 .insert(name, bandit);
112 }
113 }
114
115 Ok(())
116 }
117
118 pub fn select_variant(
120 &self,
121 policy_name: &str,
122 context: &RequestContext,
123 ) -> Result<(Uuid, ModelConfig)> {
124 let variant_id = match self.algorithm {
125 BanditAlgorithm::LinUCB => {
126 let linucb = self
127 .linucb
128 .as_ref()
129 .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
130
131 let bandit = linucb
132 .get(policy_name)
133 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
134
135 bandit.select_arm(context)?
136 }
137 BanditAlgorithm::ContextualThompson => {
138 let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
139 DecisionError::InvalidState("ContextualThompson not initialized".to_string())
140 })?;
141
142 let bandit = ct
143 .get(policy_name)
144 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
145
146 bandit.select_arm(context)?
147 }
148 };
149
150 let config = self
152 .variant_configs
153 .get(&variant_id)
154 .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?
155 .clone();
156
157 Ok((variant_id, config))
158 }
159
160 pub fn update_from_metrics(
162 &self,
163 policy_name: &str,
164 variant_id: &Uuid,
165 context: &RequestContext,
166 metrics: &ResponseMetrics,
167 ) -> Result<()> {
168 let reward = self.reward_calculator.calculate_reward_metrics_only(metrics);
169 self.update_reward(policy_name, variant_id, context, reward)
170 }
171
172 pub fn update_from_feedback(
174 &self,
175 policy_name: &str,
176 variant_id: &Uuid,
177 context: &RequestContext,
178 metrics: &ResponseMetrics,
179 feedback: &UserFeedback,
180 ) -> Result<()> {
181 let reward = self.reward_calculator.calculate_reward(metrics, feedback);
182 self.update_reward(policy_name, variant_id, context, reward)
183 }
184
185 fn update_reward(
187 &self,
188 policy_name: &str,
189 variant_id: &Uuid,
190 context: &RequestContext,
191 reward: f64,
192 ) -> Result<()> {
193 match self.algorithm {
194 BanditAlgorithm::LinUCB => {
195 let linucb = self
196 .linucb
197 .as_ref()
198 .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
199
200 let mut bandit = linucb
201 .get_mut(policy_name)
202 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
203
204 bandit.update(variant_id, context, reward)?;
205 }
206 BanditAlgorithm::ContextualThompson => {
207 let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
208 DecisionError::InvalidState("ContextualThompson not initialized".to_string())
209 })?;
210
211 let mut bandit = ct
212 .get_mut(policy_name)
213 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
214
215 bandit.update(variant_id, context, reward)?;
216 }
217 }
218
219 Ok(())
220 }
221
222 pub fn get_policy_stats(&self, policy_name: &str) -> Result<Vec<VariantStats>> {
224 match self.algorithm {
225 BanditAlgorithm::LinUCB => {
226 let linucb = self
227 .linucb
228 .as_ref()
229 .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
230
231 let bandit = linucb
232 .get(policy_name)
233 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
234
235 let rewards = bandit.get_average_rewards();
236
237 Ok(rewards
238 .iter()
239 .map(|(id, reward)| {
240 let arm = bandit.get_arm(id).unwrap();
241 VariantStats {
242 variant_id: *id,
243 average_reward: *reward,
244 num_selections: arm.num_selections,
245 total_reward: arm.total_reward,
246 }
247 })
248 .collect())
249 }
250 BanditAlgorithm::ContextualThompson => {
251 let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
252 DecisionError::InvalidState("ContextualThompson not initialized".to_string())
253 })?;
254
255 let bandit = ct
256 .get(policy_name)
257 .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
258
259 let rewards = bandit.get_average_rewards();
260
261 Ok(rewards
262 .iter()
263 .map(|(id, reward)| {
264 let arm = bandit.get_arm(id).unwrap();
265 VariantStats {
266 variant_id: *id,
267 average_reward: *reward,
268 num_selections: arm.num_selections,
269 total_reward: arm.total_reward,
270 }
271 })
272 .collect())
273 }
274 }
275 }
276
277 pub fn algorithm(&self) -> BanditAlgorithm {
279 self.algorithm
280 }
281
282 pub fn set_reward_weights(&mut self, weights: RewardWeights) {
284 self.reward_calculator.set_weights(weights);
285 }
286}
287
288#[derive(Debug, Clone)]
290pub struct VariantStats {
291 pub variant_id: Uuid,
292 pub average_reward: f64,
293 pub num_selections: u64,
294 pub total_reward: f64,
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::OutputLengthCategory;

    /// Two freshly generated variants sharing the default model configuration.
    fn test_variants() -> Vec<(Uuid, ModelConfig)> {
        vec![
            (Uuid::new_v4(), ModelConfig::default()),
            (Uuid::new_v4(), ModelConfig::default()),
        ]
    }

    /// Collects the ids of `variants` for membership checks.
    fn ids(variants: &[(Uuid, ModelConfig)]) -> Vec<Uuid> {
        variants.iter().map(|(id, _)| *id).collect()
    }

    /// Metrics for a response that should earn a comparatively high reward.
    fn good_metrics() -> ResponseMetrics {
        ResponseMetrics {
            quality_score: 0.9,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        }
    }

    /// Metrics for a response that should earn a comparatively low reward.
    fn bad_metrics() -> ResponseMetrics {
        ResponseMetrics {
            quality_score: 0.3,
            cost: 0.5,
            latency_ms: 3000.0,
            token_count: 800,
        }
    }

    #[test]
    fn test_create_engine_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        assert_eq!(engine.algorithm(), BanditAlgorithm::LinUCB);
    }

    #[test]
    fn test_create_engine_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        assert_eq!(engine.algorithm(), BanditAlgorithm::ContextualThompson);
    }

    #[test]
    fn test_create_policy_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let expected = ids(&variants);

        engine.create_policy("test_policy", variants).unwrap();

        // Every variant must have become an arm of the new policy.
        let stats = engine.get_policy_stats("test_policy").unwrap();
        assert_eq!(stats.len(), expected.len());
        assert!(stats.iter().all(|s| expected.contains(&s.variant_id)));
    }

    #[test]
    fn test_create_policy_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let expected = ids(&variants);

        engine.create_policy("test_policy", variants).unwrap();

        // The policy must be usable right away and only hand out registered variants.
        let (variant_id, _config) = engine
            .select_variant("test_policy", &RequestContext::new(100))
            .unwrap();
        assert!(expected.contains(&variant_id));
    }

    #[test]
    fn test_create_policy_without_bandit_store_is_invalid_state() {
        let mut engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        // Simulate a misconfigured engine: the algorithm says LinUCB but no store exists.
        engine.linucb = None;

        assert!(matches!(
            engine.create_policy("test_policy", test_variants()),
            Err(DecisionError::InvalidState(_))
        ));
    }

    #[test]
    fn test_select_variant_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let expected = ids(&variants);

        engine.create_policy("test_policy", variants).unwrap();

        let context = RequestContext::new(100)
            .with_task_type("generation")
            .with_output_length(OutputLengthCategory::Medium);

        let (variant_id, _config) = engine.select_variant("test_policy", &context).unwrap();
        assert!(variant_id != Uuid::nil());
        assert!(expected.contains(&variant_id));
    }

    #[test]
    fn test_select_variant_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let expected = ids(&variants);

        engine.create_policy("test_policy", variants).unwrap();

        let context = RequestContext::new(100);
        let (variant_id, _config) = engine.select_variant("test_policy", &context).unwrap();
        assert!(variant_id != Uuid::nil());
        assert!(expected.contains(&variant_id));
    }

    #[test]
    fn test_unknown_policy_is_reported() {
        let context = RequestContext::new(100);

        let linucb = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        assert!(matches!(
            linucb.select_variant("missing", &context),
            Err(DecisionError::ExperimentNotFound(_))
        ));
        assert!(matches!(
            linucb.get_policy_stats("missing"),
            Err(DecisionError::ExperimentNotFound(_))
        ));

        let thompson =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        assert!(matches!(
            thompson.select_variant("missing", &context),
            Err(DecisionError::ExperimentNotFound(_))
        ));
        assert!(matches!(
            thompson.update_from_metrics("missing", &Uuid::new_v4(), &context, &good_metrics()),
            Err(DecisionError::ExperimentNotFound(_))
        ));
    }

    #[test]
    fn test_update_from_metrics_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let variant_id = variants[0].0;

        engine.create_policy("test_policy", variants).unwrap();

        let context = RequestContext::new(100);
        engine
            .update_from_metrics("test_policy", &variant_id, &context, &good_metrics())
            .unwrap();
    }

    #[test]
    fn test_update_from_feedback() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let variant_id = variants[0].0;

        engine.create_policy("test_policy", variants).unwrap();

        let context = RequestContext::new(100);
        let metrics = ResponseMetrics {
            quality_score: 0.8,
            cost: 0.2,
            latency_ms: 1500.0,
            token_count: 600,
        };

        let mut feedback = UserFeedback::new();
        feedback.task_completed = true;
        feedback.explicit_rating = Some(4.0);

        engine
            .update_from_feedback("test_policy", &variant_id, &context, &metrics, &feedback)
            .unwrap();
    }

    #[test]
    fn test_get_policy_stats() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();

        engine.create_policy("test_policy", variants).unwrap();

        let stats = engine.get_policy_stats("test_policy").unwrap();
        assert_eq!(stats.len(), 2);
    }

    #[test]
    fn test_learning_convergence() {
        let engine = ReinforcementEngine::with_linucb(0.5, RewardWeights::default_weights());
        let variants = test_variants();
        let good_variant = variants[0].0;
        let bad_variant = variants[1].0;

        engine.create_policy("test_policy", variants).unwrap();

        // Feed both arms the same context with consistently different outcomes.
        for _ in 0..50 {
            let context = RequestContext::new(100);

            engine
                .update_from_metrics("test_policy", &good_variant, &context, &good_metrics())
                .unwrap();
            engine
                .update_from_metrics("test_policy", &bad_variant, &context, &bad_metrics())
                .unwrap();
        }

        let stats = engine.get_policy_stats("test_policy").unwrap();
        let good_stats = stats.iter().find(|s| s.variant_id == good_variant).unwrap();
        let bad_stats = stats.iter().find(|s| s.variant_id == bad_variant).unwrap();

        assert!(good_stats.average_reward > bad_stats.average_reward);
    }

    #[test]
    fn test_set_reward_weights() {
        let mut engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());

        engine.set_reward_weights(RewardWeights::cost_focused());

        // The engine must keep working normally after the weights change.
        assert_eq!(engine.algorithm(), BanditAlgorithm::LinUCB);
        let variants = test_variants();
        let variant_id = variants[0].0;
        engine.create_policy("test_policy", variants).unwrap();
        engine
            .update_from_metrics(
                "test_policy",
                &variant_id,
                &RequestContext::new(100),
                &good_metrics(),
            )
            .unwrap();
    }
}
479}