1use crate::fdct::fdct;
2use crate::huffman::{CodingClass, HuffmanTable};
3use crate::image_buffer::*;
4use crate::marker::Marker;
5use crate::quantization::{QuantizationTable, QuantizationTableType};
6use crate::writer::{JfifWrite, JfifWriter, ZIGZAG};
7use crate::{EncodingError, PixelDensity};
8
9use alloc::vec;
10use alloc::vec::Vec;
11
12#[cfg(feature = "std")]
13use std::io::BufWriter;
14
15#[cfg(feature = "std")]
16use std::fs::File;
17
18#[cfg(feature = "std")]
19use std::path::Path;
20
/// Color layout of the components that are written to the JPEG stream.
///
/// The variant decides how many components the encoder sets up and whether
/// an Adobe APP14 segment is emitted before the frame data.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum JpegColorType {
    /// One component.
    Luma,

    /// Three components.
    Ycbcr,

    /// Four components; the encoder writes an Adobe APP14 segment whose last
    /// byte is `0` for this layout.
    Cmyk,

    /// Four components; the encoder writes an Adobe APP14 segment whose last
    /// byte is `2` for this layout.
    Ycck,
}

impl JpegColorType {
    /// Returns how many components an image of this color type is made of.
    pub(crate) fn get_num_components(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Ycbcr => 3,
            Self::Cmyk | Self::Ycck => 4,
        }
    }
}
48
/// Pixel layout of the raw input buffer handed to [`Encoder::encode`].
///
/// The variant determines how many bytes each pixel occupies in the input
/// and which image buffer wrapper `encode` selects.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ColorType {
    /// One byte per pixel.
    Luma,

    /// Three bytes per pixel.
    Rgb,

    /// Four bytes per pixel.
    Rgba,

    /// Three bytes per pixel.
    Bgr,

    /// Four bytes per pixel.
    Bgra,

    /// Three bytes per pixel.
    Ycbcr,

    /// Four bytes per pixel.
    Cmyk,

    /// Four bytes per pixel, wrapped in `CmykAsYcckImage` by `encode`.
    // NOTE(review): presumably CMYK input that is stored as YCCK; the
    // conversion itself lives in `image_buffer` - confirm there.
    CmykAsYcck,

    /// Four bytes per pixel.
    Ycck,
}

impl ColorType {
    /// Returns the number of input bytes one pixel of this layout occupies.
    pub(crate) fn get_bytes_per_pixel(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Rgb | Self::Bgr | Self::Ycbcr => 3,
            Self::Rgba | Self::Bgra | Self::Cmyk | Self::CmykAsYcck | Self::Ycck => 4,
        }
    }
}
94
/// Sampling factor selection for the encoder.
///
/// Every discriminant packs the horizontal factor into bits 4..=6 and the
/// vertical factor into the low nibble. The `R_*` variants additionally set
/// bit 7, which only keeps their discriminants distinct from the `F_*`
/// variants with the same factors; it is masked out when the factors are
/// read back.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[allow(non_camel_case_types)]
pub enum SamplingFactor {
    /// Horizontal factor 1, vertical factor 1.
    F_1_1 = 1 << 4 | 1,
    /// Horizontal factor 2, vertical factor 1.
    F_2_1 = 2 << 4 | 1,
    /// Horizontal factor 1, vertical factor 2.
    F_1_2 = 1 << 4 | 2,
    /// Horizontal factor 2, vertical factor 2.
    F_2_2 = 2 << 4 | 2,
    /// Horizontal factor 4, vertical factor 1.
    F_4_1 = 4 << 4 | 1,
    /// Horizontal factor 4, vertical factor 2.
    F_4_2 = 4 << 4 | 2,
    /// Horizontal factor 1, vertical factor 4.
    F_1_4 = 1 << 4 | 4,
    /// Horizontal factor 2, vertical factor 4.
    F_2_4 = 2 << 4 | 4,

    /// Same factors as `F_1_1`.
    R_4_4_4 = 0x80 | 1 << 4 | 1,

    /// Same factors as `F_1_2`.
    R_4_4_0 = 0x80 | 1 << 4 | 2,

    /// Same factors as `F_1_4`.
    R_4_4_1 = 0x80 | 1 << 4 | 4,

    /// Same factors as `F_2_1`.
    R_4_2_2 = 0x80 | 2 << 4 | 1,

    /// Same factors as `F_2_2`.
    R_4_2_0 = 0x80 | 2 << 4 | 2,

    /// Same factors as `F_2_4`.
    R_4_2_1 = 0x80 | 2 << 4 | 4,

    /// Same factors as `F_4_1`.
    R_4_1_1 = 0x80 | 4 << 4 | 1,

    /// Same factors as `F_4_2`.
    R_4_1_0 = 0x80 | 4 << 4 | 2,
}

impl SamplingFactor {
    /// Looks up the `F_*` variant for a pair of factors.
    ///
    /// Returns `None` for any combination that has no `F_*` variant.
    pub fn from_factors(horizontal: u8, vertical: u8) -> Option<SamplingFactor> {
        let factor = match (horizontal, vertical) {
            (1, 1) => Self::F_1_1,
            (1, 2) => Self::F_1_2,
            (1, 4) => Self::F_1_4,
            (2, 1) => Self::F_2_1,
            (2, 2) => Self::F_2_2,
            (2, 4) => Self::F_2_4,
            (4, 1) => Self::F_4_1,
            (4, 2) => Self::F_4_2,
            _ => return None,
        };

        Some(factor)
    }

    /// Unpacks the discriminant into `(horizontal, vertical)` factors.
    pub(crate) fn get_sampling_factors(self) -> (u8, u8) {
        let packed = self as u8;

        // Masking with 0x07 drops the marker bit carried by the `R_*` variants.
        let horizontal = (packed >> 4) & 0x07;
        let vertical = packed & 0x0f;

        (horizontal, vertical)
    }

