Skip to main content

jxl_encoder/vardct/
splines.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Spline encoding for JPEG XL.
6//!
7//! Splines are parametric Gaussian-blurred curves overlaid additively onto
8//! decoded images. They efficiently encode thin features (power lines,
9//! horizons, etc.) that VarDCT handles poorly. The encoder quantizes
10//! splines, subtracts them from XYB, and encodes the residual via VarDCT.
11//! The decoder adds splines back after VarDCT reconstruction.
12
13use core::f32::consts::{FRAC_1_SQRT_2, PI, SQRT_2};
14
15use super::common::pack_signed;
16use crate::bit_writer::BitWriter;
17use crate::entropy_coding::encode::{
18    build_entropy_code_ans_with_options, write_entropy_code_ans, write_tokens_ans,
19};
20use crate::entropy_coding::token::Token;
21use crate::error::Result;
22
23// ── Public types ────────────────────────────────────────────────────────────
24
/// A 2-D point on (or controlling) a spline, expressed in image coordinates.
#[derive(Clone, Copy, Debug, Default)]
pub struct SplinePoint {
    /// Horizontal image-space coordinate.
    pub x: f32,
    /// Vertical image-space coordinate.
    pub y: f32,
}

impl SplinePoint {
    /// Builds a point from its two coordinates.
    pub fn new(x: f32, y: f32) -> Self {
        SplinePoint { x, y }
    }

    /// Euclidean length of the point viewed as a vector from the origin.
    fn abs(&self) -> f32 {
        f32::hypot(self.x, self.y)
    }
}

/// Component-wise vector addition.
impl core::ops::Add for SplinePoint {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        SplinePoint::new(self.x + other.x, self.y + other.y)
    }
}

/// Component-wise vector subtraction.
impl core::ops::Sub for SplinePoint {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        SplinePoint::new(self.x - other.x, self.y - other.y)
    }
}

/// Scales both coordinates by a scalar.
impl core::ops::Mul<f32> for SplinePoint {
    type Output = Self;
    fn mul(self, factor: f32) -> Self {
        SplinePoint::new(self.x * factor, self.y * factor)
    }
}

/// Divides both coordinates by a scalar.
///
/// Implemented as a multiplication by the reciprocal, so `p / s` is
/// bit-identical to multiplying each coordinate by `1.0 / s`.
impl core::ops::Div<f32> for SplinePoint {
    type Output = Self;
    fn div(self, divisor: f32) -> Self {
        let reciprocal = 1.0 / divisor;
        SplinePoint::new(self.x * reciprocal, self.y * reciprocal)
    }
}
85
86/// A spline with control points, color DCT coefficients, and sigma DCT.
87///
88/// Control points define the curve path. The 32-element DCT arrays define
89/// how color intensity and Gaussian width vary along the curve.
90#[derive(Clone, Debug)]
91pub struct Spline {
92    /// Control points of the spline (at least 1).
93    pub control_points: Vec<SplinePoint>,
94    /// Color DCT coefficients: `[channel][coeff]` for X, Y, B channels.
95    pub color_dct: [[f32; 32]; 3],
96    /// Sigma (Gaussian width) DCT coefficients.
97    pub sigma_dct: [f32; 32],
98}
99
100// ── Internal types ──────────────────────────────────────────────────────────
101
/// Integer, bitstream-ready representation of a spline.
///
/// The first control point is not stored here (it is written separately as
/// the spline's starting position); the remaining points are stored as
/// second-order differences of their rounded coordinates.
struct QuantizedSpline {
    /// Second differences `(ddx, ddy)`, one entry per control point after the
    /// first.
    control_points: Vec<(i64, i64)>,
    /// Quantized color coefficients, indexed `[channel][coeff]` (X, Y, B).
    color_dct: [[i32; 32]; 3],
    /// Quantized sigma coefficients.
    sigma_dct: [i32; 32],
}
111
/// Precomputed rendering record for one sample point along a spline
/// (a single Gaussian "splat").
#[derive(Clone, Copy, Debug, Default)]
struct SplineSegment {
    /// Horizontal position of the splat center.
    center_x: f32,
    /// Vertical position of the splat center.
    center_y: f32,
    /// Radius beyond which the splat is not rendered.
    maximum_distance: f32,
    /// `1 / sigma` for this sample.
    inv_sigma: f32,
    /// Cached `0.25 * sigma * intensity` scale factor.
    sigma_over_4_times_intensity: f32,
    /// Per-channel amplitude in X, Y, B order.
    color: [f32; 3],
}
122
123/// Fully prepared spline data ready for subtraction/addition and encoding.
124pub(crate) struct SplinesData {
125    /// Quantization adjustment parameter.
126    quantization_adjustment: i32,
127    /// Original splines (for encoding).
128    splines: Vec<Spline>,
129    /// Quantized splines (for bitstream encoding).
130    quantized: Vec<QuantizedSpline>,
131    /// Rendered segments for pixel operations.
132    segments: Vec<SplineSegment>,
133    /// Indices into `segments` sorted by y coordinate.
134    segment_indices: Vec<usize>,
135    /// Prefix-sum index: `segment_y_start[y]` is the start index in
136    /// `segment_indices` for row y. Length = image_height + 1.
137    segment_y_start: Vec<usize>,
138}
139
140// ── Constants ───────────────────────────────────────────────────────────────
141
/// Quantization weights, one per quantized quantity, ordered X, Y, B, sigma.
const CHANNEL_WEIGHT: [f32; 4] = [0.0042, 0.075, 0.07, 0.3333];

