#[cfg(feature = "alloc")]
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use serde::{Deserialize, Serialize};
use super::FeatureVector;
use super::optimizer::{OptimizerState, OptimizerType};
use super::schedule::{EarlyStoppingConfig, EarlyStoppingState, LearningRateSchedule};
use crate::core::error::Result;
/// A routing model that scores candidate sources for a feature vector and can
/// be trained in batch (`train`) or updated incrementally (`update`).
///
/// The `Send + Sync` supertraits let a model be shared across threads.
pub trait Model: Send + Sync {
/// Scores the candidates in `source_ids` for `features`, returning
/// `(source_id, score)` pairs.
///
/// NOTE(review): whether every input ID appears in the output, the output
/// ordering, and the score range are implementation-defined — confirm
/// against the implementors.
///
/// # Errors
/// Returns an error if the implementation cannot produce a prediction.
fn predict(
&self,
features: &FeatureVector,
source_ids: &[&String],
) -> Result<Vec<(String, f32)>>;
/// Name reported by this model implementation.
fn name(&self) -> &str;
/// Dimensionality of the feature vectors this model works with.
fn feature_dim(&self) -> usize;
/// Trains the model on a batch of logged samples.
fn train(&mut self, samples: &[TrainingSample]) -> Result<()>;
/// Applies one online update: `source_id` was chosen for `features` and
/// earned `reward` (e.g. from `TrainingSample::reward`, which yields values
/// in `[0.0, 1.0]`).
fn update(&mut self, features: &FeatureVector, source_id: &str, reward: f32) -> Result<()>;
/// Serializes the model to bytes.
///
/// NOTE(review): the subtrait `ModelPersistence` also declares
/// `to_bytes(&self) -> Vec<u8>`. For a type implementing both with both
/// traits in scope, `model.to_bytes()` is ambiguous and must be written as
/// `Model::to_bytes(&model)` or `ModelPersistence::to_bytes(&model)`.
fn to_bytes(&self) -> Vec<u8>;
/// Static identifier for the kind of model.
fn model_type(&self) -> &'static str;
}
/// One logged routing outcome: the features a decision was made on, the source
/// that was selected, and how that selection performed.
// Field order is kept stable: it determines the derived Serialize/Debug output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
/// Feature vector the routing decision was based on.
pub features: FeatureVector,
/// Identifier of the source that was selected.
pub selected_source: String,
/// Whether the selected source succeeded; failures always get zero reward.
pub success: bool,
/// Observed latency; the `_ms` suffix indicates milliseconds.
pub latency_ms: u32,
/// Number of results the selected source returned.
pub result_count: u32,
}
impl TrainingSample {
    /// Bundles a single routing outcome into a sample.
    ///
    /// `selected_source` accepts anything convertible into an owned `String`.
    #[must_use]
    pub fn new(
        features: FeatureVector,
        selected_source: impl Into<String>,
        success: bool,
        latency_ms: u32,
        result_count: u32,
    ) -> Self {
        let selected_source: String = selected_source.into();
        Self {
            features,
            selected_source,
            success,
            latency_ms,
            result_count,
        }
    }

