zensim 0.2.4

Fast psychovisual image similarity metric
Documentation
//! Adapter for zenpixels pixel types.
//!
//! Provides [`ZenpixelsSource`], a validated wrapper that implements
//! [`ImageSource`] for any zenpixels [`PixelSlice`] or [`PixelBuffer`].
//!
//! # Supported formats
//!
//! | zenpixels format | transfer | result |
//! |------------------|----------|--------|
//! | Rgb8 | sRGB/BT.709 | `Srgb8Rgb`, Opaque |
//! | Rgba8 | sRGB/BT.709 | `Srgb8Rgba`, from alpha mode |
//! | Bgra8 | sRGB/BT.709 | `Srgb8Bgra`, from alpha mode |
//! | Rgbx8 | sRGB/BT.709 | `Srgb8Rgba`, **Opaque** (automatic) |
//! | Bgrx8 | sRGB/BT.709 | `Srgb8Bgra`, **Opaque** (automatic) |
//! | Rgba16 | sRGB/BT.709 | `Srgb16Rgba`, from alpha mode |
//! | RgbaF32 | Linear | `LinearF32Rgba`, from alpha mode |
//!
//! Premultiplied alpha is un-premultiplied automatically (allocates a copy).
//! RGBX/BGRX padding bytes are ignored automatically (no user action needed).
//!
//! # Rejected formats
//!
//! - HDR transfers (PQ, HLG) — zensim is SDR-only
//! - Unknown transfer — can't convert losslessly
//! - Grayscale — not supported
//! - Narrow signal range — requires expansion first

use std::borrow::Cow;

use zenpixels::{AlphaMode as ZpAlpha, PixelBuffer, PixelDescriptor, PixelSlice, TransferFunction};

use crate::error::UnsupportedFormat;
use crate::source::{AlphaMode, ColorPrimaries, ImageSource, PixelFormat};

/// Validated adapter wrapping zenpixels pixel data for use with zensim.
///
/// Created via [`try_from_slice`](Self::try_from_slice) or
/// [`try_from_buffer`](Self::try_from_buffer). Validates the pixel descriptor
/// and maps it to the correct zensim format, alpha mode, and color primaries.
///
/// For premultiplied alpha, the data is un-premultiplied into an internal
/// buffer (one-time allocation). For all other formats, the original data
/// is borrowed with zero copies.
pub struct ZenpixelsSource<'a> {
    data: Cow<'a, [u8]>,
    width: usize,
    height: usize,
    stride: usize,
    pixel_format: PixelFormat,
    alpha_mode: AlphaMode,
    color_primaries: ColorPrimaries,
}

impl<'a> ZenpixelsSource<'a> {
    /// Create from a borrowed [`PixelSlice`].
    ///
    /// Validates the pixel descriptor and rejects unsupported formats.
    /// For premultiplied alpha, allocates and un-premultiplies the data.
    pub fn try_from_slice(slice: &'a PixelSlice<'a>) -> Result<Self, UnsupportedFormat> {
        let desc = slice.descriptor();
        let (pixel_format, alpha_mode, color_primaries) = map_descriptor(&desc)?;

        let width = slice.width() as usize;
        let height = slice.rows() as usize;
        let stride = slice.stride();
        let raw = slice.as_strided_bytes();

        let data = if matches!(desc.alpha, Some(ZpAlpha::Premultiplied)) {
            Cow::Owned(unpremultiply(raw, width, height, stride, &desc))
        } else {
            Cow::Borrowed(raw)
        };

        Ok(Self {
            data,
            width,
            height,
            stride,
            pixel_format,
            alpha_mode,
            color_primaries,
        })
    }

    /// Create from an owned [`PixelBuffer`].
    ///
    /// Same validation as [`try_from_slice`](Self::try_from_slice).
    /// The buffer's backing memory is borrowed for the lifetime of the source.
    pub fn try_from_buffer(buf: &'a PixelBuffer) -> Result<Self, UnsupportedFormat> {
        let desc = buf.descriptor();
        let (pixel_format, alpha_mode, color_primaries) = map_descriptor(&desc)?;

        let width = buf.width() as usize;
        let height = buf.height() as usize;
        let stride = buf.stride();

        // PixelBuffer may have internal offset; get contiguous or strided bytes.
        let raw = match buf.as_contiguous_bytes() {
            Some(bytes) => bytes,
            None => {
                return Err(UnsupportedFormat("non-contiguous PixelBuffer layout"));
            }
        };

        let data = if matches!(desc.alpha, Some(ZpAlpha::Premultiplied)) {
            Cow::Owned(unpremultiply(raw, width, height, stride, &desc))
        } else {
            Cow::Borrowed(raw)
        };

        Ok(Self {
            data,
            width,
            height,
            stride,
            pixel_format,
            alpha_mode,
            color_primaries,
        })
    }
}