    /// Reports whether this factor can be encoded with an interleaved scan.
    ///
    /// That is the case exactly for the variants whose horizontal and
    /// vertical factors are both 1 or 2.
    pub(crate) fn supports_interleaved(self) -> bool {
        let (horizontal, vertical) = self.get_sampling_factors();

        horizontal <= 2 && vertical <= 2
    }
}
171
/// Per-component settings used when writing the frame/scan headers and when
/// encoding the component's blocks.
pub(crate) struct Component {
    /// Identifier of the component.
    pub id: u8,
    /// Index into the encoder's pair of quantization tables.
    pub quantization_table: u8,
    /// Index into the encoder's Huffman table pairs; the DC table of that pair is used.
    pub dc_huffman_table: u8,
    /// Index into the encoder's Huffman table pairs; the AC table of that pair is used.
    pub ac_huffman_table: u8,
    /// Horizontal sampling factor of this component.
    pub horizontal_sampling_factor: u8,
    /// Vertical sampling factor of this component.
    pub vertical_sampling_factor: u8,
}
180
// Appends a `Component` to `$components`.
//
// `$dest` is used for the quantization table index as well as for both
// Huffman table indices, so a component always uses matching tables.
// `$h_sample` / `$v_sample` are the component's sampling factors.
macro_rules! add_component {
    ($components:expr, $id:expr, $dest:expr, $h_sample:expr, $v_sample:expr) => {
        $components.push(Component {
            id: $id,
            quantization_table: $dest,
            dc_huffman_table: $dest,
            ac_huffman_table: $dest,
            horizontal_sampling_factor: $h_sample,
            vertical_sampling_factor: $v_sample,
        });
    };
}
193
/// JPEG encoder that writes its output through a [`JfifWrite`] sink.
///
/// The encoder is configured through its setters and then consumed by one of
/// the `encode*` methods.
pub struct Encoder<W: JfifWrite> {
    /// Output writer that receives markers, segments and entropy coded data.
    writer: JfifWriter<W>,
    /// Pixel density passed to the writer when the header is written.
    density: PixelDensity,
    /// Quality value used to build the quantization tables.
    quality: u8,

    /// Components of the image being encoded; filled by `init_components`.
    components: Vec<Component>,
    /// Quantization table types; index 0 is built as the luma table, index 1 as the chroma table.
    quantization_tables: [QuantizationTableType; 2],
    /// Huffman tables as `(DC, AC)` pairs; index 0 starts as the default luma pair,
    /// index 1 as the default chroma pair.
    huffman_tables: [(HuffmanTable, HuffmanTable); 2],

    /// Sampling factor applied by `init_components`.
    sampling_factor: SamplingFactor,

    /// `Some(n)` selects progressive encoding with `n` scans per component,
    /// `None` selects baseline encoding.
    progressive_scans: Option<u8>,

    /// `Some(n)` enables restart markers with interval `n`; `None` disables them.
    restart_interval: Option<u16>,

    /// Whether Huffman tables are rebuilt from the image's symbol statistics.
    optimize_huffman_table: bool,

