1#![forbid(unsafe_code)]
/// Errors produced while parsing, decoding, or re-encoding a JPEG.
///
/// Variants that carry a `String` hold a short human-readable detail that is
/// appended to the `Display` output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DctError {
    /// The input does not begin with the JPEG SOI marker (`FF D8`).
    NotJpeg,

    /// The data ended before a marker segment or the entropy-coded scan was
    /// complete.
    Truncated,

    /// A marker segment field or the Huffman-coded entropy stream is malformed.
    CorruptEntropy,

    /// The file uses a JPEG variant or size this crate does not handle
    /// (progressive, lossless, arithmetic coding, too many MCUs, ...).
    Unsupported(String),

    /// A structure needed for decoding was not found, for example the SOS
    /// marker, a Huffman table referenced by the scan, or a scan component id
    /// that is absent from the frame header.
    Missing(String),

    /// Coefficient data handed to `write_coefficients` does not fit the target
    /// JPEG (for example a different component count, component id, or block
    /// count).
    Incompatible(String),
}
85
86impl core::fmt::Display for DctError {
87 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
88 match self {
89 DctError::NotJpeg => f.write_str("not a JPEG file"),
90 DctError::Truncated => f.write_str("truncated JPEG data"),
91 DctError::CorruptEntropy => f.write_str("corrupt or malformed JPEG entropy stream"),
92 DctError::Unsupported(s) => write!(f, "unsupported JPEG variant: {}", s),
93 DctError::Missing(s) => write!(f, "missing required JPEG structure: {}", s),
94 DctError::Incompatible(s) => {
95 write!(f, "coefficient data is incompatible with this JPEG: {}", s)
96 }
97 }
98 }
99}
100
101impl std::error::Error for DctError {}
102
/// Metadata for one frame component, as reported by `inspect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentInfo {
    /// Component identifier as written in the frame (SOF) header.
    pub id: u8,
    /// Horizontal sampling factor from the frame header.
    pub h_samp: u8,
    /// Vertical sampling factor from the frame header.
    pub v_samp: u8,
    /// Number of 8×8 blocks this component has in the first scan
    /// (0 when the component is not part of that scan).
    pub block_count: usize,
}

/// Image-level metadata returned by `inspect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegInfo {
    /// Image width in pixels, from the frame header.
    pub width: u16,
    /// Image height in pixels, from the frame header.
    pub height: u16,
    /// One entry per frame component, in frame-header order.
    pub components: Vec<ComponentInfo>,
}

/// Coefficient blocks for a single frame component.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentCoefficients {
    /// Component identifier as written in the frame (SOF) header.
    pub id: u8,
    /// One entry per 8×8 block, in the order the blocks appear in the scan.
    /// Each block is in natural row-major order (index 0 is the DC term) and
    /// holds the values exactly as entropy-coded, i.e. not dequantized.
    pub blocks: Vec<[i16; 64]>,
}

/// All coefficient data extracted from a JPEG's first scan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegCoefficients {
    /// One entry per frame component, in frame-header order.
    pub components: Vec<ComponentCoefficients>,
}
160
161#[must_use = "returns the decoded coefficients or an error; ignoring it discards the result"]
173pub fn read_coefficients(jpeg: &[u8]) -> Result<JpegCoefficients, DctError> {
174 let mut parser = JpegParser::new(jpeg)?;
175 parser.parse()?;
176 parser.decode_coefficients()
177}
178
179#[must_use = "returns the re-encoded JPEG bytes or an error; ignoring it discards the result"]
204pub fn write_coefficients(jpeg: &[u8], coeffs: &JpegCoefficients) -> Result<Vec<u8>, DctError> {
205 let mut parser = JpegParser::new(jpeg)?;
206 parser.parse()?;
207 parser.encode_coefficients(jpeg, coeffs)
208}
209
210#[must_use = "returns block counts or an error; ignoring it discards the result"]
220pub fn block_count(jpeg: &[u8]) -> Result<Vec<usize>, DctError> {
221 let mut parser = JpegParser::new(jpeg)?;
222 parser.parse()?;
223 parser.block_counts()
224}
225
226#[must_use = "returns image metadata or an error; ignoring it discards the result"]
235pub fn inspect(jpeg: &[u8]) -> Result<JpegInfo, DctError> {
236 let mut parser = JpegParser::new(jpeg)?;
237 parser.parse()?;
238 let counts = parser.block_counts()?;
239 Ok(JpegInfo {
240 width: parser.image_width,
241 height: parser.image_height,
242 components: parser
243 .frame_components
244 .iter()
245 .enumerate()
246 .map(|(i, fc)| ComponentInfo {
247 id: fc.id,
248 h_samp: fc.h_samp,
249 v_samp: fc.v_samp,
250 block_count: counts[i],
251 })
252 .collect(),
253 })
254}
255
256#[must_use = "returns the eligible AC coefficient count or an error; ignoring it discards the result"]
267pub fn eligible_ac_count(jpeg: &[u8]) -> Result<usize, DctError> {
268 Ok(read_coefficients(jpeg)?.eligible_ac_count())
269}
270
271impl JpegCoefficients {
272 #[must_use]
289 pub fn eligible_ac_count(&self) -> usize {
290 self.components
291 .iter()
292 .flat_map(|c| c.blocks.iter())
293 .flat_map(|b| b[1..].iter())
294 .filter(|&&v| v.abs() >= 2)
295 .count()
296 }
297}
298
/// Zig-zag scan order: `ZIGZAG[k]` is the natural (row-major) index inside an
/// 8×8 block of the `k`-th coefficient as it appears in the entropy stream.
///
/// Laid out one anti-diagonal per row; the walk alternates direction on each
/// diagonal, starting with the DC term at index 0.
#[rustfmt::skip]
const ZIGZAG: [u8; 64] = [
     0,
     1,  8,
    16,  9,  2,
     3, 10, 17, 24,
    32, 25, 18, 11,  4,
     5, 12, 19, 26, 33, 40,
    48, 41, 34, 27, 20, 13,  6,
     7, 14, 21, 28, 35, 42, 49, 56,
    57, 50, 43, 36, 29, 22, 15,
    23, 30, 37, 44, 51, 58,
    59, 52, 45, 38, 31,
    39, 46, 53, 60,
    61, 54, 47,
    55, 62,
    63,
];
314
/// Largest number of MCUs `decode_coefficients` will accept before refusing
/// the image, which bounds the coefficient buffers it allocates.
const MAX_MCU_COUNT: usize = 1_048_576;

