use crate::EncoderConfig;
use std::time::Duration;
/// Quality goal handed to `EncoderConfig::auto_tune`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QualityTarget {
/// Target score on the Zensim scale. `auto_tune` divides the value by 100
/// to build its normalised model inputs and snaps it to the nearest entry
/// of the quality LUT's `target_zqs` list.
Zensim(f32),
}
/// Optional constraints and preferences for `EncoderConfig::auto_tune`.
///
/// The default value (also returned by [`AutoTuneOptions::new`]) applies no
/// time budget, no speed restriction and a `pareto_weight` of `0.0`.
#[derive(Debug, Clone, Default)]
pub struct AutoTuneOptions {
    /// Upper bound on the estimated encode time; cells whose estimate exceeds
    /// it are rejected. `None` disables the check.
    pub time_budget: Option<Duration>,
    /// Inclusive range of encoder speeds that may be selected. `None` allows
    /// every speed.
    pub speed_range: Option<std::ops::RangeInclusive<u8>>,
    /// Blend between predicted size and estimated encode time when scoring a
    /// cell: `0.0` scores on size only, `1.0` on encode time only.
    pub pareto_weight: f32,
}

impl AutoTuneOptions {
    /// Creates options with no constraints; identical to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Rejects cells whose estimated encode time exceeds `budget`.
    #[must_use]
    pub fn with_time_budget(mut self, budget: Duration) -> Self {
        self.time_budget = Some(budget);
        self
    }

    /// Restricts the selectable encoder speeds to `range` (inclusive).
    #[must_use]
    pub fn with_speed_range(mut self, range: std::ops::RangeInclusive<u8>) -> Self {
        self.speed_range = Some(range);
        self
    }

