compeg/
lib.rs

1//! WebGPU compute shader JPEG decoder.
2//!
3//! Usage:
4//!
5//! - Create a [`Gpu`] context (either automatically via [`Gpu::open`] or from an existing [`wgpu`] context via [`Gpu::from_wgpu`]).
6//! - Create a [`Decoder`] (or multiple) via [`Decoder::new`].
7//! - For each JPEG image you want to decode, create an [`ImageData`] object and pass it to [`Decoder::start_decode`].
8//!   - The [`Decoder`] will automatically resize buffers and textures when they are too small for the passed [`ImageData`].
9//! - Access the output [`Texture`] via [`DecodeOp::texture`].
10//!   - [`wgpu`] will automatically ensure that the proper barriers are in place when this
11//!     [`Texture`] is used in a GPU operation.
12
13mod bits;
14mod dynamic;
15mod error;
16mod file;
17mod huffman;
18mod metadata;
19mod scan;
20
21#[cfg(test)]
22mod tests;
23
24use std::{
25    borrow::Cow,
26    mem,
27    sync::Arc,
28    time::{Duration, Instant},
29};
30
31use bytemuck::Zeroable;
32use dynamic::DynamicBindGroup;
33use error::{Error, Result};
34use file::{JpegParser, SegmentKind};
35use wgpu::*;
36
37use crate::{
38    dynamic::{DynamicBuffer, DynamicTexture},
39    file::SofMarker,
40    huffman::{HuffmanTables, TableData},
41    metadata::QTable,
42};
43
44/// **Not** part of the public API. Used for benchmarks only.
45#[doc(hidden)]
46pub use scan::ScanBuffer;
47
48const OUTPUT_FORMAT: TextureFormat = TextureFormat::Rgba8Unorm;
49
/// Invocations per workgroup of the Huffman decoding pass (one invocation per restart interval).
const HUFFMAN_WORKGROUP_SIZE: u32 = 64;

/// Invocations per workgroup of the IDCT pass.
const DCT_WORKGROUP_SIZE: u32 = 256;
/// Invocations assigned to each data unit in the IDCT pass.
const THREADS_PER_DCT: u32 = 8;
/// Data units covered by a single IDCT workgroup.
const DCTS_PER_WORKGROUP: u32 = DCT_WORKGROUP_SIZE / THREADS_PER_DCT;
// The division above must not truncate.
const _: () = assert!(DCT_WORKGROUP_SIZE % THREADS_PER_DCT == 0);