    /// Extra application segments as `(segment number, payload)`, written in insertion order.
    app_segments: Vec<(u8, Vec<u8>)>,
}
214
215impl<W: JfifWrite> Encoder<W> {
216 pub fn new(w: W, quality: u8) -> Encoder<W> {
222 let huffman_tables = [
223 (
224 HuffmanTable::default_luma_dc(),
225 HuffmanTable::default_luma_ac(),
226 ),
227 (
228 HuffmanTable::default_chroma_dc(),
229 HuffmanTable::default_chroma_ac(),
230 ),
231 ];
232
233 let quantization_tables = [
234 QuantizationTableType::Default,
235 QuantizationTableType::Default,
236 ];
237
238 let sampling_factor = if quality < 90 {
239 SamplingFactor::F_2_2
240 } else {
241 SamplingFactor::F_1_1
242 };
243
244 Encoder {
245 writer: JfifWriter::new(w),
246 density: PixelDensity::default(),
247 quality,
248 components: vec![],
249 quantization_tables,
250 huffman_tables,
251 sampling_factor,
252 progressive_scans: None,
253 restart_interval: None,
254 optimize_huffman_table: false,
255 app_segments: Vec::new(),
256 }
257 }
258
    /// Sets the pixel density that is handed to the writer when the header is written.
    pub fn set_density(&mut self, density: PixelDensity) {
        self.density = density;
    }
265
    /// Returns the currently configured pixel density.
    pub fn density(&self) -> PixelDensity {
        self.density
    }
270
    /// Replaces the sampling factor that was chosen from the quality in `new`.
    pub fn set_sampling_factor(&mut self, sampling: SamplingFactor) {
        self.sampling_factor = sampling;
    }
275
    /// Returns the currently configured sampling factor.
    pub fn sampling_factor(&self) -> SamplingFactor {
        self.sampling_factor
    }
280
    /// Sets the quantization table types.
    ///
    /// `luma` is stored at index 0 and `chroma` at index 1; the concrete
    /// tables are built from these types and the quality when encoding starts.
    pub fn set_quantization_tables(
        &mut self,
        luma: QuantizationTableType,
        chroma: QuantizationTableType,
    ) {
        self.quantization_tables = [luma, chroma];
    }
289
    /// Returns the configured quantization table types as `[luma, chroma]`.
    pub fn quantization_tables(&self) -> &[QuantizationTableType; 2] {
        &self.quantization_tables
    }
294
295 pub fn set_progressive(&mut self, progressive: bool) {
300 self.progressive_scans = if progressive { Some(4) } else { None };
301 }
302
303 pub fn set_progressive_scans(&mut self, scans: u8) {
311 assert!(
312 (2..=64).contains(&scans),
313 "Invalid number of scans: {}",
314 scans
315 );
316 self.progressive_scans = Some(scans);
317 }
318
    /// Returns the number of progressive scans, or `None` if progressive encoding is disabled.
    pub fn progressive_scans(&self) -> Option<u8> {
        self.progressive_scans
    }
323
324 pub fn set_restart_interval(&mut self, interval: u16) {
328 self.restart_interval = if interval == 0 { None } else { Some(interval) };
329 }
330
    /// Returns the restart interval, or `None` if restart markers are disabled.
    pub fn restart_interval(&self) -> Option<u16> {
        self.restart_interval
    }
335
    /// Enables or disables building Huffman tables from the image's own symbol statistics.
    ///
    /// When enabled, baseline images are encoded with one scan per component
    /// instead of a single interleaved scan.
    pub fn set_optimized_huffman_tables(&mut self, optimize_huffman_table: bool) {
        self.optimize_huffman_table = optimize_huffman_table;
    }
342
    /// Returns whether optimized Huffman tables are enabled.
    pub fn optimized_huffman_tables(&self) -> bool {
        self.optimize_huffman_table
    }
347
348 pub fn add_app_segment(&mut self, segment_nr: u8, data: Vec<u8>) -> Result<(), EncodingError> {
357 if segment_nr == 0 || segment_nr > 15 {
358 Err(EncodingError::InvalidAppSegment(segment_nr))
359 } else if data.len() > 65533 {
360 Err(EncodingError::AppSegmentTooLarge(data.len()))
361 } else {
362 self.app_segments.push((segment_nr, data));
363 Ok(())
364 }
365 }
366
367 pub fn add_icc_profile(&mut self, data: &[u8]) -> Result<(), EncodingError> {
375 const MARKER: &[u8; 12] = b"ICC_PROFILE\0";
379 const MAX_CHUNK_LENGTH: usize = 65535 - 2 - 12 - 2;
380
381 let num_chunks = ceil_div(data.len(), MAX_CHUNK_LENGTH);
382
383 if num_chunks >= 255 {
385 return Err(EncodingError::IccTooLarge(data.len()));
386 }
387
388 for (i, data) in data.chunks(MAX_CHUNK_LENGTH).enumerate() {
389 let mut chunk_data = Vec::with_capacity(MAX_CHUNK_LENGTH);
390 chunk_data.extend_from_slice(MARKER);
391 chunk_data.push(i as u8 + 1);
392 chunk_data.push(num_chunks as u8);
393 chunk_data.extend_from_slice(data);
394
395 self.add_app_segment(2, chunk_data)?;
396 }
397
398 Ok(())
399 }
400
401 pub fn add_exif_metadata(&mut self, data: &[u8]) -> Result<(), EncodingError> {
409 const EXIF_HEADER: [u8; 6] = [0x45, 0x78, 0x69, 0x66, 0x00, 0x00];
412
413 let mut formatted = EXIF_HEADER.to_vec();
414 formatted.extend_from_slice(data);
415
416 self.add_app_segment(1, formatted)
417 }
418
419 pub fn encode(
423 self,
424 data: &[u8],
425 width: u16,
426 height: u16,
427 color_type: ColorType,
428 ) -> Result<(), EncodingError> {
429 let required_data_len = width as usize * height as usize * color_type.get_bytes_per_pixel();
430
431 if data.len() < required_data_len {
432 return Err(EncodingError::BadImageData {
433 length: data.len(),
434 required: required_data_len,
435 });
436 }
437
438 #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
439 {
440 if std::is_x86_feature_detected!("avx2") {
441 use crate::avx2::*;
442
443 return match color_type {
444 ColorType::Luma => self
445 .encode_image_internal::<_, AVX2Operations>(GrayImage(data, width, height)),
446 ColorType::Rgb => self.encode_image_internal::<_, AVX2Operations>(
447 RgbImageAVX2(data, width, height),
448 ),
449 ColorType::Rgba => self.encode_image_internal::<_, AVX2Operations>(
450 RgbaImageAVX2(data, width, height),
451 ),
452 ColorType::Bgr => self.encode_image_internal::<_, AVX2Operations>(
453 BgrImageAVX2(data, width, height),
454 ),
455 ColorType::Bgra => self.encode_image_internal::<_, AVX2Operations>(
456 BgraImageAVX2(data, width, height),
457 ),
458 ColorType::Ycbcr => self.encode_image_internal::<_, AVX2Operations>(
459 YCbCrImage(data, width, height),
460 ),
461 ColorType::Cmyk => self
462 .encode_image_internal::<_, AVX2Operations>(CmykImage(data, width, height)),
463 ColorType::CmykAsYcck => self.encode_image_internal::<_, AVX2Operations>(
464 CmykAsYcckImage(data, width, height),
465 ),
466 ColorType::Ycck => self
467 .encode_image_internal::<_, AVX2Operations>(YcckImage(data, width, height)),
468 };
469 }
470 }
471
472 match color_type {
473 ColorType::Luma => self.encode_image(GrayImage(data, width, height))?,
474 ColorType::Rgb => self.encode_image(RgbImage(data, width, height))?,
475 ColorType::Rgba => self.encode_image(RgbaImage(data, width, height))?,
476 ColorType::Bgr => self.encode_image(BgrImage(data, width, height))?,
477 ColorType::Bgra => self.encode_image(BgraImage(data, width, height))?,
478 ColorType::Ycbcr => self.encode_image(YCbCrImage(data, width, height))?,
479 ColorType::Cmyk => self.encode_image(CmykImage(data, width, height))?,
480 ColorType::CmykAsYcck => self.encode_image(CmykAsYcckImage(data, width, height))?,
481 ColorType::Ycck => self.encode_image(YcckImage(data, width, height))?,
482 }
483
484 Ok(())
485 }
486
    /// Encodes an image provided through the [`ImageBuffer`] trait and consumes the encoder.
    ///
    /// With the `simd` feature on x86/x86_64 the AVX2 operations are used when
    /// the CPU reports AVX2 support; otherwise the default operations are used.
    ///
    /// # Errors
    ///
    /// Passes through the errors raised while encoding.
    pub fn encode_image<I: ImageBuffer>(self, image: I) -> Result<(), EncodingError> {
        #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
        {
            if std::is_x86_feature_detected!("avx2") {
                use crate::avx2::*;
                return self.encode_image_internal::<_, AVX2Operations>(image);
            }
        }
        self.encode_image_internal::<_, DefaultOperations>(image)
    }
498
499 fn encode_image_internal<I: ImageBuffer, OP: Operations>(
500 mut self,
501 image: I,
502 ) -> Result<(), EncodingError> {
503 if image.width() == 0 || image.height() == 0 {
504 return Err(EncodingError::ZeroImageDimensions {
505 width: image.width(),
506 height: image.height(),
507 });
508 }
509
510 let q_tables = [
511 QuantizationTable::new_with_quality(&self.quantization_tables[0], self.quality, true),
512 QuantizationTable::new_with_quality(&self.quantization_tables[1], self.quality, false),
513 ];
514
515 let jpeg_color_type = image.get_jpeg_color_type();
516 self.init_components(jpeg_color_type);
517
518 self.writer.write_marker(Marker::SOI)?;
519
520 self.writer.write_header(&self.density)?;
521
522 if jpeg_color_type == JpegColorType::Cmyk {
523 let app_14 = b"Adobe\0\0\0\0\0\0\0";
525 self.writer
526 .write_segment(Marker::APP(14), app_14.as_ref())?;
527 } else if jpeg_color_type == JpegColorType::Ycck {
528 let app_14 = b"Adobe\0\0\0\0\0\0\x02";
530 self.writer
531 .write_segment(Marker::APP(14), app_14.as_ref())?;
532 }
533
534 for (nr, data) in &self.app_segments {
535 self.writer.write_segment(Marker::APP(*nr), data)?;
536 }
537
538 if let Some(scans) = self.progressive_scans {
539 self.encode_image_progressive::<_, OP>(image, scans, &q_tables)?;
540 } else if self.optimize_huffman_table || !self.sampling_factor.supports_interleaved() {
541 self.encode_image_sequential::<_, OP>(image, &q_tables)?;
542 } else {
543 self.encode_image_interleaved::<_, OP>(image, &q_tables)?;
544 }
545
546 self.writer.write_marker(Marker::EOI)?;
547
548 Ok(())
549 }
550
551 fn init_components(&mut self, color: JpegColorType) {
552 let (horizontal_sampling_factor, vertical_sampling_factor) =
553 self.sampling_factor.get_sampling_factors();
554
555 match color {
556 JpegColorType::Luma => {
557 add_component!(self.components, 0, 0, 1, 1);
558 }
559 JpegColorType::Ycbcr => {
560 add_component!(
561 self.components,
562 0,
563 0,
564 horizontal_sampling_factor,
565 vertical_sampling_factor
566 );
567 add_component!(self.components, 1, 1, 1, 1);
568 add_component!(self.components, 2, 1, 1, 1);
569 }
570 JpegColorType::Cmyk => {
571 add_component!(self.components, 0, 1, 1, 1);
572 add_component!(self.components, 1, 1, 1, 1);
573 add_component!(self.components, 2, 1, 1, 1);
574 add_component!(
575 self.components,
576 3,
577 0,
578 horizontal_sampling_factor,
579 vertical_sampling_factor
580 );
581 }
582 JpegColorType::Ycck => {
583 add_component!(
584 self.components,
585 0,
586 0,
587 horizontal_sampling_factor,
588 vertical_sampling_factor
589 );
590 add_component!(self.components, 1, 1, 1, 1);
591 add_component!(self.components, 2, 1, 1, 1);
592 add_component!(
593 self.components,
594 3,
595 0,
596 horizontal_sampling_factor,
597 vertical_sampling_factor
598 );
599 }
600 }
601 }
602
603 fn get_max_sampling_size(&self) -> (usize, usize) {
604 let max_h_sampling = self.components.iter().fold(1, |value, component| {
605 value.max(component.horizontal_sampling_factor)
606 });
607
608 let max_v_sampling = self.components.iter().fold(1, |value, component| {
609 value.max(component.vertical_sampling_factor)
610 });
611
612 (usize::from(max_h_sampling), usize::from(max_v_sampling))
613 }
614
615 fn write_frame_header<I: ImageBuffer>(
616 &mut self,
617 image: &I,
618 q_tables: &[QuantizationTable; 2],
619 ) -> Result<(), EncodingError> {
620 self.writer.write_frame_header(
621 image.width(),
622 image.height(),
623 &self.components,
624 self.progressive_scans.is_some(),
625 )?;
626
627 self.writer.write_quantization_segment(0, &q_tables[0])?;
628 self.writer.write_quantization_segment(1, &q_tables[1])?;
629
630 self.writer
631 .write_huffman_segment(CodingClass::Dc, 0, &self.huffman_tables[0].0)?;
632
633 self.writer
634 .write_huffman_segment(CodingClass::Ac, 0, &self.huffman_tables[0].1)?;
635
636 if image.get_jpeg_color_type().get_num_components() >= 3 {
637 self.writer
638 .write_huffman_segment(CodingClass::Dc, 1, &self.huffman_tables[1].0)?;
639
640 self.writer
641 .write_huffman_segment(CodingClass::Ac, 1, &self.huffman_tables[1].1)?;
642 }
643
644 if let Some(restart_interval) = self.restart_interval {
645 self.writer.write_dri(restart_interval)?;
646 }
647
648 Ok(())
649 }
650
651 fn init_rows(&mut self, buffer_size: usize) -> [Vec<u8>; 4] {
652 match self.components.len() {
656 1 => [
657 Vec::with_capacity(buffer_size),
658 Vec::new(),
659 Vec::new(),
660 Vec::new(),
661 ],
662 3 => [
663 Vec::with_capacity(buffer_size),
664 Vec::with_capacity(buffer_size),
665 Vec::with_capacity(buffer_size),
666 Vec::new(),
667 ],
668 4 => [
669 Vec::with_capacity(buffer_size),
670 Vec::with_capacity(buffer_size),
671 Vec::with_capacity(buffer_size),
672 Vec::with_capacity(buffer_size),
673 ],
674 len => unreachable!("Unsupported component length: {}", len),
675 }
676 }
677
    /// Encodes the image as a single scan in which the blocks of all components are interleaved.
    ///
    /// The image is processed one band of blocks at a time: the band's lines
    /// are loaded into per-component buffers, padded to the buffer width, and
    /// then every block of the band is transformed, quantized and written
    /// immediately.
    ///
    /// # Errors
    ///
    /// Passes through writer errors.
    fn encode_image_interleaved<I: ImageBuffer, OP: Operations>(
        &mut self,
        image: I,
        q_tables: &[QuantizationTable; 2],
    ) -> Result<(), EncodingError> {
        self.write_frame_header(&image, q_tables)?;
        // One scan header that lists every component.
        self.writer
            .write_scan_header(&self.components.iter().collect::<Vec<_>>(), None)?;

        let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();

        let width = image.width();
        let height = image.height();

        // Number of block groups per band and number of bands, rounded up.
        let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling);
        let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling);

        // The buffers hold one complete band, padded to whole block groups.
        let buffer_width = num_cols * 8 * max_h_sampling;
        let buffer_size = buffer_width * 8 * max_v_sampling;

        let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);

        // Last written DC value per component, used for differential DC coding.
        let mut prev_dc = [0i16; 4];

        // An interval of 0 means that restart markers are disabled.
        let restart_interval = self.restart_interval.unwrap_or(0);
        let mut restarts = 0;
        let mut restarts_to_go = restart_interval;

        for block_y in 0..num_rows {
            for r in &mut row {
                r.clear();
            }

            for y in 0..(8 * max_v_sampling) {
                let y = y + block_y * 8 * max_v_sampling;
                // Lines below the image repeat the last image line.
                let y = (y.min(height as usize - 1)) as u16;

                image.fill_buffers(y, &mut row);

                // Pad the line to the buffer width by repeating the last sample
                // of every channel that is in use.
                for _ in usize::from(width)..buffer_width {
                    for channel in &mut row {
                        if !channel.is_empty() {
                            channel.push(channel[channel.len() - 1]);
                        }
                    }
                }
            }

            for block_x in 0..num_cols {
                // The interval has elapsed: flush pending bits, emit the
                // restart marker and reset the DC predictions.
                if restart_interval > 0 && restarts_to_go == 0 {
                    self.writer.finalize_bit_buffer()?;
                    self.writer
                        .write_marker(Marker::RST((restarts % 8) as u8))?;

                    prev_dc[0] = 0;
                    prev_dc[1] = 0;
                    prev_dc[2] = 0;
                    prev_dc[3] = 0;
                }

                for (i, component) in self.components.iter().enumerate() {
                    // A component contributes one block per unit of its
                    // horizontal and vertical sampling factors.
                    for v_offset in 0..component.vertical_sampling_factor as usize {
                        for h_offset in 0..component.horizontal_sampling_factor as usize {
                            // Components with smaller factors than the maximum
                            // are read with a stride, i.e. samples are skipped.
                            let mut block = get_block(
                                &row[i],
                                block_x * 8 * max_h_sampling + (h_offset * 8),
                                v_offset * 8,
                                max_h_sampling / component.horizontal_sampling_factor as usize,
                                max_v_sampling / component.vertical_sampling_factor as usize,
                                buffer_width,
                            );

                            OP::fdct(&mut block);

                            let mut q_block = [0i16; 64];

                            OP::quantize_block(
                                &block,
                                &mut q_block,
                                &q_tables[component.quantization_table as usize],
                            );

                            self.writer.write_block(
                                &q_block,
                                prev_dc[i],
                                &self.huffman_tables[component.dc_huffman_table as usize].0,
                                &self.huffman_tables[component.ac_huffman_table as usize].1,
                            )?;

                            prev_dc[i] = q_block[0];
                        }
                    }
                }

                if restart_interval > 0 {
                    // A marker was written above: start the next interval and
                    // advance the marker index, which cycles through 0..=7.
                    if restarts_to_go == 0 {
                        restarts_to_go = restart_interval;
                        restarts += 1;
                        restarts &= 7;
                    }
                    restarts_to_go -= 1;
                }
            }
        }

        self.writer.finalize_bit_buffer()?;

        Ok(())
    }
