use zune_core::bit_depth::{BitDepth, BitType};
use zune_core::colorspace::ColorSpace;
use zune_core::log::warn;
use zune_image::errors::ImageErrors;
use zune_image::image::Image;
use zune_image::metadata::AlphaState;
use zune_image::traits::OperationsTrait;
use crate::mathops::{compute_mod_u32, fastdiv_u32};
mod std_simd;
/// Image operation that converts an image's color channels between
/// premultiplied and non-premultiplied alpha representations.
#[derive(Copy, Clone)]
pub struct PremultiplyAlpha {
    /// Alpha representation the image will be converted into.
    to: AlphaState
}

impl PremultiplyAlpha {
    /// Builds the operation; executing it converts the image's alpha
    /// representation to `to`.
    #[must_use]
    pub fn new(to: AlphaState) -> Self {
        Self { to }
    }
}
impl OperationsTrait for PremultiplyAlpha {
    fn name(&self) -> &'static str {
        "pre-multiply alpha"
    }

    /// Converts the color channels of every frame to the alpha representation in
    /// `self.to`, then records that representation in the image metadata.
    ///
    /// This is a no-op (with a warning) when the colorspace has no alpha channel or
    /// when the metadata already reports the requested alpha state.
    ///
    /// # Errors
    /// - `ImageErrors::ImageOperationNotImplemented` for bit depths other than
    ///   `Eight`, `Sixteen` and `Float32`.
    /// - `ImageErrors::GenericStr` for any alpha-state transition other than
    ///   non-premultiplied -> premultiplied or premultiplied -> non-premultiplied.
    ///
    /// # Panics
    /// If the split of a frame's channels does not yield exactly one alpha channel.
    fn execute_impl(&self, image: &mut Image) -> Result<(), ImageErrors> {
        if !image.colorspace().has_alpha() {
            warn!("Image colorspace indicates no alpha channel, this operation is a no-op");
            return Ok(());
        }
        let colorspaces = image.colorspace();
        let alpha_state = image.metadata().alpha();

        if alpha_state == self.to {
            warn!("Alpha is already in required mode, exiting");
            return Ok(());
        }
        let bit_type = image.depth();

        // The reciprocal lookup tables are only read when dividing by alpha
        // (premultiplied -> non-premultiplied). Build each at most once for the whole
        // image instead of once per frame. The 16-bit table has 65536 entries, so only
        // build it when it will actually be read.
        let unpremultiplying = matches!(
            (alpha_state, self.to),
            (AlphaState::PreMultiplied, AlphaState::NonPreMultiplied)
        );
        let u8_table = if unpremultiplying && bit_type == BitDepth::Eight {
            create_unpremul_table_u8()
        } else {
            [0; 256]
        };
        let u16_table = if unpremultiplying && bit_type == BitDepth::Sixteen {
            create_unpremul_table_u16()
        } else {
            Vec::new()
        };

        for image_frame in image.frames_mut() {
            // Split into (color channels, alpha channel). For ARGB the first channel
            // is taken as alpha; otherwise the last colorspace component is.
            let (color_channels, alpha) = {
                if colorspaces == ColorSpace::ARGB {
                    let im = image_frame.channels_vec();
                    let (alpha, channels) = im.split_at_mut(1);
                    (channels, alpha)
                } else {
                    image_frame
                        .channels_mut(colorspaces, false)
                        .split_at_mut(colorspaces.num_components() - 1)
                }
            };
            assert_eq!(alpha.len(), 1);

            for channel in color_channels {
                match (alpha_state, self.to) {
                    (AlphaState::NonPreMultiplied, AlphaState::PreMultiplied) => match bit_type {
                        BitDepth::Eight => {
                            premultiply_u8(
                                channel.reinterpret_as_mut()?,
                                alpha[0].reinterpret_as()?
                            );
                        }
                        BitDepth::Sixteen => {
                            premultiply_u16(
                                channel.reinterpret_as_mut()?,
                                alpha[0].reinterpret_as()?
                            );
                        }
                        BitDepth::Float32 => premultiply_f32(
                            channel.reinterpret_as_mut()?,
                            alpha[0].reinterpret_as()?
                        ),
                        d => {
                            return Err(ImageErrors::ImageOperationNotImplemented(
                                self.name(),
                                d.bit_type()
                            ))
                        }
                    },
                    (AlphaState::PreMultiplied, AlphaState::NonPreMultiplied) => match bit_type {
                        BitDepth::Eight => {
                            unpremultiply_u8(
                                channel.reinterpret_as_mut()?,
                                alpha[0].reinterpret_as()?,
                                &u8_table
                            );
                        }
                        BitDepth::Sixteen => {
                            unpremultiply_u16(
                                channel.reinterpret_as_mut()?,
                                alpha[0].reinterpret_as()?,
                                &u16_table
                            );
                        }
                        BitDepth::Float32 => unpremultiply_f32(
                            channel.reinterpret_as_mut()?,
                            alpha[0].reinterpret_as()?
                        ),
                        d => {
                            return Err(ImageErrors::ImageOperationNotImplemented(
                                self.name(),
                                d.bit_type()
                            ))
                        }
                    },
                    (_, _) => return Err(ImageErrors::GenericStr("Could not pre-multiply alpha"))
                }
            }
        }
        image.metadata_mut().set_alpha(self.to);
        Ok(())
    }

    fn supported_types(&self) -> &'static [BitType] {
        &[BitType::F32, BitType::U16, BitType::U8]
    }
}
/// Builds the 256-entry table consumed by [`unpremultiply_u8`].
///
/// Entry `i` holds `compute_mod_u32(i)`. It is passed to `fastdiv_u32` as the
/// divisor constant for alpha value `i`. Entry 0 stays `0`, because there is no
/// divisor constant for an alpha of zero.
#[must_use]
pub fn create_unpremul_table_u8() -> [u128; 256] {
    let mut table = [0; 256];
    for (alpha_value, slot) in table.iter_mut().enumerate().skip(1) {
        *slot = compute_mod_u32(alpha_value as u64);
    }
    table
}
#[must_use]
#[allow(clippy::needless_range_loop)]
pub fn create_unpremul_table_u16() -> Vec<u128> {
let mut array = vec![0; 65536];
for i in 1..65536 {
array[i] = compute_mod_u32(i as u64);
}
array
}
/// Premultiplies 8-bit color samples by their alpha, in place.
///
/// Each color sample `c` with alpha `a` becomes `round(a * c / 255)`. An alpha of
/// 255 leaves the color unchanged, and an alpha of 0 yields 0. `input` and `alpha`
/// are paired element-wise; if their lengths differ, only the common prefix is
/// processed.
// The final `>> 8` produces at most 255 (see below), so the `as u8` cannot truncate.
#[allow(clippy::cast_possible_truncation)]
pub fn premultiply_u8(input: &mut [u8], alpha: &[u8]) {
    input.iter_mut().zip(alpha).for_each(|(color, al)| {
        // Division-free exact rounding of a*c/255 (Blinn): with t = a*c + 128,
        // `(t + (t >> 8)) >> 8` == round(a*c / 255) for all 8-bit a and c.
        // The last step must be a shift, not another division by 255: dividing
        // overshoots (a = c = 255 gave 256, which wrapped to 0 in the cast).
        // Max intermediate is 65025 + 128 + 254 = 65407, which fits in u16.
        let temp = (u16::from(*al) * u16::from(*color)) + 0x80;
        *color = ((temp + (temp >> 8)) >> 8) as u8;
    });
}
/// Premultiplies 16-bit color samples by their alpha, in place.
///
/// Each color sample `c` with alpha `a` becomes `round(a * c / 65535)`. An alpha
/// of 65535 leaves the color unchanged, and an alpha of 0 yields 0. `input` and
/// `alpha` are paired element-wise; if their lengths differ, only the common
/// prefix is processed.
// The final `>> 16` produces at most 65535 (see below), so the `as u16` cannot truncate.
#[allow(clippy::cast_possible_truncation)]
pub fn premultiply_u16(input: &mut [u16], alpha: &[u16]) {
    // Half of the 65535 divisor, rounded up; added so the shift rounds to nearest.
    const ROUNDING: u32 = 1 << 15;
    input.iter_mut().zip(alpha).for_each(|(color, al)| {
        // 16-bit form of the exact division-free rounding: with t = a*c + 2^15,
        // `(t + (t >> 16)) >> 16` == round(a*c / 65535) for all 16-bit a and c.
        // The last step must be a shift; dividing by 65535 again overshoots
        // (a = c = 65535 gave 65536, which wrapped to 0 in the cast).
        // Max intermediate is 4_294_934_527, which fits in u32.
        let temp = (u32::from(*al) * u32::from(*color)) + ROUNDING;
        *color = ((temp + (temp >> 16)) >> 16) as u16;
    });
}
/// Undoes alpha premultiplication for 8-bit samples, in place.
///
/// For each color/alpha pair, the numerator is `color * 255 + alpha / 2`. It is
/// passed to `fastdiv_u32` together with the table entry for that alpha, taken from
/// a table built by [`create_unpremul_table_u8`]. Results that do not fit in a
/// `u8` saturate to 255.
/// NOTE(review): table entry 0 is 0; the result for alpha == 0 depends on how
/// `fastdiv_u32` treats that entry — confirm in `mathops`.
pub fn unpremultiply_u8(input: &mut [u8], alpha: &[u8], premul_table: &[u128; 256]) {
    const MAX_VALUE: u32 = 255;
    for (color, &al) in input.iter_mut().zip(alpha.iter()) {
        let divisor_constant = premul_table[usize::from(al)];
        // `al / 2` rounds the division to nearest instead of truncating.
        let numerator = u32::from(*color) * MAX_VALUE + u32::from(al) / 2;
        let restored = fastdiv_u32(numerator, divisor_constant);
        *color = u8::try_from(restored).unwrap_or(u8::MAX);
    }
}
/// Undoes alpha premultiplication for 16-bit samples, in place.
///
/// For each color/alpha pair, the numerator is `color * 65535 + alpha / 2`. It is
/// passed to `fastdiv_u32` together with the table entry for that alpha, taken from
/// a table built by [`create_unpremul_table_u16`]. Results that do not fit in a
/// `u16` saturate to 65535.
///
/// `premul_table` must have at least 65536 entries. Debug builds assert this;
/// release builds return without touching `input` if the table is too short.
pub fn unpremultiply_u16(input: &mut [u16], alpha: &[u16], premul_table: &[u128]) {
    const MAX_VALUE: u32 = 65535;
    debug_assert!(premul_table.len() > 65535);
    // Every u16 alpha must be a valid index; otherwise bail out rather than panic.
    if premul_table.len() <= usize::from(u16::MAX) {
        return;
    }
    for (color, &al) in input.iter_mut().zip(alpha.iter()) {
        let divisor_constant = premul_table[usize::from(al)];
        // Max numerator is 65535 * 65535 + 32767, which still fits in u32.
        let numerator = u32::from(*color) * MAX_VALUE + u32::from(al) / 2;
        let restored = fastdiv_u32(numerator, divisor_constant);
        *color = u16::try_from(restored).unwrap_or(u16::MAX);
    }
}
/// Premultiplies floating-point color samples by their alpha, in place.
///
/// `input` and `alpha` are walked in lockstep. Samples past the end of the shorter
/// slice are left untouched.
pub fn premultiply_f32(input: &mut [f32], alpha: &[f32]) {
    for (sample, &weight) in input.iter_mut().zip(alpha.iter()) {
        *sample *= weight;
    }
}
/// Scalar implementation that divides each color sample by its alpha, in place.
///
/// A sample whose alpha compares equal to `0.0` (including `-0.0`) is set to `0.0`
/// rather than producing an infinity or NaN. Samples past the end of the shorter
/// slice are left untouched.
fn unpremultiply_f32_scalar(input: &mut [f32], alpha: &[f32]) {
    for (sample, &weight) in input.iter_mut().zip(alpha.iter()) {
        *sample = if weight == 0.0 { 0.0 } else { *sample / weight };
    }
}
/// Undoes alpha premultiplication for floating-point samples, in place.
///
/// With the `portable-simd` feature on x86/x86_64, the `std_simd` implementation
/// processes the slices. Otherwise the scalar loop does; it sets samples whose
/// alpha is zero to `0.0`. NOTE(review): the SIMD path is assumed to cover the whole
/// slice, including any tail — confirm in `std_simd`.
///
/// Exactly one implementation runs. The previous version ran the SIMD path and then
/// the scalar path, which divided every sample by alpha twice.
pub fn unpremultiply_f32(input: &mut [f32], alpha: &[f32]) {
    #[cfg(all(
        feature = "portable-simd",
        any(target_arch = "x86", target_arch = "x86_64")
    ))]
    {
        use crate::premul_alpha::std_simd::unpremultiply_std_simd;
        unpremultiply_std_simd(input, alpha);
    }
    #[cfg(not(all(
        feature = "portable-simd",
        any(target_arch = "x86", target_arch = "x86_64")
    )))]
    {
        unpremultiply_f32_scalar(input, alpha);
    }
}