use super::analysis::ImageContentType;
/// Desired zensim score for an encode, plus the tolerances and pass budget
/// consumed by the iterative search in `iteration::run`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ZensimTarget {
/// Score to aim for. When choosing among probes, a probe scoring at least
/// this value is preferred over one scoring below it.
pub target: f32,
/// Allowed score above `target` that still counts as in band (stops the
/// search). `None` means no upper bound.
pub max_overshoot: Option<f32>,
/// Allowed score below `target` that still counts as in band (stops the
/// search). `None` is treated as `0.0`.
pub max_undershoot_ship: Option<f32>,
/// Hard floor: if the best probe scores below `target - max_undershoot`,
/// the encode returns an error. `None` disables the floor.
pub max_undershoot: Option<f32>,
/// Maximum number of encodes, including the initial one. `0` is treated as
/// `1`; with a single pass the score is not measured at all.
pub max_passes: u8,
}
impl Default for ZensimTarget {
    /// Aims for a zensim score of 80, accepts results in `79.5..=81.5`,
    /// has no hard floor, and allows at most two encodes.
    fn default() -> Self {
        let target = 80.0;
        let max_passes = 2;
        Self {
            max_passes,
            max_undershoot: None,
            max_undershoot_ship: Some(0.5),
            max_overshoot: Some(1.5),
            target,
        }
    }
}
impl ZensimTarget {
    /// Builds a target aiming for `target`. Every tolerance and the pass
    /// budget come from [`ZensimTarget::default`].
    #[must_use]
    pub fn new(target: f32) -> Self {
        let mut spec = Self::default();
        spec.target = target;
        spec
    }
}
impl From<f32> for ZensimTarget {
    /// Same as [`ZensimTarget::new`]: the float becomes the target score and
    /// every other field keeps its default.
    fn from(target: f32) -> Self {
        Self {
            target,
            ..Self::default()
        }
    }
}
impl ZensimTarget {
    /// Returns a copy with `max_overshoot` replaced (`None` = no upper bound).
    #[must_use]
    pub fn with_max_overshoot(self, v: Option<f32>) -> Self {
        Self {
            max_overshoot: v,
            ..self
        }
    }

    /// Returns a copy with the hard `max_undershoot` floor replaced
    /// (`None` = no floor).
    #[must_use]
    pub fn with_max_undershoot(self, v: Option<f32>) -> Self {
        Self {
            max_undershoot: v,
            ..self
        }
    }

    /// Returns a copy with `max_undershoot_ship` replaced (`None` = no
    /// tolerance below the target).
    #[must_use]
    pub fn with_max_undershoot_ship(self, v: Option<f32>) -> Self {
        Self {
            max_undershoot_ship: v,
            ..self
        }
    }

