use alloc::vec::Vec;
use core::f32;
use zenanalyze::feature::{AnalysisFeature, AnalysisQuery, FeatureSet};
use zenpredict::{
AllowedMask, Model, Predictor, ScoreTransform,
argmin::argmin_masked_in_range as argmin_in_range,
};
use super::spec::{
CELLS, FEAT_COLS, N_CELLS, PickerConstraints, RANGE_BYTES_LOG, RANGE_FILTER_SHARPNESS,
RANGE_FILTER_STRENGTH, RANGE_SNS, SCHEMA_HASH,
};
/// Analysis features requested from `zenanalyze`, in the order their values
/// are placed in the raw feature vector.
///
/// `extract_raw_features_rgb8` emits one value per entry in exactly this order,
/// and `engineered_features` consumes those values positionally, so the order
/// is significant.
/// NOTE(review): presumably this order must match the column order of
/// `FEAT_COLS` used when the model was trained — only the *length* is checked
/// at compile time; confirm the ordering against the spec.
const ANALYSIS_FEATURES: &[AnalysisFeature] = &[
AnalysisFeature::LaplacianVarianceP50,
AnalysisFeature::LaplacianVarianceP75,
AnalysisFeature::LaplacianVariance,
AnalysisFeature::QuantSurvivalY,
AnalysisFeature::CbSharpness,
AnalysisFeature::PixelCount,
AnalysisFeature::Uniformity,
AnalysisFeature::DistinctColorBins,
AnalysisFeature::CrSharpness,
AnalysisFeature::EdgeDensity,
AnalysisFeature::NoiseFloorYP50,
AnalysisFeature::LumaHistogramEntropy,
AnalysisFeature::NaturalLikelihood,
AnalysisFeature::QuantSurvivalYP50,
AnalysisFeature::NoiseFloorUvP50,
AnalysisFeature::AqMapMean,
AnalysisFeature::CrHorizSharpness,
AnalysisFeature::MinDim,
AnalysisFeature::EdgeSlopeStdev,
AnalysisFeature::LaplacianVarianceP90,
AnalysisFeature::PatchFraction,
AnalysisFeature::MaxDim,
AnalysisFeature::AspectMinOverMax,
AnalysisFeature::AqMapP75,
AnalysisFeature::CbHorizSharpness,
AnalysisFeature::NoiseFloorYP25,
AnalysisFeature::NoiseFloorUV,
AnalysisFeature::ChromaComplexity,
AnalysisFeature::QuantSurvivalYP75,
AnalysisFeature::AqMapStd,
AnalysisFeature::GradientFraction,
AnalysisFeature::NoiseFloorYP75,
AnalysisFeature::ScreenContentLikelihood,
AnalysisFeature::HighFreqEnergyRatio,
AnalysisFeature::Colourfulness,
AnalysisFeature::QuantSurvivalUv,
];
// Compile-time guard: the number of requested analysis features must equal
// the number of feature columns declared in the spec.
const _: () = assert!(
    ANALYSIS_FEATURES.len() == FEAT_COLS.len(),
    "ANALYSIS_FEATURES.len() != FEAT_COLS.len()"
);
/// Byte-array wrapper whose only purpose is to give the embedded model a
/// 16-byte alignment.
///
/// NOTE(review): presumably `zenpredict::Model::from_bytes_with_schema` reads
/// the model in place and relies on aligned storage — confirm.
#[repr(C, align(16))]
struct AlignedModel<const N: usize>([u8; N]);
/// The baked picker model (`zenwebp_picker_v0.1.bin`), embedded at compile time.
///
/// The bytes are wrapped in `AlignedModel` before borrowing field `.0`. That
/// field is the first and only field of a `repr(C, align(16))` struct, so the
/// referenced storage starts on a 16-byte boundary. An empty file makes
/// `pick_tuning_from_features` return `PickError::NoBakedModel`.
const MODEL_BYTES_RAW: &[u8] = &AlignedModel(*include_bytes!("zenwebp_picker_v0.1.bin")).0;
/// Encoder tuning chosen by the picker for one image.
///
/// Plain data: it is `Copy`, comparable and hashable, so picks can be
/// compared in tests or used as cache keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TuningPick {
    /// Predicted SNS strength (clamped to `0..=100` by the picker).
    pub sns_strength: u8,
    /// Predicted filter strength (clamped to `0..=100` by the picker).
    pub filter_strength: u8,
    /// Predicted filter sharpness (clamped to `0..=7` by the picker).
    pub filter_sharpness: u8,
    /// `method` of the selected cell, copied from `CELLS`.
    pub method: u8,
    /// `segments` of the selected cell, copied from `CELLS`.
    pub segments: u8,
    /// Index of the selected cell in `CELLS`.
    pub cell_idx: usize,
}
/// Reasons `pick_tuning` / `pick_tuning_from_features` can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PickError {
    /// The embedded model file is empty (no model was baked into this build).
    NoBakedModel,
    /// The embedded model failed to load for a reason other than a schema mismatch.
    Parse,
    /// The embedded model's schema hash did not match.
    ///
    /// Both fields are forwarded unchanged from
    /// `zenpredict::PredictError::SchemaHashMismatch`.
    SchemaMismatch { expected: u64, got: u64 },
    /// Model evaluation failed, or produced output the picker cannot use.
    Forward,
    /// The caller's constraints exclude every candidate cell.
    NoAllowedCell,
}

