zenavif 0.1.6

Pure Rust AVIF image codec powered by rav1d and zenravif
Documentation
//! Alpha channel handling, premultiply conversion, and bit depth scaling

use crate::error::{Error, Result};
use crate::image::ColorRange;
use rgb::prelude::*;
use rgb::{Rgb, Rgba};
use whereat::at;
use zenpixels::{PixelBuffer, PixelDescriptor};

/// Expand a limited-range ("studio swing") 8-bit sample to full range.
///
/// The nominal limited window 16..=235 is stretched onto 0..=255 using a
/// truncating division. Inputs below 16 clamp to 0 and inputs above 235
/// clamp to 255.
#[inline]
fn limited_to_full_8(y: u8) -> u8 {
    // Work in u32: (235 - 16) * 255 = 55845 would not fit in an i16.
    let above_black = u32::from(y).saturating_sub(16);
    let stretched = above_black * 255 / 219;
    if stretched > 255 {
        u8::MAX
    } else {
        stretched as u8
    }
}

/// Expand a limited-range sample of the given bit depth to full range.
///
/// The limited window is the 8-bit 16..=235 window shifted left by
/// `bit_depth - 8` (e.g. 64..=940 for 10-bit, 256..=3760 for 12-bit).
/// Values below the window clamp to 0 and values above it clamp to
/// `2^bit_depth - 1`. The division truncates.
///
/// `bit_depth` is expected to be in 8..=16; below 8 the shift amount
/// `bit_depth - 8` underflows.
#[inline]
fn limited_to_full_16(y: u16, bit_depth: u8) -> u16 {
    let full_max = (1u32 << bit_depth) - 1;
    let headroom = bit_depth - 8;
    let black_level = 16u32 << headroom;
    let nominal_span = 219u32 << headroom;
    // u32 is wide enough: even at 16-bit, (65535 - 4096) * 65535 < u32::MAX.
    let stretched = u32::from(y).saturating_sub(black_level) * full_max / nominal_span;
    stretched.min(full_max) as u16
}

/// Widen a sample from its native bit depth to the full u16 range.
///
/// The sample is shifted into the most significant bits and the vacated low
/// bits are filled with a copy of its own top bits (LSB replication), so the
/// endpoints map exactly:
///
/// - 10-bit: `(v << 6) | (v >> 4)`, 0→0 and 1023→65535
/// - 12-bit: `(v << 4) | (v >> 8)`, 0→0 and 4095→65535
/// - 16-bit: returned unchanged
#[inline]
fn scale_to_u16(v: u16, bit_depth: u8) -> u16 {
    match 16 - bit_depth {
        0 => v,
        pad => {
            let high = v << pad;
            let low = v >> (bit_depth - pad);
            high | low
        }
    }
}

/// Downscale a 16-bit PixelBuffer to 8-bit by keeping the high byte of each channel.
///
/// Rgb16 becomes Rgb8 and Rgba16 becomes Rgba8. The input buffer is consumed
/// and a new buffer of the same width and height is allocated; the conversion
/// is not in-place. Any other layout is returned unchanged.
///
/// Values are assumed to be in full u16 range (0–65535), as produced by
/// `scale_pixels_to_u16`. Taking the high byte exactly undoes 8-bit LSB
/// replication (`(v << 8) | v`).
///
/// # Panics
///
/// Panics if a buffer whose descriptor is layout-compatible with RGB16/RGBA16
/// cannot be viewed as `Rgb<u16>`/`Rgba<u16>`. It also panics if
/// `PixelBuffer::from_pixels` rejects the freshly built output; both are
/// treated as invariant violations.
pub fn downscale_to_8bit(image: PixelBuffer) -> PixelBuffer {
    /// Keep the most significant byte of a full-range 16-bit sample.
    #[inline]
    fn high_byte(v: u16) -> u8 {
        (v >> 8) as u8
    }

    let desc = image.descriptor();
    let w = image.width();
    let h = image.height();
    if desc.layout_compatible(PixelDescriptor::RGB16) {
        let src = image
            .try_as_imgref::<Rgb<u16>>()
            .expect("RGB16-compatible buffer must be viewable as Rgb<u16>");
        let out: Vec<Rgb<u8>> = src
            .pixels()
            .map(|px| Rgb {
                r: high_byte(px.r),
                g: high_byte(px.g),
                b: high_byte(px.b),
            })
            .collect();
        PixelBuffer::from_pixels(out, w, h)
            .expect("allocation should succeed for same dimensions")
            .into()
    } else if desc.layout_compatible(PixelDescriptor::RGBA16) {
        let src = image
            .try_as_imgref::<Rgba<u16>>()
            .expect("RGBA16-compatible buffer must be viewable as Rgba<u16>");
        let out: Vec<Rgba<u8>> = src
            .pixels()
            .map(|px| Rgba {
                r: high_byte(px.r),
                g: high_byte(px.g),
                b: high_byte(px.b),
                a: high_byte(px.a),
            })
            .collect();
        PixelBuffer::from_pixels(out, w, h)
            .expect("allocation should succeed for same dimensions")
            .into()
    } else {
        // Not a 16-bit RGB/RGBA layout: nothing to downscale.
        image
    }
}