    /// Scalar reward in `[0.0, 1.0]` for this outcome.
    ///
    /// - A failed selection always scores `0.0`.
    /// - Latency removes up to half of the reward, linearly, saturating at
    ///   a latency of 10 000.
    /// - Returned results add a bonus of up to +10 % (saturating at 1 000
    ///   results).
    ///
    /// The combined value is clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn reward(&self) -> f32 {
        if self.success {
            let latency_penalty = (self.latency_ms as f32 / 10000.0).min(1.0);
            let latency_factor = 1.0 - latency_penalty * 0.5;
            // With zero results this is exactly 1.0, so no special case is needed.
            let result_bonus = 1.0 + self.result_count.min(1000) as f32 / 10000.0;
            (latency_factor * result_bonus).clamp(0.0, 1.0)
        } else {
            0.0
        }
    }
}
/// Hyper-parameters shared by every model family.
///
/// The `Default` impl in this file uses naive Bayes, 48 features, 10 classes,
/// learning rate `0.01` and regularization `0.001`.
// Field order is kept stable: it determines the derived Serialize/Debug output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Which model family this configuration is for.
pub model_type: ModelType,
/// Length of the feature vectors the model consumes.
pub feature_dim: usize,
/// Number of output classes. NOTE(review): presumably the number of
/// candidate sources — confirm against the model implementations.
pub num_classes: usize,
/// Training step size.
pub learning_rate: f32,
/// Regularization strength (the penalty form is defined by each model).
pub regularization: f32,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
model_type: ModelType::NaiveBayes,
feature_dim: 48,
num_classes: 10,
learning_rate: 0.01,
regularization: 0.001,
}
}
}
/// Model family selector.
///
/// `ModelState::to_bytes` stores this as a tag byte: `NaiveBayes = 0`,
/// `NeuralNetwork = 1`, `Ensemble = 2`. Keep variants in this order and
/// append new ones at the end: the binary tags and the derived serde variant
/// indices both depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
NaiveBayes,
NeuralNetwork,
Ensemble,
}
/// Persistable snapshot of a model: configuration, learned parameters, and
/// optional optimizer / schedule / early-stopping state.
///
/// Besides the derived serde impls, it has a custom versioned binary format
/// (`to_bytes` / `from_bytes`). In that format the optimizer and schedule
/// fields travel in a JSON extension blob.
// Field order is kept stable: it determines the derived Serialize/Debug output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
/// Hyper-parameters the model was built with.
pub config: ModelConfig,
/// Flat parameter buffer; its layout is defined by each model type.
pub weights: Vec<f32>,
/// Source identifiers known to the model. NOTE(review): presumably aligned
/// with the model's output classes — confirm.
pub source_ids: Vec<String>,
/// Training iteration counter.
pub iterations: u64,
/// Additional model-specific `f32` parameters.
pub extra_params: Vec<f32>,
/// Per-layer `(input_dim, output_dim)` pairs.
pub layer_dims: Vec<(usize, usize)>,
/// One byte per activation; the encoding is defined by the model implementation.
pub activation_types: Vec<u8>,
// The fields below are `#[serde(default)]` so serde input that lacks them
// (written before they existed) still deserializes.
#[serde(default)]
pub optimizer_type: Option<OptimizerType>,
#[serde(default)]
pub optimizer_state: Option<OptimizerState>,
#[serde(default)]
pub lr_schedule: Option<LearningRateSchedule>,
/// Completed epoch count (`0` when not tracked).
#[serde(default)]
pub epoch: u64,
#[serde(default)]
pub early_stopping_config: Option<EarlyStoppingConfig>,
#[serde(default)]
pub early_stopping_state: Option<EarlyStoppingState>,
}
/// Conversion between a concrete model and its [`ModelState`] snapshot, plus
/// byte-level persistence built on top of [`ModelState::to_bytes`] and
/// [`ModelState::from_bytes`].
pub trait ModelPersistence: Model {
    /// Captures the model's current parameters as a [`ModelState`].
    fn to_state(&self) -> ModelState;

    /// Rebuilds a model from a previously captured [`ModelState`].
    ///
    /// # Errors
    /// Returns an error if `state` is not valid for this model type.
    fn from_state(state: ModelState) -> Result<Self>
    where
        Self: Sized;

    /// Serializes the model by snapshotting it and encoding the snapshot.
    fn to_bytes(&self) -> Vec<u8> {
        let state = self.to_state();
        ModelState::to_bytes(&state)
    }