/// Returns the magnitude category (SSSS) of `value`: the number of bits needed
/// to represent `|value|`, or 0 for zero.
///
/// The result is clamped to 15 so it always fits the 4-bit size field of a
/// Huffman symbol; only `i16::MIN` (whose true category is 16) is affected.
#[inline]
fn category(value: i16) -> u8 {
    if value == 0 {
        return 0;
    }
    let bits_needed = 16 - value.unsigned_abs().leading_zeros();
    bits_needed.min(15) as u8
}

/// Splits `value` into `(category, extra_bits, extra_bit_count)` for Huffman
/// coding.
///
/// Positive values are sent as-is in `category` bits. Negative values are sent
/// as `value + 2^category - 1`, i.e. the ones' complement of the magnitude.
///
/// The arithmetic is done in `i32` so the offset `2^15 - 1` cannot overflow for
/// category-15 inputs. The extra bits are masked to `category` bits, so they
/// never spill into bits the writer has already queued. `i16::MIN` needs
/// category 16 and cannot be represented; it comes out as the category-15
/// pattern for +32767, so callers must reject it themselves.
#[inline]
fn encode_value(value: i16) -> (u8, u16, u8) {
    let cat = category(value);
    if cat == 0 {
        return (0, 0, 0);
    }
    let mask = (1i32 << cat) - 1;
    let raw = i32::from(value);
    let offset = if raw > 0 { 0 } else { mask };
    let bits = ((raw + offset) & mask) as u16;
    (cat, bits, cat)
}
350
351#[derive(Clone)]
363struct HuffTable {
364 lut: Vec<u16>,
366 encode: [(u16, u8); 256],
368}
369
370impl std::fmt::Debug for HuffTable {
371 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372 let entries = self.encode.iter().filter(|e| e.1 > 0).count();
373 f.debug_struct("HuffTable")
374 .field("encode_entries", &entries)
375 .finish()
376 }
377}
378
379impl HuffTable {
380 fn from_jpeg(counts: &[u8; 16], symbols: &[u8]) -> Result<Self, DctError> {
385 let mut encode = [(0u16, 0u8); 256];
386 let mut lut = vec![0u16; 65536];
387 let mut code: u16 = 0;
388 let mut sym_idx = 0usize;
389
390 for len in 1u8..=16u8 {
391 let count = counts[(len - 1) as usize] as usize;
392 for _ in 0..count {
393 if sym_idx >= symbols.len() {
394 return Err(DctError::CorruptEntropy);
395 }
396 if (code as u32) >= (1u32 << len) {
400 return Err(DctError::CorruptEntropy);
401 }
402 let sym = symbols[sym_idx];
403 sym_idx += 1;
404 encode[sym as usize] = (code, len);
405
406 let spread = 1usize << (16 - len);
410 let base = (code as usize) << (16 - len);
411 let entry = ((sym as u16) << 8) | (len as u16);
412 lut[base..base + spread].fill(entry);
413
414 code += 1;
415 }
416 code <<= 1;
417 }
418
419 Ok(HuffTable { lut, encode })
420 }
421}
422
423struct BitReader<'a> {
426 data: &'a [u8],
427 pos: usize,
428 buf: u64,
429 bits: u8,
430}
431
432impl<'a> BitReader<'a> {
433 fn new(data: &'a [u8]) -> Self {
434 BitReader {
435 data,
436 pos: 0,
437 buf: 0,
438 bits: 0,
439 }
440 }
441
442 fn refill(&mut self) {
445 while self.bits <= 56 {
446 if self.pos >= self.data.len() {
447 break;
448 }
449 let byte = self.data[self.pos];
450 if byte == 0xFF {
451 if self.pos + 1 >= self.data.len() {
452 break;
453 }
454 let next = self.data[self.pos + 1];
455 if next == 0x00 {
456 self.pos += 2;
458 self.buf = (self.buf << 8) | 0xFF;
459 self.bits += 8;
460 } else {
461 break;
463 }
464 } else {
465 self.pos += 1;
466 self.buf = (self.buf << 8) | (byte as u64);
467 self.bits += 8;
468 }
469 }
470 }
471
472 fn peek(&mut self, n: u8) -> Result<u16, DctError> {
474 if self.bits < n {
475 self.refill();
476 }
477 if self.bits < n {
478 return Err(DctError::Truncated);
479 }
480 Ok(((self.buf >> (self.bits - n)) & ((1u64 << n) - 1)) as u16)
481 }
482
483 fn consume(&mut self, n: u8) {
485 debug_assert!(self.bits >= n);
486 self.bits -= n;
487 self.buf &= (1u64 << self.bits) - 1;
488 }
489
490 fn read_bits(&mut self, n: u8) -> Result<u16, DctError> {
492 if n == 0 {
493 return Ok(0);
494 }
495 let v = self.peek(n)?;
496 self.consume(n);
497 Ok(v)
498 }
499
500 fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, DctError> {
506 if self.bits < 16 {
507 self.refill();
508 }
509 let key = if self.bits >= 16 {
511 ((self.buf >> (self.bits - 16)) & 0xFFFF) as u16
512 } else {
513 ((self.buf << (16 - self.bits)) & 0xFFFF) as u16
517 };
518
519 let entry = table.lut[key as usize];
520 let len = (entry & 0xFF) as u8;
521 let sym = (entry >> 8) as u8;
522
523 if len == 0 {
524 return Err(DctError::CorruptEntropy);
525 }
526 if self.bits < len {
527 return Err(DctError::Truncated);
528 }
529 self.consume(len);
530 Ok(sym)
531 }
532
533 fn sync_restart(&mut self) -> bool {
536 self.bits = 0;
538 self.buf = 0;
539 if self.pos + 1 < self.data.len()
541 && self.data[self.pos] == 0xFF
542 && (0xD0..=0xD7).contains(&self.data[self.pos + 1])
543 {
544 self.pos += 2;
545 return true;
546 }
547 false
548 }
549}
550
/// Writes bits MSB-first into an entropy-coded segment, inserting a `00`
/// stuffing byte after every emitted `0xFF`.
struct BitWriter {
    out: Vec<u8>,
    /// Pending bits not yet forming a whole byte (fewer than 8 between calls).
    buf: u64,
    /// Number of valid bits in `buf`.
    bits: u8,
}

