use chrono::{Datelike, Timelike};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Per-request attributes from which `RequestContext::to_feature_vector`
/// builds a fixed-length numeric feature vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestContext {
/// Optional caller identifier. Not encoded in the feature vector.
pub user_id: Option<String>,
/// Optional task label. "classification", "generation", "extraction" and
/// "summarization" get distinct feature encodings; any other value (or
/// `None`) is encoded as unknown.
pub task_type: Option<String>,
/// Size of the request input; log-scaled in the feature vector.
/// NOTE(review): units are set by the caller (presumably tokens or
/// characters) — TODO confirm.
pub input_length: usize,
/// Expected output length bucket.
pub output_length_category: OutputLengthCategory,
/// Request priority. `new` defaults it to 5 and `with_priority` caps it at
/// 10; direct assignment or deserialisation is not clamped.
pub priority: u8,
/// Optional region. Not encoded in the feature vector.
pub region: Option<String>,
/// Hour of day; 0–23 in UTC when filled in by `new`.
pub hour_of_day: u8,
/// Day of week; 0 = Monday … 6 = Sunday in UTC when filled in by `new`.
pub day_of_week: u8,
/// Optional language tag. Not encoded in the feature vector.
pub language: Option<String>,
/// Free-form key/value pairs. Not encoded in the feature vector.
pub metadata: HashMap<String, String>,
}
/// Coarse bucket describing how long a request's output is expected to be.
///
/// Serialised in snake_case, i.e. as `"short"`, `"medium"` or `"long"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutputLengthCategory {
    /// Short output (encoded as 0.0 in the feature vector).
    Short,
    /// Medium output (encoded as 0.5); the bucket `RequestContext::new` picks.
    Medium,
    /// Long output (encoded as 1.0).
    Long,
}
impl RequestContext {
    /// Length of the vector returned by [`RequestContext::to_feature_vector`].
    ///
    /// A single constant keeps the capacity, the zero-padding and
    /// [`RequestContext::feature_dimension`] from drifting apart.
    const FEATURE_DIM: usize = 10;

    /// Upper bound for `priority`; also the divisor used to normalise it.
    const MAX_PRIORITY: u8 = 10;

    /// Creates a context for a request with the given input length.
    ///
    /// Defaults:
    /// - no user, task type, region or language;
    /// - `Medium` output length;
    /// - priority 5;
    /// - empty metadata.
    ///
    /// `hour_of_day` (0–23) and `day_of_week` (0 = Monday … 6 = Sunday) are
    /// taken from the current UTC time.
    pub fn new(input_length: usize) -> Self {
        let now = chrono::Utc::now();
        Self {
            user_id: None,
            task_type: None,
            input_length,
            output_length_category: OutputLengthCategory::Medium,
            priority: 5,
            region: None,
            // chrono keeps hour() < 24 and num_days_from_monday() < 7, so these
            // narrowing casts cannot truncate.
            hour_of_day: now.hour() as u8,
            day_of_week: now.weekday().num_days_from_monday() as u8,
            language: None,
            metadata: HashMap::new(),
        }
    }

    /// Encodes the context as a vector of exactly
    /// [`RequestContext::feature_dimension`] values.
    ///
    /// Layout:
    /// - `[0]` `ln(input_length + 1) / 10`, capped at 1.0
    /// - `[1]` output length: Short 0.0, Medium 0.5, Long 1.0
    /// - `[2]` priority / 10, with priority first clamped to at most 10
    /// - `[3]`, `[4]` cosine / sine of the hour of day
    /// - `[5]`, `[6]` cosine / sine of the day of week
    /// - `[7]` task type: classification 0.0, generation 0.33,
    ///   extraction 0.67, summarization 1.0, anything else or none 0.5
    /// - `[8]` constant bias term 1.0
    /// - `[9]` zero padding
    ///
    /// `user_id`, `region`, `language` and `metadata` are not encoded.
    #[must_use]
    pub fn to_feature_vector(&self) -> Vec<f64> {
        use std::f64::consts::PI;

        let mut features = Vec::with_capacity(Self::FEATURE_DIM);

        // Log-scale so very long inputs do not dominate the other features.
        features.push(((self.input_length as f64 + 1.0).ln() / 10.0).min(1.0));

        features.push(match self.output_length_category {
            OutputLengthCategory::Short => 0.0,
            OutputLengthCategory::Medium => 0.5,
            OutputLengthCategory::Long => 1.0,
        });

        // `priority` is a public, deserialisable field, so it can exceed the range
        // that `with_priority` enforces; clamp so the feature stays within [0, 1].
        let priority = self.priority.min(Self::MAX_PRIORITY);
        features.push(f64::from(priority) / f64::from(Self::MAX_PRIORITY));

        // Sine/cosine pairs place 23:00 next to 00:00 and Sunday next to Monday.
        let hour_rad = (f64::from(self.hour_of_day) / 24.0) * 2.0 * PI;
        features.push(hour_rad.cos());
        features.push(hour_rad.sin());
        let day_rad = (f64::from(self.day_of_week) / 7.0) * 2.0 * PI;
        features.push(day_rad.cos());
        features.push(day_rad.sin());

        features.push(match self.task_type.as_deref() {
            Some("classification") => 0.0,
            Some("generation") => 0.33,
            Some("extraction") => 0.67,
            Some("summarization") => 1.0,
            _ => 0.5,
        });

        // Constant bias term.
        features.push(1.0);

        // Zero-pad the remaining slots up to the fixed dimension.
        debug_assert!(features.len() <= Self::FEATURE_DIM);
        features.resize(Self::FEATURE_DIM, 0.0);
        features
    }

