use crate::error::ZensimError;
/// Blur kernel applied when smoothing images before comparison.
///
/// Marked `#[non_exhaustive]` so additional kernel shapes can be added
/// without breaking downstream matches.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlurKernel {
    /// Iterated box blur; `passes` is the number of applications.
    Box {
        passes: u8,
    },
}
impl Default for BlurKernel {
fn default() -> Self {
Self::Box { passes: 1 }
}
}
/// Filter used when downscaling between pyramid scales.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum DownscaleFilter {
    /// Plain 2x2 box average (the default).
    #[default]
    Box2x2,
    /// Mitchell resampling (only with the `zenresize` feature).
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    Mitchell,
    /// Lanczos resampling (only with the `zenresize` feature).
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    Lanczos,
    /// Mitchell resampling with an extra blur amount parameter.
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    MitchellBlur(f32),
}
/// Tuning knobs for a single zensim computation.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct ZensimConfig {
    /// Box-blur radius in pixels.
    pub blur_radius: usize,
    /// Number of blur passes.
    pub blur_passes: u8,
    /// Kernel shape; carried but currently unread (`allow(dead_code)`).
    #[allow(dead_code)]
    pub blur_kernel: BlurKernel,
    /// Pyramid downscale filter; carried but currently unread.
    #[allow(dead_code)]
    pub downscale_filter: DownscaleFilter,
    /// Compute every feature, including ones the active profile does not score.
    pub compute_all_features: bool,
    /// Also compute the extended (masked) feature block.
    pub extended_features: bool,
    /// Strength of the extended masking term.
    pub extended_masking_strength: f32,
    /// Number of pyramid scales to evaluate.
    pub num_scales: usize,
    /// `a` in the score mapping `100 - a * distance^b`.
    pub score_mapping_a: f64,
    /// `b` in the score mapping `100 - a * distance^b`.
    pub score_mapping_b: f64,
    /// Permit multithreaded computation.
    pub allow_multithreading: bool,
}
impl Default for ZensimConfig {
    /// Defaults: radius-5 single-pass box blur, `crate::NUM_SCALES` scales,
    /// and the 18.0 / 0.7 score mapping (matching `distance_to_score`).
    fn default() -> Self {
        Self {
            blur_radius: 5,
            blur_passes: 1,
            blur_kernel: BlurKernel::default(),
            downscale_filter: DownscaleFilter::default(),
            compute_all_features: false,
            extended_features: false,
            extended_masking_strength: 4.0,
            num_scales: crate::NUM_SCALES,
            score_mapping_a: 18.0,
            score_mapping_b: 0.7,
            allow_multithreading: true,
        }
    }
}
/// Maps a raw weighted feature distance to the 0-100 quality scale using the
/// default mapping parameters (a = 18.0, b = 0.7 — same as `ZensimConfig`'s
/// defaults).
pub(crate) fn distance_to_score(raw_distance: f64) -> f64 {
    distance_to_score_mapped(raw_distance, 18.0, 0.7)
}