/// Widen every channel of a 16-bit PixelBuffer from its native bit depth to full u16.
///
/// For example, 10-bit samples (0–1023) become 0–65535 via LSB replication,
/// so the endpoints map exactly. Only RGB16- and RGBA16-compatible buffers
/// are touched; other layouts, and any `bit_depth >= 16`, are left as-is.
/// The whole backing slice returned by `buf_mut()` is rewritten.
pub fn scale_pixels_to_u16(image: &mut PixelBuffer, bit_depth: u8) {
    // Samples already span the full 16-bit range; nothing to widen.
    if bit_depth >= 16 {
        return;
    }
    let widen = |sample: u16| scale_to_u16(sample, bit_depth);
    let desc = image.descriptor();
    if desc.layout_compatible(PixelDescriptor::RGB16) {
        let mut view = image.try_as_imgref_mut::<Rgb<u16>>().unwrap();
        view.buf_mut().iter_mut().for_each(|px| {
            px.r = widen(px.r);
            px.g = widen(px.g);
            px.b = widen(px.b);
        });
    } else if desc.layout_compatible(PixelDescriptor::RGBA16) {
        let mut view = image.try_as_imgref_mut::<Rgba<u16>>().unwrap();
        view.buf_mut().iter_mut().for_each(|px| {
            px.r = widen(px.r);
            px.g = widen(px.g);
            px.b = widen(px.b);
            px.a = widen(px.a);
        });
    }
}

/// Narrow a full-range u16 sample (0–65535) to its native bit depth.
///
/// - 10-bit: `v >> 6`, 0→0 and 65535→1023
/// - 12-bit: `v >> 4`, 0→0 and 65535→4095
/// - 16-bit: returned unchanged
///
/// Only the `bit_depth` most significant bits are kept (truncation). This is
/// the exact inverse of the LSB replication in `scale_to_u16`, so values
/// produced by that function roundtrip losslessly. For arbitrary inputs the
/// bias is symmetric, and the max error is lower than with half-up rounding
/// (63 vs 95 for 10-bit).
#[cfg(feature = "encode")]
#[inline]
pub fn scale_from_u16(v: u16, bit_depth: u8) -> u16 {
    // A zero shift (16-bit target) leaves the value untouched.
    let discarded_low_bits = 16 - bit_depth;
    v >> discarded_low_bits
}

