jpeg_encoder/
encoder.rs

1use crate::fdct::fdct;
2use crate::huffman::{CodingClass, HuffmanTable};
3use crate::image_buffer::*;
4use crate::marker::Marker;
5use crate::quantization::{QuantizationTable, QuantizationTableType};
6use crate::writer::{JfifWrite, JfifWriter, ZIGZAG};
7use crate::{EncodingError, PixelDensity};
8
9use alloc::vec;
10use alloc::vec::Vec;
11
12#[cfg(feature = "std")]
13use std::io::BufWriter;
14
15#[cfg(feature = "std")]
16use std::fs::File;
17
18#[cfg(feature = "std")]
19use std::path::Path;
20
/// # Color types used in encoding
///
/// Describes the component layout that is written to the JPEG stream.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum JpegColorType {
    /// Grayscale colorspace with a single component
    Luma,

    /// YCbCr colorspace with three components
    Ycbcr,

    /// CMYK colorspace with four components
    Cmyk,

    /// YCbCrK colorspace with four components
    Ycck,
}

impl JpegColorType {
    /// Returns how many components an image of this color type consists of.
    pub(crate) fn get_num_components(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Ycbcr => 3,
            Self::Cmyk | Self::Ycck => 4,
        }
    }
}
48
/// # Color types for input images
///
/// Available color input formats for [Encoder::encode]. Other types can be used
/// by implementing an [ImageBuffer](crate::ImageBuffer).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ColorType {
    /// Grayscale with 1 byte per pixel
    Luma,

    /// RGB with 3 bytes per pixel
    Rgb,

    /// RGBA with 4 bytes per pixel. The alpha channel will be ignored during encoding.
    Rgba,

    /// BGR with 3 bytes per pixel
    Bgr,

    /// BGRA with 4 bytes per pixel. The alpha channel will be ignored during encoding.
    Bgra,

    /// YCbCr with 3 bytes per pixel.
    Ycbcr,

    /// CMYK with 4 bytes per pixel.
    Cmyk,

    /// CMYK with 4 bytes per pixel. Encoded as YCCK (YCbCrK)
    CmykAsYcck,

    /// YCCK (YCbCrK) with 4 bytes per pixel.
    Ycck,
}

impl ColorType {
    /// Returns the number of input bytes that make up a single pixel.
    pub(crate) fn get_bytes_per_pixel(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Rgb | Self::Bgr | Self::Ycbcr => 3,
            Self::Rgba | Self::Bgra | Self::Cmyk | Self::CmykAsYcck | Self::Ycck => 4,
        }
    }
}
94
/// # Sampling factors for chroma subsampling
///
/// The discriminant stores the horizontal factor in the high nibble and the
/// vertical factor in the low nibble. The `R_*` aliases additionally set bit
/// `0x80` so that they have distinct discriminants from the `F_*` variants.
///
/// ## Warning
/// Sampling factors of 4 are not supported by all decoders or applications
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// Variant names follow the `F_<h>_<v>` / `R_<ratio>` notation instead of CamelCase.
#[allow(non_camel_case_types)]
pub enum SamplingFactor {
    F_1_1 = (1 << 4) | 1,
    F_2_1 = (2 << 4) | 1,
    F_1_2 = (1 << 4) | 2,
    F_2_2 = (2 << 4) | 2,
    F_4_1 = (4 << 4) | 1,
    F_4_2 = (4 << 4) | 2,
    F_1_4 = (1 << 4) | 4,
    F_2_4 = (2 << 4) | 4,

    /// Alias for F_1_1
    R_4_4_4 = 0x80 | (1 << 4) | 1,

    /// Alias for F_1_2
    R_4_4_0 = 0x80 | (1 << 4) | 2,

    /// Alias for F_1_4
    R_4_4_1 = 0x80 | (1 << 4) | 4,

    /// Alias for F_2_1
    R_4_2_2 = 0x80 | (2 << 4) | 1,

    /// Alias for F_2_2
    R_4_2_0 = 0x80 | (2 << 4) | 2,

    /// Alias for F_2_4
    R_4_2_1 = 0x80 | (2 << 4) | 4,

    /// Alias for F_4_1
    R_4_1_1 = 0x80 | (4 << 4) | 1,

    /// Alias for F_4_2
    R_4_1_0 = 0x80 | (4 << 4) | 2,
}

impl SamplingFactor {
    /// Look up the `F_*` variant for a horizontal/vertical factor pair.
    ///
    /// Returns `None` for combinations that have no variant.
    pub fn from_factors(horizontal: u8, vertical: u8) -> Option<SamplingFactor> {
        let factor = match (horizontal, vertical) {
            (1, 1) => Self::F_1_1,
            (1, 2) => Self::F_1_2,
            (1, 4) => Self::F_1_4,
            (2, 1) => Self::F_2_1,
            (2, 2) => Self::F_2_2,
            (2, 4) => Self::F_2_4,
            (4, 1) => Self::F_4_1,
            (4, 2) => Self::F_4_2,
            _ => return None,
        };

        Some(factor)
    }

    /// Decode the `(horizontal, vertical)` factors from the discriminant.
    pub(crate) fn get_sampling_factors(self) -> (u8, u8) {
        let bits = self as u8;

        // Masking with 0x07 drops the 0x80 alias flag from the high nibble
        let horizontal = (bits >> 4) & 0x07;
        let vertical = bits & 0x0f;

        (horizontal, vertical)
    }