    /// Decodes a snapshot from `bytes` and rebuilds the model from it.
    ///
    /// # Errors
    /// Propagates decoding errors from [`ModelState::from_bytes`] and
    /// validation errors from [`ModelPersistence::from_state`].
    fn from_bytes(bytes: &[u8]) -> Result<Self>
    where
        Self: Sized,
    {
        ModelState::from_bytes(bytes).and_then(Self::from_state)
    }
}
/// Optional trailer of the v3 binary format.
///
/// `ModelState::to_bytes` writes it as a `u32` length followed by this struct
/// serialized as JSON. The length is `0` when there is nothing to store or
/// serialization failed. Every field is `#[serde(default)]`, so a blob missing
/// any key still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct V3Extension {
#[serde(default)]
optimizer_type: Option<OptimizerType>,
#[serde(default)]
optimizer_state: Option<OptimizerState>,
#[serde(default)]
lr_schedule: Option<LearningRateSchedule>,
#[serde(default)]
epoch: u64,
#[serde(default)]
early_stopping_config: Option<EarlyStoppingConfig>,
#[serde(default)]
early_stopping_state: Option<EarlyStoppingState>,
}
impl ModelState {
#[must_use]
pub fn new(config: ModelConfig, source_ids: Vec<String>) -> Self {
Self {
config,
weights: Vec::new(),
source_ids,
iterations: 0,
extra_params: Vec::new(),
layer_dims: Vec::new(),
activation_types: Vec::new(),
optimizer_type: None,
optimizer_state: None,
lr_schedule: None,
epoch: 0,
early_stopping_config: None,
early_stopping_state: None,
}
}
#[must_use]
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.extend_from_slice(&3u32.to_le_bytes());
let model_type_byte: u8 = match self.config.model_type {
ModelType::NaiveBayes => 0,
ModelType::NeuralNetwork => 1,
ModelType::Ensemble => 2,
};
bytes.push(model_type_byte);
bytes.extend_from_slice(&(self.config.feature_dim as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.num_classes as u32).to_le_bytes());
bytes.extend_from_slice(&self.config.learning_rate.to_le_bytes());
bytes.extend_from_slice(&self.config.regularization.to_le_bytes());
bytes.extend_from_slice(&(self.weights.len() as u32).to_le_bytes());
for &w in &self.weights {
bytes.extend_from_slice(&w.to_le_bytes());
}
bytes.extend_from_slice(&(self.extra_params.len() as u32).to_le_bytes());
for &p in &self.extra_params {
bytes.extend_from_slice(&p.to_le_bytes());
}
bytes.extend_from_slice(&(self.layer_dims.len() as u32).to_le_bytes());
for &(input_dim, output_dim) in &self.layer_dims {
bytes.extend_from_slice(&(input_dim as u32).to_le_bytes());
bytes.extend_from_slice(&(output_dim as u32).to_le_bytes());
}
bytes.extend_from_slice(&(self.activation_types.len() as u32).to_le_bytes());
bytes.extend_from_slice(&self.activation_types);
bytes.extend_from_slice(&(self.source_ids.len() as u32).to_le_bytes());
for id in &self.source_ids {
let id_bytes = id.as_bytes();
bytes.extend_from_slice(&(id_bytes.len() as u32).to_le_bytes());
bytes.extend_from_slice(id_bytes);
}
bytes.extend_from_slice(&self.iterations.to_le_bytes());
let has_extension = self.optimizer_type.is_some()
|| self.optimizer_state.is_some()
|| self.lr_schedule.is_some()
|| self.early_stopping_config.is_some()
|| self.early_stopping_state.is_some()
|| self.epoch > 0;
if has_extension {
let ext = V3Extension {
optimizer_type: self.optimizer_type.clone(),
optimizer_state: self.optimizer_state.clone(),
lr_schedule: self.lr_schedule.clone(),
epoch: self.epoch,
early_stopping_config: self.early_stopping_config.clone(),
early_stopping_state: self.early_stopping_state.clone(),
};
match serde_json::to_vec(&ext) {
Ok(json_bytes) => {
bytes.extend_from_slice(&(json_bytes.len() as u32).to_le_bytes());
bytes.extend_from_slice(&json_bytes);
}
Err(_) => {
bytes.extend_from_slice(&0u32.to_le_bytes());
}
}
} else {
bytes.extend_from_slice(&0u32.to_le_bytes());
}
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
use crate::core::error::OxiRouterError;
if bytes.len() < 4 {
return Err(OxiRouterError::ModelError(
"Invalid model state: too short".to_string(),
));
}
let version = u32::from_le_bytes(
bytes[0..4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid version bytes".to_string()))?,
);
match version {
1 => Self::from_bytes_v1(bytes),
2 => Self::from_bytes_v2(bytes),
3 => Self::from_bytes_v3(bytes),
_ => Err(OxiRouterError::ModelError(format!(
"Unsupported model state version: {}",
version
))),
}
}
fn from_bytes_v3(bytes: &[u8]) -> Result<Self> {
use crate::core::error::OxiRouterError;
let mut state = Self::from_bytes_v2_core(bytes)?;
let v2_end = Self::v2_end_position(bytes)?;
if v2_end + 4 <= bytes.len() {
let blob_len =
u32::from_le_bytes(bytes[v2_end..v2_end + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("Invalid v3 blob length".to_string())
})?) as usize;
if blob_len > 0 && v2_end + 4 + blob_len <= bytes.len() {
let json_bytes = &bytes[v2_end + 4..v2_end + 4 + blob_len];
match serde_json::from_slice::<V3Extension>(json_bytes) {
Ok(ext) => {
state.optimizer_type = ext.optimizer_type;
state.optimizer_state = ext.optimizer_state;
state.lr_schedule = ext.lr_schedule;
state.epoch = ext.epoch;
state.early_stopping_config = ext.early_stopping_config;
state.early_stopping_state = ext.early_stopping_state;
}
Err(_) => {
}
}
}
}
Ok(state)
}
fn v2_end_position(bytes: &[u8]) -> Result<usize> {
use crate::core::error::OxiRouterError;
let mut pos = 4;
if pos >= bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at model_type".to_string(),
));
}
pos += 1;
pos += 16;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at config".to_string(),
));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at weight count".to_string(),
));
}
let weight_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("v3: invalid weight count".to_string()))?,
) as usize;
pos += 4 + weight_count * 4;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated in weights".to_string(),
));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at extra_params count".to_string(),
));
}
let extra_count = u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("v3: invalid extra_params count".to_string())
})?) as usize;
pos += 4 + extra_count * 4;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated in extra_params".to_string(),
));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at layer_dims count".to_string(),
));
}
let layer_count =
u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("v3: invalid layer_dims count".to_string())
})?) as usize;
pos += 4 + layer_count * 8; if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated in layer_dims".to_string(),
));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at activation count".to_string(),
));
}
let activation_count =
u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("v3: invalid activation count".to_string())
})?) as usize;
pos += 4 + activation_count;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated in activations".to_string(),
));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at source_ids count".to_string(),
));
}
let id_count =
u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("v3: invalid source_ids count".to_string())
})?) as usize;
pos += 4;
for _ in 0..id_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at source ID length".to_string(),
));
}
let id_len = u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("v3: invalid source ID length".to_string())
})?) as usize;
pos += 4 + id_len;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated in source ID".to_string(),
));
}
}
pos += 8;
if pos > bytes.len() {
return Err(OxiRouterError::ModelError(
"v3: truncated at iterations".to_string(),
));
}
Ok(pos)
}
fn from_bytes_v2_core(bytes: &[u8]) -> Result<Self> {
use crate::core::error::OxiRouterError;
let mut pos = 4;
if pos >= bytes.len() {
return Err(OxiRouterError::ModelError("Missing model type".to_string()));
}
let model_type = match bytes[pos] {
0 => ModelType::NaiveBayes,
1 => ModelType::NeuralNetwork,
2 => ModelType::Ensemble,
_ => return Err(OxiRouterError::ModelError("Unknown model type".to_string())),
};
pos += 1;
if pos + 16 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing config fields".to_string(),
));
}
let feature_dim = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid feature_dim".to_string()))?,
) as usize;
pos += 4;
let num_classes = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid num_classes".to_string()))?,
) as usize;
pos += 4;
let learning_rate = f32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid learning_rate".to_string()))?,
);
pos += 4;
let regularization = f32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid regularization".to_string()))?,
);
pos += 4;
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing weight count".to_string(),
));
}
let weight_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid weight count".to_string()))?,
) as usize;
pos += 4;
let mut weights = Vec::with_capacity(weight_count);
for _ in 0..weight_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of weights".to_string(),
));
}
let w = f32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid weight".to_string()))?,
);
weights.push(w);
pos += 4;
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing extra_params count".to_string(),
));
}
let extra_count =
u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("Invalid extra_params count".to_string())
})?) as usize;
pos += 4;
let mut extra_params = Vec::with_capacity(extra_count);
for _ in 0..extra_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of extra_params".to_string(),
));
}
let p = f32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid extra param".to_string()))?,
);
extra_params.push(p);
pos += 4;
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing layer_dims count".to_string(),
));
}
let layer_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid layer_dims count".to_string()))?,
) as usize;
pos += 4;
let mut layer_dims = Vec::with_capacity(layer_count);
for _ in 0..layer_count {
if pos + 8 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of layer_dims".to_string(),
));
}
let input_dim = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid input_dim".to_string()))?,
) as usize;
pos += 4;
let output_dim = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid output_dim".to_string()))?,
) as usize;
pos += 4;
layer_dims.push((input_dim, output_dim));
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing activation_types count".to_string(),
));
}
let activation_count = u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| {
OxiRouterError::ModelError("Invalid activation_types count".to_string())
})?) as usize;
pos += 4;
if pos + activation_count > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of activation_types".to_string(),
));
}
let activation_types = bytes[pos..pos + activation_count].to_vec();
pos += activation_count;
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Missing source_ids count".to_string(),
));
}
let id_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid source_ids count".to_string()))?,
) as usize;
pos += 4;
let mut source_ids = Vec::with_capacity(id_count);
for _ in 0..id_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of source ID length".to_string(),
));
}
let id_len = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid ID length".to_string()))?,
) as usize;
pos += 4;
if pos + id_len > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of source ID".to_string(),
));
}
let id = String::from_utf8(bytes[pos..pos + id_len].to_vec()).map_err(|_| {
OxiRouterError::ModelError("Invalid UTF-8 in source ID".to_string())
})?;
source_ids.push(id);
pos += id_len;
}
let iterations = if pos + 8 <= bytes.len() {
u64::from_le_bytes(
bytes[pos..pos + 8]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid iterations".to_string()))?,
)
} else {
0
};
Ok(Self {
config: ModelConfig {
model_type,
feature_dim,
num_classes,
learning_rate,
regularization,
},
weights,
source_ids,
iterations,
extra_params,
layer_dims,
activation_types,
optimizer_type: None,
optimizer_state: None,
lr_schedule: None,
epoch: 0,
early_stopping_config: None,
early_stopping_state: None,
})
}
fn from_bytes_v1(bytes: &[u8]) -> Result<Self> {
use crate::core::error::OxiRouterError;
if bytes.len() < 20 {
return Err(OxiRouterError::ModelError(
"Invalid v1 model state: too short".to_string(),
));
}
let mut pos = 4;
let feature_dim = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid feature_dim".to_string()))?,
) as usize;
pos += 4;
let num_classes = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid num_classes".to_string()))?,
) as usize;
pos += 4;
let weight_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid weight count".to_string()))?,
) as usize;
pos += 4;
let mut weights = Vec::with_capacity(weight_count);
for _ in 0..weight_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of weights".to_string(),
));
}
let w = f32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid weight".to_string()))?,
);
weights.push(w);
pos += 4;
}
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of source IDs".to_string(),
));
}
let id_count = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid source ID count".to_string()))?,
) as usize;
pos += 4;
let mut source_ids = Vec::with_capacity(id_count);
for _ in 0..id_count {
if pos + 4 > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of source ID length".to_string(),
));
}
let id_len = u32::from_le_bytes(
bytes[pos..pos + 4]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid ID length".to_string()))?,
) as usize;
pos += 4;
if pos + id_len > bytes.len() {
return Err(OxiRouterError::ModelError(
"Unexpected end of source ID".to_string(),
));
}
let id = String::from_utf8(bytes[pos..pos + id_len].to_vec()).map_err(|_| {
OxiRouterError::ModelError("Invalid UTF-8 in source ID".to_string())
})?;
source_ids.push(id);
pos += id_len;
}
let iterations = if pos + 8 <= bytes.len() {
u64::from_le_bytes(
bytes[pos..pos + 8]
.try_into()
.map_err(|_| OxiRouterError::ModelError("Invalid iterations".to_string()))?,
)
} else {
0
};
Ok(Self {
config: ModelConfig {
model_type: ModelType::NaiveBayes,
feature_dim,
num_classes,
learning_rate: 0.01,
regularization: 0.001,
},
weights,
source_ids,
iterations,
extra_params: Vec::new(),
layer_dims: Vec::new(),
activation_types: Vec::new(),
optimizer_type: None,
optimizer_state: None,
lr_schedule: None,
epoch: 0,
early_stopping_config: None,
early_stopping_state: None,
})
}
fn from_bytes_v2(bytes: &[u8]) -> Result<Self> {
Self::from_bytes_v2_core(bytes)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(all(not(feature = "std"), feature = "alloc"))]
    use alloc::vec;

    /// A state that exercises every variable-length section of the format.
    fn sample_state() -> ModelState {
        let mut state = ModelState::new(
            ModelConfig::default(),
            vec!["src1".to_string(), "src2".to_string()],
        );
        state.weights = vec![0.1, 0.2, 0.3, 0.4];
        state.extra_params = vec![0.5, 0.6];
        state.layer_dims = vec![(10, 5), (5, 2)];
        state.activation_types = vec![0, 3];
        state.iterations = 100;
        state
    }

    #[test]
    fn test_training_sample_reward() {
        let features = FeatureVector::new();
        let sample = TrainingSample::new(features.clone(), "src1", true, 100, 50);
        assert!(sample.reward() > 0.8);
        let failed = TrainingSample::new(features.clone(), "src1", false, 100, 0);
        assert_eq!(failed.reward(), 0.0);
        let slow = TrainingSample::new(features, "src1", true, 9000, 50);
        assert!(slow.reward() < 0.7);
    }

    #[test]
    fn test_model_state_serialization() {
        let state = sample_state();
        let bytes = state.to_bytes();
        let restored = ModelState::from_bytes(&bytes).unwrap();
        assert_eq!(restored.weights, state.weights);
        assert_eq!(restored.source_ids, state.source_ids);
        assert_eq!(restored.iterations, state.iterations);
        assert_eq!(restored.extra_params, state.extra_params);
        assert_eq!(restored.layer_dims, state.layer_dims);
        assert_eq!(restored.activation_types, state.activation_types);
        assert_eq!(restored.config.model_type, state.config.model_type);
        assert_eq!(restored.config.feature_dim, state.config.feature_dim);
        assert_eq!(restored.config.num_classes, state.config.num_classes);
        assert_eq!(restored.config.learning_rate, state.config.learning_rate);
        assert_eq!(restored.config.regularization, state.config.regularization);
    }

    #[test]
    fn test_every_model_type_roundtrips() {
        for model_type in [
            ModelType::NaiveBayes,
            ModelType::NeuralNetwork,
            ModelType::Ensemble,
        ] {
            let mut state = sample_state();
            state.config.model_type = model_type;
            let restored = ModelState::from_bytes(&state.to_bytes()).unwrap();
            assert_eq!(restored.config.model_type, model_type);
        }
    }

    #[test]
    fn test_v3_extension_roundtrip() {
        let mut state = sample_state();
        state.epoch = 7;
        let restored = ModelState::from_bytes(&state.to_bytes()).unwrap();
        assert_eq!(restored.epoch, 7);
        assert!(restored.optimizer_type.is_none());
        assert!(restored.optimizer_state.is_none());
        assert!(restored.lr_schedule.is_none());
        assert!(restored.early_stopping_config.is_none());
        assert!(restored.early_stopping_state.is_none());
    }

    #[test]
    fn test_truncated_v3_is_rejected() {
        let bytes = sample_state().to_bytes();
        // Any prefix that cuts into the fixed sections or the iteration counter
        // must fail; only the trailing 4-byte extension length is optional.
        for len in 0..bytes.len() - 4 {
            assert!(
                ModelState::from_bytes(&bytes[..len]).is_err(),
                "prefix of {} bytes unexpectedly decoded",
                len
            );
        }
        assert!(ModelState::from_bytes(&bytes[..bytes.len() - 4]).is_ok());
    }

    #[test]
    fn test_unsupported_or_short_input_is_rejected() {
        assert!(ModelState::from_bytes(&[]).is_err());
        assert!(ModelState::from_bytes(&[3u8, 0]).is_err());
        assert!(ModelState::from_bytes(&99u32.to_le_bytes()).is_err());
    }

    #[test]
    fn test_count_larger_than_payload_is_rejected() {
        // v2 header whose weight count claims far more data than is present.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&2u32.to_le_bytes());
        bytes.push(0);
        bytes.extend_from_slice(&[0u8; 16]);
        bytes.extend_from_slice(&1_000_000u32.to_le_bytes());
        assert!(ModelState::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_from_bytes_v1() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.extend_from_slice(&48u32.to_le_bytes());
        bytes.extend_from_slice(&3u32.to_le_bytes());
        bytes.extend_from_slice(&2u32.to_le_bytes());
        bytes.extend_from_slice(&0.5f32.to_le_bytes());
        bytes.extend_from_slice(&(-1.0f32).to_le_bytes());
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.extend_from_slice(&3u32.to_le_bytes());
        bytes.extend_from_slice(b"abc");
        bytes.extend_from_slice(&9u64.to_le_bytes());

        let restored = ModelState::from_bytes(&bytes).unwrap();
        assert_eq!(restored.config.model_type, ModelType::NaiveBayes);
        assert_eq!(restored.config.feature_dim, 48);
        assert_eq!(restored.config.num_classes, 3);
        assert_eq!(restored.config.learning_rate, 0.01);
        assert_eq!(restored.weights, vec![0.5_f32, -1.0]);
        assert_eq!(restored.source_ids, vec!["abc".to_string()]);
        assert_eq!(restored.iterations, 9);
        assert!(restored.layer_dims.is_empty());
    }
}