790
791 fn encode_image_sequential<I: ImageBuffer, OP: Operations>(
793 &mut self,
794 image: I,
795 q_tables: &[QuantizationTable; 2],
796 ) -> Result<(), EncodingError> {
797 let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
798
799 if self.optimize_huffman_table {
800 self.optimize_huffman_table(&blocks);
801 }
802
803 self.write_frame_header(&image, q_tables)?;
804
805 for (i, component) in self.components.iter().enumerate() {
806 let restart_interval = self.restart_interval.unwrap_or(0);
807 let mut restarts = 0;
808 let mut restarts_to_go = restart_interval;
809
810 self.writer.write_scan_header(&[component], None)?;
811
812 let mut prev_dc = 0;
813
814 for block in &blocks[i] {
815 if restart_interval > 0 && restarts_to_go == 0 {
816 self.writer.finalize_bit_buffer()?;
817 self.writer
818 .write_marker(Marker::RST((restarts % 8) as u8))?;
819
820 prev_dc = 0;
821 }
822
823 self.writer.write_block(
824 block,
825 prev_dc,
826 &self.huffman_tables[component.dc_huffman_table as usize].0,
827 &self.huffman_tables[component.ac_huffman_table as usize].1,
828 )?;
829
830 prev_dc = block[0];
831
832 if restart_interval > 0 {
833 if restarts_to_go == 0 {
834 restarts_to_go = restart_interval;
835 restarts += 1;
836 restarts &= 7;
837 }
838 restarts_to_go -= 1;
839 }
840 }
841
842 self.writer.finalize_bit_buffer()?;
843 }
844
845 Ok(())
846 }
847
848 fn encode_image_progressive<I: ImageBuffer, OP: Operations>(
852 &mut self,
853 image: I,
854 scans: u8,
855 q_tables: &[QuantizationTable; 2],
856 ) -> Result<(), EncodingError> {
857 let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
858
859 if self.optimize_huffman_table {
860 self.optimize_huffman_table(&blocks);
861 }
862
863 self.write_frame_header(&image, q_tables)?;
864
865 for (i, component) in self.components.iter().enumerate() {
868 self.writer.write_scan_header(&[component], Some((0, 0)))?;
869
870 let restart_interval = self.restart_interval.unwrap_or(0);
871 let mut restarts = 0;
872 let mut restarts_to_go = restart_interval;
873
874 let mut prev_dc = 0;
875
876 for block in &blocks[i] {
877 if restart_interval > 0 && restarts_to_go == 0 {
878 self.writer.finalize_bit_buffer()?;
879 self.writer
880 .write_marker(Marker::RST((restarts % 8) as u8))?;
881
882 prev_dc = 0;
883 }
884
885 self.writer.write_dc(
886 block[0],
887 prev_dc,
888 &self.huffman_tables[component.dc_huffman_table as usize].0,
889 )?;
890
891 prev_dc = block[0];
892
893 if restart_interval > 0 {
894 if restarts_to_go == 0 {
895 restarts_to_go = restart_interval;
896 restarts += 1;
897 restarts &= 7;
898 }
899 restarts_to_go -= 1;
900 }
901 }
902
903 self.writer.finalize_bit_buffer()?;
904 }
905
906 let scans = scans as usize - 1;
908
909 let values_per_scan = 64 / scans;
910
911 for scan in 0..scans {
912 let start = (scan * values_per_scan).max(1);
913 let end = if scan == scans - 1 {
914 64
916 } else {
917 (scan + 1) * values_per_scan
918 };
919
920 for (i, component) in self.components.iter().enumerate() {
921 let restart_interval = self.restart_interval.unwrap_or(0);
922 let mut restarts = 0;
923 let mut restarts_to_go = restart_interval;
924
925 self.writer
926 .write_scan_header(&[component], Some((start as u8, end as u8 - 1)))?;
927
928 for block in &blocks[i] {
929 if restart_interval > 0 && restarts_to_go == 0 {
930 self.writer.finalize_bit_buffer()?;
931 self.writer
932 .write_marker(Marker::RST((restarts % 8) as u8))?;
933 }
934
935 self.writer.write_ac_block(
936 block,
937 start,
938 end,
939 &self.huffman_tables[component.ac_huffman_table as usize].1,
940 )?;
941
942 if restart_interval > 0 {
943 if restarts_to_go == 0 {
944 restarts_to_go = restart_interval;
945 restarts += 1;
946 restarts &= 7;
947 }
948 restarts_to_go -= 1;
949 }
950 }
951
952 self.writer.finalize_bit_buffer()?;
953 }
954 }
955
956 Ok(())
957 }
958
959 fn encode_blocks<I: ImageBuffer, OP: Operations>(
960 &mut self,
961 image: &I,
962 q_tables: &[QuantizationTable; 2],
963 ) -> [Vec<[i16; 64]>; 4] {
964 let width = image.width();
965 let height = image.height();
966
967 let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();
968
969 let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling) * max_h_sampling;
970 let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling) * max_v_sampling;
971
972 debug_assert!(num_cols > 0);
973 debug_assert!(num_rows > 0);
974
975 let buffer_width = num_cols * 8;
976 let buffer_size = num_cols * num_rows * 64;
977
978 let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);
979
980 for y in 0..num_rows * 8 {
981 let y = (y.min(usize::from(height) - 1)) as u16;
982
983 image.fill_buffers(y, &mut row);
984
985 for _ in usize::from(width)..num_cols * 8 {
986 for channel in &mut row {
987 if !channel.is_empty() {
988 channel.push(channel[channel.len() - 1]);
989 }
990 }
991 }
992 }
993
994 let num_cols = ceil_div(usize::from(width), 8);
995 let num_rows = ceil_div(usize::from(height), 8);
996
997 debug_assert!(num_cols > 0);
998 debug_assert!(num_rows > 0);
999
1000 let mut blocks: [Vec<_>; 4] = self.init_block_buffers(buffer_size / 64);
1001
1002 for (i, component) in self.components.iter().enumerate() {
1003 let h_scale = max_h_sampling / component.horizontal_sampling_factor as usize;
1004 let v_scale = max_v_sampling / component.vertical_sampling_factor as usize;
1005
1006 let cols = ceil_div(num_cols, h_scale);
1007 let rows = ceil_div(num_rows, v_scale);
1008
1009 debug_assert!(cols > 0);
1010 debug_assert!(rows > 0);
1011
1012 for block_y in 0..rows {
1013 for block_x in 0..cols {
1014 let mut block = get_block(
1015 &row[i],
1016 block_x * 8 * h_scale,
1017 block_y * 8 * v_scale,
1018 h_scale,
1019 v_scale,
1020 buffer_width,
1021 );
1022
1023 OP::fdct(&mut block);
1024
1025 let mut q_block = [0i16; 64];
1026
1027 OP::quantize_block(
1028 &block,
1029 &mut q_block,
1030 &q_tables[component.quantization_table as usize],
1031 );
1032
1033 blocks[i].push(q_block);
1034 }
1035 }
1036 }
1037 blocks
1038 }
1039
1040 fn init_block_buffers(&mut self, buffer_size: usize) -> [Vec<[i16; 64]>; 4] {
1041 match self.components.len() {
1045 1 => [
1046 Vec::with_capacity(buffer_size),
1047 Vec::new(),
1048 Vec::new(),
1049 Vec::new(),
1050 ],
1051 3 => [
1052 Vec::with_capacity(buffer_size),
1053 Vec::with_capacity(buffer_size),
1054 Vec::with_capacity(buffer_size),
1055 Vec::new(),
1056 ],
1057 4 => [
1058 Vec::with_capacity(buffer_size),
1059 Vec::with_capacity(buffer_size),
1060 Vec::with_capacity(buffer_size),
1061 Vec::with_capacity(buffer_size),
1062 ],
1063 len => unreachable!("Unsupported component length: {}", len),
1064 }
1065 }
1066
    /// Replaces the Huffman tables with tables built from the symbol frequencies of `blocks`.
    ///
    /// For every table index in use, the DC and AC symbols of all components
    /// that reference that index are counted, and a new `(DC, AC)` table pair
    /// is built from the counts. In progressive mode the AC symbols are
    /// counted per coefficient band, using the same band boundaries as
    /// `encode_image_progressive`.
    ///
    /// # Panics
    ///
    /// Panics if no component references one of the processed table indices.
    fn optimize_huffman_table(&mut self, blocks: &[Vec<[i16; 64]>; 4]) {
        // Single-component images only use table 0, all others use tables 0 and 1.
        let max_tables = self.components.len().min(2) as u8;

        for table in 0..max_tables {
            // NOTE(review): entry 256 is not a real symbol; presumably it is a
            // reserved pseudo-symbol required by `HuffmanTable::new_optimized` - confirm there.
            let mut dc_freq = [0u32; 257];
            dc_freq[256] = 1;
            let mut ac_freq = [0u32; 257];
            ac_freq[256] = 1;

            let mut had_ac = false;
            let mut had_dc = false;

            for (i, component) in self.components.iter().enumerate() {
                if component.dc_huffman_table == table {
                    had_dc = true;

                    let mut prev_dc = 0;

                    debug_assert!(!blocks[i].is_empty());

                    // The DC symbol is the bit length of the difference to the previous block's DC value.
                    for block in &blocks[i] {
                        let value = block[0];
                        let diff = value - prev_dc;
                        let num_bits = get_num_bits(diff);

                        dc_freq[num_bits as usize] += 1;

                        prev_dc = value;
                    }
                }

                if component.ac_huffman_table == table {
                    had_ac = true;

                    if let Some(scans) = self.progressive_scans {
                        // One scan is spent on the DC coefficients.
                        let scans = scans as usize - 1;

                        let values_per_scan = 64 / scans;

                        for scan in 0..scans {
                            // Index 0 is the DC coefficient, so the first band starts at 1.
                            let start = (scan * values_per_scan).max(1);
                            let end = if scan == scans - 1 {
                                // The last band takes all remaining coefficients.
                                64
                            } else {
                                (scan + 1) * values_per_scan
                            };

                            debug_assert!(!blocks[i].is_empty());

                            for block in &blocks[i] {
                                let mut zero_run = 0;

                                for &value in &block[start..end] {
                                    if value == 0 {
                                        zero_run += 1;
                                    } else {
                                        // Runs longer than 15 zeros are counted as
                                        // symbol 0xF0 once per 16 zeros.
                                        while zero_run > 15 {
                                            ac_freq[0xF0] += 1;
                                            zero_run -= 16;
                                        }
                                        // Symbol: run length in the high nibble,
                                        // bit length of the value in the low nibble.
                                        let num_bits = get_num_bits(value);
                                        let symbol = (zero_run << 4) | num_bits;

                                        ac_freq[symbol as usize] += 1;

                                        zero_run = 0;
                                    }
                                }

                                // Trailing zeros of the band are counted as symbol 0.
                                if zero_run > 0 {
                                    ac_freq[0] += 1;
                                }
                            }
                        }
                    } else {
                        // Baseline: all 63 AC coefficients of a block form a single band.
                        for block in &blocks[i] {
                            let mut zero_run = 0;

                            for &value in &block[1..] {
                                if value == 0 {
                                    zero_run += 1;
                                } else {
                                    // Runs longer than 15 zeros are counted as
                                    // symbol 0xF0 once per 16 zeros.
                                    while zero_run > 15 {
                                        ac_freq[0xF0] += 1;
                                        zero_run -= 16;
                                    }
                                    // Symbol: run length in the high nibble,
                                    // bit length of the value in the low nibble.
                                    let num_bits = get_num_bits(value);
                                    let symbol = (zero_run << 4) | num_bits;

                                    ac_freq[symbol as usize] += 1;

                                    zero_run = 0;
                                }
                            }

                            // Trailing zeros of the block are counted as symbol 0.
                            if zero_run > 0 {
                                ac_freq[0] += 1;
                            }
                        }
                    }
                }
            }

            assert!(had_dc, "Missing DC data for table {}", table);
            assert!(had_ac, "Missing AC data for table {}", table);

            self.huffman_tables[table as usize] = (
                HuffmanTable::new_optimized(dc_freq),
                HuffmanTable::new_optimized(ac_freq),
            );
        }
    }