    /// Sets the size/time trade-off weight, clamped to `0.0..=1.0`.
    ///
    /// A NaN weight falls back to the default `0.0`: `f32::clamp` passes NaN
    /// through unchanged, and a NaN weight would make every cell score NaN so
    /// that no cell could ever be selected.
    #[must_use]
    pub fn with_pareto_weight(mut self, w: f32) -> Self {
        self.pareto_weight = if w.is_nan() { 0.0 } else { w.clamp(0.0, 1.0) };
        self
    }
}
/// Failures reported by the auto-tuner and its lookup-table parsers.
#[derive(Debug, thiserror::Error)]
pub enum AutoTuneError {
/// The embedded model or one of the embedded LUT files is still an empty
/// placeholder.
#[error("auto-tune model not yet baked — run scripts/train_bake_pipeline.sh")]
ModelNotBaked,
/// Image feature extraction failed; the payload is the underlying message.
// NOTE(review): no code visible in this part of the file constructs this variant.
#[error("zenanalyze feature extraction failed: {0}")]
FeatureExtraction(String),
/// zenpredict could not load the model or run a prediction.
#[error("zenpredict inference failed: {0}")]
Inference(String),
/// An embedded LUT file is not valid JSON or lacks an expected key.
#[error("LUT JSON malformed: {0}")]
LutMalformed(String),
/// Every candidate cell was rejected by the filters applied in `auto_tune`.
#[error("no cell satisfies the constraints (time_budget too tight or speed_range empty?)")]
NoCellAllowed,
/// The quality LUT holds no usable quality value for the selected cell at
/// the requested target; the payload is the requested target.
#[error("target_zq {0:.1} is outside the LUT's covered range")]
TargetOutOfRange(f32),
/// An internal expectation of the tuner did not hold; the payload is a
/// static description.
#[error("internal: {0}")]
Internal(&'static str),
}
/// Wrapper that raises the alignment of its contents to 4 bytes.
/// `T: ?Sized` allows a reference to `Align4<[u8; N]>` to be used as the
/// unsized `&Align4<[u8]>` below.
#[repr(C, align(4))]
struct Align4<T: ?Sized>(T);
/// The picker model file, embedded at compile time inside `Align4` so the
/// byte buffer starts on a 4-byte boundary.
// NOTE(review): presumably zenpredict reads 32-bit values from this buffer in place — confirm.
static MODEL_BYTES_ALIGNED: &Align4<[u8]> =
&Align4(*include_bytes!("models/rav1e_picker_v0_1_1.bin"));
/// Plain byte-slice view of the aligned model buffer.
static MODEL_BYTES: &[u8] = &MODEL_BYTES_ALIGNED.0;
/// Encode-time lookup table (JSON text), embedded at compile time and parsed
/// by `parse_encode_ms_lut`.
const ENCODE_MS_LUT_JSON: &str = include_str!("models/rav1e_encode_ms_lut_v0_1_1.json");
/// Quality lookup table (JSON text), embedded at compile time and parsed by
/// `parse_quality_lut`.
const QUALITY_LUT_JSON: &str = include_str!("models/rav1e_quality_lut_v0_1_1.json");
/// Encode-time lookup table produced by `parse_encode_ms_lut`.
struct EncodeMsLut {
/// One row per `speedN` key of the JSON, ordered by ascending `N`. Each row
/// holds the values for the `tiny`, `small`, `medium` and `large` size
/// classes in that order; entries missing from the JSON are `f32::INFINITY`.
/// `auto_tune` multiplies a value by the image's megapixel count and
/// compares the product with the time budget expressed in milliseconds.
median_ms_per_mpx: Vec<[f32; 4]>, }
/// Quality lookup table produced by `parse_quality_lut`.
struct QualityLut {
/// Quality targets the table has a column for (JSON key `target_zqs`).
target_zqs: Vec<u32>,
/// `median_q[cell][i]` is the encoder quality `auto_tune` applies for
/// `target_zqs[i]` when it selects `cell`. `auto_tune` treats a negative
/// value as "this cell cannot be used for this target".
median_q: Vec<Vec<i32>>,
}
/// Maps an image's pixel count to a size-class index.
///
/// Returns `0` below 64×64 pixels, `1` below 256×256, `2` below 1024×1024
/// and `3` otherwise. The product is computed in `u64`, so it cannot
/// overflow for any pair of `u32` dimensions.
fn size_class_idx(w: u32, h: u32) -> usize {
    // Pixel counts at which the next size class begins, in ascending order.
    const CLASS_LOWER_BOUNDS: [u64; 3] = [64 * 64, 256 * 256, 1024 * 1024];
    let area = u64::from(w) * u64::from(h);
    // The class index equals the number of boundaries the area has reached.
    CLASS_LOWER_BOUNDS
        .iter()
        .filter(|&&bound| area >= bound)
        .count()
}
/// Parses the encode-time LUT JSON into an [`EncodeMsLut`].
///
/// The document must contain an object `median_ms_per_mpx` whose keys have
/// the form `speedN` and whose values are objects keyed by `tiny`, `small`,
/// `medium` and `large`. Unknown size labels are ignored; missing or
/// non-numeric timings become `f32::INFINITY`. Rows are returned ordered by
/// ascending `N`.
///
/// # Errors
///
/// * [`AutoTuneError::ModelNotBaked`] if `json` is empty or whitespace only.
/// * [`AutoTuneError::LutMalformed`] if the text is not valid JSON, the
///   `median_ms_per_mpx` object is missing, a key is not `speedN`, or a cell
///   value is not an object.
#[cfg(feature = "auto-tune")]
fn parse_encode_ms_lut(json: &str) -> Result<EncodeMsLut, AutoTuneError> {
    if json.trim().is_empty() {
        return Err(AutoTuneError::ModelNotBaked);
    }

    let doc: serde_json::Value = match serde_json::from_str(json) {
        Ok(doc) => doc,
        Err(e) => return Err(AutoTuneError::LutMalformed(format!("encode_ms_lut: {e}"))),
    };
    let cells = match doc.get("median_ms_per_mpx").and_then(|m| m.as_object()) {
        Some(cells) => cells,
        None => return Err(AutoTuneError::LutMalformed("missing median_ms_per_mpx".into())),
    };

    let mut by_speed: Vec<(u32, [f32; 4])> = Vec::new();
    for (key, entry) in cells {
        let parsed_speed = key
            .strip_prefix("speed")
            .and_then(|digits| digits.parse::<u32>().ok());
        let speed = match parsed_speed {
            Some(speed) => speed,
            None => return Err(AutoTuneError::LutMalformed(format!("bad cell key {key}"))),
        };
        let per_size = match entry.as_object() {
            Some(per_size) => per_size,
            None => return Err(AutoTuneError::LutMalformed("cell value not object".into())),
        };

        // Slots that the JSON does not mention stay at infinity.
        let mut timings = [f32::INFINITY; 4];
        for (label, value) in per_size {
            let slot = match label.as_str() {
                "tiny" => Some(0),
                "small" => Some(1),
                "medium" => Some(2),
                "large" => Some(3),
                _ => None,
            };
            if let Some(slot) = slot {
                timings[slot] = value.as_f64().map_or(f32::INFINITY, |ms| ms as f32);
            }
        }
        by_speed.push((speed, timings));
    }

    // NOTE(review): rows end up indexed by their rank after sorting, not by
    // the speed number itself — confirm the baked LUT has no gaps in its keys.
    by_speed.sort_by_key(|&(speed, _)| speed);
    let median_ms_per_mpx = by_speed.into_iter().map(|(_, timings)| timings).collect();
    Ok(EncodeMsLut { median_ms_per_mpx })
}
/// Parses the quality LUT JSON into a [`QualityLut`].
///
/// The document must contain an array `target_zqs` of non-negative integers
/// and an array `median_q` of rows. Both tables are positional
/// (`median_q[cell][i]` belongs to `target_zqs[i]`), so entries are never
/// dropped:
///
/// * a `target_zqs` entry that is not an integer representable as `u32` is
///   reported as [`AutoTuneError::LutMalformed`];
/// * a `median_q` entry that is not an integer representable as `i32` is
///   stored as `-1`, the negative marker `auto_tune` already treats as
///   "cell unusable for this target", so its neighbours keep their columns;
/// * a `median_q` row that is not an array becomes an empty row.
///
/// # Errors
///
/// * [`AutoTuneError::ModelNotBaked`] if `json` is empty or whitespace only.
/// * [`AutoTuneError::LutMalformed`] if the text is not valid JSON, a
///   required key is missing, or a `target_zqs` entry is invalid.
#[cfg(feature = "auto-tune")]
fn parse_quality_lut(json: &str) -> Result<QualityLut, AutoTuneError> {
    /// Stored in place of a `median_q` entry that cannot be read.
    const UNREACHABLE_Q: i32 = -1;

    if json.trim().is_empty() {
        return Err(AutoTuneError::ModelNotBaked);
    }
    let v: serde_json::Value = serde_json::from_str(json)
        .map_err(|e| AutoTuneError::LutMalformed(format!("quality_lut: {e}")))?;

    let target_zqs = v
        .get("target_zqs")
        .and_then(|a| a.as_array())
        .ok_or_else(|| AutoTuneError::LutMalformed("missing target_zqs".into()))?
        .iter()
        .map(|n| {
            n.as_u64()
                .filter(|&x| x <= u64::from(u32::MAX))
                .map(|x| x as u32)
                .ok_or_else(|| AutoTuneError::LutMalformed(format!("bad target_zqs entry {n}")))
        })
        .collect::<Result<Vec<u32>, AutoTuneError>>()?;

    let q_range = i64::from(i32::MIN)..=i64::from(i32::MAX);
    let median_q: Vec<Vec<i32>> = v
        .get("median_q")
        .and_then(|a| a.as_array())
        .ok_or_else(|| AutoTuneError::LutMalformed("missing median_q".into()))?
        .iter()
        .map(|row| {
            row.as_array()
                .map(|r| {
                    r.iter()
                        .map(|n| {
                            // Keep the column position even when the entry is unusable.
                            n.as_i64()
                                .filter(|x| q_range.contains(x))
                                .map_or(UNREACHABLE_Q, |x| x as i32)
                        })
                        .collect()
                })
                .unwrap_or_default()
        })
        .collect();

    Ok(QualityLut {
        target_zqs,
        median_q,
    })
}
/// Returns the index of the entry in `target_zqs` closest to `target`, or
/// `None` when the slice is empty.
///
/// `target` is rounded to the nearest integer and saturated to the `i32`
/// range (NaN rounds to `0`). When two entries are equally close, the one
/// that appears first wins.
///
/// Distances are computed in `i64`. The widest possible gap
/// (`u32::MAX - i32::MIN`) fits comfortably, so neither extreme targets such
/// as `f32::NEG_INFINITY` nor table entries above `i32::MAX` can overflow or
/// wrap around.
fn nearest_target_zq_idx(target_zqs: &[u32], target: f32) -> Option<usize> {
    let wanted = i64::from(target.round() as i32);
    target_zqs
        .iter()
        .enumerate()
        // `min_by_key` keeps the first of several equally small keys.
        .min_by_key(|&(_, &zq)| (i64::from(zq) - wanted).abs())
        .map(|(idx, _)| idx)
}
#[cfg(feature = "auto-tune")]
impl EncoderConfig {
    /// Chooses an encoder speed and quality for one image and returns the
    /// configuration with both applied.
    ///
    /// The method
    /// 1. loads the embedded picker model and reads its feature-column list
    ///    from the `zentrain.feature_columns` metadata entry,
    /// 2. extracts those features from `rgb` with zenanalyze (columns that
    ///    cannot be resolved or produce no value contribute `0.0`),
    /// 3. builds the model input (raw features, size-class one-hot, size and
    ///    quality terms, quality × feature interactions, one trailing `0.0`),
    /// 4. predicts one value per cell, where cell `i` maps to speed `i + 1`,
    /// 5. discards cells outside `opts.speed_range`, cells the quality LUT
    ///    marks with a negative value for the target, and cells whose
    ///    estimated encode time exceeds `opts.time_budget`,
    /// 6. scores the remaining cells as
    ///    `(1 - w) * normalised prediction + w * normalised encode time`
    ///    with `w = opts.pareto_weight`, and picks the lowest score.
    ///
    /// `rgb` is handed to `zenanalyze::analyze_features_rgb8` together with
    /// `width` and `height`; this method does not validate its length.
    ///
    /// # Errors
    ///
    /// * [`AutoTuneError::ModelNotBaked`] if the embedded model is shorter
    ///   than 16 bytes and fails to load, or an embedded LUT is empty.
    /// * [`AutoTuneError::Inference`] if the model cannot be loaded, its
    ///   feature-column metadata cannot be read, or prediction fails.
    /// * [`AutoTuneError::LutMalformed`] if an embedded LUT cannot be parsed.
    /// * [`AutoTuneError::Internal`] if the quality LUT lists no targets.
    /// * [`AutoTuneError::NoCellAllowed`] if every cell was filtered out.
    /// * [`AutoTuneError::TargetOutOfRange`] if the quality LUT has no usable
    ///   entry for the selected cell.
    pub fn auto_tune(
        self,
        rgb: &[u8],
        width: u32,
        height: u32,
        target: QualityTarget,
        opts: AutoTuneOptions,
    ) -> Result<Self, AutoTuneError> {
        let model = zenpredict::Model::from_bytes(MODEL_BYTES).map_err(|e| {
            // A buffer this short is treated as the un-baked placeholder
            // rather than as a corrupt model.
            if MODEL_BYTES.len() < 16 {
                AutoTuneError::ModelNotBaked
            } else {
                AutoTuneError::Inference(format!("Model::from_bytes: {e}"))
            }
        })?;
        let n_cells = (model.header().n_outputs as usize).max(1);

        // Reported as `Inference`, like the model-loading failure above, so
        // the message can be carried as an owned `String`; building a
        // `&'static str` for `Internal` would leak memory on every failure.
        let feature_cols_str = model
            .metadata()
            .get_utf8("zentrain.feature_columns")
            .map_err(|e| AutoTuneError::Inference(format!("metadata: {e}")))?;
        let feature_cols: Vec<&str> = feature_cols_str
            .split(|c: char| c == '\n' || c == ',')
            .filter(|s| !s.is_empty())
            .collect();

        // Resolve `feat_<name>` column names to analyser features.
        let supported = zenanalyze::feature::FeatureSet::SUPPORTED;
        let lookup = |col: &str| -> Option<zenanalyze::feature::AnalysisFeature> {
            let target = col.strip_prefix("feat_")?;
            supported.iter().find(|f| f.name() == target)
        };
        let mut feature_set = zenanalyze::feature::FeatureSet::new();
        for c in &feature_cols {
            if let Some(f) = lookup(c) {
                feature_set = feature_set.with(f);
            }
        }
        let query = zenanalyze::feature::AnalysisQuery::new(feature_set);
        let analysis = zenanalyze::analyze_features_rgb8(rgb, width, height, &query);
        let raw_feats: Vec<f32> = feature_cols
            .iter()
            .map(|c| lookup(c).and_then(|f| analysis.get_f32(f)).unwrap_or(0.0))
            .collect();

        let target_zq = match target {
            QualityTarget::Zensim(z) => z,
        };
        let pixels = (width as f32) * (height as f32);
        let log_px = pixels.max(1.0).ln();
        let zq_norm = target_zq / 100.0;

        // One-hot size class, derived from the same thresholds that index
        // the encode-time LUT below.
        let sz_idx = size_class_idx(width, height);
        let mut size_oh = [0.0_f32; 4];
        size_oh[sz_idx] = 1.0;

        // Input layout: n raw + 4 one-hot + 5 size/quality terms
        // + n interactions + 1 trailing slot.
        let n = raw_feats.len();
        let mut features = Vec::with_capacity(2 * n + 10);
        features.extend_from_slice(&raw_feats);
        features.extend_from_slice(&size_oh);
        features.extend_from_slice(&[
            log_px,
            log_px * log_px,
            zq_norm,
            zq_norm * zq_norm,
            zq_norm * log_px,
        ]);
        features.extend(raw_feats.iter().map(|f| zq_norm * f));
        // NOTE(review): trailing constant input; presumably a slot the model
        // was trained with — confirm against the training pipeline.
        features.push(0.0);

        let mut predictor = zenpredict::Predictor::new(model);
        let output = predictor
            .predict(&features)
            .map_err(|e| AutoTuneError::Inference(format!("{e}")))?;
        // May be shorter than `n_cells` if the predictor returns fewer values.
        let bytes_log: Vec<f32> = output[..n_cells.min(output.len())].to_vec();

        // NOTE(review): both LUTs are re-parsed on every call; cache them if
        // this ever shows up in a profile.
        let ms_lut = parse_encode_ms_lut(ENCODE_MS_LUT_JSON)?;
        let q_lut = parse_quality_lut(QUALITY_LUT_JSON)?;
        let zq_idx = nearest_target_zq_idx(&q_lut.target_zqs, target_zq)
            .ok_or(AutoTuneError::Internal("empty quality LUT"))?;
        let mpx = pixels / 1_000_000.0;
        let budget_ms = opts.time_budget.map(|d| d.as_secs_f32() * 1000.0);

        let mut allowed = vec![false; n_cells];
        let mut encode_ms_est = vec![f32::INFINITY; n_cells];
        // Speeds are `u8`; cells whose speed (`cell + 1`) would not fit are
        // never considered instead of wrapping around.
        let considered = n_cells.min(usize::from(u8::MAX));
        for cell in 0..considered {
            let speed = (cell + 1) as u8;
            if let Some(ref range) = opts.speed_range {
                if !range.contains(&speed) {
                    continue;
                }
            }
            // A negative LUT entry marks the cell as unusable for this target.
            if cell < q_lut.median_q.len()
                && zq_idx < q_lut.median_q[cell].len()
                && q_lut.median_q[cell][zq_idx] < 0
            {
                continue;
            }
            let ms_per_mpx = ms_lut
                .median_ms_per_mpx
                .get(cell)
                .and_then(|row| row.get(sz_idx).copied())
                .unwrap_or(f32::INFINITY);
            let est = ms_per_mpx * mpx;
            encode_ms_est[cell] = est;
            if let Some(b) = budget_ms {
                if est > b {
                    continue;
                }
            }
            allowed[cell] = true;
        }

        // The field is public, so it may hold values the builder would have
        // rejected; a NaN weight would turn every score into NaN.
        let alpha = if opts.pareto_weight.is_nan() {
            0.0
        } else {
            opts.pareto_weight.clamp(0.0, 1.0)
        };
        let (bytes_min, bytes_max) = bytes_log
            .iter()
            .copied()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), x| {
                (lo.min(x), hi.max(x))
            });
        let (ms_min, ms_max) = encode_ms_est
            .iter()
            .filter(|x| x.is_finite())
            .copied()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), x| {
                (lo.min(x), hi.max(x))
            });
        // Guard against a zero span when all candidates agree.
        let bytes_span = (bytes_max - bytes_min).max(1e-6);
        let ms_span = (ms_max - ms_min).max(1e-6);

        let mut best_cell: Option<usize> = None;
        let mut best_score = f32::INFINITY;
        // Iterating the predictions (never longer than `allowed`) means a
        // cell without a prediction is skipped instead of indexing past the
        // end of `bytes_log`.
        for (cell, &bytes) in bytes_log.iter().enumerate() {
            if !allowed[cell] {
                continue;
            }
            let bytes_norm = (bytes - bytes_min) / bytes_span;
            let ms_norm = if encode_ms_est[cell].is_finite() {
                (encode_ms_est[cell] - ms_min) / ms_span
            } else {
                // Unknown timing counts as the slowest candidate.
                1.0
            };
            let score = (1.0 - alpha) * bytes_norm + alpha * ms_norm;
            if score < best_score {
                best_score = score;
                best_cell = Some(cell);
            }
        }

        let cell = best_cell.ok_or(AutoTuneError::NoCellAllowed)?;
        // Fits in `u8`: only cells below `u8::MAX` can be allowed.
        let speed = (cell + 1) as u8;
        let q = q_lut
            .median_q
            .get(cell)
            .and_then(|row| row.get(zq_idx).copied())
            .filter(|q| *q >= 0)
            .ok_or(AutoTuneError::TargetOutOfRange(target_zq))?;
        Ok(self.speed(speed).quality(q as f32))
    }
}
#[cfg(all(test, feature = "auto-tune"))]
mod tests {
    use super::*;

    /// While the embedded model is still a placeholder (fewer than 16 bytes),
    /// `auto_tune` must report `ModelNotBaked`. With a real model embedded
    /// the test has nothing to check and returns immediately.
    #[test]
    fn auto_tune_returns_model_not_baked_with_empty_artifacts() {
        if MODEL_BYTES.len() >= 16 {
            return;
        }
        // A single mid-grey pixel is enough to reach the model-loading step.
        let pixel = [128u8, 128, 128];
        let outcome = EncoderConfig::new().auto_tune(
            &pixel,
            1,
            1,
            QualityTarget::Zensim(85.0),
            AutoTuneOptions::new(),
        );
        assert!(matches!(outcome, Err(AutoTuneError::ModelNotBaked)));
    }
}