1#![forbid(unsafe_code)]
2use thiserror::Error;
58
/// Errors returned by the JPEG coefficient reader, inspector and writer.
#[derive(Debug, Error)]
pub enum DctError {
    /// The input does not begin with the SOI marker (`FF D8`).
    #[error("not a JPEG file")]
    NotJpeg,

    /// The input ended in the middle of a marker segment or of the entropy-coded data.
    #[error("truncated JPEG data")]
    Truncated,

    /// Invalid Huffman-coded data. Also used for malformed marker segments: bad segment
    /// lengths, out-of-range table selectors, zero sampling factors, or Huffman tables whose
    /// code-length counts are inconsistent.
    #[error("corrupt or malformed JPEG entropy stream")]
    CorruptEntropy,

    /// A JPEG feature this module does not handle. Examples: progressive, lossless or
    /// arithmetic-coded frames, zero dimensions, unusual component counts, or images with
    /// too many MCUs. The payload describes the feature.
    #[error("unsupported JPEG variant: {0}")]
    Unsupported(String),

    /// A structure the decoder needs is absent. Examples: no SOS segment, an undefined
    /// Huffman table, or a scan that references an unknown component id. The payload names
    /// what is missing.
    #[error("missing required JPEG structure: {0}")]
    Missing(String),

    /// Caller-supplied coefficients do not match the target JPEG (component count, component
    /// ids or per-component block counts). The payload describes the mismatch.
    #[error("coefficient data is incompatible with this JPEG: {0}")]
    Incompatible(String),
}
93
/// Per-component metadata reported by [`inspect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentInfo {
    /// Component identifier from the SOF header.
    pub id: u8,
    /// Horizontal sampling factor from the SOF header.
    pub h_samp: u8,
    /// Vertical sampling factor from the SOF header.
    pub v_samp: u8,
    /// Number of 8x8 blocks the first scan codes for this component.
    /// This is 0 when that scan does not include the component.
    pub block_count: usize,
}
108
109#[derive(Debug, Clone)]
114pub struct JpegInfo {
115 pub width: u16,
117 pub height: u16,
119 pub components: Vec<ComponentInfo>,
121}
122
/// Quantized DCT coefficients of one frame component.
///
/// Each block is stored in natural (de-zig-zagged) order; index 0 is the DC coefficient.
/// Blocks appear in the order in which the first scan codes them. No dequantization is
/// applied, so values are exactly what the entropy-coded data carries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentCoefficients {
    /// Component identifier from the SOF header.
    pub id: u8,
    /// The component's 8x8 coefficient blocks.
    pub blocks: Vec<[i16; 64]>,
}
141
142#[derive(Debug, Clone)]
146pub struct JpegCoefficients {
147 pub components: Vec<ComponentCoefficients>,
150}
151
152#[must_use = "returns the decoded coefficients or an error; ignoring it discards the result"]
164pub fn read_coefficients(jpeg: &[u8]) -> Result<JpegCoefficients, DctError> {
165 let mut parser = JpegParser::new(jpeg)?;
166 parser.parse()?;
167 parser.decode_coefficients()
168}
169
170#[must_use = "returns the re-encoded JPEG bytes or an error; ignoring it discards the result"]
195pub fn write_coefficients(jpeg: &[u8], coeffs: &JpegCoefficients) -> Result<Vec<u8>, DctError> {
196 let mut parser = JpegParser::new(jpeg)?;
197 parser.parse()?;
198 parser.encode_coefficients(jpeg, coeffs)
199}
200
201#[must_use = "returns block counts or an error; ignoring it discards the result"]
211pub fn block_count(jpeg: &[u8]) -> Result<Vec<usize>, DctError> {
212 let mut parser = JpegParser::new(jpeg)?;
213 parser.parse()?;
214 parser.block_counts()
215}
216
217#[must_use = "returns image metadata or an error; ignoring it discards the result"]
226pub fn inspect(jpeg: &[u8]) -> Result<JpegInfo, DctError> {
227 let mut parser = JpegParser::new(jpeg)?;
228 parser.parse()?;
229 let counts = parser.block_counts()?;
230 Ok(JpegInfo {
231 width: parser.image_width,
232 height: parser.image_height,
233 components: parser
234 .frame_components
235 .iter()
236 .enumerate()
237 .map(|(i, fc)| ComponentInfo {
238 id: fc.id,
239 h_samp: fc.h_samp,
240 v_samp: fc.v_samp,
241 block_count: counts[i],
242 })
243 .collect(),
244 })
245}
246
247#[must_use = "returns the eligible AC coefficient count or an error; ignoring it discards the result"]
258pub fn eligible_ac_count(jpeg: &[u8]) -> Result<usize, DctError> {
259 Ok(read_coefficients(jpeg)?.eligible_ac_count())
260}
261
262impl JpegCoefficients {
263 #[must_use]
280 pub fn eligible_ac_count(&self) -> usize {
281 self.components
282 .iter()
283 .flat_map(|c| c.blocks.iter())
284 .flat_map(|b| b[1..].iter())
285 .filter(|&&v| v.abs() >= 2)
286 .count()
287 }
288}
289
/// Zig-zag scan order: `ZIGZAG[k]` is the in-block index of the `k`-th coefficient in
/// entropy-coded order.
///
/// The values follow the standard JPEG zig-zag sequence over an 8x8 block indexed as
/// `row * 8 + column`. `ZIGZAG[0] == 0` is the DC coefficient.
#[rustfmt::skip]
const ZIGZAG: [u8; 64] = [
     0,  1,  8, 16,  9,  2,  3, 10,
    17, 24, 32, 25, 18, 11,  4,  5,
    12, 19, 26, 33, 40, 48, 41, 34,
    27, 20, 13,  6,  7, 14, 21, 28,
    35, 42, 49, 56, 57, 50, 43, 36,
    29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46,
    53, 60, 61, 54, 47, 55, 62, 63,
];
305
/// Largest number of MCUs that coefficient decoding accepts.
///
/// Larger images are rejected as unsupported, which caps decoding work and memory use.
const MAX_MCU_COUNT: usize = 1_048_576;