/// Invocations per workgroup of the finalize (compositing) pass.
const FINALIZE_WORKGROUP_SIZE: u32 = 256;
/// Height of an MCU in pixels.
const MCU_HEIGHT: u32 = 8;
/// Invocations assigned to each MCU in the finalize pass (one per pixel row of the MCU).
const THREADS_PER_MCU: u32 = MCU_HEIGHT;
/// MCUs covered by a single finalize workgroup.
const MCUS_PER_WORKGROUP: u32 = FINALIZE_WORKGROUP_SIZE / THREADS_PER_MCU;
// The division above must not truncate.
const _: () = assert!(FINALIZE_WORKGROUP_SIZE % THREADS_PER_MCU == 0);
60
/// An open handle to a GPU.
///
/// This stores all static data (shaders, pipelines) needed for JPEG decoding.
pub struct Gpu {
    /// Device used to create every buffer, texture, bind group and pipeline.
    device: Arc<Device>,
    /// Queue used for buffer uploads and command submission.
    queue: Arc<Queue>,
    /// Layout with the read-only `metadata` storage buffer (bind group 0 of every pipeline).
    metadata_bgl: Arc<BindGroupLayout>,
    /// Layout with the read-only `huffman_l1`, `huffman_l2`, `scan_data` and `scan_positions`
    /// storage buffers used by the Huffman decoding pipeline.
    huffman_bgl: Arc<BindGroupLayout>,
    /// Layout with the read-write `coefficients` storage buffer shared by all pipelines.
    coefficients_bgl: Arc<BindGroupLayout>,
    /// Layout with the write-only `out` storage texture (in `OUTPUT_FORMAT`).
    output_bgl: Arc<BindGroupLayout>,
    /// Pipeline for the `huffman` entry point of the Huffman shader module.
    huffman_decode_pipeline: ComputePipeline,
    /// Pipeline for the `dct` entry point (the IDCT pass).
    dct_pipeline: ComputePipeline,
    /// Pipeline for the `finalize` entry point (the compositing pass).
    finalize_pipeline: ComputePipeline,
}
75
76impl Gpu {
77    /// Opens a suitable default GPU.
78    pub async fn open() -> Result<Self> {
79        let instance = Instance::new(InstanceDescriptor {
80            // The OpenGL backend panics spuriously, so don't enable it.
81            backends: Backends::PRIMARY,
82            ..Default::default()
83        });
84        let adapter = instance
85            .request_adapter(&RequestAdapterOptions::default())
86            .await
87            .ok_or_else(|| Error::from("no supported graphics adapter found"))?;
88        let (device, queue) = adapter
89            .request_device(&Default::default(), None)
90            .await
91            .map_err(|_| Error::from("no supported graphics device found"))?;
92
93        let info = adapter.get_info();
94        log::info!(
95            "opened {:?} adapter {} ({})",
96            info.backend,
97            info.name,
98            info.driver
99        );
100
101        Self::from_wgpu(device.into(), queue.into())
102    }
103
104    /// Creates a [`Gpu`] handle from an existing [`wgpu`] [`Device`] and [`Queue`].
105    pub fn from_wgpu(device: Arc<Device>, queue: Arc<Queue>) -> Result<Self> {
106        let shared = include_str!("shared.wgsl");
107        let huffman = include_str!("huffman.wgsl");
108        let dct = include_str!("dct.wgsl");
109        let huffman = format!("{shared}\n\n{huffman}");
110        let dct = format!("{shared}\n\n{dct}");
111        let huffman = device.create_shader_module(ShaderModuleDescriptor {
112            label: Some("huffman"),
113            source: wgpu::ShaderSource::Wgsl(huffman.into()),
114        });
115        let dct = device.create_shader_module(ShaderModuleDescriptor {
116            label: Some("dct"),
117            source: wgpu::ShaderSource::Wgsl(dct.into()),
118        });
119
120        let metadata_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
121            label: Some("metadata_bgl"),
122            entries: &[
123                // `metadata`
124                BindGroupLayoutEntry {
125                    binding: 0,
126                    visibility: ShaderStages::COMPUTE,
127                    ty: BindingType::Buffer {
128                        ty: BufferBindingType::Storage { read_only: true },
129                        has_dynamic_offset: false,
130                        min_binding_size: None,
131                    },
132                    count: None,
133                },
134            ],
135        });
136        let huffman_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
137            label: Some("huffman_bgl"),
138            entries: &[
139                // `huffman_l1`
140                BindGroupLayoutEntry {
141                    binding: 0,
142                    visibility: ShaderStages::COMPUTE,
143                    ty: BindingType::Buffer {
144                        ty: BufferBindingType::Storage { read_only: true },
145                        has_dynamic_offset: false,
146                        min_binding_size: None,
147                    },
148                    count: None,
149                },
150                // `huffman_l2`
151                BindGroupLayoutEntry {
152                    binding: 1,
153                    visibility: ShaderStages::COMPUTE,
154                    ty: BindingType::Buffer {
155                        ty: BufferBindingType::Storage { read_only: true },
156                        has_dynamic_offset: false,
157                        min_binding_size: None,
158                    },
159                    count: None,
160                },
161                // `scan_data`
162                BindGroupLayoutEntry {
163                    binding: 2,
164                    visibility: ShaderStages::COMPUTE,
165                    ty: BindingType::Buffer {
166                        ty: BufferBindingType::Storage { read_only: true },
167                        has_dynamic_offset: false,
168                        min_binding_size: None,
169                    },
170                    count: None,
171                },
172                // `scan_positions`
173                BindGroupLayoutEntry {
174                    binding: 3,
175                    visibility: ShaderStages::COMPUTE,
176                    ty: BindingType::Buffer {
177                        ty: BufferBindingType::Storage { read_only: true },
178                        has_dynamic_offset: false,
179                        min_binding_size: None,
180                    },
181                    count: None,
182                },
183            ],
184        });
185        let coefficients_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
186            label: Some("coefficients_bgl"),
187            entries: &[
188                // `coefficients`
189                BindGroupLayoutEntry {
190                    binding: 0,
191                    visibility: ShaderStages::COMPUTE,
192                    ty: BindingType::Buffer {
193                        ty: BufferBindingType::Storage { read_only: false },
194                        has_dynamic_offset: false,
195                        min_binding_size: None,
196                    },
197                    count: None,
198                },
199            ],
200        });
201        let output_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
202            label: Some("output_bgl"),
203            entries: &[
204                // `out`
205                BindGroupLayoutEntry {
206                    binding: 0,
207                    visibility: ShaderStages::COMPUTE,
208                    ty: BindingType::StorageTexture {
209                        access: StorageTextureAccess::WriteOnly,
210                        format: OUTPUT_FORMAT,
211                        view_dimension: TextureViewDimension::D2,
212                    },
213                    count: None,
214                },
215            ],
216        });
217        let opts = PipelineCompilationOptions {
218            // wgpu's zero init code is pretty suboptimal, so turn it off. We don't need it anyways.
219            zero_initialize_workgroup_memory: false,
220            ..Default::default()
221        };
222        let huffman_decode_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
223            label: Some("huffman_decode_pipeline"),
224            layout: Some(&device.create_pipeline_layout(&PipelineLayoutDescriptor {
225                label: None,
226                bind_group_layouts: &[&metadata_bgl, &huffman_bgl, &coefficients_bgl],
227                push_constant_ranges: &[],
228            })),
229            module: &huffman,
230            entry_point: "huffman",
231            compilation_options: opts.clone(),
232            cache: None,
233        });
234        let dct_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
235            label: Some("dct_pipeline"),
236            layout: Some(&device.create_pipeline_layout(&PipelineLayoutDescriptor {
237                label: None,
238                bind_group_layouts: &[&metadata_bgl, &coefficients_bgl, &output_bgl],
239                push_constant_ranges: &[],
240            })),
241            module: &dct,
242            entry_point: "dct",
243            compilation_options: opts.clone(),
244            cache: None,
245        });
246        let finalize_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
247            label: Some("finalize_pipeline"),
248            layout: Some(&device.create_pipeline_layout(&PipelineLayoutDescriptor {
249                label: None,
250                bind_group_layouts: &[&metadata_bgl, &coefficients_bgl, &output_bgl],
251                push_constant_ranges: &[],
252            })),
253            module: &dct,
254            entry_point: "finalize",
255            compilation_options: opts.clone(),
256            cache: None,
257        });
258
259        Ok(Self {
260            device,
261            queue,
262            metadata_bgl: Arc::new(metadata_bgl),
263            huffman_bgl: Arc::new(huffman_bgl),
264            coefficients_bgl: Arc::new(coefficients_bgl),
265            output_bgl: Arc::new(output_bgl),
266            huffman_decode_pipeline,
267            dct_pipeline,
268            finalize_pipeline,
269        })
270    }
271}
272
/// A GPU JPEG decode context.
///
/// Holds all on-GPU buffers and textures needed for JPEG decoding.
pub struct Decoder {
    /// The GPU handle whose device, queue, layouts and pipelines this decoder uses.
    gpu: Arc<Gpu>,
    /// Fixed-size buffer holding the image's `metadata::Metadata`; rewritten on every decode.
    metadata: Buffer,
    /// Fixed-size buffer (`HuffmanTables::TOTAL_L1_SIZE` bytes) holding the L1 Huffman table data.
    huffman_l1: Buffer,
    /// Growable buffer holding the L2 Huffman table data.
    huffman_l2: DynamicBuffer,
    /// Holds all the scan data of the JPEG (including all embedded RST markers). This constitutes
    /// the main input data to the shader pipeline.
    scan_data: DynamicBuffer,
    /// Start positions computed by [`ScanBuffer`]; bound as `scan_positions` for the Huffman pass.
    start_positions_buffer: DynamicBuffer,
    /// Decoded coefficients, cleared before every decode and bound in all three passes.
    coefficients: DynamicBuffer,
    /// Output texture; reserved to fit each decoded image, so it may be larger than the last one.
    output: DynamicTexture,
    /// Bind group for `metadata` (layout `Gpu::metadata_bgl`).
    metadata_bg: DynamicBindGroup,
    /// Bind group for the Huffman inputs (layout `Gpu::huffman_bgl`).
    huffman_bg: DynamicBindGroup,
    /// Bind group for `coefficients` (layout `Gpu::coefficients_bgl`).
    coefficients_bg: DynamicBindGroup,
    /// Bind group for the output texture (layout `Gpu::output_bgl`).
    output_bg: DynamicBindGroup,
    /// CPU-side buffer used to preprocess the scan data before it is uploaded.
    scan_buffer: ScanBuffer,
}
293
294impl Decoder {
295    /// [`wgpu`] only guarantees that it is able to dispatch 65535 workgroups at once, so this is
296    /// the maximum number of shader invocations we can run (and thus the max. number of restart
297    /// intervals we can process).
298    const MAX_RESTART_INTERVALS: u32 = HUFFMAN_WORKGROUP_SIZE * 65535;
299    // FIXME: fix this to use MCUs/DUs as the limiting factor?
300
301    /// Creates a new JPEG decoding context on the given [`Gpu`].
302    pub fn new(gpu: Arc<Gpu>) -> Self {
303        Self::with_texture_usages(gpu, TextureUsages::empty())
304    }
305
306    pub fn with_texture_usages(gpu: Arc<Gpu>, usage: TextureUsages) -> Self {
307        let metadata = gpu.device.create_buffer(&BufferDescriptor {
308            label: Some("metadata"),
309            size: mem::size_of::<metadata::Metadata>() as u64,
310            usage: BufferUsages::COPY_DST | BufferUsages::STORAGE,
311            mapped_at_creation: false,
312        });
313        let huffman_l1 = gpu.device.create_buffer(&BufferDescriptor {
314            label: Some("huffman_l1"),
315            size: HuffmanTables::TOTAL_L1_SIZE as u64,
316            usage: BufferUsages::COPY_DST | BufferUsages::STORAGE,
317            mapped_at_creation: false,
318        });
319        let huffman_l2 = DynamicBuffer::new(
320            gpu.clone(),
321            "huffman_l2",
322            BufferUsages::COPY_DST | BufferUsages::STORAGE,
323        );
324
325        let scan_data = DynamicBuffer::new(
326            gpu.clone(),
327            "scan_data",
328            BufferUsages::COPY_DST | BufferUsages::STORAGE,
329        );
330        let start_positions_buffer = DynamicBuffer::new(
331            gpu.clone(),
332            "start_positions",
333            BufferUsages::COPY_DST | BufferUsages::STORAGE,
334        );
335        let coefficients = DynamicBuffer::new(gpu.clone(), "coefficients", BufferUsages::STORAGE);
336
337        let output = DynamicTexture::new(
338            gpu.clone(),
339            "output",
340            TextureUsages::STORAGE_BINDING
341                | TextureUsages::TEXTURE_BINDING
342                | TextureUsages::COPY_SRC
343                | TextureUsages::COPY_DST
344                | usage,
345            OUTPUT_FORMAT,
346        );
347
348        let metadata_bg =
349            DynamicBindGroup::new(gpu.clone(), gpu.metadata_bgl.clone(), "metadata_bg");
350        let huffman_bg = DynamicBindGroup::new(gpu.clone(), gpu.huffman_bgl.clone(), "huffman_bg");
351        let coefficients_bg =
352            DynamicBindGroup::new(gpu.clone(), gpu.coefficients_bgl.clone(), "coefficients_bg");
353        let output_bg = DynamicBindGroup::new(gpu.clone(), gpu.output_bgl.clone(), "output_bg");
354
355        Self {
356            gpu,
357            metadata,
358            huffman_l1,
359            huffman_l2,
360            scan_data,
361            start_positions_buffer,
362            coefficients,
363            output,
364            metadata_bg,
365            huffman_bg,
366            coefficients_bg,
367            output_bg,
368            scan_buffer: ScanBuffer::new(),
369        }
370    }
371
372    pub fn texture(&self) -> &Texture {
373        self.output.texture()
374    }
375
376    /// Consumes this [`Decoder`] and returns its output [`Texture`].
377    ///
378    /// The [`Texture`] will only contain data if a decode operation has been previously started on
379    /// this [`Decoder`]. Depending on what data has been decoded previously, the [`Texture`] might
380    /// have larger dimensions than necessary to fit the *last* frame.
381    pub fn into_texture(self) -> Texture {
382        self.output.into_texture()
383    }
384
385    pub fn enqueue(&mut self, data: &ImageData<'_>, enc: &mut CommandEncoder) -> bool {
386        let texture_changed = self.output.reserve(data.width(), data.height());
387
388        let total_restart_intervals = data.metadata.total_restart_intervals;
389        let total_mcus = total_restart_intervals * data.metadata.restart_interval;
390        let total_dus = total_mcus * data.metadata.dus_per_mcu;
391        let t_preprocess = time(|| {
392            self.scan_buffer
393                .process(data.scan_data(), total_restart_intervals)
394        });
395
396        let t_enqueue_writes = time(|| {
397            self.gpu
398                .queue
399                .write_buffer(&self.metadata, 0, bytemuck::bytes_of(&data.metadata));
400            self.scan_data.write(self.scan_buffer.processed_scan_data());
401            self.start_positions_buffer
402                .write(self.scan_buffer.start_positions());
403
404            self.gpu
405                .queue
406                .write_buffer(&self.huffman_l1, 0, data.huffman_tables.l1_data());
407            self.huffman_l2.write(data.huffman_tables.l2_data());
408
409            // Reserve space for the decoded coefficients. There are 64 32-bit values per data unit.
410            self.coefficients
411                .reserve(4 * u64::from(data.metadata.retained_coefficients) * u64::from(total_dus));
412        });
413
414        let metadata_bg = self
415            .metadata_bg
416            .bind_group(&[self.metadata.as_entire_binding().into()]);
417        let huffman_bg = self.huffman_bg.bind_group(&[
418            self.huffman_l1.as_entire_binding().into(),
419            self.huffman_l2.as_resource(),
420            self.scan_data.as_resource(),
421            self.start_positions_buffer.as_resource(),
422        ]);
423        let coefficients_bg = self
424            .coefficients_bg
425            .bind_group(&[self.coefficients.as_resource()]);
426        let output_bg = self.output_bg.bind_group(&[self.output.as_resource()]);
427
428        enc.clear_buffer(self.coefficients.buffer(), 0, None);
429
430        let mut compute = enc.begin_compute_pass(&ComputePassDescriptor::default());
431        let huffman_workgroups =
432            (total_restart_intervals + HUFFMAN_WORKGROUP_SIZE - 1) / HUFFMAN_WORKGROUP_SIZE;
433        let dct_workgroups = (total_dus + DCTS_PER_WORKGROUP - 1) / DCTS_PER_WORKGROUP;
434        let finalize_workgroups = (total_mcus + MCUS_PER_WORKGROUP - 1) / MCUS_PER_WORKGROUP;
435
436        compute.set_bind_group(0, metadata_bg, &[]);
437        compute.set_bind_group(1, huffman_bg, &[]);
438        compute.set_bind_group(2, coefficients_bg, &[]);
439        compute.set_pipeline(&self.gpu.huffman_decode_pipeline);
440        compute.dispatch_workgroups(huffman_workgroups, 1, 1);
441
442        compute.set_bind_group(0, metadata_bg, &[]);
443        compute.set_bind_group(1, coefficients_bg, &[]);
444        compute.set_bind_group(2, output_bg, &[]);
445        compute.set_pipeline(&self.gpu.dct_pipeline);
446        compute.dispatch_workgroups(dct_workgroups, 1, 1);
447        compute.set_pipeline(&self.gpu.finalize_pipeline);
448        compute.dispatch_workgroups(finalize_workgroups, 1, 1);
449
450        drop(compute);
451
452        log::trace!(
453            "dispatching {} workgroups for huffman decoding ({} shader invocations; {} restart intervals)",
454            huffman_workgroups,
455            huffman_workgroups * HUFFMAN_WORKGROUP_SIZE,
456            total_restart_intervals,
457        );
458        log::trace!(
459            "dispatching {} workgroups for IDCT ({} shader invocations; {} MCUs; {} DUs)",
460            dct_workgroups,
461            dct_workgroups * DCT_WORKGROUP_SIZE,
462            total_mcus,
463            total_dus,
464        );
465        log::trace!(
466            "dispatching {} workgroups for compositing ({} shader invocations; {} MCUs)",
467            finalize_workgroups,
468            finalize_workgroups * FINALIZE_WORKGROUP_SIZE,
469            total_mcus,
470        );
471
472        log::trace!(
473            "t_preprocess={t_preprocess:?}, \
474            t_enqueue_writes={t_enqueue_writes:?}"
475        );
476
477        texture_changed
478    }
479
480    /// Preprocesses and uploads a JPEG image, and dispatches the decoding operation on the GPU.
481    ///
482    /// Returns a [`DecodeOp`] with information about the decode operation.
483    pub fn start_decode(&mut self, data: &ImageData<'_>) -> DecodeOp<'_> {
484        let mut enc = self
485            .gpu
486            .device
487            .create_command_encoder(&CommandEncoderDescriptor::default());
488
489        let texture_changed = self.enqueue(data, &mut enc);
490
491        let buffer = enc.finish();
492        let submission = self.gpu.queue.submit([buffer]);
493
494        DecodeOp {
495            submission,
496            texture: self.output.texture(),
497            texture_changed,
498        }
499    }
500
501    /// Performs a blocking decode operation.
502    ///
503    /// This method works identically to [`Decoder::start_decode`], but will wait until the
504    /// operation on the GPU is finished.
505    ///
506    /// Note that it is not typically necessary to use this method, since [`wgpu`] will
507    /// automatically insert barriers before the target texture is accessed.
508    pub fn decode_blocking(&mut self, data: &ImageData<'_>) -> DecodeOp<'_> {
509        // FIXME: destructuring and recreation is annoyingly needed because the `DecodeOp` will
510        // borrow `self` *mutably*, even though an immutable borrow would suffice.
511        let DecodeOp {
512            submission,
513            texture: _,
514            texture_changed,
515        } = self.start_decode(data);
516        let t_poll = time(|| {
517            self.gpu
518                .device
519                .poll(MaintainBase::WaitForSubmissionIndex(submission.clone()))
520        });
521
522        log::trace!("t_poll={:?}", t_poll);
523
524        DecodeOp {
525            submission,
526            texture: self.output.texture(),
527            texture_changed,
528        }
529    }
530}
531
/// Runs `f` once and returns the wall-clock time it took, measured with [`Instant`].
///
/// Whatever `f` returns is dropped immediately; only the elapsed [`Duration`] is reported.
fn time<R>(f: impl FnOnce() -> R) -> Duration {
    let started_at = Instant::now();
    let _ = f();
    started_at.elapsed()
}
537
/// Information about an ongoing JPEG decode operation.
///
/// Returned by [`Decoder::start_decode`].
pub struct DecodeOp<'a> {
    /// Index of the queue submission containing the decode commands.
    submission: SubmissionIndex,
    /// The [`Decoder`]'s output texture, which the submitted commands write to.
    texture: &'a Texture,
    /// Result of reserving the output texture for this decode (whether it was reallocated).
    texture_changed: bool,
}
546
547impl<'a> DecodeOp<'a> {
548    /// Returns the [`SubmissionIndex`] associated with the compute shader dispatch.
549    #[inline]
550    pub fn submission(&self) -> &SubmissionIndex {
551        &self.submission
552    }
553
554    /// Returns a reference to the target [`Texture`] that the JPEG decode operation is writing to.
555    ///
556    /// Note that, when using the [`Decoder`] with JPEG images of varying sizes, not the entire
557    /// target texture will be written to. The caller has to ensure to only use the area of the
558    /// [`Texture`] indicated by [`ImageData::width`] and [`ImageData::height`].
559    #[inline]
560    pub fn texture(&self) -> &Texture {
561        self.texture
562    }
563
564    /// Returns a [`bool`] indicating whether the target [`Texture`] has been reallocated since the
565    /// last decode operation on the same [`Decoder`] was started.
566    ///
567    /// If this is the first decode operation, this method will return `true`. The return value of
568    /// this method can be used to determine whether any bind groups referencing the target
569    /// [`Texture`] need to be recreated.
570    #[inline]
571    pub fn texture_changed(&self) -> bool {
572        self.texture_changed
573    }
574}
575
/// A parsed JPEG image, containing all data needed for on-GPU decoding.
pub struct ImageData<'a> {
    /// Per-image parameters uploaded verbatim to the GPU `metadata` buffer.
    metadata: metadata::Metadata,
    /// Image width in pixels (SOF `X`).
    width: u16,
    /// Image height in pixels (SOF `Y`).
    height: u16,
    /// The four Huffman tables (luminance/chrominance DC/AC), defaults overridden by DHT segments.
    huffman_tables: HuffmanTables,
    /// The complete JPEG file, borrowed or owned.
    jpeg: Cow<'a, [u8]>,
    /// Byte offset of the scan data (as reported by the SOS segment) within `jpeg`.
    scan_data_offset: usize,
    /// Length in bytes of the scan data starting at `scan_data_offset`.
    scan_data_len: usize,
}
586
587impl<'a> ImageData<'a> {
588    /// Reads [`ImageData`] from an in-memory JPEG file.
589    ///
590    /// If this returns an error, it either means that the JPEG file is malformed, or that it uses
591    /// features this library does not support. In either case, the application should fall back to
592    /// a more fully-featured software decoder.
593    pub fn new(jpeg: impl Into<Cow<'a, [u8]>>) -> Result<Self> {
594        Self::new_impl(jpeg.into())
595    }
596
597    fn new_impl(jpeg: Cow<'a, [u8]>) -> Result<Self> {
598        macro_rules! bail {
599            ($($args:tt)*) => {
600                return Err(Error::from(format!(
601                    $($args)*
602                )))
603            };
604        }
605
606        let mut size = None;
607        let mut ri = None;
608        let mut huffman_tables = [
609            TableData::default_luminance_dc(),
610            TableData::default_luminance_ac(),
611            TableData::default_chrominance_dc(),
612            TableData::default_chrominance_ac(),
613        ];
614        let mut qtables = [QTable::zeroed(); 4];
615        let mut scan_data = None;
616        let mut components = None;
617        let mut component_indices = None;
618        let mut component_dchuff = [0; 3];
619        let mut component_achuff = [0; 3];
620
621        let mut parser = JpegParser::new(&jpeg)?;
622        while let Some(segment) = parser.next_segment()? {
623            let Some(kind) = segment.as_segment_kind() else {
624                continue;
625            };
626            match kind {
627                SegmentKind::SOF(sof) => {
628                    if sof.sof() != SofMarker::SOF0 {
629                        bail!("not a baseline JPEG (SOF={:?})", sof.sof());
630                    }
631
632                    if sof.P() != 8 {
633                        bail!("sample precision of {} bits is not supported", sof.P());
634                    }
635
636                    if component_indices.is_some() {
637                        bail!("encountered multiple SOF markers");
638                    }
639
640                    match sof.components() {
641                        [y, u, v] => {
642                            log::trace!("frame components:");
643                            log::trace!("- {:?}", y);
644                            log::trace!("- {:?}", u);
645                            log::trace!("- {:?}", v);
646
647                            if y.Tqi() > 3 || u.Tqi() > 3 || v.Tqi() > 3 {
648                                bail!("invalid quantization table selection [{},{},{}] (only tables 0-3 are valid)", y.Tqi(), u.Tqi(), v.Tqi());
649                            }
650                            if y.Hi() != 2 || y.Vi() != 1 {
651                                bail!(
652                                    "invalid sampling factors {}x{} for Y component (expected 2x1)",
653                                    y.Hi(),
654                                    y.Vi(),
655                                );
656                            }
657                            if u.Hi() != v.Hi() || u.Vi() != v.Vi() || u.Hi() != 1 || u.Vi() != 1 {
658                                bail!(
659                                    "invalid U/V sampling factors {}x{} and {}x{} (expected 1x1)",
660                                    u.Hi(),
661                                    u.Vi(),
662                                    v.Hi(),
663                                    v.Vi(),
664                                );
665                            }
666
667                            component_indices = Some([y.Ci(), u.Ci(), v.Ci()]);
668
669                            components = Some([y, u, v]);
670                        }
671                        _ => {
672                            bail!("frame with {} components not supported (only 3 components are supported)", sof.components().len());
673                        }
674                    }
675
676                    size = Some((sof.X(), sof.Y()));
677                }
678                SegmentKind::DQT(dqt) => {
679                    for table in dqt.tables() {
680                        if table.Pq() != 0 {
681                            bail!(
682                                "invalid quantization table precision Pq={} (only 0 is allowed)",
683                                table.Pq()
684                            );
685                        }
686                        if table.Tq() > 3 {
687                            bail!(
688                                "invalid quantization table destination Tq={} (0-3 are allowed)",
689                                table.Tq()
690                            );
691                        }
692
693                        for (dest, src) in qtables[usize::from(table.Tq())]
694                            .values
695                            .iter_mut()
696                            .zip(table.Qk())
697                        {
698                            *dest = u32::from(*src);
699                        }
700                    }
701                }
702                SegmentKind::DHT(dht) => {
703                    for table in dht.tables() {
704                        let index = table.Th();
705                        if index > 1 {
706                            bail!(
707                                "DHT Th={}, only 0 and 1 are allowed for baseline JPEGs",
708                                table.Th()
709                            );
710                        }
711
712                        let class = match table.Tc() {
713                            class @ (0 | 1) => class,
714                            err => bail!("invalid table class Tc={err} (only 0 and 1 are valid)"),
715                        };
716
717                        let index = (index << 1) | class;
718                        let data = TableData::build(table.Li(), table.Vij());
719                        huffman_tables[usize::from(index)] = data;
720                    }
721                }
722                SegmentKind::DRI(dri) => {
723                    // FIXME: add some checks here, we probably should have a maximum Ri value?
724                    ri = Some(dri.Ri() as u32);
725                }
726                SegmentKind::SOS(sos) => {
727                    if sos.Ss() != 0 || sos.Se() != 63 || sos.Ah() != 0 || sos.Al() != 0 {
728                        bail!("non-baseline scan header");
729                    }
730
731                    let Some(component_indices) = component_indices else {
732                        bail!("SOS not preceded by SOF header");
733                    };
734
735                    match sos.components() {
736                        [y,u,v] => {
737                            log::trace!("scan components:");
738                            log::trace!("- {:?}", y);
739                            log::trace!("- {:?}", u);
740                            log::trace!("- {:?}", v);
741
742                            let scan_indices = [y.Csj(), u.Csj(), v.Csj()];
743                            if component_indices != scan_indices {
744                                bail!("scan component index mismatch (expected component order {:?}, got {:?})", component_indices, scan_indices);
745                            }
746
747                            component_dchuff = [y.Tdj(), u.Tdj(), v.Tdj()];
748                            component_achuff = [y.Taj(), u.Taj(), v.Taj()];
749                        }
750                        _ => bail!("scan with {} components not supported (only 3 components are supported)", sos.components().len()),
751                    }
752
753                    scan_data = Some((sos.data_offset(), sos.data().len()));
754                }
755                _ => {}
756            }
757        }
758
759        #[rustfmt::skip]
760        let (
761            Some((width, height)),
762            Some(components),
763            Some((scan_data_offset, scan_data_len)),
764        ) = (size, components, scan_data) else {
765            bail!("missing SOS/SOI marker");
766        };
767
768        let dus_per_mcu = components
769            .iter()
770            .map(|c| c.Hi() * c.Vi())
771            .sum::<u8>()
772            .into();
773
774        let max_hsample = components.iter().map(|c| c.Hi()).max().unwrap().into();
775        let max_vsample = components.iter().map(|c| c.Vi()).max().unwrap().into();
776        let width_dus = u32::from((width + 7) / 8);
777        let height_dus = u32::from((height + 7) / 8);
778        let width_mcus = (width_dus + max_hsample - 1) / max_hsample; // (round up)
779        let height_mcus = (height_dus + max_vsample - 1) / max_vsample; // (round up)
780
781        log::trace!("max Hi={} Vi={}", max_hsample, max_vsample);
782        log::trace!("width={width} height={height} width_dus={width_dus} height_dus={height_dus} width_mcus={width_mcus} height_mcus={height_mcus}");
783
784        let ri = ri.unwrap_or(height_mcus * width_mcus);
785        let total_restart_intervals = height_mcus * width_mcus / ri;
786
787        if total_restart_intervals > Decoder::MAX_RESTART_INTERVALS {
788            bail!(
789                "number of restart intervals exceeds limit ({} > {})",
790                total_restart_intervals,
791                Decoder::MAX_RESTART_INTERVALS,
792            );
793        }
794
795        let metadata = metadata::Metadata {
796            restart_interval: ri,
797            qtables,
798            components: [0, 1, 2].map(|i| metadata::Component {
799                hsample: components[i].Hi().into(),
800                vsample: components[i].Vi().into(),
801                qtable: components[i].Tqi().into(),
802                dchuff: u32::from(component_dchuff[i] << 1),
803                achuff: u32::from((component_achuff[i] << 1) | 1),
804            }),
805            total_restart_intervals,
806            width_mcus,
807            max_hsample,
808            max_vsample,
809            dus_per_mcu,
810            retained_coefficients: metadata::DEFAULT_RETAINED_COEFFICIENTS,
811        };
812
813        let huffman_tables = HuffmanTables::new(huffman_tables);
814
815        Ok(Self {
816            metadata,
817            width,
818            height,
819            huffman_tables,
820            jpeg,
821            scan_data_offset,
822            scan_data_len,
823        })
824    }
825
826    /// Returns the width of the image in pixels.
827    #[inline]
828    pub fn width(&self) -> u32 {
829        self.width.into()
830    }
831
832    /// Returns the height of the image in pixels.
833    #[inline]
834    pub fn height(&self) -> u32 {
835        self.height.into()
836    }
837
838    /// Returns the total parallelism this JPEG permits.
839    ///
840    /// This number indicates how many parts of the image can be processed in parallel. It is
841    /// crucial for performance that this number is as high as possible. If it is below 10000, it is
842    /// likely faster to use a CPU-based decoder instead.
843    #[inline]
844    pub fn parallelism(&self) -> u32 {
845        self.metadata.total_restart_intervals
846    }
847
848    fn scan_data(&self) -> &[u8] {
849        &self.jpeg[self.scan_data_offset..][..self.scan_data_len]
850    }
851}