impl BitWriter {
    fn with_capacity(cap: usize) -> Self {
        BitWriter {
            out: Vec::with_capacity(cap),
            buf: 0,
            bits: 0,
        }
    }

    /// Appends one completed byte, followed by `00` if it is `0xFF` so it
    /// cannot be mistaken for a marker prefix.
    fn push_stuffed(&mut self, byte: u8) {
        self.out.push(byte);
        if byte == 0xFF {
            self.out.push(0x00);
        }
    }

    /// Appends the low `n` bits of `value` (`n` ≤ 16), most significant first.
    ///
    /// Bits of `value` above position `n` are ignored. Without the mask they
    /// would be OR-ed over bits that are already queued and corrupt the output.
    fn write_bits(&mut self, value: u16, n: u8) {
        if n == 0 {
            return;
        }
        let masked = u64::from(value) & ((1u64 << n) - 1);
        self.buf = (self.buf << n) | masked;
        self.bits += n;
        while self.bits >= 8 {
            self.bits -= 8;
            let byte = ((self.buf >> self.bits) & 0xFF) as u8;
            self.push_stuffed(byte);
            self.buf &= (1u64 << self.bits) - 1;
        }
    }

    /// Completes a partial byte by padding it with 1-bits.
    fn flush(&mut self) {
        if self.bits > 0 {
            let pad = 8 - self.bits;
            let byte = (((self.buf << pad) | ((1u64 << pad) - 1)) & 0xFF) as u8;
            self.push_stuffed(byte);
            self.bits = 0;
            self.buf = 0;
        }
    }