    /// Whether all components can be encoded in a single interleaved scan.
    ///
    /// Interleaved mode is only supported with h/v sampling factors of 1 or 2.
    /// Sampling factors of 4 need sequential encoding.
    pub(crate) fn supports_interleaved(self) -> bool {
        let (horizontal, vertical) = self.get_sampling_factors();

        horizontal <= 2 && vertical <= 2
    }
}
171
/// Description of a single image component as used by the frame and scan headers.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) struct Component {
    /// Component identifier
    pub id: u8,
    /// Index of the quantization table used for this component
    pub quantization_table: u8,
    /// Index of the huffman table pair whose DC table is used for this component
    pub dc_huffman_table: u8,
    /// Index of the huffman table pair whose AC table is used for this component
    pub ac_huffman_table: u8,
    /// Horizontal sampling factor
    pub horizontal_sampling_factor: u8,
    /// Vertical sampling factor
    pub vertical_sampling_factor: u8,
}
180
/// Appends a `Component` to `$list`.
///
/// `$table` is used for the quantization table index as well as for both the
/// DC and AC huffman table indices.
macro_rules! add_component {
    ($list:expr, $component_id:expr, $table:expr, $h_factor:expr, $v_factor:expr) => {
        $list.push(Component {
            id: $component_id,
            quantization_table: $table,
            dc_huffman_table: $table,
            ac_huffman_table: $table,
            horizontal_sampling_factor: $h_factor,
            vertical_sampling_factor: $v_factor,
        });
    };
}
193
194/// # The JPEG encoder
195pub struct Encoder<W: JfifWrite> {
196    writer: JfifWriter<W>,
197    density: PixelDensity,
198    quality: u8,
199
200    components: Vec<Component>,
201    quantization_tables: [QuantizationTableType; 2],
202    huffman_tables: [(HuffmanTable, HuffmanTable); 2],
203
204    sampling_factor: SamplingFactor,
205
206    progressive_scans: Option<u8>,
207
208    restart_interval: Option<u16>,
209
210    optimize_huffman_table: bool,
211
212    app_segments: Vec<(u8, Vec<u8>)>,
213}
214
215impl<W: JfifWrite> Encoder<W> {
216    /// Create a new encoder with the given quality
217    ///
218    /// The quality must be between 1 and 100 where 100 is the highest image quality.<br>
219    /// By default, quality settings below 90 use a chroma subsampling (2x2 / 4:2:0) which can
220    /// be changed with [set_sampling_factor](Encoder::set_sampling_factor)
221    pub fn new(w: W, quality: u8) -> Encoder<W> {
222        let huffman_tables = [
223            (
224                HuffmanTable::default_luma_dc(),
225                HuffmanTable::default_luma_ac(),
226            ),
227            (
228                HuffmanTable::default_chroma_dc(),
229                HuffmanTable::default_chroma_ac(),
230            ),
231        ];
232
233        let quantization_tables = [
234            QuantizationTableType::Default,
235            QuantizationTableType::Default,
236        ];
237
238        let sampling_factor = if quality < 90 {
239            SamplingFactor::F_2_2
240        } else {
241            SamplingFactor::F_1_1
242        };
243
244        Encoder {
245            writer: JfifWriter::new(w),
246            density: PixelDensity::default(),
247            quality,
248            components: vec![],
249            quantization_tables,
250            huffman_tables,
251            sampling_factor,
252            progressive_scans: None,
253            restart_interval: None,
254            optimize_huffman_table: false,
255            app_segments: Vec::new(),
256        }
257    }
258
259    /// Set pixel density for the image
260    ///
261    /// By default, this value is None which is equal to "1 pixel per pixel".
262    pub fn set_density(&mut self, density: PixelDensity) {
263        self.density = density;
264    }
265
266    /// Return pixel density
267    pub fn density(&self) -> PixelDensity {
268        self.density
269    }
270
271    /// Set chroma subsampling factor
272    pub fn set_sampling_factor(&mut self, sampling: SamplingFactor) {
273        self.sampling_factor = sampling;
274    }
275
276    /// Get chroma subsampling factor
277    pub fn sampling_factor(&self) -> SamplingFactor {
278        self.sampling_factor
279    }
280
281    /// Set quantization tables for luma and chroma components
282    pub fn set_quantization_tables(
283        &mut self,
284        luma: QuantizationTableType,
285        chroma: QuantizationTableType,
286    ) {
287        self.quantization_tables = [luma, chroma];
288    }
289
290    /// Get configured quantization tables
291    pub fn quantization_tables(&self) -> &[QuantizationTableType; 2] {
292        &self.quantization_tables
293    }
294
295    /// Controls if progressive encoding is used.
296    ///
297    /// By default, progressive encoding uses 4 scans.<br>
298    /// Use [set_progressive_scans](Encoder::set_progressive_scans) to use a different number of scans
299    pub fn set_progressive(&mut self, progressive: bool) {
300        self.progressive_scans = if progressive { Some(4) } else { None };
301    }
302
303    /// Set number of scans per component for progressive encoding
304    ///
305    /// Number of scans must be between 2 and 64.
306    /// There is at least one scan for the DC coefficients and one for the remaining 63 AC coefficients.
307    ///
308    /// # Panics
309    /// If number of scans is not within valid range
310    pub fn set_progressive_scans(&mut self, scans: u8) {
311        assert!(
312            (2..=64).contains(&scans),
313            "Invalid number of scans: {}",
314            scans
315        );
316        self.progressive_scans = Some(scans);
317    }
318
319    /// Return number of progressive scans if progressive encoding is enabled
320    pub fn progressive_scans(&self) -> Option<u8> {
321        self.progressive_scans
322    }
323
324    /// Set restart interval
325    ///
326    /// Set numbers of MCUs between restart markers.
327    pub fn set_restart_interval(&mut self, interval: u16) {
328        self.restart_interval = if interval == 0 { None } else { Some(interval) };
329    }
330
331    /// Return the restart interval
332    pub fn restart_interval(&self) -> Option<u16> {
333        self.restart_interval
334    }
335
336    /// Set if optimized huffman table should be created
337    ///
338    /// Optimized tables result in slightly smaller file sizes but decrease encoding performance.
339    pub fn set_optimized_huffman_tables(&mut self, optimize_huffman_table: bool) {
340        self.optimize_huffman_table = optimize_huffman_table;
341    }
342
343    /// Returns if optimized huffman table should be generated
344    pub fn optimized_huffman_tables(&self) -> bool {
345        self.optimize_huffman_table
346    }
347
348    /// Appends a custom app segment to the JFIF file
349    ///
350    /// Segment numbers need to be in the range between 1 and 15<br>
351    /// The maximum allowed data length is 2^16 - 2 bytes.
352    ///
353    /// # Errors
354    ///
355    /// Returns an error if the segment number is invalid or data exceeds the allowed size
356    pub fn add_app_segment(&mut self, segment_nr: u8, data: Vec<u8>) -> Result<(), EncodingError> {
357        if segment_nr == 0 || segment_nr > 15 {
358            Err(EncodingError::InvalidAppSegment(segment_nr))
359        } else if data.len() > 65533 {
360            Err(EncodingError::AppSegmentTooLarge(data.len()))
361        } else {
362            self.app_segments.push((segment_nr, data));
363            Ok(())
364        }
365    }
366
367    /// Add an ICC profile
368    ///
369    /// The maximum allowed data length is 16,707,345 bytes.
370    ///
371    /// # Errors
372    ///
373    /// Returns an Error if the data exceeds the maximum size for the ICC profile
374    pub fn add_icc_profile(&mut self, data: &[u8]) -> Result<(), EncodingError> {
375        // Based on https://www.color.org/ICC_Minor_Revision_for_Web.pdf
376        // B.4  Embedding ICC profiles in JFIF files
377
378        const MARKER: &[u8; 12] = b"ICC_PROFILE\0";
379        const MAX_CHUNK_LENGTH: usize = 65535 - 2 - 12 - 2;
380
381        let num_chunks = ceil_div(data.len(), MAX_CHUNK_LENGTH);
382
383        // Sequence number is stored as a byte and starts with 1
384        if num_chunks >= 255 {
385            return Err(EncodingError::IccTooLarge(data.len()));
386        }
387
388        for (i, data) in data.chunks(MAX_CHUNK_LENGTH).enumerate() {
389            let mut chunk_data = Vec::with_capacity(MAX_CHUNK_LENGTH);
390            chunk_data.extend_from_slice(MARKER);
391            chunk_data.push(i as u8 + 1);
392            chunk_data.push(num_chunks as u8);
393            chunk_data.extend_from_slice(data);
394
395            self.add_app_segment(2, chunk_data)?;
396        }
397
398        Ok(())
399    }
400
401    /// Embeds Exif metadata into the image
402    ///
403    /// The maximum allowed data length is 65,528 bytes.
404    ///
405    /// # Errors
406    ///
407    /// Returns an Error if the data exceeds the maximum size for the Exif metadata
408    pub fn add_exif_metadata(&mut self, data: &[u8]) -> Result<(), EncodingError> {
409        // E x i f \0 \0
410        /// The header for an EXIF APP1 segment
411        const EXIF_HEADER: [u8; 6] = [0x45, 0x78, 0x69, 0x66, 0x00, 0x00];
412
413        let mut formatted = EXIF_HEADER.to_vec();
414        formatted.extend_from_slice(data);
415
416        self.add_app_segment(1, formatted)
417    }
418
419    /// Encode an image
420    ///
421    /// Data format and length must conform to specified width, height and color type.
422    pub fn encode(
423        self,
424        data: &[u8],
425        width: u16,
426        height: u16,
427        color_type: ColorType,
428    ) -> Result<(), EncodingError> {
429        let required_data_len = width as usize * height as usize * color_type.get_bytes_per_pixel();
430
431        if data.len() < required_data_len {
432            return Err(EncodingError::BadImageData {
433                length: data.len(),
434                required: required_data_len,
435            });
436        }
437
438        #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
439        {
440            if std::is_x86_feature_detected!("avx2") {
441                use crate::avx2::*;
442
443                return match color_type {
444                    ColorType::Luma => self
445                        .encode_image_internal::<_, AVX2Operations>(GrayImage(data, width, height)),
446                    ColorType::Rgb => self.encode_image_internal::<_, AVX2Operations>(
447                        RgbImageAVX2(data, width, height),
448                    ),
449                    ColorType::Rgba => self.encode_image_internal::<_, AVX2Operations>(
450                        RgbaImageAVX2(data, width, height),
451                    ),
452                    ColorType::Bgr => self.encode_image_internal::<_, AVX2Operations>(
453                        BgrImageAVX2(data, width, height),
454                    ),
455                    ColorType::Bgra => self.encode_image_internal::<_, AVX2Operations>(
456                        BgraImageAVX2(data, width, height),
457                    ),
458                    ColorType::Ycbcr => self.encode_image_internal::<_, AVX2Operations>(
459                        YCbCrImage(data, width, height),
460                    ),
461                    ColorType::Cmyk => self
462                        .encode_image_internal::<_, AVX2Operations>(CmykImage(data, width, height)),
463                    ColorType::CmykAsYcck => self.encode_image_internal::<_, AVX2Operations>(
464                        CmykAsYcckImage(data, width, height),
465                    ),
466                    ColorType::Ycck => self
467                        .encode_image_internal::<_, AVX2Operations>(YcckImage(data, width, height)),
468                };
469            }
470        }
471
472        match color_type {
473            ColorType::Luma => self.encode_image(GrayImage(data, width, height))?,
474            ColorType::Rgb => self.encode_image(RgbImage(data, width, height))?,
475            ColorType::Rgba => self.encode_image(RgbaImage(data, width, height))?,
476            ColorType::Bgr => self.encode_image(BgrImage(data, width, height))?,
477            ColorType::Bgra => self.encode_image(BgraImage(data, width, height))?,
478            ColorType::Ycbcr => self.encode_image(YCbCrImage(data, width, height))?,
479            ColorType::Cmyk => self.encode_image(CmykImage(data, width, height))?,
480            ColorType::CmykAsYcck => self.encode_image(CmykAsYcckImage(data, width, height))?,
481            ColorType::Ycck => self.encode_image(YcckImage(data, width, height))?,
482        }
483
484        Ok(())
485    }
486
487    /// Encode an image
488    pub fn encode_image<I: ImageBuffer>(self, image: I) -> Result<(), EncodingError> {
489        #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
490        {
491            if std::is_x86_feature_detected!("avx2") {
492                use crate::avx2::*;
493                return self.encode_image_internal::<_, AVX2Operations>(image);
494            }
495        }
496        self.encode_image_internal::<_, DefaultOperations>(image)
497    }
498
499    fn encode_image_internal<I: ImageBuffer, OP: Operations>(
500        mut self,
501        image: I,
502    ) -> Result<(), EncodingError> {
503        if image.width() == 0 || image.height() == 0 {
504            return Err(EncodingError::ZeroImageDimensions {
505                width: image.width(),
506                height: image.height(),
507            });
508        }
509
510        let q_tables = [
511            QuantizationTable::new_with_quality(&self.quantization_tables[0], self.quality, true),
512            QuantizationTable::new_with_quality(&self.quantization_tables[1], self.quality, false),
513        ];
514
515        let jpeg_color_type = image.get_jpeg_color_type();
516        self.init_components(jpeg_color_type);
517
518        self.writer.write_marker(Marker::SOI)?;
519
520        self.writer.write_header(&self.density)?;
521
522        if jpeg_color_type == JpegColorType::Cmyk {
523            //Set ColorTransform info to "Unknown"
524            let app_14 = b"Adobe\0\0\0\0\0\0\0";
525            self.writer
526                .write_segment(Marker::APP(14), app_14.as_ref())?;
527        } else if jpeg_color_type == JpegColorType::Ycck {
528            //Set ColorTransform info to YCCK
529            let app_14 = b"Adobe\0\0\0\0\0\0\x02";
530            self.writer
531                .write_segment(Marker::APP(14), app_14.as_ref())?;
532        }
533
534        for (nr, data) in &self.app_segments {
535            self.writer.write_segment(Marker::APP(*nr), data)?;
536        }
537
538        if let Some(scans) = self.progressive_scans {
539            self.encode_image_progressive::<_, OP>(image, scans, &q_tables)?;
540        } else if self.optimize_huffman_table || !self.sampling_factor.supports_interleaved() {
541            self.encode_image_sequential::<_, OP>(image, &q_tables)?;
542        } else {
543            self.encode_image_interleaved::<_, OP>(image, &q_tables)?;
544        }
545
546        self.writer.write_marker(Marker::EOI)?;
547
548        Ok(())
549    }
550
551    fn init_components(&mut self, color: JpegColorType) {
552        let (horizontal_sampling_factor, vertical_sampling_factor) =
553            self.sampling_factor.get_sampling_factors();
554
555        match color {
556            JpegColorType::Luma => {
557                add_component!(self.components, 0, 0, 1, 1);
558            }
559            JpegColorType::Ycbcr => {
560                add_component!(
561                    self.components,
562                    0,
563                    0,
564                    horizontal_sampling_factor,
565                    vertical_sampling_factor
566                );
567                add_component!(self.components, 1, 1, 1, 1);
568                add_component!(self.components, 2, 1, 1, 1);
569            }
570            JpegColorType::Cmyk => {
571                add_component!(self.components, 0, 1, 1, 1);
572                add_component!(self.components, 1, 1, 1, 1);
573                add_component!(self.components, 2, 1, 1, 1);
574                add_component!(
575                    self.components,
576                    3,
577                    0,
578                    horizontal_sampling_factor,
579                    vertical_sampling_factor
580                );
581            }
582            JpegColorType::Ycck => {
583                add_component!(
584                    self.components,
585                    0,
586                    0,
587                    horizontal_sampling_factor,
588                    vertical_sampling_factor
589                );
590                add_component!(self.components, 1, 1, 1, 1);
591                add_component!(self.components, 2, 1, 1, 1);
592                add_component!(
593                    self.components,
594                    3,
595                    0,
596                    horizontal_sampling_factor,
597                    vertical_sampling_factor
598                );
599            }
600        }
601    }
602
603    fn get_max_sampling_size(&self) -> (usize, usize) {
604        let max_h_sampling = self.components.iter().fold(1, |value, component| {
605            value.max(component.horizontal_sampling_factor)
606        });
607
608        let max_v_sampling = self.components.iter().fold(1, |value, component| {
609            value.max(component.vertical_sampling_factor)
610        });
611
612        (usize::from(max_h_sampling), usize::from(max_v_sampling))
613    }
614
615    fn write_frame_header<I: ImageBuffer>(
616        &mut self,
617        image: &I,
618        q_tables: &[QuantizationTable; 2],
619    ) -> Result<(), EncodingError> {
620        self.writer.write_frame_header(
621            image.width(),
622            image.height(),
623            &self.components,
624            self.progressive_scans.is_some(),
625        )?;
626
627        self.writer.write_quantization_segment(0, &q_tables[0])?;
628        self.writer.write_quantization_segment(1, &q_tables[1])?;
629
630        self.writer
631            .write_huffman_segment(CodingClass::Dc, 0, &self.huffman_tables[0].0)?;
632
633        self.writer
634            .write_huffman_segment(CodingClass::Ac, 0, &self.huffman_tables[0].1)?;
635
636        if image.get_jpeg_color_type().get_num_components() >= 3 {
637            self.writer
638                .write_huffman_segment(CodingClass::Dc, 1, &self.huffman_tables[1].0)?;
639
640            self.writer
641                .write_huffman_segment(CodingClass::Ac, 1, &self.huffman_tables[1].1)?;
642        }
643
644        if let Some(restart_interval) = self.restart_interval {
645            self.writer.write_dri(restart_interval)?;
646        }
647
648        Ok(())
649    }
650
651    fn init_rows(&mut self, buffer_size: usize) -> [Vec<u8>; 4] {
652        // To simplify the code and to give the compiler more infos to optimize stuff we always initialize 4 components
653        // Resource overhead should be minimal because an empty Vec doesn't allocate
654
655        match self.components.len() {
656            1 => [
657                Vec::with_capacity(buffer_size),
658                Vec::new(),
659                Vec::new(),
660                Vec::new(),
661            ],
662            3 => [
663                Vec::with_capacity(buffer_size),
664                Vec::with_capacity(buffer_size),
665                Vec::with_capacity(buffer_size),
666                Vec::new(),
667            ],
668            4 => [
669                Vec::with_capacity(buffer_size),
670                Vec::with_capacity(buffer_size),
671                Vec::with_capacity(buffer_size),
672                Vec::with_capacity(buffer_size),
673            ],
674            len => unreachable!("Unsupported component length: {}", len),
675        }
676    }
677
678    /// Encode all components with one scan
679    ///
680    /// This is only valid for sampling factors of 1 and 2
681    fn encode_image_interleaved<I: ImageBuffer, OP: Operations>(
682        &mut self,
683        image: I,
684        q_tables: &[QuantizationTable; 2],
685    ) -> Result<(), EncodingError> {
686        self.write_frame_header(&image, q_tables)?;
687        self.writer
688            .write_scan_header(&self.components.iter().collect::<Vec<_>>(), None)?;
689
690        let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();
691
692        let width = image.width();
693        let height = image.height();
694
695        let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling);
696        let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling);
697
698        let buffer_width = num_cols * 8 * max_h_sampling;
699        let buffer_size = buffer_width * 8 * max_v_sampling;
700
701        let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);
702
703        let mut prev_dc = [0i16; 4];
704
705        let restart_interval = self.restart_interval.unwrap_or(0);
706        let mut restarts = 0;
707        let mut restarts_to_go = restart_interval;
708
709        for block_y in 0..num_rows {
710            for r in &mut row {
711                r.clear();
712            }
713
714            for y in 0..(8 * max_v_sampling) {
715                let y = y + block_y * 8 * max_v_sampling;
716                let y = (y.min(height as usize - 1)) as u16;
717
718                image.fill_buffers(y, &mut row);
719
720                for _ in usize::from(width)..buffer_width {
721                    for channel in &mut row {
722                        if !channel.is_empty() {
723                            channel.push(channel[channel.len() - 1]);
724                        }
725                    }
726                }
727            }
728
729            for block_x in 0..num_cols {
730                if restart_interval > 0 && restarts_to_go == 0 {
731                    self.writer.finalize_bit_buffer()?;
732                    self.writer
733                        .write_marker(Marker::RST((restarts % 8) as u8))?;
734
735                    prev_dc[0] = 0;
736                    prev_dc[1] = 0;
737                    prev_dc[2] = 0;
738                    prev_dc[3] = 0;
739                }
740
741                for (i, component) in self.components.iter().enumerate() {
742                    for v_offset in 0..component.vertical_sampling_factor as usize {
743                        for h_offset in 0..component.horizontal_sampling_factor as usize {
744                            let mut block = get_block(
745                                &row[i],
746                                block_x * 8 * max_h_sampling + (h_offset * 8),
747                                v_offset * 8,
748                                max_h_sampling / component.horizontal_sampling_factor as usize,
749                                max_v_sampling / component.vertical_sampling_factor as usize,
750                                buffer_width,
751                            );
752
753                            OP::fdct(&mut block);
754
755                            let mut q_block = [0i16; 64];
756
757                            OP::quantize_block(
758                                &block,
759                                &mut q_block,
760                                &q_tables[component.quantization_table as usize],
761                            );
762
763                            self.writer.write_block(
764                                &q_block,
765                                prev_dc[i],
766                                &self.huffman_tables[component.dc_huffman_table as usize].0,
767                                &self.huffman_tables[component.ac_huffman_table as usize].1,
768                            )?;
769
770                            prev_dc[i] = q_block[0];
771                        }
772                    }
773                }
774
775                if restart_interval > 0 {
776                    if restarts_to_go == 0 {
777                        restarts_to_go = restart_interval;
778                        restarts += 1;
779                        restarts &= 7;
780                    }
781                    restarts_to_go -= 1;
782                }
783            }
784        }
785
786        self.writer.finalize_bit_buffer()?;
787
788        Ok(())
789    }
790
791    /// Encode components with one scan per component
792    fn encode_image_sequential<I: ImageBuffer, OP: Operations>(
793        &mut self,
794        image: I,
795        q_tables: &[QuantizationTable; 2],
796    ) -> Result<(), EncodingError> {
797        let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
798
799        if self.optimize_huffman_table {
800            self.optimize_huffman_table(&blocks);
801        }
802
803        self.write_frame_header(&image, q_tables)?;
804
805        for (i, component) in self.components.iter().enumerate() {
806            let restart_interval = self.restart_interval.unwrap_or(0);
807            let mut restarts = 0;
808            let mut restarts_to_go = restart_interval;
809
810            self.writer.write_scan_header(&[component], None)?;
811
812            let mut prev_dc = 0;
813
814            for block in &blocks[i] {
815                if restart_interval > 0 && restarts_to_go == 0 {
816                    self.writer.finalize_bit_buffer()?;
817                    self.writer
818                        .write_marker(Marker::RST((restarts % 8) as u8))?;
819
820                    prev_dc = 0;
821                }
822
823                self.writer.write_block(
824                    block,
825                    prev_dc,
826                    &self.huffman_tables[component.dc_huffman_table as usize].0,
827                    &self.huffman_tables[component.ac_huffman_table as usize].1,
828                )?;
829
830                prev_dc = block[0];
831
832                if restart_interval > 0 {
833                    if restarts_to_go == 0 {
834                        restarts_to_go = restart_interval;
835                        restarts += 1;
836                        restarts &= 7;
837                    }
838                    restarts_to_go -= 1;
839                }
840            }
841
842            self.writer.finalize_bit_buffer()?;
843        }
844
845        Ok(())
846    }
847
848    /// Encode image in progressive mode
849    ///
850    /// This only support spectral selection for now
851    fn encode_image_progressive<I: ImageBuffer, OP: Operations>(
852        &mut self,
853        image: I,
854        scans: u8,
855        q_tables: &[QuantizationTable; 2],
856    ) -> Result<(), EncodingError> {
857        let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
858
859        if self.optimize_huffman_table {
860            self.optimize_huffman_table(&blocks);
861        }
862
863        self.write_frame_header(&image, q_tables)?;
864
865        // Phase 1: DC Scan
866        //          Only the DC coefficients can be transfer in the first component scans
867        for (i, component) in self.components.iter().enumerate() {
868            self.writer.write_scan_header(&[component], Some((0, 0)))?;
869
870            let restart_interval = self.restart_interval.unwrap_or(0);
871            let mut restarts = 0;
872            let mut restarts_to_go = restart_interval;
873
874            let mut prev_dc = 0;
875
876            for block in &blocks[i] {
877                if restart_interval > 0 && restarts_to_go == 0 {
878                    self.writer.finalize_bit_buffer()?;
879                    self.writer
880                        .write_marker(Marker::RST((restarts % 8) as u8))?;
881
882                    prev_dc = 0;
883                }
884
885                self.writer.write_dc(
886                    block[0],
887                    prev_dc,
888                    &self.huffman_tables[component.dc_huffman_table as usize].0,
889                )?;
890
891                prev_dc = block[0];
892
893                if restart_interval > 0 {
894                    if restarts_to_go == 0 {
895                        restarts_to_go = restart_interval;
896                        restarts += 1;
897                        restarts &= 7;
898                    }
899                    restarts_to_go -= 1;
900                }
901            }
902
903            self.writer.finalize_bit_buffer()?;
904        }
905
906        // Phase 2: AC scans
907        let scans = scans as usize - 1;
908
909        let values_per_scan = 64 / scans;
910
911        for scan in 0..scans {
912            let start = (scan * values_per_scan).max(1);
913            let end = if scan == scans - 1 {
914                // ensure last scan is always transfers the remaining coefficients
915                64
916            } else {
917                (scan + 1) * values_per_scan
918            };
919
920            for (i, component) in self.components.iter().enumerate() {
921                let restart_interval = self.restart_interval.unwrap_or(0);
922                let mut restarts = 0;
923                let mut restarts_to_go = restart_interval;
924
925                self.writer
926                    .write_scan_header(&[component], Some((start as u8, end as u8 - 1)))?;
927
928                for block in &blocks[i] {
929                    if restart_interval > 0 && restarts_to_go == 0 {
930                        self.writer.finalize_bit_buffer()?;
931                        self.writer
932                            .write_marker(Marker::RST((restarts % 8) as u8))?;
933                    }
934
935                    self.writer.write_ac_block(
936                        block,
937                        start,
938                        end,
939                        &self.huffman_tables[component.ac_huffman_table as usize].1,
940                    )?;
941
942                    if restart_interval > 0 {
943                        if restarts_to_go == 0 {
944                            restarts_to_go = restart_interval;
945                            restarts += 1;
946                            restarts &= 7;
947                        }
948                        restarts_to_go -= 1;
949                    }
950                }
951
952                self.writer.finalize_bit_buffer()?;
953            }
954        }
955
956        Ok(())
957    }
958
959    fn encode_blocks<I: ImageBuffer, OP: Operations>(
960        &mut self,
961        image: &I,
962        q_tables: &[QuantizationTable; 2],
963    ) -> [Vec<[i16; 64]>; 4] {
964        let width = image.width();
965        let height = image.height();
966
967        let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();
968
969        let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling) * max_h_sampling;
970        let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling) * max_v_sampling;
971
972        debug_assert!(num_cols > 0);
973        debug_assert!(num_rows > 0);
974
975        let buffer_width = num_cols * 8;
976        let buffer_size = num_cols * num_rows * 64;
977
978        let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);
979
980        for y in 0..num_rows * 8 {
981            let y = (y.min(usize::from(height) - 1)) as u16;
982
983            image.fill_buffers(y, &mut row);
984
985            for _ in usize::from(width)..num_cols * 8 {
986                for channel in &mut row {
987                    if !channel.is_empty() {
988                        channel.push(channel[channel.len() - 1]);
989                    }
990                }
991            }
992        }
993
994        let num_cols = ceil_div(usize::from(width), 8);
995        let num_rows = ceil_div(usize::from(height), 8);
996
997        debug_assert!(num_cols > 0);
998        debug_assert!(num_rows > 0);
999
1000        let mut blocks: [Vec<_>; 4] = self.init_block_buffers(buffer_size / 64);
1001
1002        for (i, component) in self.components.iter().enumerate() {
1003            let h_scale = max_h_sampling / component.horizontal_sampling_factor as usize;
1004            let v_scale = max_v_sampling / component.vertical_sampling_factor as usize;
1005
1006            let cols = ceil_div(num_cols, h_scale);
1007            let rows = ceil_div(num_rows, v_scale);
1008
1009            debug_assert!(cols > 0);
1010            debug_assert!(rows > 0);
1011
1012            for block_y in 0..rows {
1013                for block_x in 0..cols {
1014                    let mut block = get_block(
1015                        &row[i],
1016                        block_x * 8 * h_scale,
1017                        block_y * 8 * v_scale,
1018                        h_scale,
1019                        v_scale,
1020                        buffer_width,
1021                    );
1022
1023                    OP::fdct(&mut block);
1024
1025                    let mut q_block = [0i16; 64];
1026
1027                    OP::quantize_block(
1028                        &block,
1029                        &mut q_block,
1030                        &q_tables[component.quantization_table as usize],
1031                    );
1032
1033                    blocks[i].push(q_block);
1034                }
1035            }
1036        }
1037        blocks
1038    }
1039
1040    fn init_block_buffers(&mut self, buffer_size: usize) -> [Vec<[i16; 64]>; 4] {
1041        // To simplify the code and to give the compiler more infos to optimize stuff we always initialize 4 components
1042        // Resource overhead should be minimal because an empty Vec doesn't allocate
1043
1044        match self.components.len() {
1045            1 => [
1046                Vec::with_capacity(buffer_size),
1047                Vec::new(),
1048                Vec::new(),
1049                Vec::new(),
1050            ],
1051            3 => [
1052                Vec::with_capacity(buffer_size),
1053                Vec::with_capacity(buffer_size),
1054                Vec::with_capacity(buffer_size),
1055                Vec::new(),
1056            ],
1057            4 => [
1058                Vec::with_capacity(buffer_size),
1059                Vec::with_capacity(buffer_size),
1060                Vec::with_capacity(buffer_size),
1061                Vec::with_capacity(buffer_size),
1062            ],
1063            len => unreachable!("Unsupported component length: {}", len),
1064        }
1065    }
1066
1067    // Create new huffman tables optimized for this image
1068    fn optimize_huffman_table(&mut self, blocks: &[Vec<[i16; 64]>; 4]) {
1069        // TODO: Find out if it's possible to reuse some code from the writer
1070
1071        let max_tables = self.components.len().min(2) as u8;
1072
1073        for table in 0..max_tables {
1074            let mut dc_freq = [0u32; 257];
1075            dc_freq[256] = 1;
1076            let mut ac_freq = [0u32; 257];
1077            ac_freq[256] = 1;
1078
1079            let mut had_ac = false;
1080            let mut had_dc = false;
1081
1082            for (i, component) in self.components.iter().enumerate() {
1083                if component.dc_huffman_table == table {
1084                    had_dc = true;
1085
1086                    let mut prev_dc = 0;
1087
1088                    debug_assert!(!blocks[i].is_empty());
1089
1090                    for block in &blocks[i] {
1091                        let value = block[0];
1092                        let diff = value - prev_dc;
1093                        let num_bits = get_num_bits(diff);
1094
1095                        dc_freq[num_bits as usize] += 1;
1096
1097                        prev_dc = value;
1098                    }
1099                }
1100
1101                if component.ac_huffman_table == table {
1102                    had_ac = true;
1103
1104                    if let Some(scans) = self.progressive_scans {
1105                        let scans = scans as usize - 1;
1106
1107                        let values_per_scan = 64 / scans;
1108
1109                        for scan in 0..scans {
1110                            let start = (scan * values_per_scan).max(1);
1111                            let end = if scan == scans - 1 {
1112                                // Due to rounding we might need to transfer more than values_per_scan values in the last scan
1113                                64
1114                            } else {
1115                                (scan + 1) * values_per_scan
1116                            };
1117
1118                            debug_assert!(!blocks[i].is_empty());
1119
1120                            for block in &blocks[i] {
1121                                let mut zero_run = 0;
1122
1123                                for &value in &block[start..end] {
1124                                    if value == 0 {
1125                                        zero_run += 1;
1126                                    } else {
1127                                        while zero_run > 15 {
1128                                            ac_freq[0xF0] += 1;
1129                                            zero_run -= 16;
1130                                        }
1131                                        let num_bits = get_num_bits(value);
1132                                        let symbol = (zero_run << 4) | num_bits;
1133
1134                                        ac_freq[symbol as usize] += 1;
1135
1136                                        zero_run = 0;
1137                                    }
1138                                }
1139
1140                                if zero_run > 0 {
1141                                    ac_freq[0] += 1;
1142                                }
1143                            }
1144                        }
1145                    } else {
1146                        for block in &blocks[i] {
1147                            let mut zero_run = 0;
1148
1149                            for &value in &block[1..] {
1150                                if value == 0 {
1151                                    zero_run += 1;
1152                                } else {
1153                                    while zero_run > 15 {
1154                                        ac_freq[0xF0] += 1;
1155                                        zero_run -= 16;
1156                                    }
1157                                    let num_bits = get_num_bits(value);
1158                                    let symbol = (zero_run << 4) | num_bits;
1159
1160                                    ac_freq[symbol as usize] += 1;
1161
1162                                    zero_run = 0;
1163                                }
1164                            }
1165
1166                            if zero_run > 0 {
1167                                ac_freq[0] += 1;
1168                            }
1169                        }
1170                    }
1171                }
1172            }
1173
1174            assert!(had_dc, "Missing DC data for table {}", table);
1175            assert!(had_ac, "Missing AC data for table {}", table);
1176
1177            self.huffman_tables[table as usize] = (
1178                HuffmanTable::new_optimized(dc_freq),
1179                HuffmanTable::new_optimized(ac_freq),
1180            );
1181        }
1182    }
1183}
1184
#[cfg(feature = "std")]
impl Encoder<BufWriter<File>> {
    /// Create a new encoder that writes into a file
    ///
    /// The file at `path` is created and wrapped in a [`BufWriter`] before it
    /// is handed to the encoder.
    ///
    /// See [new](Encoder::new) for further information.
    ///
    /// # Errors
    ///
    /// Returns an `IoError(std::io::Error)` if the file can't be created
    pub fn new_file<P: AsRef<Path>>(
        path: P,
        quality: u8,
    ) -> Result<Encoder<BufWriter<File>>, EncodingError> {
        let writer = BufWriter::new(File::create(path)?);

        Ok(Encoder::new(writer, quality))
    }
}
1203
/// Extract one 8x8 block of level-shifted samples from a channel buffer.
///
/// `data` is read as rows of `width` samples. The block starts at
/// (`start_x`, `start_y`) and takes every `col_stride`-th sample horizontally
/// and every `row_stride`-th row vertically. 128 is subtracted from each
/// sample, so the returned values lie in `-128..=127`.
///
/// # Panics
///
/// Panics if any addressed sample lies outside of `data`.
fn get_block(
    data: &[u8],
    start_x: usize,
    start_y: usize,
    col_stride: usize,
    row_stride: usize,
    width: usize,
) -> [i16; 64] {
    let mut block = [0i16; 64];

    for (y, out_row) in block.chunks_exact_mut(8).enumerate() {
        // Offset of the first sample of this block row inside `data`
        let row_start = (start_y + y * row_stride) * width + start_x;

        for (x, sample) in out_row.iter_mut().enumerate() {
            *sample = i16::from(data[row_start + x * col_stride]) - 128;
        }
    }

    block
}
1225
/// Divide `value` by `div`, rounding the result up to the next integer.
///
/// # Panics
///
/// Panics if `div` is zero.
fn ceil_div(value: usize, div: usize) -> usize {
    let quotient = value / div;

    if value % div == 0 {
        quotient
    } else {
        quotient + 1
    }
}
1229
/// Number of bits needed to represent the magnitude of `value`.
///
/// Returns 0 for 0, 1 for ±1, 2 for ±2..=±3 and so on. The result is used
/// in this file as the size category of DC differences and AC coefficients
/// when counting huffman symbols.
///
/// The magnitude is computed with `unsigned_abs`, so `i16::MIN` yields 16
/// instead of overflowing on negation.
fn get_num_bits(value: i16) -> u8 {
    let magnitude = value.unsigned_abs();

    // A u16 has 16 bits; everything that is not a leading zero is significant
    (16 - magnitude.leading_zeros()) as u8
}
1244
1245pub(crate) trait Operations {
1246    #[inline(always)]
1247    fn fdct(data: &mut [i16; 64]) {
1248        fdct(data);
1249    }
1250
1251    #[inline(always)]
1252    fn quantize_block(block: &[i16; 64], q_block: &mut [i16; 64], table: &QuantizationTable) {
1253        for i in 0..64 {
1254            let z = ZIGZAG[i] as usize & 0x3f;
1255            q_block[i] = table.quantize(block[z], z);
1256        }
1257    }
1258}
1259
/// [`Operations`] implementation that uses only the trait's default methods.
pub(crate) struct DefaultOperations;