/// Returns the JPEG magnitude category (SSSS) of `value`: the bit length of `|value|`.
///
/// The result is clamped to 15. Only `i16::MIN`, whose true category is 16, is affected by
/// the clamp.
#[inline]
fn category(value: i16) -> u8 {
    if value == 0 {
        return 0;
    }
    let abs = value.unsigned_abs();
    let cat = (16u32 - abs.leading_zeros()) as u8;
    cat.min(15)
}

/// Splits `value` into `(category, magnitude_bits, bit_count)` for Huffman coding.
///
/// Positive values are sent as-is. Negative values are sent as `(2^cat - 1) + value`, the
/// one's complement of `|value|` in `cat` bits. `bit_count` always equals `category`.
///
/// The arithmetic is done in `i32`: in `i16`, `(1 << 15) - 1` overflows, which made every
/// negative category-15 value panic in debug builds. The bits are always masked to `cat`
/// bits, so even `i16::MIN` cannot leak stray high bits into the output. Its clamped
/// category of 15 cannot represent it exactly; callers that need exactness must reject it.
#[inline]
fn encode_value(value: i16) -> (u8, u16, u8) {
    let cat = category(value);
    if cat == 0 {
        return (0, 0, 0);
    }
    let mask: i32 = (1 << cat) - 1;
    let raw = if value < 0 {
        i32::from(value) + mask
    } else {
        i32::from(value)
    };
    let bits = (raw & mask) as u16;
    (cat, bits, cat)
}
341
/// A canonical Huffman table, prepared for both decoding and encoding.
#[derive(Clone)]
struct HuffTable {
    /// Decode lookup table with 65536 entries, indexed by the next 16 bits of input.
    /// Each entry is `(symbol << 8) | code_length`; an entry with length 0 marks an
    /// invalid code.
    lut: Vec<u16>,
    /// Encode table mapping each symbol to `(code, code_length)`.
    /// A length of 0 means the table assigns no code to that symbol.
    encode: [(u16, u8); 256],
}

impl std::fmt::Debug for HuffTable {
    /// Prints only how many symbols have a code; the 64K-entry lookup table is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let assigned = self
            .encode
            .iter()
            .filter(|&&(_, code_len)| code_len != 0)
            .count();
        let mut out = f.debug_struct("HuffTable");
        out.field("encode_entries", &assigned);
        out.finish()
    }
}
369
370impl HuffTable {
371 fn from_jpeg(counts: &[u8; 16], symbols: &[u8]) -> Result<Self, DctError> {
376 let mut encode = [(0u16, 0u8); 256];
377 let mut lut = vec![0u16; 65536];
378 let mut code: u16 = 0;
379 let mut sym_idx = 0usize;
380
381 for len in 1u8..=16u8 {
382 let count = counts[(len - 1) as usize] as usize;
383 for _ in 0..count {
384 if sym_idx >= symbols.len() {
385 return Err(DctError::CorruptEntropy);
386 }
387 if (code as u32) >= (1u32 << len) {
391 return Err(DctError::CorruptEntropy);
392 }
393 let sym = symbols[sym_idx];
394 sym_idx += 1;
395 encode[sym as usize] = (code, len);
396
397 let spread = 1usize << (16 - len);
401 let base = (code as usize) << (16 - len);
402 let entry = ((sym as u16) << 8) | (len as u16);
403 lut[base..base + spread].fill(entry);
404
405 code += 1;
406 }
407 code <<= 1;
408 }
409
410 Ok(HuffTable { lut, encode })
411 }
412}
413
/// MSB-first bit reader over JPEG entropy-coded bytes.
///
/// A stuffed `FF 00` pair is read as a single `0xFF` data byte. Any other `0xFF`-prefixed
/// pair (a marker) stops refilling, so the reader never consumes a marker on its own.
struct BitReader<'a> {
    /// The entropy-coded bytes being read.
    data: &'a [u8],
    /// Index of the next byte of `data` to load into `buf`.
    pos: usize,
    /// Bit buffer. Only the low `bits` bits are valid, most significant first.
    buf: u64,
    /// Number of valid bits currently held in `buf`.
    bits: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first byte of `data`, with an empty buffer.
    fn new(data: &'a [u8]) -> Self {
        BitReader { data, pos: 0, buf: 0, bits: 0 }
    }

    /// Loads whole bytes into `buf` until it holds more than 56 bits, the data runs out, or
    /// a marker (`0xFF` not followed by `0x00`) is reached.
    fn refill(&mut self) {
        while self.bits <= 56 {
            if self.pos >= self.data.len() {
                break;
            }
            let byte = self.data[self.pos];
            if byte == 0xFF {
                // A trailing lone 0xFF cannot be classified yet; stop here.
                if self.pos + 1 >= self.data.len() {
                    break;
                }
                let next = self.data[self.pos + 1];
                if next == 0x00 {
                    self.pos += 2;
                    // Byte stuffing: `FF 00` carries one literal 0xFF data byte.
                    self.buf = (self.buf << 8) | 0xFF;
                    self.bits += 8;
                } else {
                    // A marker. Leave `pos` on its 0xFF so `sync_restart` can inspect it.
                    break;
                }
            } else {
                self.pos += 1;
                self.buf = (self.buf << 8) | (byte as u64);
                self.bits += 8;
            }
        }
    }

    /// Returns the next `n` bits (MSB first) without consuming them.
    ///
    /// Fails with `Truncated` if fewer than `n` bits are available before the end of data
    /// or the next marker. Callers in this file pass `n <= 16`.
    fn peek(&mut self, n: u8) -> Result<u16, DctError> {
        if self.bits < n {
            self.refill();
        }
        if self.bits < n {
            return Err(DctError::Truncated);
        }
        Ok(((self.buf >> (self.bits - n)) & ((1u64 << n) - 1)) as u16)
    }

    /// Drops the top `n` buffered bits. The caller must have made at least `n` bits
    /// available (via `peek` or `refill`).
    fn consume(&mut self, n: u8) {
        debug_assert!(self.bits >= n);
        self.bits -= n;
        self.buf &= (1u64 << self.bits) - 1;
    }

    /// Reads and consumes `n` bits. Reading 0 bits returns 0 without touching the buffer.
    fn read_bits(&mut self, n: u8) -> Result<u16, DctError> {
        if n == 0 {
            return Ok(0);
        }
        let v = self.peek(n)?;
        self.consume(n);
        Ok(v)
    }

    /// Decodes one Huffman symbol using `table`'s 16-bit lookup table.
    ///
    /// When fewer than 16 bits remain (end of data or a marker ahead), the lookup key is
    /// zero-padded on the right. The symbol is still accepted if its code fits in the
    /// remaining bits; otherwise this returns `Truncated`. A lookup-table entry with length 0
    /// (no code matches) yields `CorruptEntropy`.
    fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, DctError> {
        if self.bits < 16 {
            self.refill();
        }
        let key = if self.bits >= 16 {
            ((self.buf >> (self.bits - 16)) & 0xFFFF) as u16
        } else {
            ((self.buf << (16 - self.bits)) & 0xFFFF) as u16
        };

        // Lookup-table entries pack `(symbol << 8) | code_length`.
        let entry = table.lut[key as usize];
        let len = (entry & 0xFF) as u8;
        let sym = (entry >> 8) as u8;

        if len == 0 {
            return Err(DctError::CorruptEntropy);
        }
        if self.bits < len {
            return Err(DctError::Truncated);
        }
        self.consume(len);
        Ok(sym)
    }

    /// Handles a restart boundary.
    ///
    /// First discards every buffered bit (the padding before the marker). Then, if the next
    /// two bytes are an RSTn marker (`FF D0`..`FF D7`), skips them and returns `true`.
    /// Otherwise returns `false` and leaves `pos` unchanged. The marker's sequence number is
    /// not checked.
    fn sync_restart(&mut self) -> bool {
        self.bits = 0;
        self.buf = 0;
        if self.pos + 1 < self.data.len()
            && self.data[self.pos] == 0xFF
            && (0xD0..=0xD7).contains(&self.data[self.pos + 1])
        {
            self.pos += 2;
            return true;
        }
        false
    }
}
536
/// MSB-first bit writer that produces JPEG entropy-coded bytes.
///
/// Every emitted `0xFF` byte is followed by a `0x00` stuff byte, so it cannot be mistaken
/// for a marker.
struct BitWriter {
    /// Completed, already byte-stuffed output.
    out: Vec<u8>,
    /// Pending bits. Only the low `bits` bits are meaningful.
    buf: u64,
    /// Number of pending bits in `buf`; always below 8 between calls.
    bits: u8,
}