    /// Pads the current byte and emits marker `RST(n mod 8)`.
    fn write_restart_marker(&mut self, n: u8) {
        self.flush();
        self.out.push(0xFF);
        self.out.push(0xD0 | (n & 0x07));
    }
}
608
/// One component entry of the SOF (start-of-frame) header.
#[derive(Clone, Debug)]
struct FrameComponent {
    /// Component identifier that SOS headers refer back to.
    id: u8,
    /// Horizontal sampling factor (high nibble of the sampling byte).
    h_samp: u8,
    /// Vertical sampling factor (low nibble of the sampling byte).
    v_samp: u8,
    /// Quantization table selector. It is recorded but never read, because
    /// coefficients are passed through without (de)quantization; hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    qt_id: u8,
}
620
/// One component entry of the SOS (start-of-scan) header.
#[derive(Clone, Debug)]
struct ScanComponent {
    /// Position of the component in `JpegParser::frame_components`
    /// (an index, not the raw component id).
    comp_idx: usize,
    /// DC Huffman table selector (validated to be 0–3).
    dc_table: usize,
    /// AC Huffman table selector (validated to be 0–3).
    ac_table: usize,
}
628
629struct JpegParser<'a> {
631 data: &'a [u8],
632 pos: usize,
633
634 entropy_start: usize,
636 entropy_len: usize,
638
639 frame_components: Vec<FrameComponent>,
640 scan_components: Vec<ScanComponent>,
641 dc_tables: [Option<HuffTable>; 4],
642 ac_tables: [Option<HuffTable>; 4],
643 restart_interval: u16,
644 image_width: u16,
645 image_height: u16,
646}
647
648impl<'a> JpegParser<'a> {
649 fn new(data: &'a [u8]) -> Result<Self, DctError> {
650 if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
651 return Err(DctError::NotJpeg);
652 }
653 Ok(JpegParser {
654 data,
655 pos: 2,
656 entropy_start: 0,
657 entropy_len: 0,
658 frame_components: Vec::new(),
659 scan_components: Vec::new(),
660 dc_tables: [None, None, None, None],
661 ac_tables: [None, None, None, None],
662 restart_interval: 0,
663 image_width: 0,
664 image_height: 0,
665 })
666 }
667
668 fn read_u16(&mut self) -> Result<u16, DctError> {
670 if self.pos + 1 >= self.data.len() {
671 return Err(DctError::Truncated);
672 }
673 let v = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
674 self.pos += 2;
675 Ok(v)
676 }
677
678 fn parse(&mut self) -> Result<(), DctError> {
681 loop {
682 if self.pos >= self.data.len() {
684 return Err(DctError::Missing("SOS marker".into()));
685 }
686 if self.data[self.pos] != 0xFF {
687 return Err(DctError::CorruptEntropy);
688 }
689 while self.pos < self.data.len() && self.data[self.pos] == 0xFF {
691 self.pos += 1;
692 }
693 if self.pos >= self.data.len() {
694 return Err(DctError::Truncated);
695 }
696 let marker = self.data[self.pos];
697 self.pos += 1;
698
699 match marker {
700 0xD8 => {} 0xD9 => return Err(DctError::Missing("SOS before EOI".into())),
702
703 0xC0 | 0xC1 => self.parse_sof()?,
705
706 0xC2 => return Err(DctError::Unsupported("progressive JPEG (SOF2)".into())),
708 0xC3 => return Err(DctError::Unsupported("lossless JPEG (SOF3)".into())),
709 0xC9 => return Err(DctError::Unsupported("arithmetic coding (SOF9)".into())),
710 0xCA => {
711 return Err(DctError::Unsupported(
712 "progressive arithmetic (SOF10)".into(),
713 ))
714 }
715 0xCB => return Err(DctError::Unsupported("lossless arithmetic (SOF11)".into())),
716
717 0xC4 => self.parse_dht()?,
718 0xDD => self.parse_dri()?,
719
720 0xDA => {
721 self.parse_sos_header()?;
723 self.entropy_start = self.pos;
724 self.entropy_len = self.find_entropy_end();
725 return Ok(());
726 }
727
728 _ => {
730 let len = self.read_u16()? as usize;
731 if len < 2 {
732 return Err(DctError::CorruptEntropy);
733 }
734 let skip = len - 2;
735 if self.pos + skip > self.data.len() {
736 return Err(DctError::Truncated);
737 }
738 self.pos += skip;
739 }
740 }
741 }
742 }
743
744 fn parse_sof(&mut self) -> Result<(), DctError> {
745 let len = self.read_u16()? as usize;
746 if len < 8 {
747 return Err(DctError::CorruptEntropy);
748 }
749 let end = self.pos + len - 2;
750 if end > self.data.len() {
751 return Err(DctError::Truncated);
752 }
753 let _precision = self.data[self.pos];
754 self.pos += 1;
755 self.image_height = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
756 self.pos += 2;
757 self.image_width = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
758 self.pos += 2;
759
760 if self.image_width == 0 || self.image_height == 0 {
761 return Err(DctError::Unsupported("zero image dimension".into()));
762 }
763
764 let ncomp = self.data[self.pos] as usize;
765 self.pos += 1;
766
767 if ncomp == 0 || ncomp > 4 {
768 return Err(DctError::Unsupported(format!("{} components", ncomp)));
769 }
770 if self.pos + ncomp * 3 > end {
771 return Err(DctError::Truncated);
772 }
773
774 self.frame_components.clear();
775 for _ in 0..ncomp {
776 let id = self.data[self.pos];
777 let samp = self.data[self.pos + 1];
778 let qt_id = self.data[self.pos + 2];
779 self.pos += 3;
780 let h_samp = samp >> 4;
781 let v_samp = samp & 0x0F;
782 if h_samp == 0 || v_samp == 0 {
783 return Err(DctError::CorruptEntropy);
784 }
785 self.frame_components.push(FrameComponent {
786 id,
787 h_samp,
788 v_samp,
789 qt_id,
790 });
791 }
792 self.pos = end;
793 Ok(())
794 }
795
796 fn parse_dht(&mut self) -> Result<(), DctError> {
797 let len = self.read_u16()? as usize;
798 if len < 2 {
799 return Err(DctError::CorruptEntropy);
800 }
801 let end = self.pos + len - 2;
802 if end > self.data.len() {
803 return Err(DctError::Truncated);
804 }
805
806 while self.pos < end {
807 if self.pos >= self.data.len() {
808 return Err(DctError::Truncated);
809 }
810 let tc_th = self.data[self.pos];
811 self.pos += 1;
812 let tc = (tc_th >> 4) & 0x0F; let th = (tc_th & 0x0F) as usize; if tc > 1 {
816 return Err(DctError::CorruptEntropy);
817 }
818 if th > 3 {
819 return Err(DctError::CorruptEntropy);
820 }
821
822 if self.pos + 16 > end {
823 return Err(DctError::Truncated);
824 }
825 let mut counts = [0u8; 16];
826 counts.copy_from_slice(&self.data[self.pos..self.pos + 16]);
827 self.pos += 16;
828
829 let total: usize = counts.iter().map(|&c| c as usize).sum();
830 if total > 256 {
832 return Err(DctError::CorruptEntropy);
833 }
834 if self.pos + total > end {
835 return Err(DctError::Truncated);
836 }
837 let symbols = &self.data[self.pos..self.pos + total];
838 self.pos += total;
839
840 let table = HuffTable::from_jpeg(&counts, symbols)?;
841 if tc == 0 {
842 self.dc_tables[th] = Some(table);
843 } else {
844 self.ac_tables[th] = Some(table);
845 }
846 }
847
848 self.pos = end;
849 Ok(())
850 }
851
852 fn parse_dri(&mut self) -> Result<(), DctError> {
853 let len = self.read_u16()?;
854 if len != 4 {
855 return Err(DctError::CorruptEntropy);
856 }
857 self.restart_interval = self.read_u16()?;
858 Ok(())
859 }
860
861 fn parse_sos_header(&mut self) -> Result<(), DctError> {
862 let len = self.read_u16()? as usize;
863 if len < 3 {
864 return Err(DctError::CorruptEntropy);
865 }
866 let end = self.pos + len - 2;
867 if end > self.data.len() {
868 return Err(DctError::Truncated);
869 }
870
871 let ns = self.data[self.pos] as usize;
872 self.pos += 1;
873
874 if ns == 0 || ns > self.frame_components.len() {
875 return Err(DctError::CorruptEntropy);
876 }
877 if self.pos + ns * 2 > end {
878 return Err(DctError::Truncated);
879 }
880
881 self.scan_components.clear();
882 for _ in 0..ns {
883 let comp_id = self.data[self.pos];
884 let td_ta = self.data[self.pos + 1];
885 self.pos += 2;
886
887 let dc_table = (td_ta >> 4) as usize;
888 let ac_table = (td_ta & 0x0F) as usize;
889
890 if dc_table > 3 || ac_table > 3 {
891 return Err(DctError::CorruptEntropy);
892 }
893
894 let comp_idx = self
895 .frame_components
896 .iter()
897 .position(|fc| fc.id == comp_id)
898 .ok_or_else(|| DctError::Missing(format!("component id {} in frame", comp_id)))?;
899
900 self.scan_components.push(ScanComponent {
901 comp_idx,
902 dc_table,
903 ac_table,
904 });
905 }
906
907 self.pos = end;
909 Ok(())
910 }
911
912 fn find_entropy_end(&self) -> usize {
915 let mut i = self.entropy_start;
916 while i < self.data.len() {
917 if self.data[i] == 0xFF && i + 1 < self.data.len() {
918 let next = self.data[i + 1];
919 if next == 0x00 {
920 i += 2;
922 continue;
923 }
924 if (0xD0..=0xD7).contains(&next) {
925 i += 2;
927 continue;
928 }
929 return i - self.entropy_start;
931 }
932 i += 1;
933 }
934 self.data.len() - self.entropy_start
935 }
936
937 fn max_h_samp(&self) -> u8 {
940 self.frame_components
941 .iter()
942 .map(|c| c.h_samp)
943 .max()
944 .unwrap_or(1)
945 }
946
947 fn max_v_samp(&self) -> u8 {
948 self.frame_components
949 .iter()
950 .map(|c| c.v_samp)
951 .max()
952 .unwrap_or(1)
953 }
954
955 fn mcu_cols(&self) -> usize {
956 let max_h = self.max_h_samp() as usize;
957 (self.image_width as usize + max_h * 8 - 1) / (max_h * 8)
958 }
959
960 fn mcu_rows(&self) -> usize {
961 let max_v = self.max_v_samp() as usize;
962 (self.image_height as usize + max_v * 8 - 1) / (max_v * 8)
963 }
964
965 fn mcu_count(&self) -> Result<usize, DctError> {
966 self.mcu_cols()
967 .checked_mul(self.mcu_rows())
968 .ok_or_else(|| DctError::Unsupported("image dimensions overflow usize".into()))
969 }
970
971 fn du_per_mcu(&self) -> Vec<usize> {
973 self.scan_components
974 .iter()
975 .map(|sc| {
976 let fc = &self.frame_components[sc.comp_idx];
977 (fc.h_samp as usize) * (fc.v_samp as usize)
978 })
979 .collect()
980 }
981
982 fn block_counts(&self) -> Result<Vec<usize>, DctError> {
984 let n_mcu = self.mcu_count()?;
985 let du = self.du_per_mcu();
986 let mut counts = vec![0usize; self.frame_components.len()];
987 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
988 counts[sc.comp_idx] = n_mcu * du[sc_idx];
989 }
990 Ok(counts)
991 }
992
993 fn decode_coefficients(&self) -> Result<JpegCoefficients, DctError> {
996 let entropy = &self.data[self.entropy_start..self.entropy_start + self.entropy_len];
997 let n_mcu = self.mcu_count()?;
998
999 if n_mcu > MAX_MCU_COUNT {
1000 return Err(DctError::Unsupported(format!(
1001 "image too large ({} MCUs; max {})",
1002 n_mcu, MAX_MCU_COUNT
1003 )));
1004 }
1005
1006 let du = self.du_per_mcu();
1007
1008 let counts = self.block_counts()?;
1010 let mut comp_blocks: Vec<Vec<[i16; 64]>> =
1011 counts.iter().map(|&c| vec![[0i16; 64]; c]).collect();
1012 let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
1013
1014 let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
1015 let mut reader = BitReader::new(entropy);
1016
1017 let restart_interval = self.restart_interval as usize;
1018
1019 for mcu_idx in 0..n_mcu {
1020 if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
1022 reader.sync_restart();
1023 for p in dc_pred.iter_mut() {
1024 *p = 0;
1025 }
1026 }
1027
1028 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
1029 let dc_table = self.dc_tables[sc.dc_table]
1030 .as_ref()
1031 .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
1032 let ac_table = self.ac_tables[sc.ac_table]
1033 .as_ref()
1034 .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
1035
1036 for _du_i in 0..du[sc_idx] {
1037 let mut block = [0i16; 64];
1038
1039 let dc_cat = reader.decode_huffman(dc_table)?;
1041 let dc_cat = dc_cat.min(15);
1042 let dc_bits = reader.read_bits(dc_cat)?;
1043 let dc_diff = decode_magnitude(dc_cat, dc_bits);
1044 dc_pred[sc_idx] = dc_pred[sc_idx].saturating_add(dc_diff);
1045 block[ZIGZAG[0] as usize] = dc_pred[sc_idx];
1046
1047 let mut k = 1usize;
1049 while k < 64 {
1050 let rs = reader.decode_huffman(ac_table)?;
1051 if rs == 0x00 {
1052 break;
1054 }
1055 if rs == 0xF0 {
1056 k += 16;
1058 continue;
1059 }
1060 let run = (rs >> 4) as usize;
1061 let cat = (rs & 0x0F).min(15);
1062 k += run;
1063 if k >= 64 {
1064 break;
1065 }
1066 let bits = reader.read_bits(cat)?;
1067 let val = decode_magnitude(cat, bits);
1068 block[ZIGZAG[k] as usize] = val;
1069 k += 1;
1070 }
1071
1072 let block_idx = comp_block_idx[sc.comp_idx];
1073 if block_idx >= comp_blocks[sc.comp_idx].len() {
1074 return Err(DctError::CorruptEntropy);
1075 }
1076 comp_blocks[sc.comp_idx][block_idx] = block;
1077 comp_block_idx[sc.comp_idx] += 1;
1078 }
1079 }
1080 }
1081
1082 let components = self
1083 .frame_components
1084 .iter()
1085 .zip(comp_blocks)
1086 .map(|(fc, blocks)| ComponentCoefficients { id: fc.id, blocks })
1087 .collect();
1088
1089 Ok(JpegCoefficients { components })
1090 }
1091
1092 fn encode_coefficients(
1095 &self,
1096 original: &[u8],
1097 coeffs: &JpegCoefficients,
1098 ) -> Result<Vec<u8>, DctError> {
1099 if coeffs.components.len() != self.frame_components.len() {
1101 return Err(DctError::Incompatible(format!(
1102 "expected {} components, got {}",
1103 self.frame_components.len(),
1104 coeffs.components.len()
1105 )));
1106 }
1107 let counts = self.block_counts()?;
1108 for (i, (cc, &expected)) in coeffs.components.iter().zip(counts.iter()).enumerate() {
1109 if cc.id != self.frame_components[i].id {
1110 return Err(DctError::Incompatible(format!(
1111 "component {}: expected id {}, got {}",
1112 i, self.frame_components[i].id, cc.id
1113 )));
1114 }
1115 if cc.blocks.len() != expected {
1116 return Err(DctError::Incompatible(format!(
1117 "component {}: expected {} blocks, got {}",
1118 i,
1119 expected,
1120 cc.blocks.len()
1121 )));
1122 }
1123 }
1124
1125 let n_mcu = self.mcu_count()?;
1126 let du = self.du_per_mcu();
1127
1128 let mut writer = BitWriter::with_capacity(self.entropy_len);
1129 let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
1130 let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
1131 let restart_interval = self.restart_interval as usize;
1132 let mut rst_count: u8 = 0;
1133
1134 for mcu_idx in 0..n_mcu {
1135 if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
1136 writer.write_restart_marker(rst_count);
1137 rst_count = rst_count.wrapping_add(1) & 0x07;
1138 for p in dc_pred.iter_mut() {
1139 *p = 0;
1140 }
1141 }
1142
1143 for (sc_idx, sc) in self.scan_components.iter().enumerate() {
1144 let dc_table = self.dc_tables[sc.dc_table]
1145 .as_ref()
1146 .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
1147 let ac_table = self.ac_tables[sc.ac_table]
1148 .as_ref()
1149 .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
1150
1151 for _du_i in 0..du[sc_idx] {
1152 let block = &coeffs.components[sc.comp_idx].blocks[comp_block_idx[sc.comp_idx]];
1153 comp_block_idx[sc.comp_idx] += 1;
1154
1155 let dc_val = block[ZIGZAG[0] as usize];
1157 let dc_diff = dc_val.saturating_sub(dc_pred[sc_idx]);
1158 dc_pred[sc_idx] = dc_val;
1159 let (dc_cat, dc_bits, dc_n) = encode_value(dc_diff);
1160 let (dc_code, dc_code_len) = {
1161 let e = dc_table.encode[dc_cat as usize];
1162 if e.1 == 0 {
1163 return Err(DctError::CorruptEntropy);
1164 }
1165 e
1166 };
1167 writer.write_bits(dc_code, dc_code_len);
1168 writer.write_bits(dc_bits, dc_n);
1169
1170 let last_nonzero_zz = (1..64).rev().find(|&i| block[ZIGZAG[i] as usize] != 0);
1173
1174 let mut k = 1usize;
1175 let mut zero_run = 0usize;
1176
1177 if let Some(last_pos) = last_nonzero_zz {
1178 while k <= last_pos {
1179 let val = block[ZIGZAG[k] as usize];
1180 if val == 0 {
1181 zero_run += 1;
1182 if zero_run == 16 {
1183 let (zrl_code, zrl_len) = {
1185 let e = ac_table.encode[0xF0];
1186 if e.1 == 0 {
1187 return Err(DctError::CorruptEntropy);
1188 }
1189 e
1190 };
1191 writer.write_bits(zrl_code, zrl_len);
1192 zero_run = 0;
1193 }
1194 } else {
1195 let (cat, bits, n) = encode_value(val);
1196 let rs = ((zero_run as u8) << 4) | cat;
1197 let (ac_code, ac_len) = {
1198 let e = ac_table.encode[rs as usize];
1199 if e.1 == 0 {
1200 return Err(DctError::CorruptEntropy);
1201 }
1202 e
1203 };
1204 writer.write_bits(ac_code, ac_len);
1205 writer.write_bits(bits, n);
1206 zero_run = 0;
1207 }
1208 k += 1;
1209 }
1210 }
1211 let needs_eob = last_nonzero_zz.map_or(true, |p| p < 63);
1215 if needs_eob {
1216 let (eob_code, eob_len) = {
1217 let e = ac_table.encode[0x00];
1218 if e.1 == 0 {
1219 return Err(DctError::CorruptEntropy);
1220 }
1221 e
1222 };
1223 writer.write_bits(eob_code, eob_len);
1224 }
1225 }
1226 }
1227 }
1228
1229 writer.flush();
1230
1231 let after_entropy = self.entropy_start + self.entropy_len;
1234 let mut out = Vec::with_capacity(original.len());
1235 out.extend_from_slice(&original[..self.entropy_start]);
1236 out.extend_from_slice(&writer.out);
1237 out.extend_from_slice(&original[after_entropy..]);
1238 Ok(out)
1239 }
1240}
1241
/// Reconstructs a coefficient from its magnitude category `cat` (expected
/// 0..=15) and the `cat` extra bits that follow its Huffman symbol.
///
/// A leading 1 bit means a positive value equal to `bits`. A leading 0 bit
/// means the negative value `bits - (2^cat - 1)`.
///
/// The arithmetic is done in `i32`. In `i16`, category 15 overflows, and that
/// category is reachable whenever a Huffman table contains the symbol 15
/// (including crafted files). Every result for `cat <= 15` fits in `i16`.
fn decode_magnitude(cat: u8, bits: u16) -> i16 {
    if cat == 0 {
        return 0;
    }
    let value = i32::from(bits);
    if value >= 1 << (cat - 1) {
        value as i16
    } else {
        (value - (1 << cat) + 1) as i16
    }
}
1256
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes raw pixel data as a JPEG using the `image` crate's encoder.
    fn encode_with_image_crate(
        raw: &[u8],
        width: u32,
        height: u32,
        quality: u8,
        color: image::ExtendedColorType,
    ) -> Vec<u8> {
        use image::{codecs::jpeg::JpegEncoder, ImageEncoder};
        let mut buf = Vec::new();
        JpegEncoder::new_with_quality(&mut buf, quality)
            .write_image(raw, width, height, color)
            .unwrap();
        buf
    }

    /// Deterministic grayscale test image (diagonal ramp), quality 90.
    fn make_jpeg_gray(width: u32, height: u32) -> Vec<u8> {
        let pixels = image::GrayImage::from_fn(width, height, |x, y| {
            image::Luma([(((x * 7 + y * 13) % 200) + 28) as u8])
        });
        encode_with_image_crate(pixels.as_raw(), width, height, 90, image::ExtendedColorType::L8)
    }

    /// Deterministic RGB test image (three independent ramps), quality 85.
    fn make_jpeg_rgb(width: u32, height: u32) -> Vec<u8> {
        let pixels = image::RgbImage::from_fn(width, height, |x, y| {
            image::Rgb([
                ((x * 11 + y * 3) % 200 + 28) as u8,
                ((x * 5 + y * 17) % 200 + 28) as u8,
                ((x * 3 + y * 7) % 200 + 28) as u8,
            ])
        });
        encode_with_image_crate(pixels.as_raw(), width, height, 85, image::ExtendedColorType::Rgb8)
    }

    /// Decodes `jpeg` and immediately re-encodes the untouched coefficients.
    fn reencode_unchanged(jpeg: &[u8]) -> Vec<u8> {
        let coeffs = read_coefficients(jpeg).unwrap();
        write_coefficients(jpeg, &coeffs).unwrap()
    }

    /// Toggles the LSB of every AC coefficient with |v| >= 2 and returns how
    /// many were changed.
    fn flip_eligible_lsbs(blocks: &mut [[i16; 64]]) -> usize {
        let mut flipped = 0usize;
        for coeff in blocks.iter_mut().flat_map(|b| b[1..].iter_mut()) {
            if coeff.abs() >= 2 {
                *coeff ^= 1;
                flipped += 1;
            }
        }
        flipped
    }

    #[test]
    fn not_jpeg_returns_error() {
        assert!(matches!(
            read_coefficients(b"PNG\x00garbage"),
            Err(DctError::NotJpeg)
        ));
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(matches!(read_coefficients(b""), Err(DctError::NotJpeg)));
    }

    #[test]
    fn truncated_returns_error() {
        // SOI followed by a lone marker prefix.
        let outcome = read_coefficients(b"\xFF\xD8\xFF");
        assert!(matches!(
            outcome,
            Err(DctError::Truncated | DctError::Missing(_))
        ));
    }

    #[test]
    fn progressive_jpeg_returns_unsupported() {
        let parts: [&[u8]; 6] = [
            // SOI
            &[0xFF, 0xD8],
            // APP0 header, length 16
            &[0xFF, 0xE0, 0x00, 0x10],
            // "JFIF\0" identifier
            b"JFIF\0",
            // version, units, densities, thumbnail size
            &[0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00],
            // SOF2 header, length 11
            &[0xFF, 0xC2, 0x00, 0x0B],
            // 8-bit, 16x16, one component
            &[0x08, 0x00, 0x10, 0x00, 0x10, 0x01, 0x01, 0x11, 0x00],
        ];
        let data = parts.concat();
        assert!(matches!(
            read_coefficients(&data),
            Err(DctError::Unsupported(_))
        ));
    }

    #[test]
    fn incompatible_block_count_returns_error() {
        let jpeg = make_jpeg_gray(16, 16);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        coeffs.components[0].blocks.pop();
        assert!(matches!(
            write_coefficients(&jpeg, &coeffs),
            Err(DctError::Incompatible(_))
        ));
    }

    #[test]
    fn roundtrip_identity_gray() {
        let jpeg = make_jpeg_gray(32, 32);
        assert_eq!(jpeg, reencode_unchanged(&jpeg), "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_rgb() {
        let jpeg = make_jpeg_rgb(32, 32);
        assert_eq!(jpeg, reencode_unchanged(&jpeg), "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_non_square() {
        let jpeg = make_jpeg_rgb(48, 16);
        assert_eq!(jpeg, reencode_unchanged(&jpeg));
    }

    #[test]
    fn lsb_modification_survives_roundtrip() {
        let jpeg = make_jpeg_gray(32, 32);
        let mut coeffs = read_coefficients(&jpeg).unwrap();

        let modified_count = flip_eligible_lsbs(&mut coeffs.components[0].blocks);
        assert!(
            modified_count > 0,
            "test image had no eligible coefficients"
        );

        let modified_jpeg = write_coefficients(&jpeg, &coeffs).unwrap();
        let reread = read_coefficients(&modified_jpeg).unwrap();
        assert_eq!(coeffs.components[0].blocks, reread.components[0].blocks);
    }

    #[test]
    fn block_count_gray_16x16() {
        // A 16x16 single-component image is 2x2 blocks.
        assert_eq!(block_count(&make_jpeg_gray(16, 16)).unwrap(), vec![4]);
    }

    #[test]
    fn block_count_rgb_32x32() {
        let counts = block_count(&make_jpeg_rgb(32, 32)).unwrap();
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().sum::<usize>() > 0);
    }

    #[test]
    fn category_values() {
        let cases: [(i16, u8); 10] = [
            (0, 0),
            (1, 1),
            (-1, 1),
            (2, 2),
            (3, 2),
            (4, 3),
            (127, 7),
            (-128, 8),
            (1023, 10),
            (i16::MAX, 15),
        ];
        for &(value, expected) in cases.iter() {
            assert_eq!(category(value), expected);
        }
    }

    #[test]
    fn output_is_valid_jpeg() {
        let jpeg = make_jpeg_rgb(24, 24);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        if let Some(first) = coeffs.components[0].blocks.first_mut() {
            first[1] |= 1;
        }
        let out = write_coefficients(&jpeg, &coeffs).unwrap();
        let tail = out.len() - 2;
        assert_eq!(&out[..2], &[0xFF, 0xD8], "missing SOI");
        assert_eq!(&out[tail..], &[0xFF, 0xD9], "missing EOI");
    }

    #[test]
    fn inspect_gray_returns_correct_dimensions() {
        let info = inspect(&make_jpeg_gray(32, 16)).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 16);
        assert_eq!(info.components.len(), 1);
        // 32x16 pixels is 4x2 blocks.
        assert_eq!(info.components[0].block_count, 8);
    }

    #[test]
    fn inspect_rgb_returns_three_components() {
        let info = inspect(&make_jpeg_rgb(32, 32)).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 32);
        assert_eq!(info.components.len(), 3);
        let total: usize = info.components.iter().map(|c| c.block_count).sum();
        assert!(total > 0);
    }

    #[test]
    fn inspect_matches_block_count() {
        let jpeg = make_jpeg_rgb(48, 32);
        let from_inspect: Vec<usize> = inspect(&jpeg)
            .unwrap()
            .components
            .iter()
            .map(|c| c.block_count)
            .collect();
        assert_eq!(from_inspect, block_count(&jpeg).unwrap());
    }

    #[test]
    fn eligible_ac_count_is_positive() {
        let n = eligible_ac_count(&make_jpeg_rgb(32, 32)).unwrap();
        assert!(n > 0, "natural image should have eligible AC coefficients");
    }

    #[test]
    fn eligible_ac_count_method_matches_free_fn() {
        let jpeg = make_jpeg_gray(32, 32);
        let via_method = read_coefficients(&jpeg).unwrap().eligible_ac_count();
        let via_fn = eligible_ac_count(&jpeg).unwrap();
        assert_eq!(via_method, via_fn);
    }

    #[test]
    fn eligible_ac_count_leq_total_ac_count() {
        let coeffs = read_coefficients(&make_jpeg_rgb(32, 32)).unwrap();
        // Every block has exactly 63 AC positions.
        let total_ac: usize = coeffs.components.iter().map(|c| c.blocks.len() * 63).sum();
        assert!(coeffs.eligible_ac_count() <= total_ac);
    }

    #[test]
    fn lut_decode_matches_modification_roundtrip() {
        let jpeg = make_jpeg_rgb(64, 64);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        let flipped: usize = coeffs
            .components
            .iter_mut()
            .map(|c| flip_eligible_lsbs(&mut c.blocks))
            .sum();
        assert!(flipped > 0);

        let modified = write_coefficients(&jpeg, &coeffs).unwrap();
        let reread = read_coefficients(&modified).unwrap();
        assert_eq!(coeffs.components.len(), reread.components.len());
        for (before, after) in coeffs.components.iter().zip(reread.components.iter()) {
            assert_eq!(before.blocks, after.blocks);
        }
    }
}