/// Count of entropy-coding contexts used by the spline section.
const NUM_SPLINE_CONTEXTS: usize = 6;

/// Arc-length spacing between consecutive rendered sample points.
const DESIRED_RENDERING_DISTANCE: f32 = 1.0;

/// `1 / (2 * sqrt(2))`, the half-pixel offset used by the Gaussian splat.
const ONE_OVER_2S2: f32 = 0.353_553_38;

/// Exponent controlling the render cutoff radius (fast mode; matches the
/// jxl-rs default).
const DISTANCE_EXP: f32 = 3.0;

/// How many points each Catmull-Rom span contributes (its start point plus
/// interpolated points).
const NUM_POINTS_PER_SEGMENT: usize = 16;
159
160// ── Fast math ───────────────────────────────────────────────────────────────
161
/// Approximates the error function `erf(x)` (absolute error about 6e-4).
///
/// A degree-4 polynomial in `|x|` is squared and inverted, giving
/// `1 - 1 / poly(|x|)^4`; the sign is then copied from `x`, so the
/// approximation is odd: `fast_erf(-x) == -fast_erf(x)`.
/// Ported from jxl-rs `fast_math.rs`.
#[allow(clippy::excessive_precision)]
#[inline]
fn fast_erf(x: f32) -> f32 {
    let magnitude = x.abs();
    // Horner's scheme, leading coefficient first; each step is
    // `acc * |x| + k`, the same operation order as an unrolled evaluation.
    let lower_coeffs: [f32; 4] = [2.05260015e-04, 2.32120216e-01, 2.77820801e-01, 1.0];
    let poly = lower_coeffs
        .iter()
        .fold(7.77394369e-02_f32, |acc, &k| acc * magnitude + k);
    let inv_square = 1.0 / (poly * poly);
    (1.0 - inv_square * inv_square).copysign(x)
}
176
/// Approximates `cos(x)` (max error about 1e-4).
///
/// The argument is reduced into `[0, pi/2]` using the symmetries of cosine.
/// A polynomial then approximates a scaled `cos(x / 4)`, which two
/// double-angle steps (`cos 2a = 2 cos^2 a - 1`, with scale factors folded
/// into the constants) lift to `cos(x)`. Ported from jxl-rs `fast_math.rs`.
#[allow(clippy::excessive_precision)]
#[inline]
fn fast_cos(x: f32) -> f32 {
    let two_pi = PI * 2.0;
    let inv_two_pi = 0.5 / PI;

    // Wrap into [0, 2*pi), then fold onto [0, pi] via cos(x) = cos(2*pi - x).
    let wrapped = x - (x * inv_two_pi).floor() * two_pi;
    let folded = wrapped.min(two_pi - wrapped);

    // Fold [pi/2, pi] onto [0, pi/2] via cos(x) = -cos(pi - x).
    let negate = folded >= PI / 2.0;
    let reduced = if negate { PI - folded } else { folded };

    // Polynomial in (reduced / 4)^2.
    let quarter = reduced * 0.25;
    let q2 = quarter * quarter;
    let q4 = q2 * q2;
    let approx = q4 * 0.06960438 + (q2 * -0.84087373 + 1.68179268);

    // Two double-angle steps: quarter angle -> half angle -> full angle.
    let half_angle = approx * approx - SQRT_2;
    let full_angle = half_angle * half_angle - 1.0;

    if negate {
        -full_angle
    } else {
        full_angle
    }
}
201
202// ── Continuous IDCT ─────────────────────────────────────────────────────────
203
204/// Precomputed cosines for continuous IDCT at a given t value.
205/// Computed once per sample point and reused for all 4 DCT evaluations.
206struct PrecomputedCosines([f32; 32]);
207
208impl PrecomputedCosines {
209    #[inline]
210    fn new(t: f32) -> Self {
211        let tandhalf = t + 0.5;
212        Self(core::array::from_fn(|i| {
213            fast_cos(PI / 32.0 * i as f32 * tandhalf)
214        }))
215    }
216}
217
218/// Evaluate continuous IDCT with precomputed cosines.
219#[inline]
220fn continuous_idct(dct: &[f32; 32], precomputed: &PrecomputedCosines) -> f32 {
221    dct.iter()
222        .zip(precomputed.0.iter())
223        .map(|(&c, &cos)| c * cos)
224        .sum::<f32>()
225        * SQRT_2
226}
227
228// ── Catmull-Rom interpolation ───────────────────────────────────────────────
229
230/// Centripetal Catmull-Rom spline interpolation.
231/// Ported from libjxl `splines.cc:294-336` / jxl-rs `spline.rs`.
232fn draw_centripetal_catmull_rom(points: &[SplinePoint]) -> Vec<SplinePoint> {
233    if points.is_empty() {
234        return vec![];
235    }
236    if points.len() == 1 {
237        return vec![points[0]];
238    }
239
240    // Extend endpoints by reflection.
241    let first_extra = points[0] + (points[0] - points[1]);
242    let last_extra =
243        points[points.len() - 1] + (points[points.len() - 1] - points[points.len() - 2]);
244
245    let extended: Vec<SplinePoint> = core::iter::once(first_extra)
246        .chain(points.iter().copied())
247        .chain(core::iter::once(last_extra))
248        .collect();
249
250    // Compute centripetal distances between consecutive extended points.
251    let mut dists = Vec::with_capacity(extended.len());
252    for i in 0..extended.len() - 1 {
253        dists.push((extended[i + 1] - extended[i]).abs().sqrt());
254    }
255    // dists[i] = sqrt(|extended[i+1] - extended[i]|), length = extended.len() - 1
256
257    let num_windows = extended.len() - 3; // = points.len() - 1
258    let mut result = Vec::with_capacity(num_windows * NUM_POINTS_PER_SEGMENT + 1);
259
260    for w in 0..num_windows {
261        // Window: extended[w], extended[w+1], extended[w+2], extended[w+3]
262        // Distances: dists[w], dists[w+1], dists[w+2]
263        let p = [
264            extended[w],
265            extended[w + 1],
266            extended[w + 2],
267            extended[w + 3],
268        ];
269        let d = [dists[w], dists[w + 1], dists[w + 2]];
270
271        let mut t = [0.0f32; 4];
272        t[1] = t[0] + d[0];
273        t[2] = t[1] + d[1];
274        t[3] = t[2] + d[2];
275
276        // First point of this segment
277        result.push(p[1]);
278
279        for i in 1..NUM_POINTS_PER_SEGMENT {
280            let tt = d[0] + (i as f32 / NUM_POINTS_PER_SEGMENT as f32) * d[1];
281
282            // Three-level interpolation
283            let mut a = [SplinePoint::default(); 3];
284            for k in 0..3 {
285                a[k] = p[k] + (p[k + 1] - p[k]) * ((tt - t[k]) / d[k]);
286            }
287            let mut b = [SplinePoint::default(); 2];
288            for k in 0..2 {
289                b[k] = a[k] + (a[k + 1] - a[k]) * ((tt - t[k]) / (d[k] + d[k + 1]));
290            }
291            let point = b[0] + (b[1] - b[0]) * ((tt - t[1]) / d[1]);
292            result.push(point);
293        }
294    }
295    // Add the final point
296    result.push(points[points.len() - 1]);
297    result
298}
299
300// ── Equal-distance resampling ───────────────────────────────────────────────
301
302/// Walk curve at uniform intervals, collecting (point, multiplier) pairs.
303/// Ported from libjxl `splines.cc:344-375` / jxl-rs `spline.rs`.
304fn for_each_equally_spaced_point(
305    points: &[SplinePoint],
306    desired_distance: f32,
307) -> Vec<(SplinePoint, f32)> {
308    if points.is_empty() {
309        return vec![];
310    }
311    let mut result = Vec::new();
312    result.push((points[0], desired_distance));
313    if points.len() == 1 {
314        return result;
315    }
316
317    let mut accumulated_distance = 0.0f32;
318    for index in 0..points.len() - 1 {
319        let mut current = points[index];
320        let next = points[index + 1];
321        let segment = next - current;
322        let segment_length = segment.abs();
323        if segment_length < 1e-10 {
324            continue;
325        }
326        let unit_step = segment / segment_length;
327        if accumulated_distance + segment_length >= desired_distance {
328            current = current + unit_step * (desired_distance - accumulated_distance);
329            result.push((current, desired_distance));
330            accumulated_distance -= desired_distance;
331        }
332        accumulated_distance += segment_length;
333        while accumulated_distance >= desired_distance {
334            current = current + unit_step * desired_distance;
335            result.push((current, desired_distance));
336            accumulated_distance -= desired_distance;
337        }
338    }
339    result.push((points[points.len() - 1], accumulated_distance));
340    result
341}
342
343// ── Quantization ────────────────────────────────────────────────────────────
344
/// Reciprocal of `adjusted_quant`: `1 / (1 + a/8)` for `a >= 0`, `1 + |a|/8`
/// otherwise.
fn inv_adjusted_quant(adjustment: i32) -> f32 {
    // `unsigned_abs` avoids overflow at i32::MIN; 1 - 0.125*a == 1 + 0.125*|a|
    // exactly for negative a.
    let scale = 1.0 + 0.125 * adjustment.unsigned_abs() as f32;
    if adjustment < 0 {
        scale
    } else {
        1.0 / scale
    }
}
353
/// Quantization scale for a given adjustment: `1 + a/8` for `a >= 0`,
/// `1 / (1 + |a|/8)` otherwise (the reciprocal of `inv_adjusted_quant`).
fn adjusted_quant(adjustment: i32) -> f32 {
    // `unsigned_abs` avoids overflow at i32::MIN; 1 - 0.125*a == 1 + 0.125*|a|
    // exactly for negative a.
    let scale = 1.0 + 0.125 * adjustment.unsigned_abs() as f32;
    if adjustment < 0 {
        1.0 / scale
    } else {
        scale
    }
}
362
363impl QuantizedSpline {
364    /// Quantize a spline. Ported from libjxl `QuantizedSpline::Create()`.
365    ///
366    /// Process order: Y (channel 1) first for CfL decorrelation, then X (0), B (2).
367    fn from_spline(
368        spline: &Spline,
369        quantization_adjustment: i32,
370        y_to_x: f32,
371        y_to_b: f32,
372    ) -> Self {
373        let quant = adjusted_quant(quantization_adjustment);
374
375        // Quantize control points: delta-of-deltas encoding.
376        // Starting point is encoded separately; here we encode the second-order
377        // differences of the remaining points.
378        let mut control_points = Vec::new();
379        if spline.control_points.len() > 1 {
380            let pts = &spline.control_points;
381            let mut prev_delta_x = 0i64;
382            let mut prev_delta_y = 0i64;
383            let mut prev_x = pts[0].x.round() as i64;
384            let mut prev_y = pts[0].y.round() as i64;
385
386            for p in pts.iter().skip(1) {
387                let cur_x = p.x.round() as i64;
388                let cur_y = p.y.round() as i64;
389                let delta_x = cur_x - prev_x;
390                let delta_y = cur_y - prev_y;
391                let dd_x = delta_x - prev_delta_x;
392                let dd_y = delta_y - prev_delta_y;
393                control_points.push((dd_x, dd_y));
394                prev_delta_x = delta_x;
395                prev_delta_y = delta_y;
396                prev_x = cur_x;
397                prev_y = cur_y;
398            }
399        }
400
401        // Quantize Y channel first (channel 1) for CfL reference.
402        let mut quantized_color = [[0i32; 32]; 3];
403        for (i, qc) in quantized_color[1].iter_mut().enumerate() {
404            let dct_factor = if i == 0 { SQRT_2 } else { 1.0 };
405            *qc = (spline.color_dct[1][i] * dct_factor * quant / CHANNEL_WEIGHT[1]).round() as i32;
406        }
407
408        // Dequantize Y for CfL decorrelation reference.
409        let inv_quant = inv_adjusted_quant(quantization_adjustment);
410        let mut restored_y = [0.0f32; 32];
411        for (i, ry) in restored_y.iter_mut().enumerate() {
412            let inv_dct_factor = if i == 0 { FRAC_1_SQRT_2 } else { 1.0 };
413            *ry = quantized_color[1][i] as f32 * inv_dct_factor * CHANNEL_WEIGHT[1] * inv_quant;
414        }
415
416        // Quantize X (channel 0) and B (channel 2) with CfL decorrelation.
417        for c in [0, 2] {
418            let cfl_factor = if c == 0 { y_to_x } else { y_to_b };
419            for (i, qc) in quantized_color[c].iter_mut().enumerate() {
420                let dct_factor = if i == 0 { SQRT_2 } else { 1.0 };
421                let decorrelated = spline.color_dct[c][i] - cfl_factor * restored_y[i];
422                *qc = (decorrelated * dct_factor * quant / CHANNEL_WEIGHT[c]).round() as i32;
423            }
424        }
425
426        // Quantize sigma DCT.
427        let mut quantized_sigma = [0i32; 32];
428        for (i, qs) in quantized_sigma.iter_mut().enumerate() {
429            let dct_factor = if i == 0 { SQRT_2 } else { 1.0 };
430            *qs = (spline.sigma_dct[i] * dct_factor * quant / CHANNEL_WEIGHT[3]).round() as i32;
431        }
432
433        Self {
434            control_points,
435            color_dct: quantized_color,
436            sigma_dct: quantized_sigma,
437        }
438    }
439
440    /// Dequantize back to floating-point spline (for rendering).
441    /// This matches what the decoder will reconstruct.
442    fn dequantize(
443        &self,
444        starting_point: SplinePoint,
445        quantization_adjustment: i32,
446        y_to_x: f32,
447        y_to_b: f32,
448    ) -> DequantizedSpline {
449        let inv_quant = inv_adjusted_quant(quantization_adjustment);
450
451        // Reconstruct control points from delta-of-deltas.
452        let mut control_points = Vec::with_capacity(self.control_points.len() + 1);
453        let sp_x = starting_point.x.round() as i64;
454        let sp_y = starting_point.y.round() as i64;
455        control_points.push(SplinePoint::new(sp_x as f32, sp_y as f32));
456
457        let mut cur_x = sp_x;
458        let mut cur_y = sp_y;
459        let mut delta_x = 0i64;
460        let mut delta_y = 0i64;
461        for &(dd_x, dd_y) in &self.control_points {
462            delta_x += dd_x;
463            delta_y += dd_y;
464            cur_x += delta_x;
465            cur_y += delta_y;
466            control_points.push(SplinePoint::new(cur_x as f32, cur_y as f32));
467        }
468
469        // Dequantize color DCTs.
470        let mut color_dct = [[0.0f32; 32]; 3];
471        for (c, (out_ch, in_ch)) in color_dct.iter_mut().zip(self.color_dct.iter()).enumerate() {
472            for (i, (out, &inp)) in out_ch.iter_mut().zip(in_ch.iter()).enumerate() {
473                let inv_dct_factor = if i == 0 { FRAC_1_SQRT_2 } else { 1.0 };
474                *out = inp as f32 * inv_dct_factor * CHANNEL_WEIGHT[c] * inv_quant;
475            }
476        }
477        // Apply CfL: add Y contribution to X and B.
478        // Index-based loop required: simultaneous mutable access to channels 0/2
479        // while reading channel 1 of the same array.
480        #[allow(clippy::needless_range_loop)]
481        for i in 0..32 {
482            color_dct[0][i] += y_to_x * color_dct[1][i];
483            color_dct[2][i] += y_to_b * color_dct[1][i];
484        }
485
486        // Dequantize sigma DCT.
487        let mut sigma_dct = [0.0f32; 32];
488        for (i, (out, &inp)) in sigma_dct.iter_mut().zip(self.sigma_dct.iter()).enumerate() {
489            let inv_dct_factor = if i == 0 { FRAC_1_SQRT_2 } else { 1.0 };
490            *out = inp as f32 * inv_dct_factor * CHANNEL_WEIGHT[3] * inv_quant;
491        }
492
493        DequantizedSpline {
494            control_points,
495            color_dct,
496            sigma_dct,
497        }
498    }
499}
500
501/// Intermediate dequantized spline used for rendering.
502struct DequantizedSpline {
503    control_points: Vec<SplinePoint>,
504    color_dct: [[f32; 32]; 3],
505    sigma_dct: [f32; 32],
506}
507
508// ── Segment generation ──────────────────────────────────────────────────────
509
510/// Create a segment from a sample point along the spline.
511fn make_segment(
512    center: &SplinePoint,
513    intensity: f32,
514    color: [f32; 3],
515    sigma: f32,
516) -> Option<SplineSegment> {
517    if sigma.is_infinite() || sigma == 0.0 || (1.0 / sigma).is_infinite() || intensity.is_infinite()
518    {
519        return None;
520    }
521    let max_color = [0.01, color[0].abs(), color[1].abs(), color[2].abs()]
522        .iter()
523        .copied()
524        .map(|c| (c * intensity).abs())
525        .max_by(|a, b| a.total_cmp(b))
526        .unwrap();
527    let max_distance =
528        (-2.0 * sigma * sigma * (0.1f32.ln() * DISTANCE_EXP - max_color.ln())).sqrt();
529    if max_distance.is_nan() || max_distance <= 0.0 {
530        return None;
531    }
532    Some(SplineSegment {
533        center_x: center.x,
534        center_y: center.y,
535        color,
536        inv_sigma: 1.0 / sigma,
537        sigma_over_4_times_intensity: 0.25 * sigma * intensity,
538        maximum_distance: max_distance,
539    })
540}
541
542/// Generate segments from a dequantized spline.
543fn generate_segments(spline: &DequantizedSpline) -> Vec<SplineSegment> {
544    let intermediate = draw_centripetal_catmull_rom(&spline.control_points);
545    let points_to_draw = for_each_equally_spaced_point(&intermediate, DESIRED_RENDERING_DISTANCE);
546    if points_to_draw.len() < 2 {
547        return vec![];
548    }
549
550    let length = (points_to_draw.len() as isize - 2) as f32 * DESIRED_RENDERING_DISTANCE
551        + points_to_draw[points_to_draw.len() - 1].1;
552    if length <= 0.0 {
553        return vec![];
554    }
555
556    let inv_length = 1.0 / length;
557    let mut segments = Vec::new();
558
559    for (point_index, (point, multiplier)) in points_to_draw.iter().enumerate() {
560        let progress = (point_index as f32 * DESIRED_RENDERING_DISTANCE * inv_length).min(1.0);
561        let t = 31.0 * progress;
562
563        let precomputed = PrecomputedCosines::new(t);
564        let mut color = [0.0f32; 3];
565        for (c, coeffs) in spline.color_dct.iter().enumerate() {
566            color[c] = continuous_idct(coeffs, &precomputed);
567        }
568        let sigma = continuous_idct(&spline.sigma_dct, &precomputed);
569
570        if let Some(seg) = make_segment(point, *multiplier, color, sigma) {
571            segments.push(seg);
572        }
573    }
574    segments
575}
576
577// ── Gaussian splatting (add/subtract) ───────────────────────────────────────
578
579/// Apply a segment to a single pixel.
580#[inline]
581fn apply_segment_at(
582    planes: &mut [Vec<f32>; 3],
583    stride: usize,
584    x: usize,
585    y: usize,
586    segment: &SplineSegment,
587    add: bool,
588) {
589    let dx = x as f32 - segment.center_x;
590    let dy = y as f32 - segment.center_y;
591    let distance = (dx * dx + dy * dy).sqrt();
592    let one_dim = fast_erf((distance * 0.5 + ONE_OVER_2S2) * segment.inv_sigma)
593        - fast_erf((distance * 0.5 - ONE_OVER_2S2) * segment.inv_sigma);
594    let local_intensity = segment.sigma_over_4_times_intensity * one_dim * one_dim;
595
596    let idx = y * stride + x;
597    let sign = if add { 1.0 } else { -1.0 };
598    for (plane, &color) in planes.iter_mut().zip(segment.color.iter()) {
599        plane[idx] += sign * color * local_intensity;
600    }
601}
602
603/// Apply all spline segments to XYB planes (add or subtract).
604fn apply_splines(
605    planes: &mut [Vec<f32>; 3],
606    stride: usize,
607    width: usize,
608    height: usize,
609    data: &SplinesData,
610    add: bool,
611) {
612    for y in 0..height {
613        let first = data.segment_y_start[y];
614        let last = data.segment_y_start[y + 1];
615        for seg_idx_pos in first..last {
616            let segment = &data.segments[data.segment_indices[seg_idx_pos]];
617            let x0 = (segment.center_x - segment.maximum_distance)
618                .round()
619                .max(0.0) as usize;
620            let x1 = width.min((segment.center_x + segment.maximum_distance).round() as usize + 1);
621            for x in x0..x1 {
622                apply_segment_at(planes, stride, x, y, segment, add);
623            }
624        }
625    }
626}
627
628/// Subtract splines from XYB planes (encoder side: before VarDCT).
629pub(crate) fn subtract_splines(
630    planes: &mut [Vec<f32>; 3],
631    stride: usize,
632    width: usize,
633    height: usize,
634    data: &SplinesData,
635) {
636    apply_splines(planes, stride, width, height, data, false);
637}
638
639/// Add splines to XYB planes (reconstruction: after VarDCT decode, for butteraugli).
640#[allow(dead_code)]
641pub(crate) fn add_splines(
642    planes: &mut [Vec<f32>; 3],
643    stride: usize,
644    width: usize,
645    height: usize,
646    data: &SplinesData,
647) {
648    apply_splines(planes, stride, width, height, data, true);
649}
650
651// ── SplinesData construction ────────────────────────────────────────────────
652
653impl SplinesData {
654    /// Build SplinesData from user-provided splines.
655    ///
656    /// Quantizes, dequantizes (for pixel-accurate rendering), generates
657    /// segments, and builds the y-sorted lookup structure.
658    pub(crate) fn from_splines(
659        splines: Vec<Spline>,
660        quantization_adjustment: i32,
661        y_to_x: f32,
662        y_to_b: f32,
663        _image_width: usize,
664        image_height: usize,
665    ) -> Self {
666        let mut quantized = Vec::with_capacity(splines.len());
667        let mut all_segments: Vec<SplineSegment> = Vec::new();
668        let mut segments_by_y: Vec<(usize, usize)> = Vec::new(); // (y, segment_index)
669
670        for spline in &splines {
671            let qs = QuantizedSpline::from_spline(spline, quantization_adjustment, y_to_x, y_to_b);
672
673            // Dequantize for rendering (matches decoder reconstruction).
674            let starting_point = spline.control_points[0];
675            let dqs = qs.dequantize(starting_point, quantization_adjustment, y_to_x, y_to_b);
676
677            // Generate segments from the dequantized spline.
678            let segs = generate_segments(&dqs);
679            let base_idx = all_segments.len();
680            for (i, seg) in segs.iter().enumerate() {
681                let seg_idx = base_idx + i;
682                let y0 = 0i64.max((seg.center_y - seg.maximum_distance).round() as i64);
683                let y1 = (image_height as i64)
684                    .min((seg.center_y + seg.maximum_distance).round() as i64 + 1);
685                for y in y0..y1 {
686                    segments_by_y.push((y as usize, seg_idx));
687                }
688            }
689            all_segments.extend(segs);
690
691            quantized.push(qs);
692        }
693
694        // Sort by y for efficient row-based rendering.
695        segments_by_y.sort_by_key(|&(y, _)| y);
696
697        let mut segment_indices = Vec::with_capacity(segments_by_y.len());
698        let mut segment_y_start = vec![0usize; image_height + 1];
699
700        for &(y, idx) in &segments_by_y {
701            segment_indices.push(idx);
702            if y < image_height {
703                segment_y_start[y + 1] += 1;
704            }
705        }
706        // Prefix-sum.
707        for y in 0..image_height {
708            segment_y_start[y + 1] += segment_y_start[y];
709        }
710
711        Self {
712            quantization_adjustment,
713            splines,
714            quantized,
715            segments: all_segments,
716            segment_indices,
717            segment_y_start,
718        }
719    }
720}
721
722// ── Bitstream encoding ──────────────────────────────────────────────────────
723
724/// Encode splines section into LfGlobal.
725///
726/// Token stream layout (6 contexts):
727/// - ctx 2: num_splines - 1
728/// - ctx 1: starting positions (first absolute, rest delta-coded via pack_signed)
729/// - ctx 0: quantization_adjustment (pack_signed)
730/// - Per spline:
731///   - ctx 3: num_control_points
732///   - ctx 4: control point double-deltas (pack_signed)
733///   - ctx 5: DCT coefficients (3×32 color + 32 sigma, pack_signed)
734pub(crate) fn encode_splines_section(data: &SplinesData, writer: &mut BitWriter) -> Result<()> {
735    let mut tokens = Vec::new();
736
737    let num_splines = data.splines.len();
738    // num_splines - 1
739    tokens.push(Token::new(2, (num_splines - 1) as u32));
740
741    // Starting positions: first is unsigned absolute, rest are signed deltas.
742    let mut last_x = 0i64;
743    let mut last_y = 0i64;
744    for (i, spline) in data.splines.iter().enumerate() {
745        let sp = spline.control_points[0];
746        let x = sp.x.round() as i64;
747        let y = sp.y.round() as i64;
748        if i == 0 {
749            tokens.push(Token::new(1, x as u32));
750            tokens.push(Token::new(1, y as u32));
751        } else {
752            let dx = x - last_x;
753            let dy = y - last_y;
754            tokens.push(Token::new(1, pack_signed(dx as i32)));
755            tokens.push(Token::new(1, pack_signed(dy as i32)));
756        }
757        last_x = x;
758        last_y = y;
759    }
760
761    // Quantization adjustment.
762    tokens.push(Token::new(0, pack_signed(data.quantization_adjustment)));
763
764    // Per-spline data.
765    for qs in &data.quantized {
766        // num_control_points (double-deltas, not including starting point)
767        tokens.push(Token::new(3, qs.control_points.len() as u32));
768
769        // Control point double-deltas.
770        for &(dd_x, dd_y) in &qs.control_points {
771            tokens.push(Token::new(4, pack_signed(dd_x as i32)));
772            tokens.push(Token::new(4, pack_signed(dd_y as i32)));
773        }
774
775        // Color DCT coefficients (3 channels × 32).
776        for channel in &qs.color_dct {
777            for &coeff in channel {
778                tokens.push(Token::new(5, pack_signed(coeff)));
779            }
780        }
781
782        // Sigma DCT coefficients (32).
783        for &coeff in &qs.sigma_dct {
784            tokens.push(Token::new(5, pack_signed(coeff)));
785        }
786    }
787
788    // Write LZ77 disabled flag.
789    writer.write(1, 0)?; // lz77_enabled = false
790
791    // Build and write ANS entropy code, then tokens.
792    let code =
793        build_entropy_code_ans_with_options(&tokens, NUM_SPLINE_CONTEXTS, false, true, None, None);
794    write_entropy_code_ans(&code, writer)?;
795    write_tokens_ans(&tokens, &code, None, writer)?;
796
797    Ok(())
798}
799
800// ── Tests ───────────────────────────────────────────────────────────────────
801
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-point diagonal spline with a Y-only DC color and a constant sigma,
    /// shared by the construction and rendering tests.
    fn diagonal_spline(sigma_dc: f32) -> Spline {
        let mut color_dct = [[0.0f32; 32]; 3];
        color_dct[1][0] = 0.5;
        let mut sigma_dct = [0.0f32; 32];
        sigma_dct[0] = sigma_dc;
        Spline {
            control_points: vec![SplinePoint::new(10.0, 10.0), SplinePoint::new(50.0, 50.0)],
            color_dct,
            sigma_dct,
        }
    }

    #[test]
    fn test_fast_erf_accuracy() {
        // Golden data from the Wikipedia error function table.
        let golden = [
            (0.0, 0.0),
            (0.1, 0.112_462_92),
            (0.2, 0.222_702_6),
            (0.5, 0.520_499_9),
            (1.0, 0.842_700_8),
            (1.5, 0.966_105_16),
            (2.0, 0.995_322_3),
            (2.5, 0.999_593),
            (3.0, 0.999_977_9),
        ];
        for (x, expected) in golden {
            let got = fast_erf(x);
            assert!(
                (got - expected).abs() < 6e-4,
                "fast_erf({x}) = {got}, expected {expected}"
            );
            let got_neg = fast_erf(-x);
            assert!(
                (got_neg - (-expected)).abs() < 6e-4,
                "fast_erf(-{x}) = {got_neg}, expected {}",
                -expected
            );
        }
    }

    #[test]
    fn test_fast_cos_accuracy() {
        for i in 0..100 {
            let x = i as f32 / 100.0 * (5.0 * PI) - (2.5 * PI);
            let got = fast_cos(x);
            let expected = x.cos();
            assert!(
                (got - expected).abs() < 1e-4,
                "fast_cos({x}) = {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_continuous_idct_values() {
        // With only the DC coefficient set, every term except i = 0 is
        // multiplied by zero, and the i = 0 cosine is cos(0) = 1 for any t.
        // The IDCT is therefore constant: dct[0] * 1 * SQRT_2 = SQRT_2.
        let mut dct = [0.0f32; 32];
        dct[0] = 1.0;
        for t_idx in 0..32 {
            let t = t_idx as f32;
            let pc = PrecomputedCosines::new(t);
            let val = continuous_idct(&dct, &pc);
            assert!(
                (val - SQRT_2).abs() < 0.01,
                "DC-only IDCT at t={t} = {val}, expected ~{SQRT_2}"
            );
        }
    }

    #[test]
    fn test_spline_point_ops() {
        let p = SplinePoint::new(1.0, 2.0);
        let q = SplinePoint::new(4.0, 6.0);
        let sum = p + q;
        assert_eq!((sum.x, sum.y), (5.0, 8.0));
        let diff = q - p;
        assert_eq!((diff.x, diff.y), (3.0, 4.0));
        assert!((diff.abs() - 5.0).abs() < 1e-6);
        let scaled = p * 3.0;
        assert_eq!((scaled.x, scaled.y), (3.0, 6.0));
        let halved = q / 2.0;
        assert_eq!((halved.x, halved.y), (2.0, 3.0));
    }

    #[test]
    fn test_adjusted_quant_inverse() {
        for adj in -16..=16 {
            let product = adjusted_quant(adj) * inv_adjusted_quant(adj);
            assert!(
                (product - 1.0).abs() < 1e-6,
                "adj={adj}: adjusted * inverse = {product}"
            );
        }
    }

    #[test]
    fn test_catmull_rom_basic() {
        // Two control points should produce a straight line with interpolation.
        let points = vec![SplinePoint::new(0.0, 0.0), SplinePoint::new(10.0, 0.0)];
        let interpolated = draw_centripetal_catmull_rom(&points);
        assert!(interpolated.len() > 2, "should produce intermediate points");
        // First and last should match input.
        assert!((interpolated[0].x - 0.0).abs() < 0.01);
        assert!((interpolated[0].y - 0.0).abs() < 0.01);
        let last = interpolated[interpolated.len() - 1];
        assert!((last.x - 10.0).abs() < 0.01);
        assert!((last.y - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_equally_spaced_points_straight_line() {
        let points = [SplinePoint::new(0.0, 0.0), SplinePoint::new(10.0, 0.0)];
        let samples = for_each_equally_spaced_point(&points, 1.0);
        // The output is the start point, then ten unit steps (the last one
        // landing on the end point), then the end point again carrying the
        // zero leftover distance.
        assert_eq!(samples.len(), 12);
        for (i, (p, multiplier)) in samples.iter().take(11).enumerate() {
            assert!((p.x - i as f32).abs() < 1e-5, "sample {i}: x = {}", p.x);
            assert!(p.y.abs() < 1e-5);
            assert_eq!(*multiplier, 1.0);
        }
        let (end, leftover) = samples[11];
        assert!((end.x - 10.0).abs() < 1e-5);
        assert!(leftover.abs() < 1e-5);
    }

    #[test]
    fn test_quantize_roundtrip() {
        // Create a simple spline with small DCT values, quantize, dequantize.
        let spline = Spline {
            control_points: vec![SplinePoint::new(10.0, 10.0), SplinePoint::new(50.0, 50.0)],
            color_dct: {
                let mut dct = [[0.0f32; 32]; 3];
                dct[1][0] = 0.5; // Y DC
                dct[0][0] = 0.1; // X DC
                dct[2][0] = 0.2; // B DC
                dct
            },
            sigma_dct: {
                let mut s = [0.0f32; 32];
                s[0] = 2.0;
                s
            },
        };

        let adj = 0;
        let y_to_x = 0.0;
        let y_to_b = 1.13;

        let qs = QuantizedSpline::from_spline(&spline, adj, y_to_x, y_to_b);
        let dqs = qs.dequantize(spline.control_points[0], adj, y_to_x, y_to_b);

        // Control points round-trip to their integer-rounded positions.
        assert_eq!(dqs.control_points.len(), 2);
        assert!((dqs.control_points[0].x - 10.0).abs() < 1.0);
        assert!((dqs.control_points[1].x - 50.0).abs() < 1.0);

        // Sigma should be close (within quantization error).
        assert!(
            (dqs.sigma_dct[0] - spline.sigma_dct[0]).abs() < 0.5,
            "sigma DC roundtrip: got {}, expected {}",
            dqs.sigma_dct[0],
            spline.sigma_dct[0]
        );
    }

    #[test]
    fn test_double_delta_encoding() {
        // Verify that delta-of-deltas encoding is correct.
        let spline = Spline {
            control_points: vec![
                SplinePoint::new(0.0, 0.0),
                SplinePoint::new(10.0, 0.0),
                SplinePoint::new(20.0, 5.0),
                SplinePoint::new(30.0, 15.0),
            ],
            color_dct: [[0.0; 32]; 3],
            sigma_dct: {
                let mut s = [0.0; 32];
                s[0] = 1.0;
                s
            },
        };

        let qs = QuantizedSpline::from_spline(&spline, 0, 0.0, 0.0);

        // Deltas: (10,0), (10,5), (10,10)
        // Double-deltas: (10,0), (0,5), (0,5)
        assert_eq!(qs.control_points.len(), 3); // 4 points - 1 starting point = 3
        assert_eq!(qs.control_points[0], (10, 0));
        assert_eq!(qs.control_points[1], (0, 5));
        assert_eq!(qs.control_points[2], (0, 5));

        // Dequantizing integrates the second differences back to the inputs.
        let dqs = qs.dequantize(spline.control_points[0], 0, 0.0, 0.0);
        assert_eq!(dqs.control_points.len(), 4);
        for (got, want) in dqs.control_points.iter().zip(&spline.control_points) {
            assert_eq!((got.x, got.y), (want.x, want.y));
        }
    }

    #[test]
    fn test_splines_data_construction() {
        let data = SplinesData::from_splines(vec![diagonal_spline(3.0)], 0, 0.0, 1.13, 64, 64);

        assert_eq!(data.splines.len(), 1);
        assert_eq!(data.quantized.len(), 1);
        assert!(!data.segments.is_empty(), "should have rendered segments");
        assert_eq!(
            data.segment_y_start.len(),
            65,
            "y_start should have height+1 entries"
        );
    }

    #[test]
    fn test_subtract_then_add_restores_planes() {
        let data = SplinesData::from_splines(vec![diagonal_spline(3.0)], 0, 0.0, 1.13, 64, 64);
        let mut planes: [Vec<f32>; 3] = core::array::from_fn(|_| vec![0.0f32; 64 * 64]);

        subtract_splines(&mut planes, 64, 64, 64, &data);
        assert!(
            planes[1].iter().any(|&v| v < 0.0),
            "subtracting a positive-Y spline should darken the Y plane"
        );

        add_splines(&mut planes, 64, 64, 64, &data);
        for (c, plane) in planes.iter().enumerate() {
            for (i, &v) in plane.iter().enumerate() {
                assert!(v.abs() < 1e-4, "plane {c} pixel {i} not restored: {v}");
            }
        }
    }
}