impl BitWriter {
    /// Creates an empty writer whose output buffer reserves `cap` bytes.
    fn with_capacity(cap: usize) -> Self {
        BitWriter { out: Vec::with_capacity(cap), buf: 0, bits: 0 }
    }

    /// Appends the low `n` bits of `value`, most significant first. `n` must be at most 16.
    fn write_bits(&mut self, value: u16, n: u8) {
        if n == 0 {
            return;
        }
        debug_assert!(n <= 16, "write_bits supports at most 16 bits, got {n}");
        // Keep only the low `n` bits. Without this mask, stray high bits in `value` would be
        // OR-ed over the still-pending bits and silently corrupt the stream.
        let masked = u64::from(value) & ((1u64 << n) - 1);
        self.buf = (self.buf << n) | masked;
        self.bits += n;
        while self.bits >= 8 {
            self.bits -= 8;
            let byte = ((self.buf >> self.bits) & 0xFF) as u8;
            self.push_stuffed(byte);
        }
        self.buf &= (1u64 << self.bits) - 1;
    }

    /// Pushes one output byte, followed by a `0x00` stuff byte if it is `0xFF`.
    fn push_stuffed(&mut self, byte: u8) {
        self.out.push(byte);
        if byte == 0xFF {
            self.out.push(0x00);
        }
    }

    /// Pads any partial byte with 1-bits and emits it, leaving the writer byte-aligned.
    fn flush(&mut self) {
        if self.bits > 0 {
            let pad = 8 - self.bits;
            let byte = (((self.buf << pad) | ((1u64 << pad) - 1)) & 0xFF) as u8;
            self.push_stuffed(byte);
            self.bits = 0;
            self.buf = 0;
        }
    }