// Nothing is overridden: every method falls back to the defaults of `Operations`.
impl Operations for DefaultOperations {}
1263
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::encoder::get_num_bits;
    use crate::writer::get_code;
    use crate::{Encoder, SamplingFactor};

    /// `get_num_bits` has to agree with the bit count reported by the writer's
    /// `get_code` for every value in `-2^13..=2^13`.
    #[test]
    fn test_get_num_bits() {
        let limit = 1i16 << 13;

        for value in -limit..=limit {
            let from_encoder = get_num_bits(value);
            let from_writer = get_code(value).0;

            assert_eq!(
                from_encoder, from_writer,
                "Difference in num bits for value {}: {} vs {}",
                value, from_encoder, from_writer
            );
        }
    }

    /// Every sampling factor variant maps to the expected (horizontal, vertical) pair.
    #[test]
    fn sampling_factors() {
        let cases = vec![
            (SamplingFactor::F_1_1, (1, 1)),
            (SamplingFactor::F_2_1, (2, 1)),
            (SamplingFactor::F_1_2, (1, 2)),
            (SamplingFactor::F_2_2, (2, 2)),
            (SamplingFactor::F_4_1, (4, 1)),
            (SamplingFactor::F_4_2, (4, 2)),
            (SamplingFactor::F_1_4, (1, 4)),
            (SamplingFactor::F_2_4, (2, 4)),
            // Variants named after the subsampling ratio
            (SamplingFactor::R_4_4_4, (1, 1)),
            (SamplingFactor::R_4_4_0, (1, 2)),
            (SamplingFactor::R_4_4_1, (1, 4)),
            (SamplingFactor::R_4_2_2, (2, 1)),
            (SamplingFactor::R_4_2_0, (2, 2)),
            (SamplingFactor::R_4_2_1, (2, 4)),
            (SamplingFactor::R_4_1_1, (4, 1)),
            (SamplingFactor::R_4_1_0, (4, 2)),
        ];

        for (factor, expected) in cases {
            assert_eq!(factor.get_sampling_factors(), expected);
        }
    }

    /// Enabling progressive mode selects 4 scans, disabling it clears the setting.
    #[test]
    fn test_set_progressive() {
        let mut encoder = Encoder::new(vec![], 100);

        for (enabled, expected) in vec![(true, Some(4)), (false, None)] {
            encoder.set_progressive(enabled);
            assert_eq!(encoder.progressive_scans(), expected);
        }
    }
}