impl core::fmt::Display for PickError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            PickError::NoBakedModel => f.write_str("no picker model was baked into this build"),
            PickError::Parse => f.write_str("failed to parse the baked picker model"),
            PickError::SchemaMismatch { expected, got } => write!(
                f,
                "picker model schema hash mismatch: expected {:#x}, got {:#x}",
                expected, got
            ),
            PickError::Forward => f.write_str("picker model evaluation failed"),
            PickError::NoAllowedCell => f.write_str("no cell satisfies the picker constraints"),
        }
    }
}
/// Build the model input vector from raw analysis features and the request.
///
/// With `n = raw_feats.len()` (expected to equal `FEAT_COLS.len()`),
/// `px = width * height` floored at 1, and `t = target_zensim / 100`, the layout is:
///
/// - `[0, n)`: the raw features, unchanged;
/// - `[n, n + 4)`: one-hot pixel-count bucket for `< 64*64`, `< 256*256`,
///   `< 1024*1024`, and everything larger;
/// - `[n + 4, n + 9)`: `ln(px)`, `ln(px)^2`, `t`, `t^2`, `t * ln(px)`;
/// - `[n + 9, 2n + 9)`: each raw feature multiplied by `t`;
/// - `2n + 9`: a constant `0.0`.
fn engineered_features(raw_feats: &[f32], width: u32, height: u32, target_zensim: f32) -> Vec<f32> {
    debug_assert_eq!(
        raw_feats.len(),
        FEAT_COLS.len(),
        "raw feature count mismatch"
    );

    let target = target_zensim / 100.0;
    let log_px = libm::logf(((width as f32) * (height as f32)).max(1.0));

    // Bucket on the exact integer pixel count; each upper bound is exclusive.
    let area = (width as u64) * (height as u64);
    let bucket = if area < 64 * 64 {
        0
    } else if area < 256 * 256 {
        1
    } else if area < 1024 * 1024 {
        2
    } else {
        3
    };
    let mut size_onehot = [0.0_f32; 4];
    size_onehot[bucket] = 1.0;

    let mut out = Vec::with_capacity(2 * raw_feats.len() + 10);
    out.extend_from_slice(raw_feats);
    out.extend_from_slice(&size_onehot);
    out.extend_from_slice(&[
        log_px,
        log_px * log_px,
        target,
        target * target,
        target * log_px,
    ]);
    out.extend(raw_feats.iter().map(|&feat| target * feat));
    // Trailing constant column. NOTE(review): presumably padding/placeholder
    // expected by the model's input width — confirm against the training spec.
    out.push(0.0);
    out
}
/// Run `zenanalyze` over the pixel buffer and return one value per entry of
/// `ANALYSIS_FEATURES`, in that order.
///
/// Every feature in `ANALYSIS_FEATURES` is requested. Any feature the analysis
/// does not report is filled with `0.0`.
/// NOTE(review): presumably `rgb` is a packed RGB8 buffer of `width * height`
/// pixels, as `analyze_features_rgb8` expects — confirm.
fn extract_raw_features_rgb8(rgb: &[u8], width: u32, height: u32) -> Vec<f32> {
    let requested = ANALYSIS_FEATURES
        .iter()
        .fold(FeatureSet::new(), |set, &feature| set.with(feature));
    // Kept in a named binding: the analysis result may borrow the query.
    let query = AnalysisQuery::new(requested);
    let analysis = zenanalyze::analyze_features_rgb8(rgb, width, height, &query);

    let mut values = Vec::with_capacity(ANALYSIS_FEATURES.len());
    for &feature in ANALYSIS_FEATURES {
        values.push(analysis.get_f32(feature).unwrap_or(0.0));
    }
    values
}
/// Analyze an image and pick encoder tuning for it.
///
/// Convenience wrapper: extracts the raw analysis features from `rgb` and
/// hands them, together with the remaining arguments, to
/// `pick_tuning_from_features`. That function's result (and errors) are
/// returned unchanged.
pub fn pick_tuning(
    rgb: &[u8],
    width: u32,
    height: u32,
    target_zensim: f32,
    constraints: &PickerConstraints,
) -> Result<TuningPick, PickError> {
    pick_tuning_from_features(
        &extract_raw_features_rgb8(rgb, width, height),
        width,
        height,
        target_zensim,
        constraints,
    )
}
pub fn pick_tuning_from_features(
raw_feats: &[f32],
width: u32,
height: u32,
target_zensim: f32,
constraints: &PickerConstraints,
) -> Result<TuningPick, PickError> {
if MODEL_BYTES_RAW.is_empty() {
return Err(PickError::NoBakedModel);
}
let model = Model::from_bytes_with_schema(MODEL_BYTES_RAW, SCHEMA_HASH).map_err(|e| {
match e {
zenpredict::PredictError::SchemaHashMismatch { expected, got } => {
PickError::SchemaMismatch { expected, got }
}
_ => PickError::Parse,
}
})?;
let mut predictor = Predictor::new(model);
let feats = engineered_features(raw_feats, width, height, target_zensim);
let mask_arr = constraints.allowed_mask();
let mask = AllowedMask::new(&mask_arr);
let output = predictor.predict(&feats).map_err(|_| PickError::Forward)?;
let cell_idx = argmin_in_range(
output,
(RANGE_BYTES_LOG.start, RANGE_BYTES_LOG.end),
&mask,
ScoreTransform::Exp,
None,
)
.ok_or(PickError::NoAllowedCell)?;
if cell_idx >= N_CELLS {
return Err(PickError::Forward);
}
let sns = clamp_to_u8(output[RANGE_SNS.start + cell_idx], 0.0, 100.0);
let filter_strength = clamp_to_u8(output[RANGE_FILTER_STRENGTH.start + cell_idx], 0.0, 100.0);
let filter_sharpness = clamp_to_u8(output[RANGE_FILTER_SHARPNESS.start + cell_idx], 0.0, 7.0);
let cell = CELLS[cell_idx];
Ok(TuningPick {
sns_strength: sns,
filter_strength,
filter_sharpness,
method: cell.method,
segments: cell.segments,
cell_idx,
})
}
/// Bound `v` below by `lo` and then above by `hi`, round it with
/// `libm::roundf`, and convert the result to `u8`.
///
/// A NaN input maps to `lo`. The final `as u8` cast saturates, so callers
/// should pass bounds inside `0..=255`.
fn clamp_to_u8(v: f32, lo: f32, hi: f32) -> u8 {
    let bounded = match v {
        nan if nan.is_nan() => lo,
        finite_or_inf => finite_or_inf.max(lo).min(hi),
    };
    libm::roundf(bounded) as u8
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Total length and the size one-hot position for a mid-sized image.
    #[test]
    fn engineered_features_layout_correct() {
        const N: usize = FEAT_COLS.len();
        let raw = [0.1_f32; N];
        let v = engineered_features(&raw, 512, 512, 80.0);
        assert_eq!(v.len(), N + 4 + 5 + N + 1);
        assert_eq!(v[N..N + 4], [0.0, 0.0, 1.0, 0.0]);
    }

    /// Scalar terms, target-interaction terms and the trailing constant column.
    #[test]
    fn engineered_features_scalar_and_interaction_terms() {
        const N: usize = FEAT_COLS.len();
        let raw = [0.1_f32; N];
        let v = engineered_features(&raw, 512, 512, 80.0);
        let log_px = libm::logf(512.0 * 512.0);
        let t = 80.0_f32 / 100.0;
        assert_eq!(v[N + 4..N + 9], [log_px, log_px * log_px, t, t * t, t * log_px]);
        assert!(v[N + 9..N + 9 + N].iter().all(|&x| x == t * 0.1));
        assert_eq!(v[2 * N + 9], 0.0);
    }

    #[test]
    fn engineered_features_size_class_buckets() {
        const N: usize = FEAT_COLS.len();
        let raw = [0.0_f32; N];
        let v = engineered_features(&raw, 32, 32, 80.0);
        assert_eq!(v[N..N + 4], [1.0, 0.0, 0.0, 0.0]);
        let v = engineered_features(&raw, 200, 200, 80.0);
        assert_eq!(v[N..N + 4], [0.0, 1.0, 0.0, 0.0]);
        let v = engineered_features(&raw, 2048, 2048, 80.0);
        assert_eq!(v[N..N + 4], [0.0, 0.0, 0.0, 1.0]);
    }

    /// Each bucket's upper bound is exclusive; zero-area images land in the
    /// smallest bucket and get `ln(1) = 0` for the log-pixel term.
    #[test]
    fn engineered_features_size_class_boundaries() {
        const N: usize = FEAT_COLS.len();
        let raw = [0.0_f32; N];
        let cases: [(u32, u32, [f32; 4]); 7] = [
            (0, 0, [1.0, 0.0, 0.0, 0.0]),
            (63, 64, [1.0, 0.0, 0.0, 0.0]),
            (64, 64, [0.0, 1.0, 0.0, 0.0]),
            (255, 256, [0.0, 1.0, 0.0, 0.0]),
            (256, 256, [0.0, 0.0, 1.0, 0.0]),
            (1023, 1024, [0.0, 0.0, 1.0, 0.0]),
            (1024, 1024, [0.0, 0.0, 0.0, 1.0]),
        ];
        for &(w, h, expected) in &cases {
            let v = engineered_features(&raw, w, h, 80.0);
            assert_eq!(v[N..N + 4], expected, "{}x{}", w, h);
        }
        let v = engineered_features(&raw, 0, 0, 80.0);
        assert_eq!(v[N + 4], 0.0);
    }

    #[test]
    fn clamp_to_u8_clamps_rounds_and_handles_nan() {
        assert_eq!(clamp_to_u8(f32::NAN, 0.0, 100.0), 0);
        assert_eq!(clamp_to_u8(-5.0, 0.0, 100.0), 0);
        assert_eq!(clamp_to_u8(250.0, 0.0, 100.0), 100);
        assert_eq!(clamp_to_u8(f32::INFINITY, 0.0, 7.0), 7);
        assert_eq!(clamp_to_u8(f32::NEG_INFINITY, 0.0, 7.0), 0);
        assert_eq!(clamp_to_u8(2.5, 0.0, 7.0), 3);
        assert_eq!(clamp_to_u8(2.49, 0.0, 7.0), 2);
    }

    #[test]
    fn picker_loads_and_picks_a_cell() {
        let raw = [0.5_f32; FEAT_COLS.len()];
        let pick = pick_tuning_from_features(&raw, 512, 512, 80.0, &PickerConstraints::default());
        assert!(pick.is_ok(), "pick_tuning failed: {:?}", pick);
        let p = pick.unwrap();
        assert!(p.cell_idx < N_CELLS);
        assert!(p.sns_strength <= 100);
        assert!(p.filter_strength <= 100);
        assert!(p.filter_sharpness <= 7);
        assert!([4, 5, 6].contains(&p.method));
        assert!([1, 4].contains(&p.segments));
    }

    #[test]
    fn picker_respects_method_constraint() {
        let raw = [0.5_f32; FEAT_COLS.len()];
        let constraints = PickerConstraints {
            allowed_methods: Some(&[4]),
            ..Default::default()
        };
        let pick = pick_tuning_from_features(&raw, 512, 512, 80.0, &constraints);
        assert!(pick.is_ok());
        assert_eq!(pick.unwrap().method, 4);
    }

    /// Method 3 is assumed not to be offered by any cell (the other tests
    /// expect methods 4..=6), so restricting to it masks everything.
    #[test]
    fn picker_returns_no_allowed_cell_when_all_masked() {
        let raw = [0.5_f32; FEAT_COLS.len()];
        let constraints = PickerConstraints {
            allowed_methods: Some(&[3]),
            ..Default::default()
        };
        let pick = pick_tuning_from_features(&raw, 512, 512, 80.0, &constraints);
        assert!(matches!(pick, Err(PickError::NoAllowedCell)));
    }
}