    /// Returns a copy with the encode budget replaced.
    #[must_use]
    pub fn with_max_passes(self, n: u8) -> Self {
        Self {
            max_passes: n,
            ..self
        }
    }
}
/// Outcome reported by a zensim-targeted encode.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ZensimEncodeMetrics {
/// Zensim score of the returned encode; `NaN` when no score was measured
/// (no target, or a single-pass budget).
pub achieved_score: f32,
/// Number of encode passes the search reports as used.
pub passes_used: u8,
/// Length in bytes of the returned encoded image.
pub bytes: usize,
/// Whether the search considers the target satisfied.
pub targets_met: bool,
}
impl ZensimEncodeMetrics {
    /// Metrics for an encode that was not driven by a zensim target: a
    /// single pass, no measured score (`NaN`), and targets trivially met.
    pub(crate) fn no_target(bytes: usize) -> Self {
        let unmeasured = f32::NAN;
        Self {
            bytes,
            passes_used: 1,
            targets_met: true,
            achieved_score: unmeasured,
        }
    }
}
/// Maps a zensim `target` to a starting encoder quality for one content bucket.
///
/// Each bucket owns a `(zensim, quality)` anchor curve. Targets between
/// anchors are linearly interpolated, and targets outside a curve take its
/// endpoint quality (see `interpolate_anchors`). `Text` shares the
/// `Drawing` curve.
#[must_use]
pub(crate) fn zensim_to_starting_q_for_bucket(target: f32, bucket: ImageContentType) -> f32 {
    const PHOTO_CURVE: &[(f32, f32)] = &[
        (60.0, 30.0),
        (70.0, 60.0),
        (75.0, 75.0),
        (80.0, 85.0),
        (85.0, 90.0),
        (90.0, 98.0),
        (95.0, 100.0),
    ];
    const DRAWING_CURVE: &[(f32, f32)] = &[
        (60.0, 30.0),
        (70.0, 60.0),
        (75.0, 72.5),
        (80.0, 85.0),
        (85.0, 90.0),
        (90.0, 100.0),
        (95.0, 100.0),
    ];
    const ICON_CURVE: &[(f32, f32)] = &[
        (60.0, 65.0),
        (70.0, 78.0),
        (75.0, 85.0),
        (80.0, 90.0),
        (85.0, 95.0),
        (90.0, 98.0),
        (95.0, 100.0),
    ];
    interpolate_anchors(
        target,
        match bucket {
            ImageContentType::Photo => PHOTO_CURVE,
            ImageContentType::Drawing | ImageContentType::Text => DRAWING_CURVE,
            ImageContentType::Icon => ICON_CURVE,
        },
    )
}
/// Starting quality obtained by blending the Photo and Drawing curves with the
/// analyzer's content likelihoods.
///
/// Images whose sides are both at most 128 px always use the Icon curve.
/// Otherwise the Drawing curve is weighted by
/// `clamp(screen_content + text_likelihood, 0, 1)` and the Photo curve by
/// `clamp(natural_likelihood, 0, 1)`. When both weights are ~0, the Photo
/// quality is used.
///
/// Returns `None` when either side is under 8 px, when the buffer is shorter
/// than `width * height` pixels, or for layouts other than `Rgb8`/`Rgba8`.
#[cfg(feature = "analyzer")]
#[allow(clippy::needless_pass_by_value, dead_code)]
fn starting_q_via_likelihoods(
    target: f32,
    pixels: &[u8],
    layout: crate::PixelLayout,
    width: u32,
    height: u32,
) -> Option<f32> {
    use crate::PixelLayout;
    use crate::encoder::analysis::ZenanalyzeDiag;
    use crate::encoder::analysis::classifier::{classify_image_type_rgb8_diag, rgba8_to_rgb8};

    let (cols, rows) = (width as usize, height as usize);
    if cols < 8 || rows < 8 {
        return None;
    }
    if width <= 128 && height <= 128 {
        return Some(zensim_to_starting_q_for_bucket(target, ImageContentType::Icon));
    }

    let diag: ZenanalyzeDiag = match layout {
        PixelLayout::Rgb8 => {
            let needed = cols * rows * 3;
            if pixels.len() < needed {
                return None;
            }
            classify_image_type_rgb8_diag(&pixels[..needed], width, height).1
        }
        PixelLayout::Rgba8 => {
            let needed = cols * rows * 4;
            if pixels.len() < needed {
                return None;
            }
            let rgb = rgba8_to_rgb8(&pixels[..needed]);
            classify_image_type_rgb8_diag(&rgb, width, height).1
        }
        _ => return None,
    };

    let screen_weight = (diag.screen_content + diag.text_likelihood).clamp(0.0, 1.0);
    let photo_weight = diag.natural_likelihood.clamp(0.0, 1.0);
    let photo_q = zensim_to_starting_q_for_bucket(target, ImageContentType::Photo);
    let total = screen_weight + photo_weight;
    if total <= f32::EPSILON {
        return Some(photo_q);
    }
    let drawing_q = zensim_to_starting_q_for_bucket(target, ImageContentType::Drawing);
    Some((screen_weight * drawing_q + photo_weight * photo_q) / total)
}
/// Piecewise-linear lookup over `(x, y)` anchors sorted by `x`.
///
/// - At or below the first anchor's `x`, returns the first `y`.
/// - At or above the last anchor's `x`, returns the last `y`, capped at 100.
/// - Otherwise, interpolates linearly within the first bracketing pair.
///
/// With no anchors, or when no pair brackets `target` (e.g. `target` is NaN),
/// `target` is returned unchanged.
fn interpolate_anchors(target: f32, anchors: &[(f32, f32)]) -> f32 {
    let (Some(&(first_x, first_y)), Some(&(last_x, last_y))) = (anchors.first(), anchors.last())
    else {
        return target;
    };
    if target <= first_x {
        return first_y;
    }
    if target >= last_x {
        return last_y.min(100.0);
    }
    anchors
        .windows(2)
        .find_map(|pair| {
            let (x0, y0) = pair[0];
            let (x1, y1) = pair[1];
            (x0 <= target && target <= x1).then(|| {
                let frac = (target - x0) / (x1 - x0);
                y0 + frac * (y1 - y0)
            })
        })
        .unwrap_or(target)
}
/// Research toggles that disable or alter individual parts of the zensim
/// target search. Install them on the current thread with [`set_toggles`].
#[cfg(feature = "ablation")]
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct AblationToggles {
/// Never use per-segment quantizer overrides; only global quality moves.
pub disable_phase3: bool,
/// Print `PHASE3_TRACE` lines to stderr for each pass.
pub trace_phase3: bool,
/// Skip bucket detection and use the single `naive_starting_q` curve.
pub naive_starting_q: bool,
/// Keep `multi_pass_stats` disabled on refinement passes too.
pub no_multi_pass_stats: bool,
/// Start from the `*_HAND` anchor tables in `pre_phase2_starting_q`.
pub pre_phase2_anchors: bool,
/// Replace the secant quality update with a fixed `0.5 * gap` step.
pub no_secant: bool,
/// Pick per-segment overrides from image quadrants instead of the encoder's
/// segment map.
pub use_quadrant_proxy: bool,
/// Override the score gap under which per-segment refinement is tried;
/// `None` keeps `PHASE3_FINE_GAP_DEFAULT`.
pub phase3_fine_gap: Option<f32>,
}
/// Per-thread storage for the active [`AblationToggles`], written by
/// [`set_toggles`] and read by the search in `iteration`.
#[cfg(feature = "ablation")]
mod ablation_runtime {
use core::cell::Cell;
use std::thread_local;
thread_local! {
// One cell per toggle; all start disabled. `PHASE3_FINE_GAP` uses NaN
// as its "not set" value, so readers fall back to the default gap.
pub(crate) static DISABLE_PHASE3: Cell<bool> = const { Cell::new(false) };
pub(crate) static TRACE_PHASE3: Cell<bool> = const { Cell::new(false) };
pub(crate) static NAIVE_STARTING_Q: Cell<bool> = const { Cell::new(false) };
pub(crate) static NO_MULTI_PASS_STATS: Cell<bool> = const { Cell::new(false) };
pub(crate) static PRE_PHASE2_ANCHORS: Cell<bool> = const { Cell::new(false) };
pub(crate) static NO_SECANT: Cell<bool> = const { Cell::new(false) };
pub(crate) static USE_QUADRANT_PROXY: Cell<bool> = const { Cell::new(false) };
pub(crate) static PHASE3_FINE_GAP: Cell<f32> = const { Cell::new(f32::NAN) };
}
}
/// Installs `t` as the ablation configuration for the calling thread.
///
/// The settings are thread-local: other threads keep their own configuration,
/// which starts with every toggle disabled.
#[cfg(feature = "ablation")]
pub fn set_toggles(t: AblationToggles) {
    use ablation_runtime::{
        DISABLE_PHASE3, NAIVE_STARTING_Q, NO_MULTI_PASS_STATS, NO_SECANT, PHASE3_FINE_GAP,
        PRE_PHASE2_ANCHORS, TRACE_PHASE3, USE_QUADRANT_PROXY,
    };
    let AblationToggles {
        disable_phase3,
        trace_phase3,
        naive_starting_q,
        no_multi_pass_stats,
        pre_phase2_anchors,
        no_secant,
        use_quadrant_proxy,
        phase3_fine_gap,
    } = t;
    DISABLE_PHASE3.set(disable_phase3);
    TRACE_PHASE3.set(trace_phase3);
    NAIVE_STARTING_Q.set(naive_starting_q);
    NO_MULTI_PASS_STATS.set(no_multi_pass_stats);
    PRE_PHASE2_ANCHORS.set(pre_phase2_anchors);
    NO_SECANT.set(no_secant);
    USE_QUADRANT_PROXY.set(use_quadrant_proxy);
    // NaN is the "unset" value; the search then falls back to its default gap.
    PHASE3_FINE_GAP.set(phase3_fine_gap.unwrap_or(f32::NAN));
}
/// Default absolute score gap (`|target - score|`) at or below which a
/// refinement pass may adjust per-segment quantizer overrides instead of the
/// global quality.
#[cfg(feature = "target-zensim")]
const PHASE3_FINE_GAP_DEFAULT: f32 = 0.5;
#[cfg(feature = "target-zensim")]
pub(crate) mod iteration {
use super::*;
use crate::PixelLayout;
use crate::encoder::api::{EncodeDiagnostics, EncodeError};
use crate::encoder::config::LossyConfig;
use alloc::format;
use alloc::vec::Vec;
/// Encoded bytes plus search metrics, or the error that stopped the search.
pub(crate) type IterationResult = Result<(Vec<u8>, ZensimEncodeMetrics), EncodeError>;
/// Ablation: never try per-segment (phase 3) refinement.
#[cfg(feature = "ablation")]
fn ablate_disable_phase3() -> bool {
    super::ablation_runtime::DISABLE_PHASE3.get()
}
#[cfg(not(feature = "ablation"))]
const fn ablate_disable_phase3() -> bool {
    false
}
/// Ablation: emit `PHASE3_TRACE` diagnostics on stderr.
#[cfg(feature = "ablation")]
fn trace_phase3() -> bool {
    super::ablation_runtime::TRACE_PHASE3.get()
}
#[cfg(not(feature = "ablation"))]
const fn trace_phase3() -> bool {
    false
}
/// Ablation: use the content-agnostic starting-quality curve.
#[cfg(feature = "ablation")]
fn ablate_naive_starting_q() -> bool {
    super::ablation_runtime::NAIVE_STARTING_Q.get()
}
#[cfg(not(feature = "ablation"))]
const fn ablate_naive_starting_q() -> bool {
    false
}
/// Ablation: keep multi-pass stats off on refinement passes.
#[cfg(feature = "ablation")]
fn ablate_no_multi_pass_stats() -> bool {
    super::ablation_runtime::NO_MULTI_PASS_STATS.get()
}
#[cfg(not(feature = "ablation"))]
const fn ablate_no_multi_pass_stats() -> bool {
    false
}
/// Ablation: use the hand-written pre-phase-2 anchor tables.
#[cfg(feature = "ablation")]
fn ablate_pre_phase2_anchors() -> bool {
    super::ablation_runtime::PRE_PHASE2_ANCHORS.get()
}
#[cfg(not(feature = "ablation"))]
const fn ablate_pre_phase2_anchors() -> bool {
    false
}
/// Ablation: replace the secant quality update with a fixed half-gap step.
#[cfg(feature = "ablation")]
fn ablate_no_secant() -> bool {
    super::ablation_runtime::NO_SECANT.get()
}
#[cfg(not(feature = "ablation"))]
const fn ablate_no_secant() -> bool {
    false
}
/// Ablation baseline: one content-agnostic `(zensim, quality)` curve used in
/// place of per-bucket calibration.
#[cfg_attr(not(feature = "ablation"), allow(dead_code))]
fn naive_starting_q(target: f32) -> f32 {
    interpolate_anchors(
        target,
        &[
            (60.0, 50.0),
            (70.0, 65.0),
            (75.0, 75.0),
            (80.0, 82.0),
            (85.0, 90.0),
            (90.0, 96.0),
            (95.0, 100.0),
        ],
    )
}
/// Ablation baseline: the hand-written per-bucket `(zensim, quality)` anchor
/// tables. `Text` shares the drawing table.
#[cfg_attr(not(feature = "ablation"), allow(dead_code))]
fn pre_phase2_starting_q(target: f32, bucket: ImageContentType) -> f32 {
    const PHOTO_HAND: &[(f32, f32)] = &[
        (60.0, 50.0),
        (70.0, 65.0),
        (75.0, 72.0),
        (80.0, 80.0),
        (85.0, 88.0),
        (90.0, 94.0),
        (95.0, 98.0),
    ];
    const DRAWING_HAND: &[(f32, f32)] = &[
        (60.0, 55.0),
        (70.0, 70.0),
        (75.0, 78.0),
        (80.0, 84.0),
        (85.0, 90.0),
        (90.0, 96.0),
        (95.0, 100.0),
    ];
    const ICON_HAND: &[(f32, f32)] = &[
        (60.0, 65.0),
        (70.0, 78.0),
        (75.0, 85.0),
        (80.0, 90.0),
        (85.0, 95.0),
        (90.0, 98.0),
        (95.0, 100.0),
    ];
    let table: &[(f32, f32)] = match bucket {
        ImageContentType::Icon => ICON_HAND,
        ImageContentType::Photo => PHOTO_HAND,
        ImageContentType::Drawing | ImageContentType::Text => DRAWING_HAND,
    };
    interpolate_anchors(target, table)
}
/// Largest global-quality change one refinement pass may make.
const MAX_DELTA_Q: f32 = 10.0;
/// Quality change per point of score gap when no usable secant slope exists.
const DEFAULT_SENSITIVITY: f32 = 1.5;
/// Smallest global-quality change a secant/fallback step is rounded up to.
const MIN_DELTA_Q: f32 = 0.5;
pub(crate) fn run(
cfg: &LossyConfig,
target: ZensimTarget,
pixels: &[u8],
layout: PixelLayout,
width: u32,
height: u32,
) -> IterationResult {
match layout {
PixelLayout::Rgb8 | PixelLayout::Rgba8 => {}
other => {
return Err(EncodeError::TargetZensimUnsupportedLayout(other));
}
}
let bucket = if ablate_naive_starting_q() {
None
} else {
detect_bucket(pixels, layout, width, height)
};
let mut q = if ablate_naive_starting_q() {
naive_starting_q(target.target)
} else if ablate_pre_phase2_anchors() {
match bucket {
Some(b) => pre_phase2_starting_q(target.target, b),
None => pre_phase2_starting_q(target.target, ImageContentType::Photo),
}
} else {
#[cfg(feature = "analyzer")]
{
match starting_q_via_likelihoods(target.target, pixels, layout, width, height) {
Some(q) => q,
None => match bucket {
Some(b) => zensim_to_starting_q_for_bucket(target.target, b),
None => {
zensim_to_starting_q_for_bucket(target.target, ImageContentType::Photo)
}
},
}
}
#[cfg(not(feature = "analyzer"))]
{
match bucket {
Some(b) => zensim_to_starting_q_for_bucket(target.target, b),
None => zensim_to_starting_q_for_bucket(target.target, ImageContentType::Photo),
}
}
};
q = q.clamp(0.0, 100.0);
let (bytes0, diag0) =
encode_at_with_diagnostics(cfg, q, false, None, pixels, layout, width, height)?;
let max_passes = target.max_passes.max(1);
if trace_phase3() {
eprintln!(
"PHASE3_TRACE pass=0 target={:.3} q={:.3} mps=false bytes={} num_segs={} mb={}x{}",
target.target,
q,
bytes0.len(),
diag0.num_segments,
diag0.mb_width,
diag0.mb_height
);
}
if max_passes <= 1 {
return Ok((
bytes0.clone(),
ZensimEncodeMetrics {
achieved_score: f32::NAN,
passes_used: 1,
bytes: bytes0.len(),
targets_met: true,
},
));
}
let z = zensim::Zensim::new(zensim::ZensimProfile::latest());
let pre = build_source_reference(&z, pixels, layout, width, height).ok_or_else(|| {
EncodeError::InvalidBufferSize(
"zensim precompute_reference failed (image too small?)".into(),
)
})?;
let (score0, dm0) = measure_score_and_diffmap(&z, &pre, &bytes0, layout, width, height)?;
if trace_phase3() {
eprintln!(
"PHASE3_TRACE pass=0_measured target={:.3} q={:.3} achieved={:.4} bytes={} gap={:.4}",
target.target,
q,
score0,
bytes0.len(),
target.target - score0
);
}
let mut best = Candidate {
bytes: bytes0,
score: score0,
q,
seg_overrides: None,
};
if in_band(score0, &target) {
return finalize(best, 1, &target);
}
let per_segment_enabled = !ablate_disable_phase3()
&& diag0.num_segments > 1
&& !diag0.segment_map.is_empty()
&& diag0.mb_width > 0
&& diag0.mb_height > 0;
let num_segments = diag0.num_segments as usize;
let mut last_diag = diag0;
let mut prev_probe: Option<(f32, f32)> = None;
let mut last_q = q;
let mut last_score = score0;
let mut last_dm = dm0;
let mut cum_overrides: [i8; 4] = [0; 4];
#[cfg(feature = "ablation")]
let phase3_fine_gap: f32 = {
let v = super::ablation_runtime::PHASE3_FINE_GAP.with(|c| c.get());
if v.is_nan() {
super::PHASE3_FINE_GAP_DEFAULT
} else {
v
}
};
#[cfg(not(feature = "ablation"))]
let phase3_fine_gap: f32 = super::PHASE3_FINE_GAP_DEFAULT;
for pass in 1..max_passes {
let abs_gap = (target.target - last_score).abs();
let use_per_segment =
per_segment_enabled && abs_gap <= phase3_fine_gap && prev_probe.is_none();
let (next_q, next_overrides) = if use_per_segment {
let dec = next_segment_overrides(
cum_overrides,
&last_dm,
width,
height,
&last_diag,
last_score,
&target,
);
if trace_phase3() {
eprintln!(
"PHASE3_TRACE pass={}_decide use_per_segment=true gap={:.4} \
means=[{:.4},{:.4},{:.4},{:.4}] counts=[{},{},{},{}] \
cum_before=[{},{},{},{}] cum_after=[{},{},{},{}] picked_seg={:?} dir={}",
pass,
target.target - last_score,
dec.means[0],
dec.means[1],
dec.means[2],
dec.means[3],
dec.counts[0],
dec.counts[1],
dec.counts[2],
dec.counts[3],
cum_overrides[0],
cum_overrides[1],
cum_overrides[2],
cum_overrides[3],
dec.overrides[0],
dec.overrides[1],
dec.overrides[2],
dec.overrides[3],
dec.picked_seg,
dec.direction,
);
}
(last_q, Some(dec.overrides))
} else {
let nq = compute_next_q(last_q, last_score, prev_probe, &target);
if trace_phase3() {
eprintln!(
"PHASE3_TRACE pass={}_decide use_per_segment=false gap={:.4} \
last_q={:.3} next_q={:.3} prev_probe={:?}",
pass,
target.target - last_score,
last_q,
nq.clamp(0.0, 100.0),
prev_probe,
);
}
(nq.clamp(0.0, 100.0), None)
};
let q_moved = (next_q - last_q).abs() >= 0.05;
let overrides_moved = match next_overrides {
Some(o) => o != cum_overrides,
None => false,
};
if !q_moved && !overrides_moved {
break;
}
let mps = !ablate_no_multi_pass_stats();
let (bytes_n, diag_n) = encode_at_with_diagnostics(
cfg,
next_q,
mps,
next_overrides,
pixels,
layout,
width,
height,
)?;
let (score_n, dm_n) =
measure_score_and_diffmap(&z, &pre, &bytes_n, layout, width, height)?;
let passes_used = pass + 1;
if trace_phase3() {
eprintln!(
"PHASE3_TRACE pass={}_measured target={:.3} q={:.3} achieved={:.4} bytes={} \
overrides={:?} num_segs={} delta_score={:.4}",
pass,
target.target,
next_q,
score_n,
bytes_n.len(),
next_overrides,
diag_n.num_segments,
score_n - last_score,
);
}
best = pick_best(
best,
Candidate {
bytes: bytes_n,
score: score_n,
q: next_q,
seg_overrides: next_overrides,
},
&target,
);
if in_band(score_n, &target) {
return finalize(best, passes_used, &target);
}
if let Some(ov) = next_overrides {
cum_overrides = ov;
}
prev_probe = Some((last_q, last_score));
last_q = next_q;
last_score = score_n;
last_dm = dm_n;
let _ = num_segments; last_diag = diag_n;
}
finalize(best, max_passes, &target)
}
/// One encoded probe together with the zensim score it achieved.
struct Candidate {
/// Encoded image bytes.
bytes: Vec<u8>,
/// Measured zensim score of `bytes`.
score: f32,
// Quality and overrides that produced this probe. They are recorded but not
// read anywhere in this module, hence the `dead_code` allowances.
#[allow(dead_code)]
q: f32,
#[allow(dead_code)]
seg_overrides: Option<[i8; 4]>,
}
/// Keeps the better of two probes.
///
/// A probe is "feasible" when its score is at least `target.target`.
/// - A feasible probe beats an infeasible one.
/// - Between two feasible probes, the smaller encode wins.
/// - Between two infeasible probes, the higher score wins.
///
/// Ties keep `prev`.
fn pick_best(prev: Candidate, cand: Candidate, target: &ZensimTarget) -> Candidate {
    let goal = target.target;
    let take_cand = match (prev.score >= goal, cand.score >= goal) {
        (false, true) => true,
        (true, false) => false,
        (true, true) => cand.bytes.len() < prev.bytes.len(),
        (false, false) => cand.score > prev.score,
    };
    if take_cand { cand } else { prev }
}
/// True when `score` lies in
/// `target - max_undershoot_ship ..= target + max_overshoot`.
///
/// A missing undershoot tolerance counts as `0.0`, and a missing overshoot
/// tolerance as unbounded. NaN scores are never in band.
fn in_band(score: f32, target: &ZensimTarget) -> bool {
    let below = target.max_undershoot_ship.unwrap_or(0.0);
    let above = target.max_overshoot.unwrap_or(f32::INFINITY);
    ((target.target - below)..=(target.target + above)).contains(&score)
}
/// Proposes the next global quality from the last probe `(q, last_score)`.
///
/// When a previous probe `prev = (q_prev, s_prev)` differs from the last one
/// by more than 0.1 in quality and 0.05 in score, a secant step
/// `gap / max(slope, 0.05)` is used. Otherwise `fallback_step` is used. The
/// secant step is clamped to `±MAX_DELTA_Q` and, if smaller than
/// `MIN_DELTA_Q`, rounded up to it. The `no_secant` ablation instead uses
/// `0.5 * gap` with the same bounds.
fn compute_next_q(
    q: f32,
    last_score: f32,
    prev: Option<(f32, f32)>,
    target: &ZensimTarget,
) -> f32 {
    let gap = target.target - last_score;
    if ablate_no_secant() {
        let mut step = (gap * 0.5).clamp(-MAX_DELTA_Q, MAX_DELTA_Q);
        if step.abs() < MIN_DELTA_Q {
            let sign_source = if gap == 0.0 { 1.0 } else { gap };
            step = MIN_DELTA_Q.copysign(sign_source);
        }
        return q + step;
    }
    let step = prev
        .and_then(|(q_prev, s_prev)| {
            let dq = q - q_prev;
            let ds = last_score - s_prev;
            (dq.abs() > 0.1 && ds.abs() > 0.05).then(|| {
                // A flat or inverted slope is floored so the step stays finite.
                let slope = (ds / dq).max(0.05);
                let raw = (gap / slope).clamp(-MAX_DELTA_Q, MAX_DELTA_Q);
                if raw.abs() < MIN_DELTA_Q {
                    MIN_DELTA_Q.copysign(raw)
                } else {
                    raw
                }
            })
        })
        .unwrap_or_else(|| fallback_step(gap));
    q + step
}
/// Quality step used without a usable secant slope: `gap * DEFAULT_SENSITIVITY`,
/// clamped to `±MAX_DELTA_Q`.
///
/// If the result is smaller than `MIN_DELTA_Q`, it is rounded up to
/// `MIN_DELTA_Q` in the direction of `gap` (positive when `gap` is zero).
fn fallback_step(gap: f32) -> f32 {
    let scaled = (gap * DEFAULT_SENSITIVITY).clamp(-MAX_DELTA_Q, MAX_DELTA_Q);
    if scaled.abs() < MIN_DELTA_Q {
        let direction = if gap == 0.0 { 1.0 } else { gap };
        MIN_DELTA_Q.copysign(direction)
    } else {
        scaled
    }
}
/// Converts the best probe into the search result.
///
/// # Errors
/// Returns `EncodeError::InvalidBufferSize` when `target.max_undershoot` is
/// set and the best score is below `target.target - max_undershoot`. The
/// message reports the violated floor as well as the target.
fn finalize(best: Candidate, passes_used: u8, target: &ZensimTarget) -> IterationResult {
    if let Some(slack) = target.max_undershoot
        && best.score < target.target - slack
    {
        // Report the floor that was actually violated (target minus slack);
        // the previous message labelled the target itself as the floor.
        return Err(EncodeError::InvalidBufferSize(format!(
            "target_zensim: achieved {:.3} below floor {:.3} (target {:.3}, max_undershoot {:.3}) after {} passes",
            best.score,
            target.target - slack,
            target.target,
            slack,
            passes_used,
        )));
    }
    // NOTE(review): on this Ok path `targets_met` is false only for a NaN
    // score. With `max_undershoot = None` the second clause is always true,
    // and with `Some(t)` the early return above has already rejected scores
    // more than `t` below target. Confirm whether it was meant to reflect
    // `max_undershoot_ship` / `in_band` instead.
    let targets_met = best.score >= target.target
        || target
            .max_undershoot
            .is_none_or(|t| (target.target - best.score) <= t);
    let bytes_len = best.bytes.len();
    Ok((
        best.bytes,
        ZensimEncodeMetrics {
            achieved_score: best.score,
            passes_used,
            bytes: bytes_len,
            targets_met,
        },
    ))
}
/// Encodes once at quality `q` (clamped to `0..=100`) with the given
/// multi-pass-stats flag and per-segment overrides. Returns the bytes and the
/// encoder diagnostics.
///
/// The probe config clears `target_size`, `target_psnr` and `target_zensim`,
/// so each probe is a plain fixed-quality encode. Presumably this also keeps
/// it from re-entering this search; confirm against the encoder dispatch.
fn encode_at_with_diagnostics(
    cfg: &LossyConfig,
    q: f32,
    enable_multi_pass: bool,
    seg_overrides: Option<[i8; 4]>,
    pixels: &[u8],
    layout: PixelLayout,
    width: u32,
    height: u32,
) -> Result<(Vec<u8>, EncodeDiagnostics), EncodeError> {
    let probe_cfg = {
        let mut c = cfg.clone();
        c.quality = q.clamp(0.0, 100.0);
        c.multi_pass_stats = enable_multi_pass;
        c.segment_quant_overrides = seg_overrides;
        c.target_size = 0;
        c.target_psnr = 0.0;
        c.target_zensim = None;
        c
    };
    let request =
        crate::encoder::api::EncodeRequest::lossy(&probe_cfg, pixels, layout, width, height);
    request
        .encode_inner_with_diagnostics()
        .map(|(bytes, _stats, diag)| (bytes, diag))
        .map_err(|at_err| at_err.decompose().0)
}
/// Precomputes the zensim reference for the source image.
///
/// Returns `None` for layouts other than `Rgb8`/`Rgba8`, when `pixels` is
/// shorter than `width * height` pixels, or when zensim rejects the image.
fn build_source_reference(
    z: &zensim::Zensim,
    pixels: &[u8],
    layout: PixelLayout,
    width: u32,
    height: u32,
) -> Option<zensim::PrecomputedReference> {
    let (cols, rows) = (width as usize, height as usize);
    match layout {
        PixelLayout::Rgb8 => {
            let raw = pixels.get(..cols * rows * 3)?;
            let rgb: &[[u8; 3]] = bytemuck::cast_slice(raw);
            let view = zensim::RgbSlice::new(rgb, cols, rows);
            z.precompute_reference(&view).ok()
        }
        PixelLayout::Rgba8 => {
            let raw = pixels.get(..cols * rows * 4)?;
            let rgba: &[[u8; 4]] = bytemuck::cast_slice(raw);
            let view = zensim::RgbaSlice::new(rgba, cols, rows);
            z.precompute_reference(&view).ok()
        }
        _ => None,
    }
}
fn measure_score_and_diffmap(
z: &zensim::Zensim,
pre: &zensim::PrecomputedReference,
webp: &[u8],
layout: PixelLayout,
width: u32,
height: u32,
) -> Result<(f32, Vec<f32>), EncodeError> {
match layout {
PixelLayout::Rgb8 => measure_rgb(z, pre, webp, width, height),
PixelLayout::Rgba8 => measure_rgba(z, pre, webp, width, height),
other => Err(EncodeError::TargetZensimUnsupportedLayout(other)),
}
}
/// Decodes `webp` as RGB8, checks that it matches the source dimensions, and
/// scores it against `pre`. Returns the score and the diffmap.
///
/// # Errors
/// `InvalidBufferSize` on decode failure, dimension mismatch, a short decoded
/// buffer, or a zensim failure.
fn measure_rgb(
    z: &zensim::Zensim,
    pre: &zensim::PrecomputedReference,
    webp: &[u8],
    width: u32,
    height: u32,
) -> Result<(f32, Vec<f32>), EncodeError> {
    let decoded = crate::oneshot::decode_rgb(webp);
    let (rgb, w, h) = decoded.map_err(|e| {
        EncodeError::InvalidBufferSize(format!(
            "target_zensim: decode for measurement failed: {:?}",
            e.decompose().0,
        ))
    })?;
    if (w, h) != (width, height) {
        return Err(EncodeError::InvalidBufferSize(format!(
            "target_zensim: decoded dims {}x{} != source {}x{}",
            w, h, width, height,
        )));
    }
    let (cols, rows) = (w as usize, h as usize);
    let needed = cols * rows * 3;
    if rgb.len() < needed {
        return Err(EncodeError::InvalidBufferSize(
            "target_zensim: short decoded buffer".into(),
        ));
    }
    let pixels: &[[u8; 3]] = bytemuck::cast_slice(&rgb[..needed]);
    let view = zensim::RgbSlice::new(pixels, cols, rows);
    let result = z
        .compute_with_ref_and_diffmap(pre, &view, zensim::DiffmapWeighting::Trained)
        .map_err(|e| {
            EncodeError::InvalidBufferSize(format!(
                "zensim compute_with_ref_and_diffmap failed: {:?}",
                e
            ))
        })?;
    Ok((result.score() as f32, result.diffmap().to_vec()))
}
/// Decodes `webp` as RGBA8, checks that it matches the source dimensions, and
/// scores it against `pre`. Returns the score and the diffmap.
///
/// # Errors
/// `InvalidBufferSize` on decode failure, dimension mismatch, a short decoded
/// buffer, or a zensim failure.
fn measure_rgba(
    z: &zensim::Zensim,
    pre: &zensim::PrecomputedReference,
    webp: &[u8],
    width: u32,
    height: u32,
) -> Result<(f32, Vec<f32>), EncodeError> {
    let decoded = crate::oneshot::decode_rgba(webp);
    let (rgba, w, h) = decoded.map_err(|e| {
        EncodeError::InvalidBufferSize(format!(
            "target_zensim: rgba decode for measurement failed: {:?}",
            e.decompose().0,
        ))
    })?;
    if (w, h) != (width, height) {
        return Err(EncodeError::InvalidBufferSize(format!(
            "target_zensim: decoded dims {}x{} != source {}x{}",
            w, h, width, height,
        )));
    }
    let (cols, rows) = (w as usize, h as usize);
    let needed = cols * rows * 4;
    if rgba.len() < needed {
        return Err(EncodeError::InvalidBufferSize(
            "target_zensim: short decoded rgba buffer".into(),
        ));
    }
    let pixels: &[[u8; 4]] = bytemuck::cast_slice(&rgba[..needed]);
    let view = zensim::RgbaSlice::new(pixels, cols, rows);
    let result = z
        .compute_with_ref_and_diffmap(pre, &view, zensim::DiffmapWeighting::Trained)
        .map_err(|e| {
            EncodeError::InvalidBufferSize(format!(
                "zensim compute_with_ref_and_diffmap failed: {:?}",
                e
            ))
        })?;
    Ok((result.score() as f32, result.diffmap().to_vec()))
}
/// Result of one per-segment override decision, together with the
/// per-segment statistics it was based on (printed by the phase-3 trace).
pub(crate) struct OverrideDecision {
/// Cumulative per-segment quantizer overrides for the next probe.
pub overrides: [i8; 4],
/// Mean diffmap value per segment (0.0 where no pixels were counted).
pub means: [f32; 4],
/// Number of diffmap values averaged per segment.
pub counts: [u64; 4],
/// Segment whose override was changed, if any.
pub picked_seg: Option<usize>,
/// -1 for an undershoot correction, +1 for an overshoot correction, and 0
/// otherwise (including every quadrant-proxy decision).
pub direction: i8, }
/// Chooses the next cumulative per-segment quantizer overrides from the last
/// probe's zensim diffmap.
///
/// The diffmap is read as a row-major `width * height` buffer and averaged per
/// encoder segment using `diag.segment_map`, which holds one segment id per
/// 16x16 macroblock. Then:
/// - If the score is below `target.target`, the segment with the largest mean
///   gets -1, -2 or -3 added (gap <= 2, <= 4, > 4).
/// - If the score exceeds `target.target + max_overshoot`, the segment with
///   the smallest mean gets +1, +2 or +3 by the same thresholds on the excess.
/// - Otherwise nothing changes.
///
/// Overrides accumulate on top of `cum` and are clamped to `-16..=16`. With
/// the `ablation` quadrant-proxy toggle, the work is delegated to
/// `next_segment_overrides_quadrant_proxy` and zeroed statistics are
/// reported.
///
/// `cum` is returned untouched (with `picked_seg: None`) when the segment map
/// does not match the macroblock grid, when `diffmap` holds fewer than
/// `width * height` values, or when no segment covers any pixel.
fn next_segment_overrides(
    cum: [i8; 4],
    diffmap: &[f32],
    width: u32,
    height: u32,
    diag: &EncodeDiagnostics,
    score: f32,
    target: &ZensimTarget,
) -> OverrideDecision {
    #[cfg(feature = "ablation")]
    let use_quadrant = super::ablation_runtime::USE_QUADRANT_PROXY.with(|c| c.get());
    #[cfg(not(feature = "ablation"))]
    let use_quadrant = false;
    if use_quadrant {
        let q_overrides =
            next_segment_overrides_quadrant_proxy(cum, diffmap, width, height, score, target);
        return OverrideDecision {
            overrides: q_overrides,
            means: [0.0; 4],
            counts: [0; 4],
            picked_seg: None,
            direction: 0,
        };
    }
    // Decision that leaves the cumulative overrides exactly as they were.
    let unchanged = |means: [f32; 4], counts: [u64; 4]| OverrideDecision {
        overrides: cum,
        means,
        counts,
        picked_seg: None,
        direction: 0,
    };
    let n = (diag.num_segments as usize).clamp(2, 4);
    let w = width as usize;
    let h = height as usize;
    let mb_w = diag.mb_width as usize;
    let mb_h = diag.mb_height as usize;
    let expected = mb_w.saturating_mul(mb_h);
    if diag.segment_map.len() != expected || expected == 0 {
        return unchanged([0.0; 4], [0; 4]);
    }
    // The per-row slicing below indexes `diffmap` as a `w * h` buffer; refuse
    // to decide instead of panicking if it is shorter than that.
    if diffmap.len() < w.saturating_mul(h) {
        return unchanged([0.0; 4], [0; 4]);
    }
    let mut sum = [0.0f64; 4];
    let mut counts = [0u64; 4];
    for mb_y in 0..mb_h {
        let py0 = mb_y * 16;
        let py1 = (py0 + 16).min(h);
        if py0 >= h {
            break;
        }
        for mb_x in 0..mb_w {
            let px0 = mb_x * 16;
            let px1 = (px0 + 16).min(w);
            if px0 >= w {
                continue;
            }
            let seg = diag.segment_map[mb_y * mb_w + mb_x] as usize;
            if seg >= n {
                continue;
            }
            // Accumulate in f64, pixel by pixel, to keep the mean stable on
            // large images.
            let mut block_sum = 0.0f64;
            let mut block_count = 0u64;
            for py in py0..py1 {
                let row = &diffmap[py * w + px0..py * w + px1];
                for &v in row {
                    block_sum += v as f64;
                    block_count += 1;
                }
            }
            if block_count > 0 {
                sum[seg] += block_sum;
                counts[seg] += block_count;
            }
        }
    }
    let mut means = [0.0f32; 4];
    for s in 0..n {
        if counts[s] > 0 {
            means[s] = (sum[s] / counts[s] as f64) as f32;
        }
    }
    // Largest / smallest mean among segments that actually covered pixels.
    let mut worst = 0usize;
    let mut best = 0usize;
    let mut found_worst = false;
    let mut found_best = false;
    for s in 0..n {
        if counts[s] == 0 {
            continue;
        }
        if !found_worst || means[s] > means[worst] {
            worst = s;
            found_worst = true;
        }
        if !found_best || means[s] < means[best] {
            best = s;
            found_best = true;
        }
    }
    if !found_worst {
        return unchanged(means, counts);
    }
    let mut out = cum;
    let gap = target.target - score;
    let mut picked: Option<usize> = None;
    let mut direction: i8 = 0;
    if gap > 0.0 {
        let step = if gap > 4.0 {
            -3
        } else if gap > 2.0 {
            -2
        } else {
            -1
        };
        out[worst] = (i32::from(out[worst]) + step).clamp(-16, 16) as i8;
        picked = Some(worst);
        direction = -1;
    } else if let Some(t) = target.max_overshoot
        && (score - target.target) > t
    {
        let overshoot = score - target.target - t;
        let step = if overshoot > 4.0 {
            3
        } else if overshoot > 2.0 {
            2
        } else {
            1
        };
        out[best] = (i32::from(out[best]) + step).clamp(-16, 16) as i8;
        picked = Some(best);
        direction = 1;
    }
    OverrideDecision {
        overrides: out,
        means,
        counts,
        picked_seg: picked,
        direction,
    }
}
/// Ablation variant of `next_segment_overrides` that treats the four image
/// quadrants (index `qy * 2 + qx`, split at `width / 2` and `height / 2`) as
/// the "segments" instead of the encoder's segment map.
///
/// It uses the same step sizes and the same `-16..=16` clamp: the
/// largest-mean quadrant gets -1/-2/-3 when the score is below target, and
/// the smallest-mean quadrant gets +1/+2/+3 when the score exceeds
/// `target + max_overshoot`.
///
/// Returns `cum` unchanged if `diffmap` holds fewer than `width * height`
/// values, instead of panicking on the out-of-range slice.
#[cfg_attr(not(feature = "ablation"), allow(dead_code))]
fn next_segment_overrides_quadrant_proxy(
    cum: [i8; 4],
    diffmap: &[f32],
    width: u32,
    height: u32,
    score: f32,
    target: &ZensimTarget,
) -> [i8; 4] {
    let n: usize = 4;
    let w = width as usize;
    let h = height as usize;
    if diffmap.len() < w.saturating_mul(h) {
        return cum;
    }
    let mut sum = [0.0f64; 4];
    let mut count = [0u64; 4];
    let half_w = w / 2;
    let half_h = h / 2;
    for y in 0..h {
        let row = &diffmap[y * w..y * w + w];
        for (x, &v) in row.iter().enumerate() {
            let qx = usize::from(x >= half_w);
            let qy = usize::from(y >= half_h);
            let q = qy * 2 + qx;
            sum[q] += v as f64;
            count[q] += 1;
        }
    }
    let mut means = [0.0f32; 4];
    for s in 0..n {
        if count[s] > 0 {
            means[s] = (sum[s] / count[s] as f64) as f32;
        }
    }
    let mut worst = 0usize;
    let mut best = 0usize;
    let mut found = false;
    for s in 0..n {
        if count[s] == 0 {
            continue;
        }
        if !found || means[s] > means[worst] {
            worst = s;
        }
        if !found || means[s] < means[best] {
            best = s;
        }
        found = true;
    }
    let mut out = cum;
    let gap = target.target - score;
    if gap > 0.0 {
        let step = if gap > 4.0 {
            -3
        } else if gap > 2.0 {
            -2
        } else {
            -1
        };
        out[worst] = (i32::from(out[worst]) + step).clamp(-16, 16) as i8;
    } else if let Some(t) = target.max_overshoot
        && (score - target.target) > t
    {
        let overshoot = score - target.target - t;
        let step = if overshoot > 4.0 {
            3
        } else if overshoot > 2.0 {
            2
        } else {
            1
        };
        out[best] = (i32::from(out[best]) + step).clamp(-16, 16) as i8;
    }
    out
}
/// Classifies the source image into a content bucket for starting-quality
/// calibration.
///
/// Returns `None` for layouts other than `Rgb8`/`Rgba8`, for images under
/// 8 px on either side, or when `pixels` is shorter than `width * height`
/// pixels. With the `analyzer` feature, the analyzer's RGB classifier is
/// used. Otherwise a luma plane and a 256-bin histogram are built and handed
/// to `classify_image_type`.
fn detect_bucket(
    pixels: &[u8],
    layout: PixelLayout,
    width: u32,
    height: u32,
) -> Option<ImageContentType> {
    let (w, h) = (width as usize, height as usize);
    let bpp: usize = match layout {
        PixelLayout::Rgb8 => 3,
        PixelLayout::Rgba8 => 4,
        _ => return None,
    };
    if w < 8 || h < 8 || pixels.len() < w * h * bpp {
        return None;
    }
    #[cfg(feature = "analyzer")]
    {
        use crate::encoder::analysis::{classify_image_type_rgb8, rgba8_to_rgb8};
        let src = &pixels[..w * h * bpp];
        let bucket = if bpp == 4 {
            let rgb = rgba8_to_rgb8(src);
            classify_image_type_rgb8(&rgb, width, height)
        } else {
            classify_image_type_rgb8(src, width, height)
        };
        Some(bucket)
    }
    #[cfg(not(feature = "analyzer"))]
    {
        // Integer luma with weights 76/150/30 (out of 256), close to BT.601.
        let luma = |px: &[u8]| -> u8 {
            ((u32::from(px[0]) * 76 + u32::from(px[1]) * 150 + u32::from(px[2]) * 30) >> 8)
                as u8
        };
        let mut y_plane: Vec<u8> = Vec::with_capacity(w * h);
        let mut alpha_hist = [0u32; 256];
        for px in pixels.chunks_exact(bpp).take(w * h) {
            let y = luma(px);
            y_plane.push(y);
            // NOTE(review): for Rgb8 this bins luma, while for Rgba8 it bins
            // the alpha channel. Confirm which quantity `classify_image_type`
            // expects in `alpha_hist`.
            let bin = if bpp == 4 { px[3] } else { y };
            alpha_hist[usize::from(bin)] += 1;
        }
        Some(crate::encoder::analysis::classify_image_type(
            &y_plane,
            w,
            h,
            w,
            &alpha_hist,
        ))
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults: target 80, +1.5 / -0.5 band, no floor, 2 passes.
    #[test]
    fn target_default() {
        let expected = ZensimTarget {
            target: 80.0,
            max_overshoot: Some(1.5),
            max_undershoot_ship: Some(0.5),
            max_undershoot: None,
            max_passes: 2,
        };
        assert_eq!(ZensimTarget::default(), expected);
    }

    /// Each builder replaces exactly its own field.
    #[test]
    fn target_builder() {
        let built = ZensimTarget::new(85.0)
            .with_max_overshoot(Some(0.5))
            .with_max_undershoot_ship(Some(0.25))
            .with_max_undershoot(Some(2.0))
            .with_max_passes(3);
        let expected = ZensimTarget {
            target: 85.0,
            max_overshoot: Some(0.5),
            max_undershoot_ship: Some(0.25),
            max_undershoot: Some(2.0),
            max_passes: 3,
        };
        assert_eq!(built, expected);
    }

    /// Untargeted metrics: unmeasured score, one pass, trivially met.
    #[test]
    fn metrics_no_target() {
        let ZensimEncodeMetrics {
            achieved_score,
            passes_used,
            bytes,
            targets_met,
        } = ZensimEncodeMetrics::no_target(1234);
        assert!(achieved_score.is_nan());
        assert_eq!(passes_used, 1);
        assert_eq!(bytes, 1234);
        assert!(targets_met);
    }

    /// Starting quality never decreases as the target rises, and stays in 1..=100.
    #[test]
    fn calibration_monotonic_per_bucket() {
        let buckets = [
            ImageContentType::Photo,
            ImageContentType::Drawing,
            ImageContentType::Icon,
        ];
        for b in buckets {
            let mut prev: Option<f32> = None;
            for t in (60u32..=95).step_by(5) {
                let q = zensim_to_starting_q_for_bucket(t as f32, b);
                assert!((1.0..=100.0).contains(&q), "{b:?} target {t} q={q}");
                if let Some(p) = prev {
                    assert!(q >= p, "non-monotonic at {b:?} target {t}: {p} -> {q}");
                }
                prev = Some(q);
            }
        }
    }

    /// Targets outside the anchor range clamp to the endpoint qualities.
    #[test]
    fn calibration_clamps_at_endpoints() {
        assert_eq!(
            zensim_to_starting_q_for_bucket(40.0, ImageContentType::Photo),
            30.0
        );
        assert_eq!(
            zensim_to_starting_q_for_bucket(99.0, ImageContentType::Photo),
            100.0
        );
    }
}