/// Score mapping `100 - a * d^b`; non-positive distances clamp to a perfect
/// score of 100. The result is unbounded below for very large distances.
fn distance_to_score_mapped(raw_distance: f64, a: f64, b: f64) -> f64 {
    match raw_distance {
        d if d <= 0.0 => 100.0,
        d => 100.0 - a * d.powf(b),
    }
}
/// Computes `(score, raw_distance)` from a flat feature vector and matching
/// weights. The raw distance is the weighted feature sum normalized by the
/// number of scales, which is inferred from the vector length.
///
/// # Panics
/// Panics when `features` and `weights` differ in length.
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub fn score_from_features(features: &[f64], weights: &[f64]) -> (f64, f64) {
    assert_eq!(
        features.len(),
        weights.len(),
        "features and weights must have same length"
    );
    // Weighted sum over every feature.
    let mut raw_distance = 0.0f64;
    for (&feat, &weight) in features.iter().zip(weights.iter()) {
        raw_distance += weight * feat;
    }
    // Infer the per-scale feature count, trying the richest layout first so
    // an extended vector is not mistaken for more scales of a smaller layout.
    let per_scale_candidates = [
        FEATURES_PER_CHANNEL_EXTENDED * 3,
        FEATURES_PER_CHANNEL_WITH_PEAKS * 3,
        FEATURES_PER_CHANNEL_BASIC * 3,
    ];
    let mut features_per_scale = FEATURES_PER_CHANNEL_BASIC * 3;
    for &per_scale in per_scale_candidates.iter() {
        if per_scale > 0 && features.len() % per_scale == 0 {
            features_per_scale = per_scale;
            break;
        }
    }
    let n_scales = features.len() / features_per_scale;
    // `max(1)` guards the degenerate empty-feature case against divide-by-zero.
    let raw_distance = raw_distance / n_scales.max(1) as f64;
    (distance_to_score(raw_distance), raw_distance)
}
/// Builds a reusable precomputed reference from packed RGB8 pixels so one
/// source image can be compared against many candidates.
///
/// # Errors
/// `ImageTooSmall` when either dimension is under 8; `InvalidDataLength`
/// when `source.len() != width * height`.
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub fn precompute_reference_with_scales(
    source: &[[u8; 3]],
    width: usize,
    height: usize,
    num_scales: usize,
) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
    if width.min(height) < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if source.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    let image = crate::source::RgbSlice::new(source, width, height);
    // Reference precomputation always runs with multithreading enabled here.
    Ok(crate::streaming::PrecomputedReference::new(
        &image, num_scales, true,
    ))
}
/// Scores `distorted` against an already-precomputed reference with an
/// explicit `ZensimConfig` (training builds only).
///
/// # Errors
/// `ImageTooSmall` when either dimension is under 8; `InvalidDataLength`
/// when `distorted.len() != width * height`.
#[cfg(feature = "training")]
pub fn compute_zensim_with_ref_and_config(
    precomputed: &crate::streaming::PrecomputedReference,
    distorted: &[[u8; 3]],
    width: usize,
    height: usize,
    config: ZensimConfig,
) -> Result<ZensimResult, ZensimError> {
    if width.min(height) < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if distorted.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    let image = crate::source::RgbSlice::new(distorted, width, height);
    Ok(crate::streaming::compute_zensim_streaming_with_ref(
        precomputed,
        &image,
        &config,
        WEIGHTS,
    ))
}
/// Per-scale statistics accumulator.
/// Array layouts follow how `combine_scores` consumes them.
#[derive(Default)]
pub(crate) struct ScaleStats {
    // Two per channel: [mean, 4th-moment] (see combine_scores push order).
    pub(crate) ssim: [f64; 6],
    // Four per channel: [art mean, art 4th, det mean, det 4th].
    pub(crate) edge: [f64; 12],
    // One per channel.
    pub(crate) mse: [f64; 3],
    pub(crate) hf_energy_loss: [f64; 3],
    pub(crate) hf_mag_loss: [f64; 3],
    // Second moments, one per channel.
    pub(crate) ssim_2nd: [f64; 3],
    // Two per channel: [art 2nd, det 2nd].
    pub(crate) edge_2nd: [f64; 6],
    pub(crate) hf_energy_gain: [f64; 3],
    // Extended (masked) block — only populated with extended_features.
    // Three per channel: [mean, 4th, 2nd].
    pub(crate) masked_ssim: [f64; 9],
    pub(crate) masked_art_4th: [f64; 3],
    pub(crate) masked_det_4th: [f64; 3],
    pub(crate) masked_mse: [f64; 3],
    // Peak (maximum) statistics, one per channel.
    pub(crate) ssim_max: [f64; 3],
    pub(crate) art_max: [f64; 3],
    pub(crate) det_max: [f64; 3],
    // 95th-percentile statistics, one per channel.
    pub(crate) ssim_p95: [f64; 3],
    pub(crate) art_p95: [f64; 3],
    pub(crate) det_p95: [f64; 3],
}
/// Outcome of a zensim comparison.
#[derive(Debug, Clone)]
pub struct ZensimResult {
    // Mapped quality score: 100 = identical; unbounded below for bad pairs.
    score: f64,
    // Weighted feature distance before score mapping.
    raw_distance: f64,
    // Flat per-scale/per-channel feature vector (layout: see combine_scores).
    features: Vec<f64>,
    // Profile the score was computed with.
    profile: crate::profile::ZensimProfile,
    // Per-channel mean offset — NOTE(review): presumably source-vs-distorted
    // mean intensity difference; confirm against the streaming pipeline.
    mean_offset: [f64; 3],
}
impl ZensimResult {
    /// Bundles a finished comparison.
    pub(crate) fn new(
        score: f64,
        raw_distance: f64,
        features: Vec<f64>,
        profile: crate::profile::ZensimProfile,
        mean_offset: [f64; 3],
    ) -> Self {
        Self {
            score,
            raw_distance,
            features,
            profile,
            mean_offset,
        }
    }
    /// Returns the same result re-tagged with `profile`.
    pub(crate) fn with_profile(mut self, profile: crate::profile::ZensimProfile) -> Self {
        self.profile = profile;
        self
    }
    /// Sentinel result: NaN score/distance, empty feature vector.
    pub fn nan() -> Self {
        Self {
            score: f64::NAN,
            raw_distance: f64::NAN,
            features: vec![],
            profile: crate::profile::ZensimProfile::PreviewV0_1,
            mean_offset: [f64::NAN; 3],
        }
    }
    /// Mapped quality score (100 = identical).
    pub fn score(&self) -> f64 {
        self.score
    }
    /// Weighted feature distance before score mapping.
    pub fn raw_distance(&self) -> f64 {
        self.raw_distance
    }
    /// Borrowed view of the raw feature vector.
    pub fn features(&self) -> &[f64] {
        &self.features
    }
    /// Consumes the result, yielding the feature vector.
    pub fn into_features(self) -> Vec<f64> {
        self.features
    }
    /// Profile the score was computed with.
    pub fn profile(&self) -> crate::profile::ZensimProfile {
        self.profile
    }
    /// Per-channel mean offset between the images.
    pub fn mean_offset(&self) -> [f64; 3] {
        self.mean_offset
    }
    /// Dissimilarity derived from the score (0.0 = identical).
    pub fn dissimilarity(&self) -> f64 {
        score_to_dissimilarity(self.score)
    }
    /// Approximate SSIMULACRA2-style score derived from the raw distance.
    /// Constants appear to be regression fits — do not round. Clamped at -100.
    pub fn approx_ssim2(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 100.0;
        }
        (100.0 - 19.0379 * self.raw_distance.powf(0.5979)).max(-100.0)
    }
    /// Approximate DSSIM value (0.0 = identical); fitted power-law mapping.
    pub fn approx_dssim(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 0.0;
        }
        0.000922 * self.raw_distance.powf(1.2244)
    }
    /// Approximate butteraugli distance (0.0 = identical); fitted mapping.
    pub fn approx_butteraugli(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 0.0;
        }
        2.365353 * self.raw_distance.powf(0.6130)
    }
}
/// Converts a 0-100 score into a dissimilarity value (0.0 = identical,
/// 1.0 = score of zero). Not clamped: scores below 0 give values above 1.
pub fn score_to_dissimilarity(score: f64) -> f64 {
    let delta = 100.0 - score;
    delta / 100.0
}
/// Converts a dissimilarity value back into a score, clamped to [0, 100].
/// (`clamp` is kept so a NaN input propagates as NaN.)
pub fn dissimilarity_to_score(dissimilarity: f64) -> f64 {
    let unclamped = 100.0 * (1.0 - dissimilarity);
    unclamped.clamp(0.0, 100.0)
}
/// Dominant error pattern detected by `derive_classification`.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// No differing pixels at all.
    Identical,
    /// Deltas consistent with quantization/rounding noise.
    RoundingError,
    /// One channel untouched, others far off — channels likely reordered.
    ChannelSwap,
    /// Errors concentrated in semi-transparent pixels.
    AlphaCompositing,
    /// None of the heuristics fired.
    Unclassified,
}
/// Result of the heuristic error classifier.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ErrorClassification {
    /// Best-scoring category.
    pub dominant: ErrorCategory,
    /// Heuristic confidence in [0, 1].
    pub confidence: f64,
    /// Populated only when `dominant` is `RoundingError`.
    pub rounding_bias: Option<RoundingBias>,
}
/// Direction statistics for rounding-style errors.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RoundingBias {
    /// Per-channel fraction of nonzero deltas that are positive (0.5 = neutral).
    pub positive_fraction: [f64; 3],
    /// True when no channel deviates significantly from a 50/50 split.
    pub balanced: bool,
}
/// Raw pixel-delta statistics gathered for classification.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DeltaStats {
    /// Per-channel mean signed delta.
    pub mean_delta: [f64; 3],
    /// Per-channel delta standard deviation.
    pub stddev_delta: [f64; 3],
    /// Per-channel maximum absolute delta (normalized; 1/255 = one 8-bit step).
    pub max_abs_delta: [f64; 3],
    /// Per-channel 7-bin histogram of small signed deltas; bins 0-2 are
    /// negative and 4-6 positive (bin 3 presumably zero — see
    /// `compute_rounding_bias`).
    pub signed_small_histogram: [[u64; 7]; 3],
    /// NOTE(review): native value range maximum — confirm semantics upstream.
    pub native_max: f64,
    pub pixel_count: u64,
    pub pixels_differing: u64,
    pub pixels_differing_by_more_than_1: u64,
    pub has_alpha: bool,
    pub alpha_max_delta: u8,
    pub alpha_pixels_differing: u64,
    /// Per-channel (RGBA) 256-bin value histograms of source and distorted.
    pub src_histogram: [[u64; 256]; 4],
    pub dst_histogram: [[u64; 256]; 4],
    /// Stats restricted to fully-opaque pixels, when alpha is present.
    pub opaque_stats: Option<AlphaStratifiedStats>,
    /// Stats restricted to semi-transparent pixels, when alpha is present.
    pub semitransparent_stats: Option<AlphaStratifiedStats>,
    /// Correlation between alpha value and color error, when computable.
    pub alpha_error_correlation: Option<f64>,
}
/// Delta statistics restricted to one alpha stratum (opaque or semi-transparent).
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AlphaStratifiedStats {
    /// Number of pixels in this stratum.
    pub pixel_count: u64,
    /// Per-channel mean absolute delta within the stratum.
    pub mean_abs_delta: [f64; 3],
    /// Per-channel maximum absolute delta within the stratum.
    pub max_abs_delta: [f64; 3],
}
/// Score plus error classification, as returned by `Zensim::classify`.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ClassifiedResult {
    /// The regular similarity result.
    pub result: ZensimResult,
    /// Heuristic error-category label.
    pub classification: ErrorClassification,
    /// Raw statistics the classification was derived from.
    pub delta_stats: DeltaStats,
}
use crate::profile::{ProfileParams, ZensimProfile};
use crate::source::ImageSource;
/// High-level comparator: a profile plus a parallelism switch.
#[derive(Clone, Debug)]
pub struct Zensim {
    // Profile supplying weights, scales, blur and score-mapping parameters.
    profile: ZensimProfile,
    // Whether multithreading is allowed.
    parallel: bool,
}
impl Zensim {
    /// Creates a comparator for `profile` with multithreading enabled.
    pub fn new(profile: ZensimProfile) -> Self {
        Self {
            profile,
            parallel: true,
        }
    }
    /// Builder-style toggle for multithreaded computation.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }
    /// The profile this comparator scores with.
    pub fn profile(&self) -> ZensimProfile {
        self.profile
    }
    /// Whether multithreading is allowed.
    pub fn parallel(&self) -> bool {
        self.parallel
    }
    /// Compares `distorted` against `source` and returns the scored result.
    ///
    /// # Errors
    /// `ImageTooSmall` when the source is under 8x8; `DimensionMismatch`
    /// when the images differ in size.
    pub fn compute(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        validate_pair(source, distorted)?;
        let config = config_from_params(params, self.parallel);
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result.with_profile(self.profile))
    }
    /// Precomputes reference-side data so one source can be scored against
    /// many distorted candidates via [`Zensim::compute_with_ref`].
    ///
    /// # Errors
    /// `ImageTooSmall` when the source is under 8x8.
    pub fn precompute_reference(
        &self,
        source: &impl ImageSource,
    ) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
        let params = self.profile.params();
        if source.width() < 8 || source.height() < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        Ok(crate::streaming::PrecomputedReference::new(
            source,
            params.num_scales,
            self.parallel,
        ))
    }
    /// Scores `distorted` against a precomputed reference.
    ///
    /// # Errors
    /// `ImageTooSmall` when the distorted image is under 8x8.
    /// NOTE(review): the distorted dimensions are not compared against the
    /// precomputed reference here — presumably checked downstream; confirm.
    pub fn compute_with_ref(
        &self,
        precomputed: &crate::streaming::PrecomputedReference,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        if distorted.width() < 8 || distorted.height() < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        let config = config_from_params(params, self.parallel);
        let result = crate::streaming::compute_zensim_streaming_with_ref(
            precomputed,
            distorted,
            &config,
            params.weights,
        );
        Ok(result.with_profile(self.profile))
    }
    /// Precomputes a reference from planar linear-light f32 channel data.
    /// `stride` is the row stride (in elements) of each plane.
    ///
    /// # Errors
    /// `ImageTooSmall` when either dimension is under 8.
    pub fn precompute_reference_linear_planar(
        &self,
        planes: [&[f32]; 3],
        width: usize,
        height: usize,
        stride: usize,
    ) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
        let params = self.profile.params();
        if width < 8 || height < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        Ok(crate::streaming::PrecomputedReference::from_linear_planar(
            planes,
            width,
            height,
            stride,
            params.num_scales,
            self.parallel,
        ))
    }
    /// Like [`Zensim::compute`] but forces every feature to be populated,
    /// including ones the profile does not score (training builds only).
    #[cfg(feature = "training")]
    pub fn compute_all_features(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        validate_pair(source, distorted)?;
        let mut config = config_from_params(params, self.parallel);
        config.compute_all_features = true;
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result.with_profile(self.profile))
    }
}
#[cfg(feature = "classification")]
impl Zensim {
    /// Scores the pair and additionally labels the dominant error pattern.
    /// Raw pixel-delta statistics drive the heuristic classifier.
    ///
    /// # Errors
    /// Same validation errors as [`Zensim::compute`].
    pub fn classify(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ClassifiedResult, ZensimError> {
        validate_pair(source, distorted)?;
        let delta_stats = crate::streaming::compute_delta_stats(source, distorted);
        let result = self.compute(source, distorted)?;
        let classification = derive_classification(&delta_stats, &result);
        Ok(ClassifiedResult {
            result,
            classification,
            delta_stats,
        })
    }
}
#[cfg(feature = "training")]
impl Zensim {
    /// Scores a pair with explicit profile parameters (e.g. candidate fits
    /// during training). Always runs with multithreading enabled.
    ///
    /// # Errors
    /// Same validation errors as [`Zensim::compute`].
    pub fn compute_with_params(
        params: &ProfileParams,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        validate_pair(source, distorted)?;
        let config = config_from_params(params, true);
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result)
    }
}
/// Heuristically labels the dominant error pattern from pixel-delta stats.
/// Each candidate category gets a confidence score; the highest wins, with
/// ties resolved toward the later entry (`max_by` keeps the last maximum).
#[cfg(feature = "classification")]
fn derive_classification(delta_stats: &DeltaStats, _result: &ZensimResult) -> ErrorClassification {
    let mut rounding_bias: Option<RoundingBias> = None;
    // Candidate confidences, filled in by the heuristics below.
    let mut score_rounding = 0.0f64;
    let mut score_swap = 0.0f64;
    let mut score_alpha = 0.0f64;
    // No differing pixels at all: trivially identical.
    if delta_stats.pixels_differing == 0 {
        return ErrorClassification {
            dominant: ErrorCategory::Identical,
            confidence: 1.0,
            rounding_bias: None,
        };
    }
    let max_delta = delta_stats
        .max_abs_delta
        .iter()
        .copied()
        .fold(0.0f64, f64::max);
    // Rounding: all deltas within one 8-bit step is conclusive; up to two or
    // three steps still counts, with decreasing confidence.
    if delta_stats.pixels_differing_by_more_than_1 == 0 {
        score_rounding = 1.0;
    } else if max_delta <= 2.0 / 255.0 {
        score_rounding = 0.95;
    } else if max_delta <= 3.0 / 255.0 {
        score_rounding = 0.9;
    }
    // Channel swap: exactly one channel untouched while at least one other
    // is far off suggests channels were reordered rather than recompressed.
    let mut zero_channels = 0u32;
    let mut hot_channels = 0u32;
    for ch in 0..3 {
        if delta_stats.max_abs_delta[ch] < 1.0 / 255.0 {
            zero_channels += 1;
        }
        if delta_stats.max_abs_delta[ch] > 0.1 {
            hot_channels += 1;
        }
    }
    if zero_channels == 1 && hot_channels >= 1 && max_delta > 0.05 {
        score_swap = 0.9;
    }
    // Alpha compositing: opaque pixels nearly clean while semi-transparent
    // pixels are off points at premultiply/background handling.
    if let Some(ref opaque) = delta_stats.opaque_stats
        && let Some(ref semi) = delta_stats.semitransparent_stats
    {
        let opaque_max = opaque.mean_abs_delta.iter().copied().fold(0.0f64, f64::max);
        let semi_mean = semi.mean_abs_delta.iter().copied().fold(0.0f64, f64::max);
        if opaque_max < 0.005 && semi_mean > 0.02 && semi.pixel_count > 100 {
            score_alpha = 0.9;
        }
    }
    // Strong alpha/color-error correlation also implies compositing issues.
    if let Some(corr) = delta_stats.alpha_error_correlation
        && corr > 0.8
    {
        score_alpha = score_alpha.max(corr);
    }
    let scores = [
        (ErrorCategory::RoundingError, score_rounding),
        (ErrorCategory::ChannelSwap, score_swap),
        (ErrorCategory::AlphaCompositing, score_alpha),
    ];
    // Scores are never NaN (built from literals and `corr`), so unwrap is safe.
    let (best_cat, best_score) = scores
        .iter()
        .copied()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .unwrap();
    let (dominant, confidence) = if best_score > 0.0 {
        (best_cat, best_score)
    } else {
        (ErrorCategory::Unclassified, 0.0)
    };
    // Only rounding errors get the directional-bias breakdown.
    if dominant == ErrorCategory::RoundingError {
        rounding_bias = Some(compute_rounding_bias(delta_stats));
    }
    ErrorClassification {
        dominant,
        confidence,
        rounding_bias,
    }
}
/// Summarizes the direction of rounding errors per channel: the fraction of
/// nonzero small deltas that are positive, plus an overall "balanced" flag
/// that clears when any channel deviates significantly from a 50/50 split.
#[cfg(feature = "classification")]
fn compute_rounding_bias(delta_stats: &DeltaStats) -> RoundingBias {
    let hist = &delta_stats.signed_small_histogram;
    let mut positive_fraction = [0.5f64; 3];
    let mut balanced = true;
    for (ch, bins) in hist.iter().enumerate() {
        // Bins 0..3 hold negative deltas, bins 4..7 positive ones.
        let neg: u64 = bins[..3].iter().sum();
        let pos: u64 = bins[4..].iter().sum();
        let total_nonzero = neg + pos;
        if total_nonzero == 0 {
            // No nonzero deltas: leave the neutral 0.5 in place.
            continue;
        }
        let frac = pos as f64 / total_nonzero as f64;
        positive_fraction[ch] = frac;
        // Significance test: flag only deviations beyond three standard
        // deviations of a fair coin AND beyond an absolute 0.1 margin.
        let expected_std = 0.5 / (total_nonzero as f64).sqrt();
        let deviation = (frac - 0.5).abs();
        if deviation > 3.0 * expected_std && deviation > 0.1 {
            balanced = false;
        }
    }
    RoundingBias {
        positive_fraction,
        balanced,
    }
}
/// Shared validation for image pairs: the source must be at least 8x8 and
/// both images must have identical dimensions.
///
/// # Errors
/// `ImageTooSmall` or `DimensionMismatch` accordingly.
pub(crate) fn validate_pair(
    source: &impl ImageSource,
    distorted: &impl ImageSource,
) -> Result<(), ZensimError> {
    let (src_w, src_h) = (source.width(), source.height());
    if src_w < 8 || src_h < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if src_w != distorted.width() || src_h != distorted.height() {
        return Err(ZensimError::DimensionMismatch);
    }
    Ok(())
}
/// Fast equality check used to short-circuit scoring of identical images.
///
/// Images are identical when dimensions, pixel format, and color primaries
/// all match and every row's bytes compare equal. When both images carry
/// meaningful alpha, pixels that are fully transparent in BOTH images are
/// treated as equal regardless of their (undefined) color bytes.
fn images_byte_identical(source: &impl ImageSource, distorted: &impl ImageSource) -> bool {
    use crate::source::{AlphaMode, PixelFormat};
    let (w, h) = (source.width(), source.height());
    if w != distorted.width() || h != distorted.height() {
        return false;
    }
    if source.pixel_format() != distorted.pixel_format() {
        return false;
    }
    if source.color_primaries() != distorted.color_primaries() {
        return false;
    }
    let fmt = source.pixel_format();
    let bpp = fmt.bytes_per_pixel();
    let row_len = w * bpp;
    // Alpha-aware comparison only applies when the format has alpha and
    // neither side declares itself fully opaque.
    let alpha_aware = fmt.has_alpha()
        && !matches!(source.alpha_mode(), AlphaMode::Opaque)
        && !matches!(distorted.alpha_mode(), AlphaMode::Opaque);
    for y in 0..h {
        let sr = source.row_bytes(y);
        let dr = distorted.row_bytes(y);
        // Fast path: the raw row bytes match exactly.
        if sr[..row_len] == dr[..row_len] {
            continue;
        }
        if !alpha_aware {
            return false;
        }
        // Rows differ: re-check per pixel, skipping fully-transparent pairs.
        match fmt {
            PixelFormat::Srgb8Rgba | PixelFormat::Srgb8Bgra => {
                for x in 0..w {
                    let o = x * 4;
                    // Alpha is byte 3 in both RGBA and BGRA layouts.
                    if sr[o + 3] == 0 && dr[o + 3] == 0 {
                        continue;
                    }
                    if sr[o..o + 4] != dr[o..o + 4] {
                        return false;
                    }
                }
            }
            PixelFormat::Srgb16Rgba => {
                for x in 0..w {
                    let o = x * 8;
                    // Alpha occupies bytes 6..8 of each 8-byte pixel.
                    let sa = u16::from_ne_bytes([sr[o + 6], sr[o + 7]]);
                    let da = u16::from_ne_bytes([dr[o + 6], dr[o + 7]]);
                    if sa == 0 && da == 0 {
                        continue;
                    }
                    if sr[o..o + 8] != dr[o..o + 8] {
                        return false;
                    }
                }
            }
            PixelFormat::LinearF32Rgba => {
                for x in 0..w {
                    let o = x * 16;
                    // Alpha occupies bytes 12..16 of each 16-byte pixel.
                    let sa = f32::from_ne_bytes([sr[o + 12], sr[o + 13], sr[o + 14], sr[o + 15]]);
                    let da = f32::from_ne_bytes([dr[o + 12], dr[o + 13], dr[o + 14], dr[o + 15]]);
                    // Non-positive alpha counts as fully transparent.
                    if sa <= 0.0 && da <= 0.0 {
                        continue;
                    }
                    if sr[o..o + 16] != dr[o..o + 16] {
                        return false;
                    }
                }
            }
            // Unknown alpha-bearing format: conservatively report a difference.
            _ => return false,
        }
    }
    true
}
/// Runs the full streaming comparison, short-circuiting byte-identical pairs
/// to a perfect score with an all-zero feature vector of the correct size.
fn compute_with_config_inner(
    source: &impl ImageSource,
    distorted: &impl ImageSource,
    config: &ZensimConfig,
    weights: &[f64],
) -> ZensimResult {
    if !images_byte_identical(source, distorted) {
        return crate::streaming::compute_zensim_streaming(source, distorted, config, weights);
    }
    // Identical inputs: synthesize a perfect result without running the pipeline.
    let per_channel = if config.extended_features {
        FEATURES_PER_CHANNEL_EXTENDED
    } else {
        FEATURES_PER_CHANNEL_WITH_PEAKS
    };
    let feature_count = config.num_scales * 3 * per_channel;
    ZensimResult::new(
        100.0,
        0.0,
        vec![0.0; feature_count],
        ZensimProfile::latest(),
        [0.0; 3],
    )
}
/// Translates profile parameters into a `ZensimConfig`.
/// Extended/masked features and `compute_all_features` default to off; the
/// masking strength is a fixed constant here.
pub(crate) fn config_from_params(params: &ProfileParams, parallel: bool) -> ZensimConfig {
    ZensimConfig {
        blur_radius: params.blur_radius,
        blur_passes: params.blur_passes,
        blur_kernel: BlurKernel::Box {
            passes: params.blur_passes,
        },
        downscale_filter: DownscaleFilter::default(),
        compute_all_features: false,
        extended_features: false,
        extended_masking_strength: 4.0,
        num_scales: params.num_scales,
        score_mapping_a: params.score_mapping_a,
        score_mapping_b: params.score_mapping_b,
        allow_multithreading: parallel,
    }
}
/// Scored features per channel per scale (see `combine_scores`' basic block).
pub const FEATURES_PER_CHANNEL_BASIC: usize = 13;
/// Basic features plus the 6 peak/percentile features per channel.
pub const FEATURES_PER_CHANNEL_WITH_PEAKS: usize = 19;
/// With-peaks features plus the 6 masked features per channel.
pub const FEATURES_PER_CHANNEL_EXTENDED: usize = 25;
/// Typed, read-only view over a flat feature vector produced by
/// `combine_scores`, decoding its scored / peaks / masked layout.
#[derive(Debug, Clone)]
pub struct FeatureView<'a> {
    // Underlying flat feature slice.
    features: &'a [f64],
    // Number of pyramid scales the vector covers.
    n_scales: usize,
    // Length of the scored (basic) prefix.
    scored_total: usize,
    // Length of the peaks block (0 when absent).
    peaks_total: usize,
}
// Channel indices into per-channel feature groups.
// NOTE(review): names suggest an X/Y/B (XYB-like) color space — confirm.
#[cfg(feature = "training")]
pub const CH_X: usize = 0;
#[cfg(feature = "training")]
pub const CH_Y: usize = 1;
#[cfg(feature = "training")]
pub const CH_B: usize = 2;
impl<'a> FeatureView<'a> {
    /// Validates `features.len()` against the layouts this view understands:
    /// scored-only, scored + peaks, or scored + peaks + masked.
    /// Returns `None` for any other length.
    pub fn new(features: &'a [f64], n_scales: usize) -> Option<Self> {
        let basic_total = n_scales * 3 * FEATURES_PER_CHANNEL_BASIC;
        // Peaks and masked blocks each add 6 features per channel per scale.
        let peaks_total = n_scales * 3 * 6;
        let masked_total = n_scales * 3 * 6;
        let (scored_total, peaks_total) = if features.len() == basic_total {
            (basic_total, 0)
        } else if features.len() == basic_total + peaks_total
            || features.len() == basic_total + peaks_total + masked_total
        {
            (basic_total, peaks_total)
        } else {
            return None;
        };
        Some(Self {
            features,
            n_scales,
            scored_total,
            peaks_total,
        })
    }
    /// Number of pyramid scales covered.
    pub fn n_scales(&self) -> usize {
        self.n_scales
    }
    /// True when the peaks block is present.
    pub fn has_peaks(&self) -> bool {
        self.peaks_total > 0
    }
    /// True when the masked block is present.
    pub fn has_masked(&self) -> bool {
        self.features.len() > self.scored_total + self.peaks_total
    }
    // Index into the scored block: 13 features per channel, 3 channels per scale.
    fn scored_idx(&self, scale: usize, ch: usize, offset: usize) -> usize {
        scale * 3 * FEATURES_PER_CHANNEL_BASIC + ch * FEATURES_PER_CHANNEL_BASIC + offset
    }
    /// SSIM error mean.
    pub fn ssim_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 0)]
    }
    /// SSIM error 4th moment.
    pub fn ssim_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 1)]
    }
    /// SSIM error 2nd moment.
    pub fn ssim_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 2)]
    }
    /// Artifact (edge-added) mean.
    pub fn art_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 3)]
    }
    /// Artifact 4th moment.
    pub fn art_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 4)]
    }
    /// Artifact 2nd moment.
    pub fn art_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 5)]
    }
    /// Detail-loss (edge-removed) mean.
    pub fn det_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 6)]
    }
    /// Detail-loss 4th moment.
    pub fn det_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 7)]
    }
    /// Detail-loss 2nd moment.
    pub fn det_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 8)]
    }
    /// Mean squared error.
    pub fn mse(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 9)]
    }
    /// High-frequency energy loss.
    pub fn hf_energy_loss(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 10)]
    }
    /// High-frequency magnitude loss.
    pub fn hf_mag_loss(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 11)]
    }
    /// High-frequency energy gain.
    pub fn hf_energy_gain(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 12)]
    }
    // Index into the peaks block (6 per channel); `None` when absent.
    fn peak_idx(&self, scale: usize, ch: usize, offset: usize) -> Option<usize> {
        if self.peaks_total == 0 {
            return None;
        }
        Some(self.scored_total + scale * 3 * 6 + ch * 6 + offset)
    }
    /// Peak SSIM error.
    pub fn ssim_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 0).map(|i| self.features[i])
    }
    /// Peak artifact value.
    pub fn art_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 1).map(|i| self.features[i])
    }
    /// Peak detail-loss value.
    pub fn det_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 2).map(|i| self.features[i])
    }
    // Offsets 3-5 correspond to the `*_p95` entries pushed by `combine_scores`
    // (the `_l8` naming here is historical — NOTE(review): confirm intent).
    pub fn ssim_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 3).map(|i| self.features[i])
    }
    pub fn art_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 4).map(|i| self.features[i])
    }
    pub fn det_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 5).map(|i| self.features[i])
    }
    // Index into the masked block (6 per channel); `None` when absent.
    fn masked_idx(&self, scale: usize, ch: usize, offset: usize) -> Option<usize> {
        if !self.has_masked() {
            return None;
        }
        Some(self.scored_total + self.peaks_total + scale * 3 * 6 + ch * 6 + offset)
    }
    /// Masked SSIM error mean.
    pub fn masked_ssim_mean(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 0).map(|i| self.features[i])
    }
    /// Masked SSIM error 4th moment.
    pub fn masked_ssim_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 1).map(|i| self.features[i])
    }
    /// Masked SSIM error 2nd moment.
    pub fn masked_ssim_2nd(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 2).map(|i| self.features[i])
    }
    /// Masked artifact 4th moment.
    pub fn masked_art_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 3).map(|i| self.features[i])
    }
    /// Masked detail-loss 4th moment.
    pub fn masked_det_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 4).map(|i| self.features[i])
    }
    /// Masked mean squared error.
    pub fn masked_mse(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 5).map(|i| self.features[i])
    }
    /// The scored (basic) prefix of the vector.
    pub fn scored_features(&self) -> &[f64] {
        &self.features[..self.scored_total]
    }
    /// The peaks block, when present.
    pub fn peak_features(&self) -> Option<&[f64]> {
        if self.peaks_total == 0 {
            None
        } else {
            Some(&self.features[self.scored_total..self.scored_total + self.peaks_total])
        }
    }
    /// The masked block, when present.
    pub fn masked_features(&self) -> Option<&[f64]> {
        if !self.has_masked() {
            None
        } else {
            Some(&self.features[self.scored_total + self.peaks_total..])
        }
    }
}
/// Convenience entry point for packed RGB8 buffers (training/tests only).
///
/// # Errors
/// `ImageTooSmall` when either dimension is under 8; `InvalidDataLength`
/// when either buffer's length does not equal `width * height`.
#[cfg(any(feature = "training", test))]
pub fn compute_zensim_with_config(
    source: &[[u8; 3]],
    distorted: &[[u8; 3]],
    width: usize,
    height: usize,
    config: ZensimConfig,
) -> Result<ZensimResult, ZensimError> {
    if width < 8 || height < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if source.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    if distorted.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    // Both lengths now equal `width * height`, so a separate length-mismatch
    // (DimensionMismatch) check would be unreachable; it was removed.
    if source == distorted {
        // Identical buffers: perfect score with an all-zero feature vector,
        // without running the pipeline.
        let fpc = if config.extended_features {
            FEATURES_PER_CHANNEL_EXTENDED
        } else {
            FEATURES_PER_CHANNEL_WITH_PEAKS
        };
        let num_features = config.num_scales * 3 * fpc;
        return Ok(ZensimResult::new(
            100.0,
            0.0,
            vec![0.0; num_features],
            ZensimProfile::latest(),
            [0.0; 3],
        ));
    }
    let src_img = crate::source::RgbSlice::new(source, width, height);
    let dst_img = crate::source::RgbSlice::new(distorted, width, height);
    let result = crate::streaming::compute_zensim_streaming(&src_img, &dst_img, &config, WEIGHTS);
    Ok(result)
}
/// Scored + peak features per scale (19 per channel x 3 channels).
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub const FEATURES_PER_SCALE: usize = FEATURES_PER_CHANNEL_WITH_PEAKS * 3;
/// Default weight vector (preview v0.2 fit) used by the training/test helpers.
#[cfg(any(feature = "training", test))]
pub const WEIGHTS: &[f64; 228] = &crate::profile::WEIGHTS_PREVIEW_V0_2;
/// Folds per-scale statistics into the final feature vector and score.
///
/// Feature layout (must stay in sync with `FeatureView`):
/// 1. scored block — 13 per channel per scale: ssim mean/4th/2nd,
///    art mean/4th/2nd, det mean/4th/2nd, mse, hf energy loss,
///    hf magnitude loss, hf energy gain;
/// 2. peaks block — 6 per channel per scale: ssim/art/det max, then p95;
/// 3. masked block (extended only) — 6 per channel per scale.
/// Only the scored + peaks prefix is weighted into the raw distance.
pub(crate) fn combine_scores(
    scale_stats: &[ScaleStats],
    weights: &[f64],
    config: &ZensimConfig,
    mean_offset: [f64; 3],
) -> ZensimResult {
    let extended = config.extended_features;
    let n_scales = scale_stats.len();
    let basic_per_ch = FEATURES_PER_CHANNEL_BASIC;
    let basic_total = n_scales * basic_per_ch * 3;
    let peak_total = n_scales * 6 * 3;
    let masked_total = if extended { n_scales * 6 * 3 } else { 0 };
    let total = basic_total + peak_total + masked_total;
    let mut features = Vec::with_capacity(total);
    let mut raw_distance = 0.0f64;
    // Scored block; `.abs()` makes every feature a non-negative error magnitude.
    for ss in scale_stats.iter() {
        for c in 0..3 {
            features.push(ss.ssim[c * 2].abs());
            features.push(ss.ssim[c * 2 + 1].abs());
            features.push(ss.ssim_2nd[c].abs());
            features.push(ss.edge[c * 4].abs());
            features.push(ss.edge[c * 4 + 1].abs());
            features.push(ss.edge_2nd[c * 2].abs());
            features.push(ss.edge[c * 4 + 2].abs());
            features.push(ss.edge[c * 4 + 3].abs());
            features.push(ss.edge_2nd[c * 2 + 1].abs());
            features.push(ss.mse[c]);
            features.push(ss.hf_energy_loss[c]);
            features.push(ss.hf_mag_loss[c]);
            features.push(ss.hf_energy_gain[c]);
        }
    }
    // Peaks block.
    for ss in scale_stats.iter() {
        for c in 0..3 {
            features.push(ss.ssim_max[c]);
            features.push(ss.art_max[c]);
            features.push(ss.det_max[c]);
            features.push(ss.ssim_p95[c]);
            features.push(ss.art_p95[c]);
            features.push(ss.det_p95[c]);
        }
    }
    // Masked block, only in extended mode.
    if extended {
        for ss in scale_stats.iter() {
            for c in 0..3 {
                features.push(ss.masked_ssim[c * 3].abs());
                features.push(ss.masked_ssim[c * 3 + 1].abs());
                features.push(ss.masked_ssim[c * 3 + 2].abs());
                features.push(ss.masked_art_4th[c].abs());
                features.push(ss.masked_det_4th[c].abs());
                features.push(ss.masked_mse[c]);
            }
        }
    }
    // Weight only the scored + peaks prefix (never past the weight vector).
    let scored_total = basic_total + peak_total;
    let n_score = scored_total.min(weights.len());
    for (i, &feat) in features[..n_score].iter().enumerate() {
        raw_distance += feat * weights[i];
    }
    // Normalize by scale count; `max(1)` guards the empty-stats case.
    raw_distance /= scale_stats.len().max(1) as f64;
    let score =
        distance_to_score_mapped(raw_distance, config.score_mapping_a, config.score_mapping_b);
    ZensimResult::new(
        score,
        raw_distance,
        features,
        ZensimProfile::PreviewV0_1,
        mean_offset,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    // `compute_all_features` must not change the score, only populate more
    // of the feature vector.
    #[test]
    fn compute_all_matches_default() {
        let w = 128;
        let h = 128;
        let n = w * h;
        let mut src = vec![[128u8, 128, 128]; n];
        let mut dst = vec![[128u8, 128, 128]; n];
        // Smooth two-axis gradient with a small constant distortion.
        for y in 0..h {
            for x in 0..w {
                let r = ((x * 255) / w) as u8;
                let g = ((y * 255) / h) as u8;
                let b = 128;
                src[y * w + x] = [r, g, b];
                dst[y * w + x] = [r.saturating_add(5), g, b.saturating_sub(3)];
            }
        }
        let default_result =
            compute_zensim_with_config(&src, &dst, w, h, ZensimConfig::default()).unwrap();
        let all_result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(
            (default_result.score - all_result.score).abs() < 0.01,
            "default {} vs all_features {}",
            default_result.score,
            all_result.score,
        );
        // 19 features x 3 channels x 4 scales = 228 in both modes.
        assert_eq!(all_result.features.len(), 228);
        assert_eq!(default_result.features.len(), 228);
        // compute_all may fill in features the default path leaves at zero.
        let all_nonzero = all_result
            .features
            .iter()
            .filter(|f| f.abs() > 1e-12)
            .count();
        let default_nonzero = default_result
            .features
            .iter()
            .filter(|f| f.abs() > 1e-12)
            .count();
        assert!(
            all_nonzero >= default_nonzero,
            "compute_all should have >= features: {} vs {}",
            all_nonzero,
            default_nonzero,
        );
    }
    // Shared fixture: gradient source plus a uniformly shifted distortion.
    fn make_gradient_pair(w: usize, h: usize) -> (Vec<[u8; 3]>, Vec<[u8; 3]>) {
        let n = w * h;
        let mut src = vec![[128u8, 128, 128]; n];
        let mut dst = vec![[128u8, 128, 128]; n];
        for y in 0..h {
            for x in 0..w {
                let r = ((x * 255) / w) as u8;
                let g = ((y * 255) / h) as u8;
                let b = 128;
                src[y * w + x] = [r, g, b];
                dst[y * w + x] = [
                    r.saturating_add(10),
                    g.saturating_sub(5),
                    b.saturating_add(3),
                ];
            }
        }
        (src, dst)
    }
    // With extended_features off, compute_all must match the basic path.
    #[test]
    fn extended_features_backward_compat() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);
        let basic = compute_zensim_with_config(&src, &dst, w, h, ZensimConfig::default()).unwrap();
        let extended = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: false,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(basic.features.len(), 228);
        assert_eq!(extended.features.len(), 228);
        assert!(
            (basic.score - extended.score).abs() < 0.01,
            "basic {} vs compute_all {}",
            basic.score,
            extended.score,
        );
    }
    // Extended mode: 25 per channel x 3 channels x 4 scales = 300, all >= 0.
    #[test]
    fn extended_features_count_and_nonneg() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);
        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(
            result.features.len(),
            300,
            "Expected 25 × 3 × 4 = 300 features"
        );
        for (i, &f) in result.features.iter().enumerate() {
            assert!(f >= 0.0, "Feature {} is negative: {}", i, f);
        }
    }
    // Ordering invariants: mean <= 4th-moment <= max, and p95 <= max.
    #[test]
    fn extended_features_ordering() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);
        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        // Index math mirrors the layout emitted by combine_scores.
        let scored_per_ch = FEATURES_PER_CHANNEL_BASIC;
        let peaks_offset = 4 * scored_per_ch * 3;
        let peaks_per_ch = 6;
        for scale in 0..4 {
            for ch in 0..3 {
                let scored_base = scale * scored_per_ch * 3 + ch * scored_per_ch;
                let peaks_base = peaks_offset + scale * peaks_per_ch * 3 + ch * peaks_per_ch;
                let ssim_mean = result.features[scored_base];
                let ssim_4th = result.features[scored_base + 1];
                let ssim_max = result.features[peaks_base];
                let ssim_p95 = result.features[peaks_base + 3];
                assert!(
                    ssim_max >= ssim_4th - 1e-10,
                    "s{} c{}: max {:.6} < 4th {:.6}",
                    scale,
                    ch,
                    ssim_max,
                    ssim_4th,
                );
                assert!(
                    ssim_4th >= ssim_mean - 1e-10,
                    "s{} c{}: 4th {:.6} < mean {:.6}",
                    scale,
                    ch,
                    ssim_4th,
                    ssim_mean,
                );
                assert!(
                    ssim_p95 <= ssim_max + 1e-10,
                    "s{} c{}: p95 {:.6} > max {:.6}",
                    scale,
                    ch,
                    ssim_p95,
                    ssim_max,
                );
            }
        }
    }
    // Identical inputs must yield a perfect score and an all-zero vector.
    #[test]
    fn extended_features_identical_zero() {
        let (w, h) = (64, 64);
        let (src, _) = make_gradient_pair(w, h);
        let result = compute_zensim_with_config(
            &src,
            &src,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(result.score, 100.0);
        assert_eq!(result.features.len(), 300);
        for (i, &f) in result.features.iter().enumerate() {
            assert!(
                f.abs() < 1e-10,
                "Feature {} not zero for identical: {}",
                i,
                f
            );
        }
    }
    // Masked SSIM statistics must never exceed their unmasked counterparts.
    #[test]
    fn extended_masked_leq_unmasked() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);
        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();
        // Masked block sits after the scored and peaks blocks.
        let scored_per_ch = FEATURES_PER_CHANNEL_BASIC;
        let masked_offset = 4 * scored_per_ch * 3 + 4 * 6 * 3;
        let masked_per_ch = 6;
        for scale in 0..4 {
            for ch in 0..3 {
                let scored_base = scale * scored_per_ch * 3 + ch * scored_per_ch;
                let masked_base = masked_offset + scale * masked_per_ch * 3 + ch * masked_per_ch;
                let ssim_mean = result.features[scored_base];
                let ssim_4th = result.features[scored_base + 1];
                let ssim_2nd = result.features[scored_base + 2];
                let masked_ssim_mean = result.features[masked_base];
                let masked_ssim_4th = result.features[masked_base + 1];
                let masked_ssim_2nd = result.features[masked_base + 2];
                assert!(
                    masked_ssim_mean <= ssim_mean + 1e-10,
                    "s{} c{}: masked_mean {:.6} > mean {:.6}",
                    scale,
                    ch,
                    masked_ssim_mean,
                    ssim_mean,
                );
                assert!(
                    masked_ssim_4th <= ssim_4th + 1e-10,
                    "s{} c{}: masked_4th {:.6} > 4th {:.6}",
                    scale,
                    ch,
                    masked_ssim_4th,
                    ssim_4th,
                );
                assert!(
                    masked_ssim_2nd <= ssim_2nd + 1e-10,
                    "s{} c{}: masked_2nd {:.6} > 2nd {:.6}",
                    scale,
                    ch,
                    masked_ssim_2nd,
                    ssim_2nd,
                );
            }
        }
    }
}