impl ImageSource for ZenpixelsSource<'_> {
    #[inline]
    fn width(&self) -> usize {
        self.width
    }
    #[inline]
    fn height(&self) -> usize {
        self.height
    }
    #[inline]
    fn pixel_format(&self) -> PixelFormat {
        self.pixel_format
    }
    #[inline]
    fn alpha_mode(&self) -> AlphaMode {
        self.alpha_mode
    }
    #[inline]
    fn color_primaries(&self) -> ColorPrimaries {
        self.color_primaries
    }
    #[inline]
    fn row_bytes(&self, y: usize) -> &[u8] {
        let bpp = self.pixel_format.bytes_per_pixel();
        let start = y * self.stride;
        &self.data[start..start + self.width * bpp]
    }
}

/// Map a zenpixels [`PixelDescriptor`] to zensim types.
fn map_descriptor(
    desc: &PixelDescriptor,
) -> Result<(PixelFormat, AlphaMode, ColorPrimaries), UnsupportedFormat> {
    use zenpixels::{ColorPrimaries as ZpPrimaries, PixelFormat as ZpFormat, SignalRange};

    // Reject unsupported transfers
    match desc.transfer {
        TransferFunction::Srgb | TransferFunction::Bt709 => {}
        TransferFunction::Linear => {
            if !matches!(desc.format, ZpFormat::RgbaF32) {
                return Err(UnsupportedFormat("linear transfer requires RgbaF32 format"));
            }
        }
        TransferFunction::Pq | TransferFunction::Hlg => {
            return Err(UnsupportedFormat(
                "HDR transfers (PQ, HLG) are not supported — zensim is SDR-only",
            ));
        }
        _ => {
            return Err(UnsupportedFormat(
                "unknown transfer function — cannot convert losslessly",
            ));
        }
    }

    // Reject narrow signal range
    if desc.signal_range != SignalRange::Full {
        return Err(UnsupportedFormat(
            "narrow/limited signal range not supported — expand to full range first",
        ));
    }

    // Map pixel format
    let pixel_format = match desc.format {
        ZpFormat::Rgb8 => PixelFormat::Srgb8Rgb,
        ZpFormat::Rgba8 | ZpFormat::Rgbx8 => PixelFormat::Srgb8Rgba,
        ZpFormat::Bgra8 | ZpFormat::Bgrx8 => PixelFormat::Srgb8Bgra,
        ZpFormat::Rgba16 => PixelFormat::Srgb16Rgba,
        ZpFormat::RgbaF32 => PixelFormat::LinearF32Rgba,
        ZpFormat::Gray8
        | ZpFormat::Gray16
        | ZpFormat::GrayF32
        | ZpFormat::GrayA8
        | ZpFormat::GrayA16
        | ZpFormat::GrayAF32 => {
            return Err(UnsupportedFormat("grayscale formats not supported"));
        }
        _ => {
            return Err(UnsupportedFormat("unsupported pixel format"));
        }
    };

    // Map alpha mode
    let alpha_mode = match desc.alpha {
        None => AlphaMode::Opaque,
        Some(ZpAlpha::Undefined) => AlphaMode::Opaque,
        Some(ZpAlpha::Opaque) => AlphaMode::Opaque,
        Some(ZpAlpha::Straight) => AlphaMode::Straight,
        Some(ZpAlpha::Premultiplied) => AlphaMode::Straight, // un-premultiplied by adapter
        _ => AlphaMode::Straight,                            // future variants → treat as straight
    };

    // Map color primaries
    let color_primaries = match desc.primaries {
        ZpPrimaries::Bt709 => ColorPrimaries::Srgb,
        ZpPrimaries::DisplayP3 => ColorPrimaries::DisplayP3,
        ZpPrimaries::Bt2020 => ColorPrimaries::Bt2020,
        _ => ColorPrimaries::Srgb,
    };

    Ok((pixel_format, alpha_mode, color_primaries))
}

