use serde::{Deserialize, Serialize};
/// Gradient-descent step size used by `CompressionPredictor::train`.
const LEARNING_RATE: f32 = 0.01;
/// Number of full passes over the training batch per `train` call.
const EPOCHS: usize = 50;
/// Feature vector describing one compression decision context.
///
/// All fields are `f32` so the struct can be flattened directly into the
/// model's input via [`CompressionFeatures::to_vec`].
#[derive(Debug, Clone, Copy)]
pub struct CompressionFeatures {
    /// Candidate compression ratio being evaluated.
    pub compression_ratio: f32,
    /// Message count feature (scaling convention set by the caller).
    pub message_count: f32,
    /// Average message length feature (scaling convention set by the caller).
    pub avg_message_length: f32,
    /// Fraction of content that is tool output.
    pub tool_output_fraction: f32,
}

impl CompressionFeatures {
    /// Flattens the features into a fixed-size array, in the same order as
    /// the model weights.
    #[must_use]
    pub fn to_vec(&self) -> [f32; 4] {
        let Self {
            compression_ratio,
            message_count,
            avg_message_length,
            tool_output_fraction,
        } = *self;
        [
            compression_ratio,
            message_count,
            avg_message_length,
            tool_output_fraction,
        ]
    }
}
/// Persistable parameters of the logistic model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionModelWeights {
    /// One weight per entry of `CompressionFeatures::to_vec` (4 total).
    pub weights: Vec<f32>,
    /// Bias term added to the weighted sum before the sigmoid.
    pub bias: f32,
    /// Number of training samples seen so far; drives cold-start checks.
    pub sample_count: u64,
}

impl Default for CompressionModelWeights {
    /// Hand-picked starting weights used before any training data exists.
    fn default() -> Self {
        Self {
            // Order matches `CompressionFeatures::to_vec`:
            // [compression_ratio, message_count, avg_message_length,
            //  tool_output_fraction]. Only tool_output_fraction starts
            // with a negative influence.
            weights: vec![0.3, 0.05, 0.05, -0.1],
            bias: 0.0,
            sample_count: 0,
        }
    }
}
/// Logistic-regression model that predicts a quality score for a
/// candidate compression configuration described by `CompressionFeatures`.
pub struct CompressionPredictor {
    // Current parameters (weights, bias, and samples-seen counter).
    weights: CompressionModelWeights,
}
impl CompressionPredictor {
    /// Creates a predictor seeded with the default heuristic weights.
    #[must_use]
    pub fn new() -> Self {
        Self {
            weights: CompressionModelWeights::default(),
        }
    }

    /// Restores a predictor from previously persisted weights.
    #[must_use]
    pub fn from_weights(weights: CompressionModelWeights) -> Self {
        Self { weights }
    }

    /// Serializes the current weights to a JSON string for persistence.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` if serialization fails.
    pub fn serialize(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(&self.weights)
    }

    /// Total number of training samples this model has seen.
    #[must_use]
    pub fn sample_count(&self) -> u64 {
        self.weights.sample_count
    }

    /// Returns `true` while the model has seen fewer than `min_samples`
    /// training samples, i.e. its predictions are not yet trustworthy.
    #[must_use]
    pub fn is_cold_start(&self, min_samples: u64) -> bool {
        self.weights.sample_count < min_samples
    }

    /// Predicts a quality score in (0, 1) for `features` using the
    /// logistic model `sigmoid(w . x + b)`.
    #[must_use]
    pub fn predict(&self, features: &CompressionFeatures) -> f32 {
        let x = features.to_vec();
        let dot: f32 = x
            .iter()
            .zip(self.weights.weights.iter())
            .map(|(xi, wi)| xi * wi)
            .sum();
        sigmoid(dot + self.weights.bias)
    }

    /// Trains the model with full-batch gradient descent for `EPOCHS`
    /// passes over `samples` (feature vector, target score) pairs.
    ///
    /// The per-sample gradient `(pred - target) * pred * (1 - pred)` is
    /// the derivative of squared error taken through the sigmoid.
    /// Increments the sample counter; a no-op on an empty slice.
    pub fn train(&mut self, samples: &[(CompressionFeatures, f32)]) {
        if samples.is_empty() {
            return;
        }
        let n_features = self.weights.weights.len();
        #[allow(clippy::cast_precision_loss)]
        let n = samples.len() as f32;
        // Allocate the gradient buffer once and reuse it across epochs
        // instead of reallocating a fresh Vec per epoch.
        let mut grad_w = vec![0.0f32; n_features];
        for _ in 0..EPOCHS {
            grad_w.fill(0.0);
            let mut grad_b = 0.0f32;
            for (features, target) in samples {
                let pred = self.predict(features);
                // d(1/2 * (pred - target)^2)/d(logit)
                //   = (pred - target) * sigmoid'(logit)
                let error = (pred - target) * pred * (1.0 - pred);
                let x = features.to_vec();
                for (i, xi) in x.iter().enumerate() {
                    grad_w[i] += error * xi;
                }
                grad_b += error;
            }
            // Average the batch gradient and take one step.
            for (wi, gi) in self.weights.weights.iter_mut().zip(grad_w.iter()) {
                *wi -= LEARNING_RATE * gi / n;
            }
            self.weights.bias -= LEARNING_RATE * grad_b / n;
        }
        self.weights.sample_count += samples.len() as u64;
    }