1183}
1184
#[cfg(feature = "std")]
impl Encoder<BufWriter<File>> {
    /// Creates an encoder that writes to the file at `path` through a buffered writer.
    ///
    /// The file is opened with [`File::create`], so an existing file is truncated.
    ///
    /// # Errors
    ///
    /// Returns the converted I/O error if the file cannot be created.
    pub fn new_file<P: AsRef<Path>>(
        path: P,
        quality: u8,
    ) -> Result<Encoder<BufWriter<File>>, EncodingError> {
        let writer = BufWriter::new(File::create(path)?);

        Ok(Self::new(writer, quality))
    }
}
1203
/// Extracts an 8x8 block of samples from a plane and shifts it into the signed range.
///
/// The sample for block position `(x, y)` is read from
/// `data[(start_y + y * row_stride) * width + start_x + x * col_stride]`;
/// 128 is subtracted from every sample. The result is stored row by row.
///
/// # Panics
///
/// Panics if one of the computed indices lies outside of `data`.
fn get_block(
    data: &[u8],
    start_x: usize,
    start_y: usize,
    col_stride: usize,
    row_stride: usize,
    width: usize,
) -> [i16; 64] {
    let mut block = [0i16; 64];

    for (position, sample) in block.iter_mut().enumerate() {
        let (row, col) = (position / 8, position % 8);

        let source_x = start_x + col * col_stride;
        let source_y = start_y + row * row_stride;

        *sample = i16::from(data[source_y * width + source_x]) - 128;
    }

    block
}
1225
/// Divides `value` by `div` and rounds the result up to the next integer.
///
/// # Panics
///
/// Panics if `div` is 0.
fn ceil_div(value: usize, div: usize) -> usize {
    let quotient = value / div;

    if value % div == 0 {
        quotient
    } else {
        quotient + 1
    }
}
1229
/// Returns the number of bits needed to represent the magnitude of `value`.
///
/// This is 0 for 0 and otherwise the position of the highest set bit of
/// `|value|`, counted from 1. The magnitude is computed in 32 bits so that
/// `i16::MIN`, whose negation does not fit into an `i16`, yields 16 instead
/// of overflowing.
fn get_num_bits(value: i16) -> u8 {
    let magnitude = i32::from(value).abs() as u32;

    (32 - magnitude.leading_zeros()) as u8
}
1244
1245pub(crate) trait Operations {
1246 #[inline(always)]
1247 fn fdct(data: &mut [i16; 64]) {
1248 fdct(data);
1249 }
1250
1251 #[inline(always)]
1252 fn quantize_block(block: &[i16; 64], q_block: &mut [i16; 64], table: &QuantizationTable) {
1253 for i in 0..64 {
1254 let z = ZIGZAG[i] as usize & 0x3f;
1255 q_block[i] = table.quantize(block[z], z);
1256 }
1257 }
1258}
1259
/// Marker type that selects the default method implementations of [`Operations`].
pub(crate) struct DefaultOperations;