/// Un-premultiply pixel data, returning a new buffer with straight alpha.
///
/// Divides each RGB channel by alpha in the native domain (sRGB for u8/u16,
/// linear for f32). Fully transparent pixels (A=0) get RGB zeroed.
fn unpremultiply(
    data: &[u8],
    width: usize,
    height: usize,
    stride: usize,
    desc: &PixelDescriptor,
) -> Vec<u8> {
    use zenpixels::PixelFormat as ZpFormat;

    let mut out = data.to_vec();

    match desc.format {
        ZpFormat::Rgba8 | ZpFormat::Rgbx8 | ZpFormat::Bgra8 | ZpFormat::Bgrx8 => {
            for y in 0..height {
                let row_start = y * stride;
                for x in 0..width {
                    let off = row_start + x * 4;
                    let a = out[off + 3];
                    if a == 0 {
                        out[off] = 0;
                        out[off + 1] = 0;
                        out[off + 2] = 0;
                    } else if a < 255 {
                        let inv = 255.0 / a as f32;
                        out[off] = (out[off] as f32 * inv).min(255.0) as u8;
                        out[off + 1] = (out[off + 1] as f32 * inv).min(255.0) as u8;
                        out[off + 2] = (out[off + 2] as f32 * inv).min(255.0) as u8;
                    }
                }
            }
        }
        ZpFormat::Rgba16 => {
            for y in 0..height {
                let row_start = y * stride;
                for x in 0..width {
                    let off = row_start + x * 8;
                    let a = u16::from_ne_bytes([out[off + 6], out[off + 7]]);
                    if a == 0 {
                        out[off..off + 6].fill(0);
                    } else if a < 65535 {
                        let inv = 65535.0 / a as f32;
                        for c in 0..3 {
                            let co = off + c * 2;
                            let v = u16::from_ne_bytes([out[co], out[co + 1]]);
                            let unpremul = (v as f32 * inv).min(65535.0) as u16;
                            out[co..co + 2].copy_from_slice(&unpremul.to_ne_bytes());
                        }
                    }
                }
            }
        }
        ZpFormat::RgbaF32 => {
            for y in 0..height {
                let row_start = y * stride;
                for x in 0..width {
                    let off = row_start + x * 16;
                    let a = f32::from_ne_bytes([
                        out[off + 12],
                        out[off + 13],
                        out[off + 14],
                        out[off + 15],
                    ]);
                    if a <= 0.0 {
                        out[off..off + 12].fill(0);
                    } else if a < 1.0 {
                        let inv = 1.0 / a;
                        for c in 0..3 {
                            let co = off + c * 4;
                            let v = f32::from_ne_bytes([
                                out[co],
                                out[co + 1],
                                out[co + 2],
                                out[co + 3],
                            ]);
                            out[co..co + 4].copy_from_slice(&(v * inv).to_ne_bytes());
                        }
                    }
                }
            }
        }
        _ => {}
    }

    out
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rgbx_maps_to_opaque() {
        let desc = PixelDescriptor::RGBX8_SRGB;
        let (fmt, alpha, _) = map_descriptor(&desc).unwrap();
        assert_eq!(fmt, PixelFormat::Srgb8Rgba);
        assert_eq!(alpha, AlphaMode::Opaque);
    }

    #[test]
    fn bgrx_maps_to_opaque() {
        let desc = PixelDescriptor::BGRX8_SRGB;
        let (fmt, alpha, _) = map_descriptor(&desc).unwrap();
        assert_eq!(fmt, PixelFormat::Srgb8Bgra);
        assert_eq!(alpha, AlphaMode::Opaque);
    }

    #[test]
    fn rgba_straight_maps_correctly() {
        let desc = PixelDescriptor::RGBA8_SRGB;
        let (fmt, alpha, _) = map_descriptor(&desc).unwrap();
        assert_eq!(fmt, PixelFormat::Srgb8Rgba);
        assert_eq!(alpha, AlphaMode::Straight);
    }

    #[test]
    fn hdr_rejected() {
        let desc = PixelDescriptor::RGBA8_SRGB.with_transfer(TransferFunction::Pq);
        assert!(map_descriptor(&desc).is_err());
    }

    #[test]
    fn unknown_transfer_rejected() {
        let desc = PixelDescriptor::RGBA8_SRGB.with_transfer(TransferFunction::Unknown);
        assert!(map_descriptor(&desc).is_err());
    }

    #[test]
    fn grayscale_rejected() {
        let desc = PixelDescriptor::GRAY8_SRGB;
        assert!(map_descriptor(&desc).is_err());
    }

    #[test]
    fn p3_primaries_mapped() {
        let desc = PixelDescriptor::RGBA8_SRGB.with_primaries(zenpixels::ColorPrimaries::DisplayP3);
        let (_, _, primaries) = map_descriptor(&desc).unwrap();
        assert_eq!(primaries, ColorPrimaries::DisplayP3);
    }

    #[test]
    fn unpremultiply_u8_roundtrip() {
        // Premultiplied: R=100 means straight R = 100 * 255/200 ≈ 127
        let data = vec![100u8, 50, 25, 200];
        let result = unpremultiply(&data, 1, 1, 4, &PixelDescriptor::RGBA8_SRGB);
        assert_eq!(result[3], 200); // alpha unchanged
        assert_eq!(result[0], 127); // 100 * 255/200 = 127.5 → 127
        assert_eq!(result[1], 63); // 50 * 255/200 = 63.75 → 63
    }

    #[test]
    fn unpremultiply_a0_clears_rgb() {
        let data = vec![255u8, 128, 64, 0];
        let result = unpremultiply(&data, 1, 1, 4, &PixelDescriptor::RGBA8_SRGB);
        assert_eq!(result, [0, 0, 0, 0]);
    }
}