    /// Flushes pending bits, then writes marker `RST(n mod 8)` (`FF D0`..`FF D7`).
    fn write_restart_marker(&mut self, n: u8) {
        self.flush();
        self.out.extend_from_slice(&[0xFF, 0xD0 | (n & 0x07)]);
    }
}
590
/// One component specification from the SOF header.
#[derive(Debug, Clone)]
struct FrameComponent {
    /// Component identifier; SOS headers refer to components by this id.
    id: u8,
    /// Horizontal sampling factor (high nibble of the SOF sampling byte).
    h_samp: u8,
    /// Vertical sampling factor (low nibble of the SOF sampling byte).
    v_samp: u8,
    // Quantization table selector. It is parsed from SOF but never read in this module:
    // no dequantization is performed (DQT segments are skipped), hence the allow.
    #[allow(dead_code)]
    qt_id: u8,
}
602
/// One component entry from the SOS header, with its id resolved to a frame index.
#[derive(Debug, Clone)]
struct ScanComponent {
    /// Index into `JpegParser::frame_components` (not the component id).
    comp_idx: usize,
    /// DC Huffman table slot, validated to be in 0..=3.
    dc_table: usize,
    /// AC Huffman table slot, validated to be in 0..=3.
    ac_table: usize,
}
610
/// Marker-level parser state for one JPEG buffer.
struct JpegParser<'a> {
    /// The complete JPEG file being parsed.
    data: &'a [u8],
    /// Read cursor into `data`.
    pos: usize,

    /// Offset of the first entropy-coded byte of the first scan, just after its SOS header.
    entropy_start: usize,
    /// Length of that entropy-coded segment. It ends at the first marker other than a
    /// stuffed `FF 00` or an RSTn, or at the end of `data`.
    entropy_len: usize,

    /// Components declared by the most recent SOF0/SOF1 header, in SOF order.
    frame_components: Vec<FrameComponent>,
    /// Components coded by the first scan, in SOS order.
    scan_components: Vec<ScanComponent>,
    /// DC Huffman tables by slot (0..=3), filled from DHT segments.
    dc_tables: [Option<HuffTable>; 4],
    /// AC Huffman tables by slot (0..=3), filled from DHT segments.
    ac_tables: [Option<HuffTable>; 4],
    /// MCUs between restart markers, from DRI; 0 means restart markers are not used.
    restart_interval: u16,
    /// Image width in pixels, from SOF.
    image_width: u16,
    /// Image height in pixels, from SOF.
    image_height: u16,
}
629
630impl<'a> JpegParser<'a> {
631 fn new(data: &'a [u8]) -> Result<Self, DctError> {
632 if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
633 return Err(DctError::NotJpeg);
634 }
635 Ok(JpegParser {
636 data,
637 pos: 2,
638 entropy_start: 0,
639 entropy_len: 0,
640 frame_components: Vec::new(),
641 scan_components: Vec::new(),
642 dc_tables: [None, None, None, None],
643 ac_tables: [None, None, None, None],
644 restart_interval: 0,
645 image_width: 0,
646 image_height: 0,
647 })
648 }
649
650 fn read_u16(&mut self) -> Result<u16, DctError> {
652 if self.pos + 1 >= self.data.len() {
653 return Err(DctError::Truncated);
654 }
655 let v = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
656 self.pos += 2;
657 Ok(v)
658 }
659
660 fn parse(&mut self) -> Result<(), DctError> {
663 loop {
664 if self.pos >= self.data.len() {
666 return Err(DctError::Missing("SOS marker".into()));
667 }
668 if self.data[self.pos] != 0xFF {
669 return Err(DctError::CorruptEntropy);
670 }
671 while self.pos < self.data.len() && self.data[self.pos] == 0xFF {
673 self.pos += 1;
674 }
675 if self.pos >= self.data.len() {
676 return Err(DctError::Truncated);
677 }
678 let marker = self.data[self.pos];
679 self.pos += 1;
680
681 match marker {
682 0xD8 => {} 0xD9 => return Err(DctError::Missing("SOS before EOI".into())),
684
685 0xC0 | 0xC1 => self.parse_sof()?,
687
688 0xC2 => return Err(DctError::Unsupported("progressive JPEG (SOF2)".into())),
690 0xC3 => return Err(DctError::Unsupported("lossless JPEG (SOF3)".into())),
691 0xC9 => return Err(DctError::Unsupported("arithmetic coding (SOF9)".into())),
692 0xCA => return Err(DctError::Unsupported("progressive arithmetic (SOF10)".into())),
693 0xCB => return Err(DctError::Unsupported("lossless arithmetic (SOF11)".into())),
694
695 0xC4 => self.parse_dht()?,
696 0xDD => self.parse_dri()?,
697
698 0xDA => {
699 self.parse_sos_header()?;
701 self.entropy_start = self.pos;
702 self.entropy_len = self.find_entropy_end();
703 return Ok(());
704 }
705
706 _ => {
708 let len = self.read_u16()? as usize;
709 if len < 2 {
710 return Err(DctError::CorruptEntropy);
711 }
712 let skip = len - 2;
713 if self.pos + skip > self.data.len() {
714 return Err(DctError::Truncated);
715 }
716 self.pos += skip;
717 }
718 }
719 }
720 }
721
722 fn parse_sof(&mut self) -> Result<(), DctError> {
723 let len = self.read_u16()? as usize;
724 if len < 8 {
725 return Err(DctError::CorruptEntropy);
726 }
727 let end = self.pos + len - 2;
728 if end > self.data.len() {
729 return Err(DctError::Truncated);
730 }
731 let _precision = self.data[self.pos];
732 self.pos += 1;
733 self.image_height = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
734 self.pos += 2;
735 self.image_width = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
736 self.pos += 2;
737
738 if self.image_width == 0 || self.image_height == 0 {
739 return Err(DctError::Unsupported("zero image dimension".into()));
740 }
741
742 let ncomp = self.data[self.pos] as usize;
743 self.pos += 1;
744
745 if ncomp == 0 || ncomp > 4 {
746 return Err(DctError::Unsupported(format!("{ncomp} components")));
747 }
748 if self.pos + ncomp * 3 > end {
749 return Err(DctError::Truncated);
750 }
751
752 self.frame_components.clear();
753 for _ in 0..ncomp {
754 let id = self.data[self.pos];
755 let samp = self.data[self.pos + 1];
756 let qt_id = self.data[self.pos + 2];
757 self.pos += 3;
758 let h_samp = samp >> 4;
759 let v_samp = samp & 0x0F;
760 if h_samp == 0 || v_samp == 0 {
761 return Err(DctError::CorruptEntropy);
762 }
763 self.frame_components.push(FrameComponent {
764 id,
765 h_samp,
766 v_samp,
767 qt_id,
768 });
769 }
770 self.pos = end;
771 Ok(())
772 }
773
774 fn parse_dht(&mut self) -> Result<(), DctError> {
775 let len = self.read_u16()? as usize;
776 if len < 2 {
777 return Err(DctError::CorruptEntropy);
778 }
779 let end = self.pos + len - 2;
780 if end > self.data.len() {
781 return Err(DctError::Truncated);
782 }
783
784 while self.pos < end {
785 if self.pos >= self.data.len() {
786 return Err(DctError::Truncated);
787 }
788 let tc_th = self.data[self.pos];
789 self.pos += 1;
790 let tc = (tc_th >> 4) & 0x0F; let th = (tc_th & 0x0F) as usize; if tc > 1 {
794 return Err(DctError::CorruptEntropy);
795 }
796 if th > 3 {
797 return Err(DctError::CorruptEntropy);
798 }
799
800 if self.pos + 16 > end {
801 return Err(DctError::Truncated);
802 }
803 let mut counts = [0u8; 16];
804 counts.copy_from_slice(&self.data[self.pos..self.pos + 16]);
805 self.pos += 16;
806
807 let total: usize = counts.iter().map(|&c| c as usize).sum();
808 if total > 256 {
810 return Err(DctError::CorruptEntropy);
811 }
812 if self.pos + total > end {
813 return Err(DctError::Truncated);
814 }
815 let symbols = &self.data[self.pos..self.pos + total];
816 self.pos += total;
817
818 let table = HuffTable::from_jpeg(&counts, symbols)?;
819 if tc == 0 {
820 self.dc_tables[th] = Some(table);
821 } else {
822 self.ac_tables[th] = Some(table);
823 }
824 }
825
826 self.pos = end;
827 Ok(())
828 }
829
830 fn parse_dri(&mut self) -> Result<(), DctError> {
831 let len = self.read_u16()?;
832 if len != 4 {
833 return Err(DctError::CorruptEntropy);
834 }
835 self.restart_interval = self.read_u16()?;
836 Ok(())
837 }
838
839 fn parse_sos_header(&mut self) -> Result<(), DctError> {
840 let len = self.read_u16()? as usize;
841 if len < 3 {
842 return Err(DctError::CorruptEntropy);
843 }
844 let end = self.pos + len - 2;
845 if end > self.data.len() {
846 return Err(DctError::Truncated);
847 }
848
849 let ns = self.data[self.pos] as usize;
850 self.pos += 1;
851
852 if ns == 0 || ns > self.frame_components.len() {
853 return Err(DctError::CorruptEntropy);
854 }
855 if self.pos + ns * 2 > end {
856 return Err(DctError::Truncated);
857 }
858
859 self.scan_components.clear();
860 for _ in 0..ns {
861 let comp_id = self.data[self.pos];
862 let td_ta = self.data[self.pos + 1];
863 self.pos += 2;
864
865 let dc_table = (td_ta >> 4) as usize;
866 let ac_table = (td_ta & 0x0F) as usize;
867
868 if dc_table > 3 || ac_table > 3 {
869 return Err(DctError::CorruptEntropy);
870 }
871
872 let comp_idx = self
873 .frame_components
874 .iter()
875 .position(|fc| fc.id == comp_id)
876 .ok_or_else(|| {
877 DctError::Missing(format!("component id {comp_id} in frame"))
878 })?;
879
880 self.scan_components.push(ScanComponent { comp_idx, dc_table, ac_table });
881 }
882
883 self.pos = end;
885 Ok(())
886 }
887
888 fn find_entropy_end(&self) -> usize {
891 let mut i = self.entropy_start;
892 while i < self.data.len() {
893 if self.data[i] == 0xFF && i + 1 < self.data.len() {
894 let next = self.data[i + 1];
895 if next == 0x00 {
896 i += 2;
898 continue;
899 }
900 if (0xD0..=0xD7).contains(&next) {
901 i += 2;
903 continue;
904 }
905 return i - self.entropy_start;
907 }
908 i += 1;
909 }
910 self.data.len() - self.entropy_start
911 }
912
913 fn max_h_samp(&self) -> u8 {
916 self.frame_components.iter().map(|c| c.h_samp).max().unwrap_or(1)
917 }
918
919 fn max_v_samp(&self) -> u8 {
920 self.frame_components.iter().map(|c| c.v_samp).max().unwrap_or(1)
921 }
922
923 fn mcu_cols(&self) -> usize {
924 let max_h = self.max_h_samp() as usize;
925 (self.image_width as usize + max_h * 8 - 1) / (max_h * 8)
926 }
927
928 fn mcu_rows(&self) -> usize {
929 let max_v = self.max_v_samp() as usize;
930 (self.image_height as usize + max_v * 8 - 1) / (max_v * 8)
931 }
932
933 fn mcu_count(&self) -> Result<usize, DctError> {
934 self.mcu_cols()
935 .checked_mul(self.mcu_rows())
936 .ok_or_else(|| DctError::Unsupported("image dimensions overflow usize".into()))
937 }
938
939 fn du_per_mcu(&self) -> Vec<usize> {
941 self.scan_components
942 .iter()
943 .map(|sc| {
944 let fc = &self.frame_components[sc.comp_idx];
945 (fc.h_samp as usize) * (fc.v_samp as usize)
946 })
947 .collect()
948 }
949
950 fn block_counts(&self) -> Result<Vec<usize>, DctError> {
952 let n_mcu = self.mcu_count()?;
953 let du = self.du_per_mcu();
954 let mut counts = vec![0usize; self.frame_components.len()];
955 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
956 counts[sc.comp_idx] = n_mcu * du[sc_idx];
957 }
958 Ok(counts)
959 }
960
961 fn decode_coefficients(&self) -> Result<JpegCoefficients, DctError> {
964 let entropy = &self.data[self.entropy_start..self.entropy_start + self.entropy_len];
965 let n_mcu = self.mcu_count()?;
966
967 if n_mcu > MAX_MCU_COUNT {
968 return Err(DctError::Unsupported(format!(
969 "image too large ({n_mcu} MCUs; max {MAX_MCU_COUNT})"
970 )));
971 }
972
973 let du = self.du_per_mcu();
974
975 let counts = self.block_counts()?;
977 let mut comp_blocks: Vec<Vec<[i16; 64]>> =
978 counts.iter().map(|&c| vec![[0i16; 64]; c]).collect();
979 let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
980
981 let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
982 let mut reader = BitReader::new(entropy);
983
984 let restart_interval = self.restart_interval as usize;
985
986 for mcu_idx in 0..n_mcu {
987 if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
989 reader.sync_restart();
990 for p in dc_pred.iter_mut() {
991 *p = 0;
992 }
993 }
994
995 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
996 let dc_table = self.dc_tables[sc.dc_table]
997 .as_ref()
998 .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
999 let ac_table = self.ac_tables[sc.ac_table]
1000 .as_ref()
1001 .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
1002
1003 for _du_i in 0..du[sc_idx] {
1004 let mut block = [0i16; 64];
1005
1006 let dc_cat = reader.decode_huffman(dc_table)?;
1008 let dc_cat = dc_cat.min(15);
1009 let dc_bits = reader.read_bits(dc_cat)?;
1010 let dc_diff = decode_magnitude(dc_cat, dc_bits);
1011 dc_pred[sc_idx] = dc_pred[sc_idx].saturating_add(dc_diff);
1012 block[ZIGZAG[0] as usize] = dc_pred[sc_idx];
1013
1014 let mut k = 1usize;
1016 while k < 64 {
1017 let rs = reader.decode_huffman(ac_table)?;
1018 if rs == 0x00 {
1019 break;
1021 }
1022 if rs == 0xF0 {
1023 k += 16;
1025 continue;
1026 }
1027 let run = (rs >> 4) as usize;
1028 let cat = (rs & 0x0F).min(15);
1029 k += run;
1030 if k >= 64 {
1031 break;
1032 }
1033 let bits = reader.read_bits(cat)?;
1034 let val = decode_magnitude(cat, bits);
1035 block[ZIGZAG[k] as usize] = val;
1036 k += 1;
1037 }
1038
1039 let block_idx = comp_block_idx[sc.comp_idx];
1040 if block_idx >= comp_blocks[sc.comp_idx].len() {
1041 return Err(DctError::CorruptEntropy);
1042 }
1043 comp_blocks[sc.comp_idx][block_idx] = block;
1044 comp_block_idx[sc.comp_idx] += 1;
1045 }
1046 }
1047 }
1048
1049 let components = self
1050 .frame_components
1051 .iter()
1052 .zip(comp_blocks)
1053 .map(|(fc, blocks)| ComponentCoefficients { id: fc.id, blocks })
1054 .collect();
1055
1056 Ok(JpegCoefficients { components })
1057 }
1058
1059 fn encode_coefficients(
1062 &self,
1063 original: &[u8],
1064 coeffs: &JpegCoefficients,
1065 ) -> Result<Vec<u8>, DctError> {
1066 if coeffs.components.len() != self.frame_components.len() {
1068 return Err(DctError::Incompatible(format!(
1069 "expected {} components, got {}",
1070 self.frame_components.len(),
1071 coeffs.components.len()
1072 )));
1073 }
1074 let counts = self.block_counts()?;
1075 for (i, (cc, &expected)) in coeffs.components.iter().zip(counts.iter()).enumerate() {
1076 if cc.id != self.frame_components[i].id {
1077 return Err(DctError::Incompatible(format!(
1078 "component {i}: expected id {}, got {}",
1079 self.frame_components[i].id, cc.id
1080 )));
1081 }
1082 if cc.blocks.len() != expected {
1083 return Err(DctError::Incompatible(format!(
1084 "component {i}: expected {expected} blocks, got {}",
1085 cc.blocks.len()
1086 )));
1087 }
1088 }
1089
1090 let n_mcu = self.mcu_count()?;
1091 let du = self.du_per_mcu();
1092
1093 let mut writer = BitWriter::with_capacity(self.entropy_len);
1094 let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
1095 let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
1096 let restart_interval = self.restart_interval as usize;
1097 let mut rst_count: u8 = 0;
1098
1099 for mcu_idx in 0..n_mcu {
1100 if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
1101 writer.write_restart_marker(rst_count);
1102 rst_count = rst_count.wrapping_add(1) & 0x07;
1103 for p in dc_pred.iter_mut() {
1104 *p = 0;
1105 }
1106 }
1107
1108 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
1109 let dc_table = self.dc_tables[sc.dc_table]
1110 .as_ref()
1111 .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
1112 let ac_table = self.ac_tables[sc.ac_table]
1113 .as_ref()
1114 .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
1115
1116 for _du_i in 0..du[sc_idx] {
1117 let block = &coeffs.components[sc.comp_idx].blocks
1118 [comp_block_idx[sc.comp_idx]];
1119 comp_block_idx[sc.comp_idx] += 1;
1120
1121 let dc_val = block[ZIGZAG[0] as usize];
1123 let dc_diff = dc_val.saturating_sub(dc_pred[sc_idx]);
1124 dc_pred[sc_idx] = dc_val;
1125 let (dc_cat, dc_bits, dc_n) = encode_value(dc_diff);
1126 let (dc_code, dc_code_len) = {
1127 let e = dc_table.encode[dc_cat as usize];
1128 if e.1 == 0 { return Err(DctError::CorruptEntropy); }
1129 e
1130 };
1131 writer.write_bits(dc_code, dc_code_len);
1132 writer.write_bits(dc_bits, dc_n);
1133
1134 let last_nonzero_zz = (1..64)
1137 .rev()
1138 .find(|&i| block[ZIGZAG[i] as usize] != 0);
1139
1140 let mut k = 1usize;
1141 let mut zero_run = 0usize;
1142
1143 if let Some(last_pos) = last_nonzero_zz {
1144 while k <= last_pos {
1145 let val = block[ZIGZAG[k] as usize];
1146 if val == 0 {
1147 zero_run += 1;
1148 if zero_run == 16 {
1149 let (zrl_code, zrl_len) = {
1151 let e = ac_table.encode[0xF0];
1152 if e.1 == 0 { return Err(DctError::CorruptEntropy); }
1153 e
1154 };
1155 writer.write_bits(zrl_code, zrl_len);
1156 zero_run = 0;
1157 }
1158 } else {
1159 let (cat, bits, n) = encode_value(val);
1160 let rs = ((zero_run as u8) << 4) | cat;
1161 let (ac_code, ac_len) = {
1162 let e = ac_table.encode[rs as usize];
1163 if e.1 == 0 { return Err(DctError::CorruptEntropy); }
1164 e
1165 };
1166 writer.write_bits(ac_code, ac_len);
1167 writer.write_bits(bits, n);
1168 zero_run = 0;
1169 }
1170 k += 1;
1171 }
1172 }
1173 let needs_eob = last_nonzero_zz.map_or(true, |p| p < 63);
1177 if needs_eob {
1178 let (eob_code, eob_len) = {
1179 let e = ac_table.encode[0x00];
1180 if e.1 == 0 { return Err(DctError::CorruptEntropy); }
1181 e
1182 };
1183 writer.write_bits(eob_code, eob_len);
1184 }
1185 }
1186 }
1187 }
1188
1189 writer.flush();
1190
1191 let after_entropy = self.entropy_start + self.entropy_len;
1194 let mut out = Vec::with_capacity(original.len());
1195 out.extend_from_slice(&original[..self.entropy_start]);
1196 out.extend_from_slice(&writer.out);
1197 out.extend_from_slice(&original[after_entropy..]);
1198 Ok(out)
1199 }
1200}
1201
/// Reconstructs a signed value from its magnitude category and `cat` magnitude bits.
///
/// This is the inverse of the encoder's magnitude coding. A leading 1 bit means the value is
/// positive and equal to `bits`. Otherwise the value is negative: `bits - (2^cat - 1)`.
///
/// The arithmetic is done in `i32`. In `i16`, `bits - (1 << 15)` overflows for category 15,
/// so a malformed stream used to panic in debug builds. Callers pass `cat <= 15`, for which
/// every result fits in `i16`.
fn decode_magnitude(cat: u8, bits: u16) -> i16 {
    if cat == 0 {
        return 0;
    }
    let value = i32::from(bits);
    let half = 1i32 << (cat - 1);
    let decoded = if value >= half {
        value
    } else {
        value - (1i32 << cat) + 1
    };
    decoded as i16
}
1216
/// Unit and round-trip tests. Test JPEGs are generated with the `image` crate's encoder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a deterministic grayscale gradient of the given size as a JPEG (quality 90).
    fn make_jpeg_gray(width: u32, height: u32) -> Vec<u8> {
        use image::{GrayImage, ImageEncoder, codecs::jpeg::JpegEncoder};
        let img = GrayImage::from_fn(width, height, |x, y| {
            image::Luma([(((x * 7 + y * 13) % 200) + 28) as u8])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 90);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::L8)
            .unwrap();
        buf
    }

    /// Encodes a deterministic RGB pattern of the given size as a JPEG (quality 85).
    fn make_jpeg_rgb(width: u32, height: u32) -> Vec<u8> {
        use image::{ImageEncoder, RgbImage, codecs::jpeg::JpegEncoder};
        let img = RgbImage::from_fn(width, height, |x, y| {
            image::Rgb([
                ((x * 11 + y * 3) % 200 + 28) as u8,
                ((x * 5 + y * 17) % 200 + 28) as u8,
                ((x * 3 + y * 7) % 200 + 28) as u8,
            ])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 85);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        buf
    }

    #[test]
    // Input without the FF D8 SOI prefix is rejected up front.
    fn not_jpeg_returns_error() {
        let result = read_coefficients(b"PNG\x00garbage");
        assert!(matches!(result, Err(DctError::NotJpeg)));
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(matches!(read_coefficients(b""), Err(DctError::NotJpeg)));
    }

    #[test]
    fn truncated_returns_error() {
        assert!(matches!(
            // SOI followed by a lone 0xFF: the marker code is cut off.
            read_coefficients(b"\xFF\xD8\xFF"),
            Err(DctError::Truncated | DctError::Missing(_))
        ));
    }

    #[test]
    fn progressive_jpeg_returns_unsupported() {
        let mut data = vec![0xFF, 0xD8];
        // APP0 segment, length 16: "JFIF\0" v1.1, no density units, 1x1 density, no thumbnail.
        data.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x10]);
        data.extend_from_slice(&[0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00]);
        // SOF2 (progressive) header, length 11.
        data.extend_from_slice(&[0xFF, 0xC2, 0x00, 0x0B]);
        // 8-bit, 16x16, one component (id 1, 1x1 sampling, quant table 0).
        data.extend_from_slice(&[0x08, 0x00, 0x10, 0x00, 0x10, 0x01, 0x01, 0x11, 0x00]);
        let result = read_coefficients(&data);
        assert!(matches!(result, Err(DctError::Unsupported(_))));
    }

    #[test]
    fn incompatible_block_count_returns_error() {
        let jpeg = make_jpeg_gray(16, 16);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        // Dropping one block makes the count disagree with the frame geometry.
        coeffs.components[0].blocks.pop();
        let result = write_coefficients(&jpeg, &coeffs);
        assert!(matches!(result, Err(DctError::Incompatible(_))));
    }

    #[test]
    // Decoding then re-encoding unmodified coefficients must reproduce the input byte-for-byte.
    fn roundtrip_identity_gray() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_rgb() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_non_square() {
        let jpeg = make_jpeg_rgb(48, 16);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded);
    }

    #[test]
    // Flipping the LSB of every AC coefficient with |v| >= 2 must survive a write/read cycle.
    // NOTE(review): for negative values the flip can change the magnitude category
    // (e.g. -2 ^ 1 == -1). This presumably relies on the encoder's AC table defining every
    // needed symbol — confirm if the test image source changes.
    fn lsb_modification_survives_roundtrip() {
        let jpeg = make_jpeg_gray(32, 32);
        let mut coeffs = read_coefficients(&jpeg).unwrap();

        let mut modified_count = 0usize;
        for block in &mut coeffs.components[0].blocks {
            for coeff in block[1..].iter_mut() {
                if coeff.abs() >= 2 {
                    *coeff ^= 1;
                    modified_count += 1;
                }
            }
        }
        assert!(modified_count > 0, "test image had no eligible coefficients");

        let modified_jpeg = write_coefficients(&jpeg, &coeffs).unwrap();

        // Decoding the re-encoded file must yield exactly the modified coefficients.
        let coeffs2 = read_coefficients(&modified_jpeg).unwrap();
        assert_eq!(coeffs.components[0].blocks, coeffs2.components[0].blocks);
    }

    #[test]
    // A 16x16 image with one 1x1-sampled component has (16/8) * (16/8) = 4 blocks.
    fn block_count_gray_16x16() {
        let jpeg = make_jpeg_gray(16, 16);
        let counts = block_count(&jpeg).unwrap();
        assert_eq!(counts, vec![4]);
    }

    #[test]
    fn block_count_rgb_32x32() {
        let jpeg = make_jpeg_rgb(32, 32);
        let counts = block_count(&jpeg).unwrap();
        assert_eq!(counts.len(), 3);
        // Exact per-component counts depend on the encoder's chroma sampling.
        let total: usize = counts.iter().sum();
        assert!(total > 0);
    }

    #[test]
    // Category is the bit length of |v|, clamped to 15.
    fn category_values() {
        assert_eq!(category(0), 0);
        assert_eq!(category(1), 1);
        assert_eq!(category(-1), 1);
        assert_eq!(category(2), 2);
        assert_eq!(category(3), 2);
        assert_eq!(category(4), 3);
        assert_eq!(category(127), 7);
        assert_eq!(category(-128), 8);
        assert_eq!(category(1023), 10);
        assert_eq!(category(i16::MAX), 15);
    }

    #[test]
    // A modified re-encode must still start with SOI and end with EOI.
    fn output_is_valid_jpeg() {
        let jpeg = make_jpeg_rgb(24, 24);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        if let Some(block) = coeffs.components[0].blocks.first_mut() {
            // Touch one AC coefficient so the entropy-coded data actually changes.
            block[1] |= 1;
        }
        let out = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(&out[..2], &[0xFF, 0xD8], "missing SOI");
        assert_eq!(&out[out.len() - 2..], &[0xFF, 0xD9], "missing EOI");
    }

    #[test]
    // 32x16 grayscale with 1x1 sampling: (32/8) * (16/8) = 8 blocks.
    fn inspect_gray_returns_correct_dimensions() {
        let jpeg = make_jpeg_gray(32, 16);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 16);
        assert_eq!(info.components.len(), 1);
        assert_eq!(info.components[0].block_count, 8);
    }

    #[test]
    fn inspect_rgb_returns_three_components() {
        let jpeg = make_jpeg_rgb(32, 32);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 32);
        assert_eq!(info.components.len(), 3);
        let total: usize = info.components.iter().map(|c| c.block_count).sum();
        assert!(total > 0);
    }

    #[test]
    // `inspect` and `block_count` must agree on per-component block counts.
    fn inspect_matches_block_count() {
        let jpeg = make_jpeg_rgb(48, 32);
        let info = inspect(&jpeg).unwrap();
        let counts = block_count(&jpeg).unwrap();
        let info_counts: Vec<usize> = info.components.iter().map(|c| c.block_count).collect();
        assert_eq!(info_counts, counts);
    }

    #[test]
    // A textured image should contain some AC coefficients with |v| >= 2.
    fn eligible_ac_count_is_positive() {
        let jpeg = make_jpeg_rgb(32, 32);
        let n = eligible_ac_count(&jpeg).unwrap();
        assert!(n > 0, "natural image should have eligible AC coefficients");
    }

    #[test]
    fn eligible_ac_count_method_matches_free_fn() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let via_method = coeffs.eligible_ac_count();
        let via_fn = eligible_ac_count(&jpeg).unwrap();
        assert_eq!(via_method, via_fn);
    }

    #[test]
    fn eligible_ac_count_leq_total_ac_count() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let eligible = coeffs.eligible_ac_count();
        // Each block has 63 AC positions.
        let total_ac: usize = coeffs.components.iter()
            .flat_map(|c| c.blocks.iter())
            .map(|_| 63)
            .sum();
        assert!(eligible <= total_ac);
    }

    #[test]
    // Flips LSBs in every component of a larger image, then checks that the lookup-table
    // decoder reads back exactly what was written.
    fn lut_decode_matches_modification_roundtrip() {
        let jpeg = make_jpeg_rgb(64, 64);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        let mut flipped = 0usize;
        for comp in &mut coeffs.components {
            for block in &mut comp.blocks {
                for coeff in block[1..].iter_mut() {
                    if coeff.abs() >= 2 {
                        *coeff ^= 1;
                        flipped += 1;
                    }
                }
            }
        }
        assert!(flipped > 0);
        let modified = write_coefficients(&jpeg, &coeffs).unwrap();
        let coeffs2 = read_coefficients(&modified).unwrap();
        assert_eq!(coeffs.components.len(), coeffs2.components.len());
        for (c1, c2) in coeffs.components.iter().zip(coeffs2.components.iter()) {
            assert_eq!(c1.blocks, c2.blocks);
        }
    }
}