/// Fill the alpha channel of an 8-bit RGBA buffer from an alpha plane.
///
/// Each alpha row is paired with the next image row. Limited-range alpha is
/// expanded to full range first. When `premultiplied` is set, each row's
/// colour is then converted to straight alpha with `unpremultiply8`. If the
/// iterator yields fewer rows than the image has, the remaining image rows
/// are left untouched.
///
/// # Errors
///
/// Returns `Error::Unsupported` in three cases: `buf` cannot be viewed as
/// `Rgba<u8>`, its dimensions differ from `width`×`height`, or an alpha row
/// is shorter than the image row it is paired with.
pub fn add_alpha8<'a>(
    buf: &mut PixelBuffer,
    alpha_rows: impl Iterator<Item = &'a [u8]>,
    width: usize,
    height: usize,
    alpha_range: ColorRange,
    premultiplied: bool,
) -> Result<()> {
    let mut img = buf.try_as_imgref_mut::<Rgba<u8>>().ok_or_else(|| {
        at!(Error::Unsupported(
            "cannot add 8-bit alpha to this image type",
        ))
    })?;

    let size_matches = img.width() == width && img.height() == height;
    if !size_matches {
        return Err(at!(Error::Unsupported("alpha size mismatch")));
    }

    for (src, dst) in alpha_rows.zip(img.rows_mut()) {
        if src.len() < dst.len() {
            return Err(at!(Error::Unsupported("alpha width mismatch")));
        }
        dst.iter_mut().zip(src).for_each(|(px, &sample)| {
            px.a = match alpha_range {
                ColorRange::Limited => limited_to_full_8(sample),
                ColorRange::Full => sample,
            };
        });
        if premultiplied {
            unpremultiply8(dst);
        }
    }

    Ok(())
}

/// Fill the alpha channel of a 16-bit RGBA buffer from an alpha plane.
///
/// Plane samples are in native bit depth (e.g. 0–1023 for 10-bit). Each
/// sample is first range-converted (limited→full when `alpha_range` is
/// `Limited`). It is then widened to full u16 (0–65535) with `scale_to_u16`
/// so it matches RGB channels that have already been scaled. When
/// `premultiplied` is set, each row's colour is converted to straight alpha
/// with `unpremultiply16`. If the iterator yields fewer rows than the image
/// has, the remaining image rows are left untouched.
///
/// # Errors
///
/// Returns `Error::Unsupported` in three cases: `buf` cannot be viewed as
/// `Rgba<u16>`, its dimensions differ from `width`×`height`, or an alpha row
/// is shorter than the image row it is paired with.
pub fn add_alpha16<'a>(
    buf: &mut PixelBuffer,
    alpha_rows: impl Iterator<Item = &'a [u16]>,
    width: usize,
    height: usize,
    alpha_range: ColorRange,
    bit_depth: u8,
    premultiplied: bool,
) -> Result<()> {
    let mut img = buf.try_as_imgref_mut::<Rgba<u16>>().ok_or_else(|| {
        at!(Error::Unsupported(
            "cannot add 16-bit alpha to this image type",
        ))
    })?;

    let size_matches = img.width() == width && img.height() == height;
    if !size_matches {
        return Err(at!(Error::Unsupported("alpha size mismatch")));
    }

    for (src, dst) in alpha_rows.zip(img.rows_mut()) {
        if src.len() < dst.len() {
            return Err(at!(Error::Unsupported("alpha width mismatch")));
        }
        dst.iter_mut().zip(src).for_each(|(px, &sample)| {
            let full_range = match alpha_range {
                ColorRange::Limited => limited_to_full_16(sample, bit_depth),
                ColorRange::Full => sample,
            };
            px.a = scale_to_u16(full_range, bit_depth);
        });
        if premultiplied {
            unpremultiply16(dst);
        }
    }

    Ok(())
}

/// Convert one row of 8-bit RGBA from premultiplied to straight alpha, in place.
///
/// Each colour channel becomes `(c * 255 + a / 2) / a`, i.e. `c * 255 / a`
/// rounded half-up, clamped to 255. Fully opaque pixels (`a == 255`) and
/// fully transparent pixels (`a == 0`, which also avoids a division by zero)
/// are left unchanged.
#[inline(never)]
pub fn unpremultiply8(img_row: &mut [Rgba<u8>]) {
    img_row
        .iter_mut()
        .filter(|px| px.a != 0 && px.a != u8::MAX)
        .for_each(|px| {
            let alpha = u16::from(px.a);
            let half = alpha / 2;
            // Max numerator 255 * 255 + 127 = 65152 still fits in u16.
            *px.rgb_mut() = px
                .rgb()
                .map(|c| ((u16::from(c) * 255 + half) / alpha).min(255) as u8);
        });
}