// Nothing is overridden: every operation uses the trait's default implementation.
impl Operations for DefaultOperations {}
1263
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::encoder::get_num_bits;
    use crate::writer::get_code;
    use crate::{Encoder, SamplingFactor};

    /// `get_num_bits` must agree with the bit count the writer computes for the same value.
    #[test]
    fn test_get_num_bits() {
        let limit = 1i16 << 13;

        for value in -limit..=limit {
            let num_bits1 = get_num_bits(value);
            let num_bits2 = get_code(value).0;

            assert_eq!(
                num_bits1, num_bits2,
                "Difference in num bits for value {}: {} vs {}",
                value, num_bits1, num_bits2
            );
        }
    }

    /// Every variant unpacks to the factors its name describes.
    #[test]
    fn sampling_factors() {
        let expectations: [(SamplingFactor, (u8, u8)); 16] = [
            (SamplingFactor::F_1_1, (1, 1)),
            (SamplingFactor::F_2_1, (2, 1)),
            (SamplingFactor::F_1_2, (1, 2)),
            (SamplingFactor::F_2_2, (2, 2)),
            (SamplingFactor::F_4_1, (4, 1)),
            (SamplingFactor::F_4_2, (4, 2)),
            (SamplingFactor::F_1_4, (1, 4)),
            (SamplingFactor::F_2_4, (2, 4)),
            (SamplingFactor::R_4_4_4, (1, 1)),
            (SamplingFactor::R_4_4_0, (1, 2)),
            (SamplingFactor::R_4_4_1, (1, 4)),
            (SamplingFactor::R_4_2_2, (2, 1)),
            (SamplingFactor::R_4_2_0, (2, 2)),
            (SamplingFactor::R_4_2_1, (2, 4)),
            (SamplingFactor::R_4_1_1, (4, 1)),
            (SamplingFactor::R_4_1_0, (4, 2)),
        ];

        for &(factor, expected) in expectations.iter() {
            assert_eq!(factor.get_sampling_factors(), expected);
        }
    }

    /// Toggling progressive mode selects 4 scans and clears them again.
    #[test]
    fn test_set_progressive() {
        let mut encoder = Encoder::new(vec![], 100);

        encoder.set_progressive(true);
        assert_eq!(encoder.progressive_scans(), Some(4));

        encoder.set_progressive(false);
        assert_eq!(encoder.progressive_scans(), None);
    }
}