// llm_optimizer_decision/context.rs
use chrono::{Datelike, Timelike};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RequestContext {
13 pub user_id: Option<String>,
15 pub task_type: Option<String>,
17 pub input_length: usize,
19 pub output_length_category: OutputLengthCategory,
21 pub priority: u8,
23 pub region: Option<String>,
25 pub hour_of_day: u8,
27 pub day_of_week: u8,
29 pub language: Option<String>,
31 pub metadata: HashMap<String, String>,
33}
34
35#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
37#[serde(rename_all = "snake_case")]
38pub enum OutputLengthCategory {
39 Short, Medium, Long, }
43
44impl RequestContext {
45 pub fn new(input_length: usize) -> Self {
47 let now = chrono::Utc::now();
48
49 Self {
50 user_id: None,
51 task_type: None,
52 input_length,
53 output_length_category: OutputLengthCategory::Medium,
54 priority: 5,
55 region: None,
56 hour_of_day: now.hour() as u8,
57 day_of_week: now.weekday().num_days_from_monday() as u8,
58 language: None,
59 metadata: HashMap::new(),
60 }
61 }
62
63 pub fn to_feature_vector(&self) -> Vec<f64> {
65 let mut features = Vec::with_capacity(10);
66
67 features.push(((self.input_length as f64 + 1.0).ln() / 10.0).min(1.0));
69
70 features.push(match self.output_length_category {
72 OutputLengthCategory::Short => 0.0,
73 OutputLengthCategory::Medium => 0.5,
74 OutputLengthCategory::Long => 1.0,
75 });
76
77 features.push(self.priority as f64 / 10.0);
79
80 let hour_rad = (self.hour_of_day as f64 / 24.0) * 2.0 * std::f64::consts::PI;
82 features.push(hour_rad.cos());
83 features.push(hour_rad.sin());
84
85 let day_rad = (self.day_of_week as f64 / 7.0) * 2.0 * std::f64::consts::PI;
86 features.push(day_rad.cos());
87 features.push(day_rad.sin());
88
89 let task_indicator = match self.task_type.as_deref() {
91 Some("classification") => 0.0,
92 Some("generation") => 0.33,
93 Some("extraction") => 0.67,
94 Some("summarization") => 1.0,
95 _ => 0.5, };
97 features.push(task_indicator);
98
99 features.push(1.0);
101
102 while features.len() < 10 {
104 features.push(0.0);
105 }
106
107 features
108 }
109
110 pub fn feature_dimension() -> usize {
112 10
113 }
114
115 pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
117 self.task_type = Some(task_type.into());
118 self
119 }
120
121 pub fn with_output_length(mut self, category: OutputLengthCategory) -> Self {
123 self.output_length_category = category;
124 self
125 }
126
127 pub fn with_priority(mut self, priority: u8) -> Self {
129 self.priority = priority.min(10);
130 self
131 }
132
133 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
135 self.user_id = Some(user_id.into());
136 self
137 }
138
139 pub fn with_language(mut self, language: impl Into<String>) -> Self {
141 self.language = Some(language.into());
142 self
143 }
144
145 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
147 self.metadata.insert(key.into(), value.into());
148 self
149 }
150}
151
152impl Default for RequestContext {
153 fn default() -> Self {
154 Self::new(0)
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let context = RequestContext::new(100);
        assert_eq!(context.input_length, 100);
        assert_eq!(context.priority, 5);
    }

    #[test]
    fn test_feature_vector() {
        let context = RequestContext::new(256)
            .with_task_type("generation")
            .with_output_length(OutputLengthCategory::Long)
            .with_priority(8);

        let features = context.to_feature_vector();
        assert_eq!(features.len(), RequestContext::feature_dimension());

        for &f in &features {
            // Every feature must be finite AND normalized into [-1, 1]; both
            // conditions are required, so they are joined with `&&`.
            assert!(
                f.is_finite() && (-1.0..=1.0).contains(&f),
                "feature out of range: {}",
                f
            );
        }
    }

    #[test]
    fn test_feature_dimension() {
        assert_eq!(RequestContext::feature_dimension(), 10);
    }

    #[test]
    fn test_cyclical_time_encoding() {
        let context1 = RequestContext::new(100);
        let mut context2 = context1.clone();
        context2.hour_of_day = (context1.hour_of_day + 12) % 24;

        let f1 = context1.to_feature_vector();
        let f2 = context2.to_feature_vector();

        // A 12-hour shift moves the point to the opposite side of the circle, so
        // at least one of the cos/sin components must change noticeably.
        assert!((f1[3] - f2[3]).abs() > 0.1 || (f1[4] - f2[4]).abs() > 0.1);
    }

    #[test]
    fn test_builder_pattern() {
        let context = RequestContext::new(500)
            .with_task_type("classification")
            .with_priority(9)
            .with_user_id("user123")
            .with_language("en")
            .with_metadata("custom_key", "custom_value");

        assert_eq!(context.task_type, Some("classification".to_string()));
        assert_eq!(context.priority, 9);
        assert_eq!(context.user_id, Some("user123".to_string()));
        assert_eq!(context.language, Some("en".to_string()));
        assert_eq!(context.metadata.get("custom_key"), Some(&"custom_value".to_string()));
    }

    #[test]
    fn test_output_length_categories() {
        let short = RequestContext::new(10).with_output_length(OutputLengthCategory::Short);
        let medium = RequestContext::new(10).with_output_length(OutputLengthCategory::Medium);
        let long = RequestContext::new(10).with_output_length(OutputLengthCategory::Long);

        assert_eq!(short.to_feature_vector()[1], 0.0);
        assert_eq!(medium.to_feature_vector()[1], 0.5);
        assert_eq!(long.to_feature_vector()[1], 1.0);
    }
}