Skip to main content

llm_optimizer_decision/
reinforcement_feedback.rs

1//! Reinforcement Feedback Engine
2//!
3//! This module provides the main API for reinforcement learning-based
4//! optimization using contextual bandits with user feedback.
5
6use std::sync::Arc;
7use dashmap::DashMap;
8use llm_optimizer_types::models::ModelConfig;
9use uuid::Uuid;
10
11use crate::{
12    context::RequestContext,
13    contextual_bandit::{ContextualThompson, LinUCB},
14    errors::{DecisionError, Result},
15    reward::{RewardCalculator, RewardWeights, ResponseMetrics, UserFeedback},
16};
17
/// Contextual-bandit algorithm that backs a `ReinforcementEngine`.
///
/// The algorithm is fixed when the engine is constructed (`with_linucb` or
/// `with_contextual_thompson`). `Hash` is derived so the value can key maps and sets,
/// e.g. per-algorithm metrics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BanditAlgorithm {
    /// Linear Upper Confidence Bound; exploration is controlled by `alpha`.
    LinUCB,
    /// Contextual Thompson Sampling.
    ContextualThompson,
}
26
27/// Reinforcement feedback engine
28pub struct ReinforcementEngine {
29    /// Algorithm to use
30    algorithm: BanditAlgorithm,
31    /// LinUCB instance (if using LinUCB)
32    linucb: Option<Arc<DashMap<String, LinUCB>>>,
33    /// Contextual Thompson instance (if using ContextualThompson)
34    contextual_thompson: Option<Arc<DashMap<String, ContextualThompson>>>,
35    /// Reward calculator
36    reward_calculator: RewardCalculator,
37    /// Variant configurations
38    variant_configs: Arc<DashMap<Uuid, ModelConfig>>,
39    /// Feature dimension
40    feature_dimension: usize,
41    /// LinUCB exploration parameter (alpha)
42    alpha: f64,
43}
44
45impl ReinforcementEngine {
46    /// Create a new reinforcement engine with LinUCB
47    pub fn with_linucb(alpha: f64, reward_weights: RewardWeights) -> Self {
48        let feature_dim = RequestContext::feature_dimension();
49
50        Self {
51            algorithm: BanditAlgorithm::LinUCB,
52            linucb: Some(Arc::new(DashMap::new())),
53            contextual_thompson: None,
54            reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
55            variant_configs: Arc::new(DashMap::new()),
56            feature_dimension: feature_dim,
57            alpha,
58        }
59    }
60
61    /// Create a new reinforcement engine with Contextual Thompson Sampling
62    pub fn with_contextual_thompson(reward_weights: RewardWeights) -> Self {
63        let feature_dim = RequestContext::feature_dimension();
64
65        Self {
66            algorithm: BanditAlgorithm::ContextualThompson,
67            linucb: None,
68            contextual_thompson: Some(Arc::new(DashMap::new())),
69            reward_calculator: RewardCalculator::new(reward_weights, 1.0, 5000.0),
70            variant_configs: Arc::new(DashMap::new()),
71            feature_dimension: feature_dim,
72            alpha: 0.0, // Not used for Thompson Sampling
73        }
74    }
75
76    /// Create a new experiment/policy
77    pub fn create_policy(
78        &self,
79        policy_name: impl Into<String>,
80        variants: Vec<(Uuid, ModelConfig)>,
81    ) -> Result<()> {
82        let name = policy_name.into();
83
84        // Store variant configurations
85        for (variant_id, config) in &variants {
86            self.variant_configs.insert(*variant_id, config.clone());
87        }
88
89        // Initialize bandit for this policy
90        match self.algorithm {
91            BanditAlgorithm::LinUCB => {
92                let mut bandit = LinUCB::new(self.alpha, self.feature_dimension);
93                for (variant_id, _) in variants {
94                    bandit.add_arm(variant_id);
95                }
96                self.linucb
97                    .as_ref()
98                    .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?
99                    .insert(name, bandit);
100            }
101            BanditAlgorithm::ContextualThompson => {
102                let mut bandit = ContextualThompson::new(self.feature_dimension);
103                for (variant_id, _) in variants {
104                    bandit.add_arm(variant_id);
105                }
106                self.contextual_thompson
107                    .as_ref()
108                    .ok_or_else(|| {
109                        DecisionError::InvalidState("ContextualThompson not initialized".to_string())
110                    })?
111                    .insert(name, bandit);
112            }
113        }
114
115        Ok(())
116    }
117
118    /// Select best variant for given context
119    pub fn select_variant(
120        &self,
121        policy_name: &str,
122        context: &RequestContext,
123    ) -> Result<(Uuid, ModelConfig)> {
124        let variant_id = match self.algorithm {
125            BanditAlgorithm::LinUCB => {
126                let linucb = self
127                    .linucb
128                    .as_ref()
129                    .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
130
131                let bandit = linucb
132                    .get(policy_name)
133                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
134
135                bandit.select_arm(context)?
136            }
137            BanditAlgorithm::ContextualThompson => {
138                let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
139                    DecisionError::InvalidState("ContextualThompson not initialized".to_string())
140                })?;
141
142                let bandit = ct
143                    .get(policy_name)
144                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
145
146                bandit.select_arm(context)?
147            }
148        };
149
150        // Get variant configuration
151        let config = self
152            .variant_configs
153            .get(&variant_id)
154            .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?
155            .clone();
156
157        Ok((variant_id, config))
158    }
159
160    /// Update with observed reward from metrics only
161    pub fn update_from_metrics(
162        &self,
163        policy_name: &str,
164        variant_id: &Uuid,
165        context: &RequestContext,
166        metrics: &ResponseMetrics,
167    ) -> Result<()> {
168        let reward = self.reward_calculator.calculate_reward_metrics_only(metrics);
169        self.update_reward(policy_name, variant_id, context, reward)
170    }
171
172    /// Update with observed reward from metrics and user feedback
173    pub fn update_from_feedback(
174        &self,
175        policy_name: &str,
176        variant_id: &Uuid,
177        context: &RequestContext,
178        metrics: &ResponseMetrics,
179        feedback: &UserFeedback,
180    ) -> Result<()> {
181        let reward = self.reward_calculator.calculate_reward(metrics, feedback);
182        self.update_reward(policy_name, variant_id, context, reward)
183    }
184
185    /// Internal: update with computed reward
186    fn update_reward(
187        &self,
188        policy_name: &str,
189        variant_id: &Uuid,
190        context: &RequestContext,
191        reward: f64,
192    ) -> Result<()> {
193        match self.algorithm {
194            BanditAlgorithm::LinUCB => {
195                let linucb = self
196                    .linucb
197                    .as_ref()
198                    .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
199
200                let mut bandit = linucb
201                    .get_mut(policy_name)
202                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
203
204                bandit.update(variant_id, context, reward)?;
205            }
206            BanditAlgorithm::ContextualThompson => {
207                let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
208                    DecisionError::InvalidState("ContextualThompson not initialized".to_string())
209                })?;
210
211                let mut bandit = ct
212                    .get_mut(policy_name)
213                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
214
215                bandit.update(variant_id, context, reward)?;
216            }
217        }
218
219        Ok(())
220    }
221
222    /// Get current performance statistics for all variants in a policy
223    pub fn get_policy_stats(&self, policy_name: &str) -> Result<Vec<VariantStats>> {
224        match self.algorithm {
225            BanditAlgorithm::LinUCB => {
226                let linucb = self
227                    .linucb
228                    .as_ref()
229                    .ok_or_else(|| DecisionError::InvalidState("LinUCB not initialized".to_string()))?;
230
231                let bandit = linucb
232                    .get(policy_name)
233                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
234
235                let rewards = bandit.get_average_rewards();
236
237                Ok(rewards
238                    .iter()
239                    .map(|(id, reward)| {
240                        let arm = bandit.get_arm(id).unwrap();
241                        VariantStats {
242                            variant_id: *id,
243                            average_reward: *reward,
244                            num_selections: arm.num_selections,
245                            total_reward: arm.total_reward,
246                        }
247                    })
248                    .collect())
249            }
250            BanditAlgorithm::ContextualThompson => {
251                let ct = self.contextual_thompson.as_ref().ok_or_else(|| {
252                    DecisionError::InvalidState("ContextualThompson not initialized".to_string())
253                })?;
254
255                let bandit = ct
256                    .get(policy_name)
257                    .ok_or_else(|| DecisionError::ExperimentNotFound(policy_name.to_string()))?;
258
259                let rewards = bandit.get_average_rewards();
260
261                Ok(rewards
262                    .iter()
263                    .map(|(id, reward)| {
264                        let arm = bandit.get_arm(id).unwrap();
265                        VariantStats {
266                            variant_id: *id,
267                            average_reward: *reward,
268                            num_selections: arm.num_selections,
269                            total_reward: arm.total_reward,
270                        }
271                    })
272                    .collect())
273            }
274        }
275    }
276
277    /// Get algorithm type
278    pub fn algorithm(&self) -> BanditAlgorithm {
279        self.algorithm
280    }
281
282    /// Update reward calculator weights
283    pub fn set_reward_weights(&mut self, weights: RewardWeights) {
284        self.reward_calculator.set_weights(weights);
285    }
286}
287
288/// Variant statistics
289#[derive(Debug, Clone)]
290pub struct VariantStats {
291    pub variant_id: Uuid,
292    pub average_reward: f64,
293    pub num_selections: u64,
294    pub total_reward: f64,
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::OutputLengthCategory;

    const POLICY: &str = "test_policy";

    /// Two fresh variants with default model configs.
    fn test_variants() -> Vec<(Uuid, ModelConfig)> {
        vec![
            (Uuid::new_v4(), ModelConfig::default()),
            (Uuid::new_v4(), ModelConfig::default()),
        ]
    }

    /// Variant ids, in the order given.
    fn ids_of(variants: &[(Uuid, ModelConfig)]) -> Vec<Uuid> {
        variants.iter().map(|(id, _)| *id).collect()
    }

    /// Sorted copy of `ids`, for order-insensitive comparisons.
    fn sorted(mut ids: Vec<Uuid>) -> Vec<Uuid> {
        ids.sort_unstable();
        ids
    }

    /// One engine per supported algorithm.
    fn all_engines() -> Vec<ReinforcementEngine> {
        vec![
            ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights()),
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights()),
        ]
    }

    /// High-quality, cheap, fast response.
    fn good_metrics() -> ResponseMetrics {
        ResponseMetrics {
            quality_score: 0.9,
            cost: 0.1,
            latency_ms: 1000.0,
            token_count: 500,
        }
    }

    /// Low-quality, expensive, slow response.
    fn bad_metrics() -> ResponseMetrics {
        ResponseMetrics {
            quality_score: 0.3,
            cost: 0.5,
            latency_ms: 3000.0,
            token_count: 800,
        }
    }

    #[test]
    fn test_create_engine_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        assert_eq!(engine.algorithm(), BanditAlgorithm::LinUCB);
    }

    #[test]
    fn test_create_engine_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        assert_eq!(engine.algorithm(), BanditAlgorithm::ContextualThompson);
    }

    #[test]
    fn test_create_policy_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let expected = sorted(ids_of(&variants));

        engine.create_policy(POLICY, variants).unwrap();

        // The new policy must expose exactly the variants it was created with.
        let stats = engine.get_policy_stats(POLICY).unwrap();
        let reported = sorted(stats.iter().map(|s| s.variant_id).collect());
        assert_eq!(reported, expected);
    }

    #[test]
    fn test_create_policy_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let ids = ids_of(&variants);

        engine.create_policy(POLICY, variants).unwrap();

        // The policy is usable immediately and only ever picks one of its own variants.
        let (chosen, _config) = engine
            .select_variant(POLICY, &RequestContext::new(100))
            .unwrap();
        assert!(ids.contains(&chosen));
    }

    #[test]
    fn test_select_variant_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let ids = ids_of(&variants);

        engine.create_policy(POLICY, variants).unwrap();

        let context = RequestContext::new(100)
            .with_task_type("generation")
            .with_output_length(OutputLengthCategory::Medium);

        let (variant_id, _config) = engine.select_variant(POLICY, &context).unwrap();
        assert!(variant_id != Uuid::nil());
        assert!(ids.contains(&variant_id));
    }

    #[test]
    fn test_select_variant_contextual_thompson() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let ids = ids_of(&variants);

        engine.create_policy(POLICY, variants).unwrap();

        let context = RequestContext::new(100);
        let (variant_id, _config) = engine.select_variant(POLICY, &context).unwrap();
        assert!(variant_id != Uuid::nil());
        assert!(ids.contains(&variant_id));
    }

    #[test]
    fn test_unknown_policy_is_reported() {
        let context = RequestContext::new(100);

        for engine in all_engines() {
            assert!(matches!(
                engine.select_variant("missing", &context),
                Err(DecisionError::ExperimentNotFound(_))
            ));
            assert!(matches!(
                engine.get_policy_stats("missing"),
                Err(DecisionError::ExperimentNotFound(_))
            ));
            assert!(matches!(
                engine.update_from_metrics("missing", &Uuid::new_v4(), &context, &good_metrics()),
                Err(DecisionError::ExperimentNotFound(_))
            ));
        }
    }

    #[test]
    fn test_update_from_metrics_linucb() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let variant_id = variants[0].0;

        engine.create_policy(POLICY, variants).unwrap();

        let context = RequestContext::new(100);
        engine
            .update_from_metrics(POLICY, &variant_id, &context, &good_metrics())
            .unwrap();
    }

    #[test]
    fn test_update_from_feedback() {
        let engine =
            ReinforcementEngine::with_contextual_thompson(RewardWeights::default_weights());
        let variants = test_variants();
        let variant_id = variants[0].0;

        engine.create_policy(POLICY, variants).unwrap();

        let context = RequestContext::new(100);
        let metrics = ResponseMetrics {
            quality_score: 0.8,
            cost: 0.2,
            latency_ms: 1500.0,
            token_count: 600,
        };

        let mut feedback = UserFeedback::new();
        feedback.task_completed = true;
        feedback.explicit_rating = Some(4.0);

        engine
            .update_from_feedback(POLICY, &variant_id, &context, &metrics, &feedback)
            .unwrap();
    }

    #[test]
    fn test_get_policy_stats() {
        let engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());
        let variants = test_variants();
        let expected = sorted(ids_of(&variants));

        engine.create_policy(POLICY, variants).unwrap();

        let stats = engine.get_policy_stats(POLICY).unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(sorted(stats.iter().map(|s| s.variant_id).collect()), expected);
    }

    #[test]
    fn test_learning_convergence() {
        let engine = ReinforcementEngine::with_linucb(0.5, RewardWeights::default_weights());
        let variants = test_variants();
        let good_variant = variants[0].0;
        let bad_variant = variants[1].0;

        engine.create_policy(POLICY, variants).unwrap();

        // Simulate: good variant gets high rewards, bad variant gets low rewards.
        for _ in 0..50 {
            let context = RequestContext::new(100);
            engine
                .update_from_metrics(POLICY, &good_variant, &context, &good_metrics())
                .unwrap();
            engine
                .update_from_metrics(POLICY, &bad_variant, &context, &bad_metrics())
                .unwrap();
        }

        let stats = engine.get_policy_stats(POLICY).unwrap();
        let good_stats = stats.iter().find(|s| s.variant_id == good_variant).unwrap();
        let bad_stats = stats.iter().find(|s| s.variant_id == bad_variant).unwrap();

        // Good variant should have higher average reward.
        assert!(good_stats.average_reward > bad_stats.average_reward);
    }

    #[test]
    fn test_set_reward_weights() {
        let mut engine = ReinforcementEngine::with_linucb(1.0, RewardWeights::default_weights());

        engine.set_reward_weights(RewardWeights::cost_focused());
        assert_eq!(engine.algorithm(), BanditAlgorithm::LinUCB);

        // The engine must remain fully usable after its weights are swapped.
        let variants = test_variants();
        let variant_id = variants[0].0;
        engine.create_policy(POLICY, variants).unwrap();
        engine
            .update_from_metrics(POLICY, &variant_id, &RequestContext::new(100), &good_metrics())
            .unwrap();
    }
}