Skip to main content

llm_optimizer_decision/
context.rs

1//! Context feature extraction for contextual bandits
2//!
3//! This module provides feature extraction from request context for
4//! context-aware model selection and parameter optimization.
5
6use chrono::{Datelike, Timelike};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Request context for contextual bandit decisions
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RequestContext {
13    /// User ID (if available)
14    pub user_id: Option<String>,
15    /// Task type (classification, generation, extraction, etc.)
16    pub task_type: Option<String>,
17    /// Input length in tokens
18    pub input_length: usize,
19    /// Expected output length category
20    pub output_length_category: OutputLengthCategory,
21    /// Request priority (1-10)
22    pub priority: u8,
23    /// Geographic region
24    pub region: Option<String>,
25    /// Time of day (hour 0-23)
26    pub hour_of_day: u8,
27    /// Day of week (0-6, where 0 is Monday)
28    pub day_of_week: u8,
29    /// Language code (e.g., "en", "es", "fr")
30    pub language: Option<String>,
31    /// Custom metadata
32    pub metadata: HashMap<String, String>,
33}
34
35/// Expected output length category
36#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
37#[serde(rename_all = "snake_case")]
38pub enum OutputLengthCategory {
39    Short,   // < 100 tokens
40    Medium,  // 100-500 tokens
41    Long,    // > 500 tokens
42}
43
44impl RequestContext {
45    /// Create a new request context with defaults
46    pub fn new(input_length: usize) -> Self {
47        let now = chrono::Utc::now();
48
49        Self {
50            user_id: None,
51            task_type: None,
52            input_length,
53            output_length_category: OutputLengthCategory::Medium,
54            priority: 5,
55            region: None,
56            hour_of_day: now.hour() as u8,
57            day_of_week: now.weekday().num_days_from_monday() as u8,
58            language: None,
59            metadata: HashMap::new(),
60        }
61    }
62
63    /// Extract feature vector for machine learning
64    pub fn to_feature_vector(&self) -> Vec<f64> {
65        let mut features = Vec::with_capacity(10);
66
67        // Normalized input length (log scale)
68        features.push(((self.input_length as f64 + 1.0).ln() / 10.0).min(1.0));
69
70        // Output length category (one-hot encoding-ish)
71        features.push(match self.output_length_category {
72            OutputLengthCategory::Short => 0.0,
73            OutputLengthCategory::Medium => 0.5,
74            OutputLengthCategory::Long => 1.0,
75        });
76
77        // Normalized priority
78        features.push(self.priority as f64 / 10.0);
79
80        // Time features (cyclical encoding)
81        let hour_rad = (self.hour_of_day as f64 / 24.0) * 2.0 * std::f64::consts::PI;
82        features.push(hour_rad.cos());
83        features.push(hour_rad.sin());
84
85        let day_rad = (self.day_of_week as f64 / 7.0) * 2.0 * std::f64::consts::PI;
86        features.push(day_rad.cos());
87        features.push(day_rad.sin());
88
89        // Task type indicator (basic encoding)
90        let task_indicator = match self.task_type.as_deref() {
91            Some("classification") => 0.0,
92            Some("generation") => 0.33,
93            Some("extraction") => 0.67,
94            Some("summarization") => 1.0,
95            _ => 0.5, // unknown
96        };
97        features.push(task_indicator);
98
99        // Bias term (always 1.0)
100        features.push(1.0);
101
102        // Add padding to ensure consistent feature vector size
103        while features.len() < 10 {
104            features.push(0.0);
105        }
106
107        features
108    }
109
110    /// Get feature dimension
111    pub fn feature_dimension() -> usize {
112        10
113    }
114
115    /// Set task type
116    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
117        self.task_type = Some(task_type.into());
118        self
119    }
120
121    /// Set output length category
122    pub fn with_output_length(mut self, category: OutputLengthCategory) -> Self {
123        self.output_length_category = category;
124        self
125    }
126
127    /// Set priority
128    pub fn with_priority(mut self, priority: u8) -> Self {
129        self.priority = priority.min(10);
130        self
131    }
132
133    /// Set user ID
134    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
135        self.user_id = Some(user_id.into());
136        self
137    }
138
139    /// Set language
140    pub fn with_language(mut self, language: impl Into<String>) -> Self {
141        self.language = Some(language.into());
142        self
143    }
144
145    /// Add custom metadata
146    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
147        self.metadata.insert(key.into(), value.into());
148        self
149    }
150}
151
152impl Default for RequestContext {
153    fn default() -> Self {
154        Self::new(0)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the input length and applies the default priority.
    #[test]
    fn test_context_creation() {
        let context = RequestContext::new(100);
        assert_eq!(context.input_length, 100);
        assert_eq!(context.priority, 5);
    }

    /// The feature vector has the advertised length and every entry is a
    /// finite number within [-1, 1].
    #[test]
    fn test_feature_vector() {
        let context = RequestContext::new(256)
            .with_task_type("generation")
            .with_output_length(OutputLengthCategory::Long)
            .with_priority(8);

        let features = context.to_feature_vector();
        assert_eq!(features.len(), RequestContext::feature_dimension());

        // Both conditions must hold. The previous `range || is_finite()` form
        // accepted any finite value, so the range was never actually checked.
        for &f in &features {
            assert!(
                f.is_finite() && (-1.0..=1.0).contains(&f),
                "feature {} is outside [-1, 1]",
                f
            );
        }
    }

    #[test]
    fn test_feature_dimension() {
        assert_eq!(RequestContext::feature_dimension(), 10);
    }

    /// Two contexts twelve hours apart must differ in their hour features.
    #[test]
    fn test_cyclical_time_encoding() {
        let context1 = RequestContext::new(100);
        let mut context2 = context1.clone();
        context2.hour_of_day = (context1.hour_of_day + 12) % 24;

        let f1 = context1.to_feature_vector();
        let f2 = context2.to_feature_vector();

        // Indices 3 and 4 hold the cosine and sine of the hour angle.
        assert!((f1[3] - f2[3]).abs() > 0.1 || (f1[4] - f2[4]).abs() > 0.1);
    }

    /// Each builder method stores the value it was given.
    #[test]
    fn test_builder_pattern() {
        let context = RequestContext::new(500)
            .with_task_type("classification")
            .with_priority(9)
            .with_user_id("user123")
            .with_language("en")
            .with_metadata("custom_key", "custom_value");

        assert_eq!(context.task_type, Some("classification".to_string()));
        assert_eq!(context.priority, 9);
        assert_eq!(context.user_id, Some("user123".to_string()));
        assert_eq!(context.language, Some("en".to_string()));
        assert_eq!(context.metadata.get("custom_key"), Some(&"custom_value".to_string()));
    }

    /// `with_priority` caps values above the top of the scale.
    #[test]
    fn test_priority_is_capped() {
        let context = RequestContext::new(1).with_priority(200);
        assert_eq!(context.priority, 10);
    }

    /// Index 1 carries the output length as 0.0 / 0.5 / 1.0.
    #[test]
    fn test_output_length_categories() {
        let short = RequestContext::new(10).with_output_length(OutputLengthCategory::Short);
        let medium = RequestContext::new(10).with_output_length(OutputLengthCategory::Medium);
        let long = RequestContext::new(10).with_output_length(OutputLengthCategory::Long);

        assert_eq!(short.to_feature_vector()[1], 0.0);
        assert_eq!(medium.to_feature_vector()[1], 0.5);
        assert_eq!(long.to_feature_vector()[1], 1.0);
    }
}