    /// Number of entries produced by [`RequestContext::to_feature_vector`].
    pub const fn feature_dimension() -> usize {
        Self::FEATURE_DIM
    }

    /// Sets the task type label.
    #[must_use]
    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
        self.task_type = Some(task_type.into());
        self
    }

    /// Sets the expected output length bucket.
    #[must_use]
    pub fn with_output_length(mut self, category: OutputLengthCategory) -> Self {
        self.output_length_category = category;
        self
    }

    /// Sets the priority; values above 10 are clamped to 10.
    #[must_use]
    pub fn with_priority(mut self, priority: u8) -> Self {
        self.priority = priority.min(Self::MAX_PRIORITY);
        self
    }

    /// Sets the caller identifier.
    #[must_use]
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the region.
    #[must_use]
    pub fn with_region(mut self, region: impl Into<String>) -> Self {
        self.region = Some(region.into());
        self
    }

    /// Sets the language tag.
    #[must_use]
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Inserts (or overwrites) a metadata entry.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}
impl Default for RequestContext {
fn default() -> Self {
Self::new(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let context = RequestContext::new(100);
        assert_eq!(context.input_length, 100);
        assert_eq!(context.priority, 5);
        assert_eq!(context.output_length_category, OutputLengthCategory::Medium);
        assert!(context.hour_of_day < 24);
        assert!(context.day_of_week < 7);
    }

    #[test]
    fn test_default_is_empty_input() {
        let context = RequestContext::default();
        assert_eq!(context.input_length, 0);
        assert_eq!(context.priority, 5);
    }

    #[test]
    fn test_feature_vector() {
        let context = RequestContext::new(256)
            .with_task_type("generation")
            .with_output_length(OutputLengthCategory::Long)
            .with_priority(8);
        let features = context.to_feature_vector();
        assert_eq!(features.len(), RequestContext::feature_dimension());
        for &f in &features {
            // Both conditions must hold: finite AND inside [-1, 1].
            assert!(
                f.is_finite() && (-1.0..=1.0).contains(&f),
                "feature out of range: {}",
                f
            );
        }
        // Bias term and zero padding occupy the last two slots.
        assert_eq!(features[8], 1.0);
        assert_eq!(features[9], 0.0);
    }

    #[test]
    fn test_feature_dimension() {
        assert_eq!(RequestContext::feature_dimension(), 10);
    }

    #[test]
    fn test_cyclical_time_encoding() {
        let context1 = RequestContext::new(100);
        let mut context2 = context1.clone();
        context2.hour_of_day = (context1.hour_of_day + 12) % 24;
        let f1 = context1.to_feature_vector();
        let f2 = context2.to_feature_vector();
        assert!((f1[3] - f2[3]).abs() > 0.1 || (f1[4] - f2[4]).abs() > 0.1);
    }

    #[test]
    fn test_builder_pattern() {
        let context = RequestContext::new(500)
            .with_task_type("classification")
            .with_priority(9)
            .with_user_id("user123")
            .with_language("en")
            .with_metadata("custom_key", "custom_value");
        assert_eq!(context.task_type, Some("classification".to_string()));
        assert_eq!(context.priority, 9);
        assert_eq!(context.user_id, Some("user123".to_string()));
        assert_eq!(context.language, Some("en".to_string()));
        assert_eq!(context.metadata.get("custom_key"), Some(&"custom_value".to_string()));
    }

    #[test]
    fn test_priority_is_clamped_by_builder() {
        let context = RequestContext::new(1).with_priority(42);
        assert_eq!(context.priority, 10);
        assert_eq!(context.to_feature_vector()[2], 1.0);
    }

    #[test]
    fn test_task_type_encoding() {
        let encode = |task: Option<&str>| {
            let mut context = RequestContext::new(1);
            context.task_type = task.map(str::to_string);
            context.to_feature_vector()[7]
        };
        assert_eq!(encode(Some("classification")), 0.0);
        assert_eq!(encode(Some("generation")), 0.33);
        assert_eq!(encode(Some("extraction")), 0.67);
        assert_eq!(encode(Some("summarization")), 1.0);
        assert_eq!(encode(Some("something_else")), 0.5);
        assert_eq!(encode(None), 0.5);
    }

    #[test]
    fn test_output_length_categories() {
        let short = RequestContext::new(10).with_output_length(OutputLengthCategory::Short);
        let medium = RequestContext::new(10).with_output_length(OutputLengthCategory::Medium);
        let long = RequestContext::new(10).with_output_length(OutputLengthCategory::Long);
        assert_eq!(short.to_feature_vector()[1], 0.0);
        assert_eq!(medium.to_feature_vector()[1], 0.5);
        assert_eq!(long.to_feature_vector()[1], 1.0);
    }
}