Skip to main content

jxl_encoder_simd/
lib.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! SIMD-accelerated primitives for jxl_encoder.
6//!
7//! This crate wraps platform-specific SIMD intrinsics behind safe public functions.
8//! The main encoder crate (`jxl_encoder`) maintains `#![forbid(unsafe_code)]` and
9//! calls into these safe wrappers.
10//!
11//! Uses [archmage](https://docs.rs/archmage) for token-based SIMD dispatch
12//! and [magetypes](https://docs.rs/magetypes) for cross-platform vector types.
13//!
14//! # Direct variant access
15//!
//! Most kernels are available in three forms:
//! - A dispatching function (e.g. `dct_8x8`) that picks the best implementation at runtime
//! - Token-taking SIMD variants: `_avx2(token, ...)`, `_neon(token, ...)`, `_wasm128(token, ...)`
//! - A portable `_scalar(...)` variant that needs no token
19//!
20//! For hot loops, callers should summon a token once, then call the concrete
21//! variant directly from an `#[arcane]` function so LLVM can inline across the
22//! target-feature boundary.
23
24#![cfg_attr(not(feature = "unsafe-performance"), forbid(unsafe_code))]
25#![cfg_attr(feature = "unsafe-performance", deny(unsafe_code))]
26// Numerical SIMD/DSP code: range loops and many-parameter kernels are natural.
27#![allow(clippy::needless_range_loop, clippy::too_many_arguments)]
28#![no_std]
29extern crate alloc;
30
/// Return a `[f32; N]` scratch buffer (unsafe-performance path).
///
/// The buffer is zero-filled. Materialising an uninitialized `f32` (for example via
/// `MaybeUninit::uninit().assume_init()`) is undefined behavior in Rust even though
/// IEEE 754 floats have no trap representations, so this path must not skip
/// initialization. All call sites are DCT/IDCT scratch arrays that are filled by
/// copy_from_slice, transpose, or gather_col before any read, so the optimizer can
/// usually discard the zeroing stores as dead.
#[cfg(feature = "unsafe-performance")]
#[inline(always)]
pub(crate) fn scratch_buf<const N: usize>() -> [f32; N] {
    [0.0f32; N]
}
45
/// Produce a `[f32; N]` scratch array with every element set to `0.0`
/// (safe default path).
#[cfg(not(feature = "unsafe-performance"))]
#[inline(always)]
pub(crate) fn scratch_buf<const N: usize>() -> [f32; N] {
    let zero = 0.0_f32;
    [zero; N]
}
52
/// Allocate a `Vec<f32>` of length `n` for use as an output buffer
/// (unsafe-performance path).
///
/// Despite the name, the contents are zero-filled. Returning a `Vec` whose length
/// covers uninitialized `f32`s from a safe `pub fn` is unsound: it breaks the
/// `Vec::set_len` contract (elements up to the new length must be initialized) and
/// lets safe callers read uninitialized memory. A zero-valued `vec!` is usually
/// cheap, because the standard library can satisfy it with a zeroed allocation.
///
/// Intended for output buffers that are immediately overwritten by IDCT, EPF,
/// gaborish, or similar operations.
#[cfg(feature = "unsafe-performance")]
#[inline]
pub fn vec_f32_dirty(n: usize) -> alloc::vec::Vec<f32> {
    alloc::vec![0.0f32; n]
}
68
/// Allocate a `Vec<f32>` holding `n` zeros (safe default path).
///
/// Every element is initialised to `0.0`, so the buffer is valid to read even
/// before the caller overwrites it.
#[cfg(not(feature = "unsafe-performance"))]
#[inline]
pub fn vec_f32_dirty(n: usize) -> alloc::vec::Vec<f32> {
    let zero: f32 = 0.0;
    alloc::vec![zero; n]
}
75
/// Slice from offset without bounds check (unsafe-performance path).
///
/// Equivalent to `&s[offset..]`, except the bounds check is skipped unless debug
/// assertions are enabled.
///
/// # Safety
/// Caller must ensure `offset <= s.len()`. This is a safe `fn`, so the compiler does
/// not enforce the contract: with debug assertions disabled, violating it is
/// undefined behavior. With debug assertions enabled, the `debug_assert!` panics instead.
#[cfg(all(feature = "unsafe-performance", target_arch = "x86_64"))]
#[inline(always)]
#[allow(unsafe_code)]
pub(crate) fn slice_from(s: &[f32], offset: usize) -> &[f32] {
    debug_assert!(offset <= s.len());
    // SAFETY: Caller guarantees offset <= s.len(); debug_assert checks in debug builds.
    unsafe { s.get_unchecked(offset..) }
}
88
/// Return the tail of `s` beginning at `offset` (safe default path).
///
/// Panics if `offset > s.len()`.
#[cfg(all(not(feature = "unsafe-performance"), target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn slice_from(s: &[f32], offset: usize) -> &[f32] {
    let (_head, tail) = s.split_at(offset);
    tail
}
95
/// Load 8 floats at offset — no bounds checks (unsafe-performance path).
///
/// Bypasses both slice-from bounds check AND `f32x8::from_slice`'s internal
/// `[..8]` bounds check by using `_mm256_loadu_ps` directly.
///
/// Reads `s[offset..offset + 8]` with an unaligned load, so `offset` does not need
/// to be a multiple of 8.
///
/// # Safety
/// Caller must ensure `offset + 8 <= s.len()`. Only a `debug_assert!` checks this.
/// With debug assertions disabled, an out-of-range `offset` reads past the end of
/// the slice (undefined behavior).
///
/// `_mm256_loadu_ps` is an AVX instruction. NOTE(review): the `token` argument is
/// presumably what guarantees AVX is available at runtime — confirm against the
/// archmage `X64V3Token` docs.
#[cfg(all(feature = "unsafe-performance", target_arch = "x86_64"))]
#[inline(always)]
#[allow(unsafe_code)]
pub(crate) fn load_f32x8(
    token: archmage::X64V3Token,
    s: &[f32],
    offset: usize,
) -> magetypes::simd::f32x8 {
    use magetypes::simd::f32x8;
    debug_assert!(
        offset + 8 <= s.len(),
        "load_f32x8: offset={offset}, len={}",
        s.len()
    );
    // SAFETY: Caller guarantees offset + 8 <= s.len(); debug_assert checks in debug builds.
    unsafe {
        let ptr = s.as_ptr().add(offset);
        f32x8::from_m256(token, core::arch::x86_64::_mm256_loadu_ps(ptr))
    }
}
123
124/// Load 8 floats at offset — with bounds checks (safe default path).
125#[cfg(all(not(feature = "unsafe-performance"), target_arch = "x86_64"))]
126#[inline(always)]
127pub(crate) fn load_f32x8(
128    token: archmage::X64V3Token,
129    s: &[f32],
130    offset: usize,
131) -> magetypes::simd::f32x8 {
132    use magetypes::simd::f32x8;
133    f32x8::from_slice(token, &s[offset..])
134}
135
/// Store 8 floats at offset — no bounds checks (unsafe-performance path).
///
/// Bypasses slice bounds check and `try_into().unwrap()` by using
/// `_mm256_storeu_ps` directly.
///
/// Writes the lanes of `v` to `s[offset..offset + 8]` with an unaligned store, so
/// `offset` does not need to be a multiple of 8.
///
/// # Safety
/// Caller must ensure `offset + 8 <= s.len()`. Only a `debug_assert!` checks this.
/// With debug assertions disabled, an out-of-range `offset` writes past the end of
/// the slice (undefined behavior).
///
/// `_mm256_storeu_ps` is an AVX instruction, yet this function takes no token.
/// NOTE(review): this presumably relies on an `f32x8` value only being obtainable
/// with an `X64V3Token` (which would prove AVX support). Confirm against the
/// magetypes docs.
#[cfg(all(feature = "unsafe-performance", target_arch = "x86_64"))]
#[inline(always)]
#[allow(unsafe_code)]
pub(crate) fn store_f32x8(s: &mut [f32], offset: usize, v: magetypes::simd::f32x8) {
    debug_assert!(
        offset + 8 <= s.len(),
        "store_f32x8: offset={offset}, len={}",
        s.len()
    );
    // SAFETY: Caller guarantees offset + 8 <= s.len(); debug_assert checks in debug builds.
    unsafe {
        let ptr = s.as_mut_ptr().add(offset);
        core::arch::x86_64::_mm256_storeu_ps(ptr, v.raw());
    }
}
158
159/// Store 8 floats at offset — with bounds checks (safe default path).
160#[cfg(all(not(feature = "unsafe-performance"), target_arch = "x86_64"))]
161#[inline(always)]
162pub(crate) fn store_f32x8(s: &mut [f32], offset: usize, v: magetypes::simd::f32x8) {
163    let out: &mut [f32; 8] = (&mut s[offset..offset + 8]).try_into().unwrap();
164    v.store(out);
165}
166
167/// Load column `j` from 8 consecutive rows starting at `base_row` with given stride.
168///
169/// Unsafe-performance path: uses unchecked indexing (validated by debug_assert).
170/// Safe path: uses bounds-checked indexing.
171#[cfg(target_arch = "x86_64")]
172#[inline(always)]
173#[cfg_attr(feature = "unsafe-performance", allow(unsafe_code))]
174pub(crate) fn gather_col_strided(
175    token: archmage::X64V3Token,
176    data: &[f32],
177    base_row: usize,
178    j: usize,
179    stride: usize,
180) -> magetypes::simd::f32x8 {
181    #[cfg(feature = "unsafe-performance")]
182    {
183        debug_assert!(
184            (base_row + 7) * stride + j < data.len(),
185            "gather_col_strided OOB: base_row={base_row}, j={j}, stride={stride}, len={}",
186            data.len()
187        );
188        // SAFETY: Caller guarantees (base_row + 7) * stride + j < data.len().
189        // All lower indices are within bounds since base_row + r <= base_row + 7.
190        unsafe {
191            let arr = [
192                *data.get_unchecked(base_row * stride + j),
193                *data.get_unchecked((base_row + 1) * stride + j),
194                *data.get_unchecked((base_row + 2) * stride + j),
195                *data.get_unchecked((base_row + 3) * stride + j),
196                *data.get_unchecked((base_row + 4) * stride + j),
197                *data.get_unchecked((base_row + 5) * stride + j),
198                *data.get_unchecked((base_row + 6) * stride + j),
199                *data.get_unchecked((base_row + 7) * stride + j),
200            ];
201            magetypes::simd::f32x8::from_array(token, arr)
202        }
203    }
204    #[cfg(not(feature = "unsafe-performance"))]
205    magetypes::simd::f32x8::from_array(
206        token,
207        [
208            data[base_row * stride + j],
209            data[(base_row + 1) * stride + j],
210            data[(base_row + 2) * stride + j],
211            data[(base_row + 3) * stride + j],
212            data[(base_row + 4) * stride + j],
213            data[(base_row + 5) * stride + j],
214            data[(base_row + 6) * stride + j],
215            data[(base_row + 7) * stride + j],
216        ],
217    )
218}
219
220/// Store f32x8 lanes back to column `j` of 8 consecutive rows with given stride.
221///
222/// Unsafe-performance path: uses unchecked indexing (validated by debug_assert).
223/// Safe path: uses bounds-checked indexing.
224#[cfg(target_arch = "x86_64")]
225#[inline(always)]
226#[cfg_attr(feature = "unsafe-performance", allow(unsafe_code))]
227pub(crate) fn scatter_col_strided(
228    v: magetypes::simd::f32x8,
229    data: &mut [f32],
230    base_row: usize,
231    j: usize,
232    stride: usize,
233) {
234    let mut lane = [0.0f32; 8];
235    v.store(&mut lane);
236    #[cfg(feature = "unsafe-performance")]
237    {
238        debug_assert!(
239            (base_row + 7) * stride + j < data.len(),
240            "scatter_col_strided OOB: base_row={base_row}, j={j}, stride={stride}, len={}",
241            data.len()
242        );
243        // SAFETY: Caller guarantees (base_row + 7) * stride + j < data.len().
244        unsafe {
245            for (r, &val) in lane.iter().enumerate() {
246                *data.get_unchecked_mut((base_row + r) * stride + j) = val;
247            }
248        }
249    }
250    #[cfg(not(feature = "unsafe-performance"))]
251    for (r, &val) in lane.iter().enumerate() {
252        data[(base_row + r) * stride + j] = val;
253    }
254}
255
256mod adaptive_quant;
257mod block_l2;
258mod cfl;
259mod dct16;
260mod dct32;
261mod dct4;
262mod dct64;
263mod dct8;
264mod dequant;
265mod entropy;
266mod epf;
267mod fused_dct8;
268mod gab;
269mod gaborish5x5;
270mod idct16;
271mod idct32;
272mod idct64;
273mod mask1x1;
274mod noise;
275mod pixel_loss;
276mod quantize;
277mod transpose;
278mod xyb;
279
280// Re-export archmage token types so callers don't need a direct archmage dependency
281#[cfg(target_arch = "aarch64")]
282pub use archmage::NeonToken;
283pub use archmage::SimdToken;
284#[cfg(target_arch = "wasm32")]
285pub use archmage::Wasm128Token;
286#[cfg(target_arch = "x86_64")]
287pub use archmage::X64V3Token;
288
289// --- Dispatching functions (runtime auto-select) ---
290
291pub use adaptive_quant::{compute_pre_erosion, per_block_modulations};
292pub use block_l2::compute_block_l2_errors;
293pub use cfl::find_best_multiplier as cfl_find_best_multiplier;
294pub use cfl::find_best_multiplier_newton as cfl_find_best_multiplier_newton;
295pub use cfl::{NEWTON_EPS_DEFAULT, NEWTON_MAX_ITERS_DEFAULT};
296pub use dct4::{
297    dct_4x4_full, dct_4x8_full, dct_8x4_full, idct_4x4_full, idct_4x8_full, idct_8x4_full,
298};
299pub use dct8::{dct_8x8, idct_8x8};
300pub use dct16::{dct_8x16, dct_16x8, dct_16x16};
301pub use dct32::{dct_16x32, dct_32x16, dct_32x32};
302pub use dct64::{dct_32x64, dct_64x32, dct_64x64};
303pub use dequant::dequant_block_dct8;
304pub use entropy::{
305    EntropyCoeffResult, entropy_estimate_coeffs, fast_log2f, fast_pow2f, fast_powf,
306    shannon_entropy_bits,
307};
308pub use epf::{epf_step1, epf_step2, pad_plane};
309pub use fused_dct8::fused_dct8_entropy;
310pub use gab::gab_smooth_channel;
311pub use gaborish5x5::gaborish_5x5_channel;
312pub use idct16::{idct_8x16, idct_16x8, idct_16x16};
313pub use idct32::{idct_16x32, idct_32x16, idct_32x32};
314pub use idct64::{idct_32x64, idct_64x32, idct_64x64};
315pub use mask1x1::compute_mask1x1;
316pub use noise::denoise_channel;
317pub use pixel_loss::pixel_domain_loss;
318pub use quantize::{quantize_block_dct8, quantize_block_large};
319pub use transpose::transpose_8x8;
320pub use xyb::{linear_rgb_to_xyb_batch, xyb_to_linear_rgb_batch, xyb_to_linear_rgb_planar};
321
322// --- Scalar variants (no token needed) ---
323
324pub use adaptive_quant::{compute_pre_erosion_scalar, per_block_modulations_scalar};
325pub use block_l2::compute_block_l2_errors_scalar;
326pub use cfl::find_best_multiplier_newton_scalar as cfl_find_best_multiplier_newton_scalar;
327pub use cfl::find_best_multiplier_scalar as cfl_find_best_multiplier_scalar;
328pub use dct4::{
329    dct_4x4_full_scalar, dct_4x8_full_scalar, dct_8x4_full_scalar, idct_4x4_full_scalar,
330    idct_4x8_full_scalar, idct_8x4_full_scalar,
331};
332pub use dct8::{dct_8x8_scalar, idct_8x8_scalar};
333pub use dct16::{dct_8x16_scalar, dct_16x8_scalar, dct_16x16_scalar};
334pub use dct32::{dct_16x32_scalar, dct_32x16_scalar, dct_32x32_scalar};
335pub use dct64::{dct_32x64_scalar, dct_64x32_scalar, dct_64x64_scalar};
336pub use dequant::dequant_dct8_scalar;
337pub use entropy::{entropy_coeffs_scalar, shannon_entropy_scalar};
338pub use epf::{epf_step1_scalar, epf_step2_scalar};
339pub use fused_dct8::fused_dct8_entropy_fallback;
340pub use gab::gab_smooth_scalar;
341pub use gaborish5x5::gaborish_5x5_scalar;
342pub use idct16::{idct_8x16_scalar, idct_16x8_scalar, idct_16x16_scalar};
343pub use idct32::{idct_16x32_scalar, idct_32x16_scalar, idct_32x32_scalar};
344pub use idct64::{idct_32x64_scalar, idct_64x32_scalar, idct_64x64_scalar};
345pub use mask1x1::compute_mask1x1_scalar;
346pub use noise::denoise_channel_scalar;
347pub use pixel_loss::pixel_domain_loss_scalar;
348pub use quantize::{quantize_dct8_scalar, quantize_large_scalar};
349// transpose has no separate scalar — the dispatching fn IS the scalar fallback
350pub use xyb::{forward_xyb_scalar, inverse_xyb_planar_scalar, inverse_xyb_scalar};
351
352// --- AVX2 variants (require X64V3Token) ---
353
354#[cfg(target_arch = "x86_64")]
355pub use adaptive_quant::{compute_pre_erosion_avx2, per_block_modulations_avx2};
356#[cfg(target_arch = "x86_64")]
357pub use block_l2::compute_block_l2_errors_avx2;
358#[cfg(target_arch = "x86_64")]
359pub use cfl::find_best_multiplier_avx2 as cfl_find_best_multiplier_avx2;
360#[cfg(target_arch = "x86_64")]
361pub use dct4::{
362    dct_4x4_full_avx2, dct_4x8_full_avx2, dct_8x4_full_avx2, idct_4x4_full_avx2,
363    idct_4x8_full_avx2, idct_8x4_full_avx2,
364};
365#[cfg(target_arch = "x86_64")]
366pub use dct8::{dct_8x8_avx2, idct_8x8_avx2};
367#[cfg(target_arch = "x86_64")]
368pub use dct16::{dct_8x16_avx2, dct_16x8_avx2, dct_16x16_avx2};
369#[cfg(target_arch = "x86_64")]
370pub use dct32::{dct_16x32_avx2, dct_32x16_avx2, dct_32x32_avx2};
371#[cfg(target_arch = "x86_64")]
372pub use dct64::{dct_32x64_avx2, dct_64x32_avx2, dct_64x64_avx2};
373#[cfg(target_arch = "x86_64")]
374pub use dequant::dequant_dct8_avx2;
375#[cfg(target_arch = "x86_64")]
376pub use entropy::{entropy_coeffs_avx2, shannon_entropy_avx2};
377#[cfg(target_arch = "x86_64")]
378pub use epf::{epf_step1_avx2, epf_step2_avx2};
379#[cfg(target_arch = "x86_64")]
380pub use fused_dct8::fused_dct8_entropy_avx2;
381#[cfg(target_arch = "x86_64")]
382pub use gab::gab_smooth_avx2;
383#[cfg(target_arch = "x86_64")]
384pub use gaborish5x5::gaborish_5x5_avx2;
385#[cfg(target_arch = "x86_64")]
386pub use idct16::{idct_8x16_avx2, idct_16x8_avx2, idct_16x16_avx2};
387#[cfg(target_arch = "x86_64")]
388pub use idct32::{idct_16x32_avx2, idct_32x16_avx2, idct_32x32_avx2};
389#[cfg(target_arch = "x86_64")]
390pub use idct64::{idct_32x64_avx2, idct_64x32_avx2, idct_64x64_avx2};
391#[cfg(target_arch = "x86_64")]
392pub use mask1x1::compute_mask1x1_avx2;
393#[cfg(target_arch = "x86_64")]
394pub use noise::denoise_channel_avx2;
395#[cfg(target_arch = "x86_64")]
396pub use pixel_loss::pixel_domain_loss_avx2;
397#[cfg(target_arch = "x86_64")]
398pub use quantize::{quantize_dct8_avx2, quantize_large_avx2};
399#[cfg(target_arch = "x86_64")]
400pub use transpose::transpose_8x8_avx2;
401#[cfg(target_arch = "x86_64")]
402pub use xyb::{forward_xyb_avx2, inverse_xyb_avx2, inverse_xyb_planar_avx2};
403
404// --- NEON variants (require NeonToken) ---
405
406#[cfg(target_arch = "aarch64")]
407pub use adaptive_quant::{compute_pre_erosion_neon, per_block_modulations_neon};
408#[cfg(target_arch = "aarch64")]
409pub use block_l2::compute_block_l2_errors_neon;
410#[cfg(target_arch = "aarch64")]
411pub use cfl::find_best_multiplier_neon as cfl_find_best_multiplier_neon;
412#[cfg(target_arch = "aarch64")]
413pub use dct8::{dct_8x8_neon, idct_8x8_neon};
414#[cfg(target_arch = "aarch64")]
415pub use dct16::{dct_8x16_neon, dct_16x8_neon, dct_16x16_neon};
416#[cfg(target_arch = "aarch64")]
417pub use dequant::dequant_dct8_neon;
418#[cfg(target_arch = "aarch64")]
419pub use entropy::{entropy_coeffs_neon, shannon_entropy_neon};
420#[cfg(target_arch = "aarch64")]
421pub use epf::{epf_step1_neon, epf_step2_neon};
422#[cfg(target_arch = "aarch64")]
423pub use gab::gab_smooth_neon;
424#[cfg(target_arch = "aarch64")]
425pub use gaborish5x5::gaborish_5x5_neon;
426#[cfg(target_arch = "aarch64")]
427pub use idct16::{idct_8x16_neon, idct_16x8_neon, idct_16x16_neon};
428#[cfg(target_arch = "aarch64")]
429pub use mask1x1::compute_mask1x1_neon;
430#[cfg(target_arch = "aarch64")]
431pub use noise::denoise_channel_neon;
432#[cfg(target_arch = "aarch64")]
433pub use pixel_loss::pixel_domain_loss_neon;
434#[cfg(target_arch = "aarch64")]
435pub use quantize::{quantize_dct8_neon, quantize_large_neon};
436#[cfg(target_arch = "aarch64")]
437pub use transpose::transpose_8x8_neon;
438#[cfg(target_arch = "aarch64")]
439pub use xyb::{forward_xyb_neon, inverse_xyb_neon, inverse_xyb_planar_neon};
440
441// --- WASM SIMD128 variants (require Wasm128Token) ---
442
443#[cfg(target_arch = "wasm32")]
444pub use adaptive_quant::{compute_pre_erosion_wasm128, per_block_modulations_wasm128};
445#[cfg(target_arch = "wasm32")]
446pub use block_l2::compute_block_l2_errors_wasm128;
447#[cfg(target_arch = "wasm32")]
448pub use cfl::find_best_multiplier_wasm128 as cfl_find_best_multiplier_wasm128;
449#[cfg(target_arch = "wasm32")]
450pub use dct8::{dct_8x8_wasm128, idct_8x8_wasm128};
451#[cfg(target_arch = "wasm32")]
452pub use dct16::{dct_8x16_wasm128, dct_16x8_wasm128, dct_16x16_wasm128};
453#[cfg(target_arch = "wasm32")]
454pub use dequant::dequant_dct8_wasm128;
455#[cfg(target_arch = "wasm32")]
456pub use entropy::{entropy_coeffs_wasm128, shannon_entropy_wasm128};
457#[cfg(target_arch = "wasm32")]
458pub use epf::{epf_step1_wasm128, epf_step2_wasm128};
459#[cfg(target_arch = "wasm32")]
460pub use idct16::{idct_8x16_wasm128, idct_16x8_wasm128, idct_16x16_wasm128};
461#[cfg(target_arch = "wasm32")]
462pub use mask1x1::compute_mask1x1_wasm128;
463#[cfg(target_arch = "wasm32")]
464pub use noise::denoise_channel_wasm128;
465#[cfg(target_arch = "wasm32")]
466pub use pixel_loss::pixel_domain_loss_wasm128;
467#[cfg(target_arch = "wasm32")]
468pub use quantize::{quantize_dct8_wasm128, quantize_large_wasm128};
469#[cfg(target_arch = "wasm32")]
470pub use xyb::{forward_xyb_wasm128, inverse_xyb_planar_wasm128, inverse_xyb_wasm128};