    /// Scans `candidate_ratios` in ascending order and returns the first
    /// (lowest) ratio whose predicted quality clears `floor`, or `None`
    /// if no candidate passes. Per the tests, the lowest passing ratio is
    /// treated as the most aggressive acceptable choice.
    #[must_use]
    pub fn select_ratio(
        &self,
        floor: f32,
        candidate_ratios: &[f32],
        message_count: f32,
        avg_message_length: f32,
        tool_output_fraction: f32,
    ) -> Option<f32> {
        let mut sorted = candidate_ratios.to_vec();
        // total_cmp provides a consistent total order even if a NaN sneaks
        // in, unlike `partial_cmp(..).unwrap_or(Equal)`, which yields an
        // inconsistent comparator in that case.
        sorted.sort_unstable_by(f32::total_cmp);
        for &ratio in &sorted {
            let features = CompressionFeatures {
                compression_ratio: ratio,
                message_count,
                avg_message_length,
                tool_output_fraction,
            };
            if self.predict(&features) >= floor {
                return Some(ratio);
            }
        }
        None
    }
}
impl Default for CompressionPredictor {
    /// Equivalent to [`CompressionPredictor::new`]: default heuristic weights.
    fn default() -> Self {
        Self::new()
    }
}
/// Standard logistic function: maps any real input into the open
/// interval (0, 1), with `sigmoid(0) == 0.5`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    ((-x).exp() + 1.0).recip()
}
// Unit tests covering the sigmoid, prediction range, cold-start logic,
// the training loop's effect, ratio selection, and weight serialization.
#[cfg(test)]
mod tests {
    use super::*;

    // sigmoid(0) is the logistic midpoint.
    #[test]
    fn sigmoid_at_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    // Untrained model must still produce a probability-like output.
    #[test]
    fn predict_with_default_weights_returns_valid_range() {
        let model = CompressionPredictor::new();
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let p = model.predict(&features);
        assert!(p > 0.0 && p < 1.0, "prediction must be in (0, 1): {p}");
    }

    // Fresh model has seen 0 samples, below any positive threshold.
    #[test]
    fn is_cold_start_true_when_no_training() {
        let model = CompressionPredictor::new();
        assert!(model.is_cold_start(10));
    }

    // Training 10 samples reaches the min_samples = 10 threshold exactly.
    #[test]
    fn is_cold_start_false_after_training() {
        let mut model = CompressionPredictor::new();
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let samples: Vec<_> = (0..10).map(|_| (features, 0.7f32)).collect();
        model.train(&samples);
        assert!(!model.is_cold_start(10));
    }

    // Gradient descent toward a high target must raise the prediction.
    #[test]
    fn training_on_high_ratio_improves_high_ratio_prediction() {
        let mut model = CompressionPredictor::new();
        let high_ratio = CompressionFeatures {
            compression_ratio: 0.8,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let initial_pred = model.predict(&high_ratio);
        let samples: Vec<_> = (0..50).map(|_| (high_ratio, 0.9f32)).collect();
        model.train(&samples);
        let trained_pred = model.predict(&high_ratio);
        assert!(
            trained_pred > initial_pred,
            "training on high target must increase prediction: {initial_pred} -> {trained_pred}"
        );
    }

    // After mixed training, any selected ratio must itself clear the floor.
    // (The selection may legitimately be None, hence the if-let.)
    #[test]
    fn select_ratio_returns_most_aggressive_passing_ratio() {
        let mut model = CompressionPredictor::new();
        let good = CompressionFeatures {
            compression_ratio: 0.8,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let bad = CompressionFeatures {
            compression_ratio: 0.2,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let samples: Vec<(CompressionFeatures, f32)> = (0..100)
            .map(|i| if i % 2 == 0 { (good, 0.95) } else { (bad, 0.1) })
            .collect();
        model.train(&samples);
        let ratios = vec![0.2, 0.4, 0.6, 0.8, 0.9];
        let selected = model.select_ratio(0.5, &ratios, 0.5, 0.5, 0.2);
        if let Some(r) = selected {
            let features = CompressionFeatures {
                compression_ratio: r,
                message_count: 0.5,
                avg_message_length: 0.5,
                tool_output_fraction: 0.2,
            };
            assert!(
                model.predict(&features) >= 0.5,
                "selected ratio must clear floor"
            );
        }
    }

    // An unreachable floor (0.99) on an untrained model yields None.
    #[test]
    fn select_ratio_returns_none_when_nothing_passes() {
        let model = CompressionPredictor::new();
        let ratios = vec![0.1, 0.2];
        let selected = model.select_ratio(0.99, &ratios, 0.5, 0.5, 0.2);
        assert!(
            selected.is_none(),
            "expected None when no ratio clears the floor"
        );
    }

    // sample_count accumulates across multiple train() calls.
    #[test]
    fn sample_count_increments_after_training() {
        let mut model = CompressionPredictor::new();
        assert_eq!(model.sample_count(), 0);
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        model.train(&[(features, 0.7)]);
        assert_eq!(model.sample_count(), 1);
        model.train(&[(features, 0.6), (features, 0.8)]);
        assert_eq!(model.sample_count(), 3);
    }

    // JSON round-trip must preserve the weight vector shape and bias.
    #[test]
    fn serialize_roundtrip() {
        let model = CompressionPredictor::new();
        let json = model.serialize().expect("serialize");
        let weights: CompressionModelWeights = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(weights.weights.len(), model.weights.weights.len());
        assert!((weights.bias - model.weights.bias).abs() < 1e-6);
    }
}