/// Convert one row of 16-bit RGBA from premultiplied to straight alpha, in place.
///
/// Each colour channel becomes `(c * 65535 + a / 2) / a`, i.e. `c * 65535 / a`
/// rounded half-up, clamped to 65535. This is the same rounding convention
/// as `unpremultiply8`; plain truncation would bias every recovered channel
/// downward. Fully opaque pixels (`a == 65535`) and fully transparent pixels
/// (`a == 0`, which also avoids a division by zero) are left unchanged.
#[inline(never)]
pub fn unpremultiply16(img_row: &mut [Rgba<u16>]) {
    for px in img_row.iter_mut() {
        if px.a != 0xFFFF && px.a != 0 {
            let a = u32::from(px.a);
            let half = a / 2;
            // Max numerator: 0xFFFF * 0xFFFF + 0x7FFF = 4_294_868_992 < u32::MAX.
            *px.rgb_mut() = px
                .rgb()
                .map(|c| ((u32::from(c) * 0xFFFF + half) / a).min(0xFFFF) as u16);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn limited_to_full_8_no_overflow() {
        // Regression: i16 arithmetic overflowed for y > 144
        // (y-16)*255 = (235-16)*255 = 55845 > i16::MAX (32767)
        assert_eq!(limited_to_full_8(16), 0);
        assert_eq!(limited_to_full_8(235), 255);
        // y=145: (145-16)*255 = 32895 > 32767 — would overflow with i16
        assert_eq!(limited_to_full_8(145), 150);
        // y=200: (200-16)*255 = 46920 — definitely overflows i16
        assert_eq!(limited_to_full_8(200), 214);
        // Below range clamps to 0
        assert_eq!(limited_to_full_8(0), 0);
        assert_eq!(limited_to_full_8(15), 0);
        // Above range clamps to 255
        assert_eq!(limited_to_full_8(255), 255);
    }

    #[test]
    fn limited_to_full_8_all_values_in_range() {
        // Every u8 input must map without panicking, clamp outside the
        // nominal 16..=235 window, and never decrease as the input grows.
        let mut prev = 0u8;
        for y in 0..=255u8 {
            let result = limited_to_full_8(y);
            if y <= 16 {
                assert_eq!(result, 0, "y={y} should clamp to black");
            }
            if y >= 235 {
                assert_eq!(result, 255, "y={y} should clamp to white");
            }
            assert!(result >= prev, "not monotonic at y={y}");
            prev = result;
        }
    }

    #[test]
    fn limited_to_full_16_endpoints() {
        // 10-bit
        assert_eq!(limited_to_full_16(64, 10), 0); // 16<<2 = 64
        assert_eq!(limited_to_full_16(940, 10), 1023); // 235<<2 = 940
        // 12-bit
        assert_eq!(limited_to_full_16(256, 12), 0); // 16<<4 = 256
        assert_eq!(limited_to_full_16(3760, 12), 4095); // 235<<4 = 3760
        // Out-of-window values clamp instead of wrapping
        assert_eq!(limited_to_full_16(0, 10), 0);
        assert_eq!(limited_to_full_16(1023, 10), 1023);
        assert_eq!(limited_to_full_16(0, 12), 0);
        assert_eq!(limited_to_full_16(4095, 12), 4095);
    }

    #[test]
    fn scale_to_u16_endpoints() {
        // 8-bit
        assert_eq!(scale_to_u16(0, 8), 0);
        assert_eq!(scale_to_u16(255, 8), 65535);
        // 10-bit
        assert_eq!(scale_to_u16(0, 10), 0);
        assert_eq!(scale_to_u16(1023, 10), 65535);
        // 12-bit
        assert_eq!(scale_to_u16(0, 12), 0);
        assert_eq!(scale_to_u16(4095, 12), 65535);
        // 16-bit no-op
        assert_eq!(scale_to_u16(12345, 16), 12345);
    }

    #[test]
    fn scale_to_u16_top_bits_roundtrip() {
        // LSB replication only fills the low bits, so shifting back down
        // must recover every native value exactly.
        for bit_depth in [8u8, 10, 12] {
            let max = (1u16 << bit_depth) - 1;
            for v in 0..=max {
                assert_eq!(
                    scale_to_u16(v, bit_depth) >> (16 - bit_depth),
                    v,
                    "bit_depth={bit_depth} v={v}"
                );
            }
        }
    }
}