Skip to main content

media_codec_video_toolbox/
decoder.rs

1use std::{
2    collections::VecDeque,
3    sync::{Arc, LazyLock, Mutex},
4};
5
6use core_media::{
7    block_buffer::CMBlockBuffer,
8    format_description::{
9        kCMVideoCodecType_AV1, kCMVideoCodecType_H264, kCMVideoCodecType_HEVC, kCMVideoCodecType_VP9, CMVideoFormatDescription, TCMFormatDescription,
10    },
11    sample_buffer::{CMSampleBuffer, CMSampleTimingInfo},
12    time::CMTime,
13};
14use core_video::pixel_buffer::CVPixelBuffer;
15use media_codec_types::{
16    codec::{Codec, CodecID},
17    decoder::{Decoder, DecoderBuilder, ExtraData, VideoDecoder, VideoDecoderParameters},
18    packet::Packet,
19    CodecBuilder, CodecInformation, CodecParameters,
20};
21use media_core::{
22    error::Error, failed_error, frame::SharedFrame, frame_pool::FramePool, not_found_error, rational::Rational64, unsupported_error,
23    variant::Variant, video::VideoFrame, Result,
24};
25use os_ver::if_greater_than;
26use video_toolbox::{decompression_session::VTDecompressionSession, errors::VTDecodeFrameFlags};
27
28const CODEC_NAME: &str = "video-toolbox";
29
30pub struct VTDecoder {
31    id: CodecID,
32    session: VTDecompressionSession,
33    format_desc: CMVideoFormatDescription,
34    output_queue: Arc<Mutex<VecDeque<SharedFrame<VideoFrame<'static>>>>>,
35}
36
37unsafe impl Send for VTDecoder {}
38unsafe impl Sync for VTDecoder {}
39
40impl CodecInformation for VTDecoder {
41    fn id(&self) -> CodecID {
42        self.id
43    }
44
45    fn name(&self) -> &'static str {
46        CODEC_NAME
47    }
48}
49
50impl Codec<VideoDecoder> for VTDecoder {
51    fn configure(&mut self, _params: Option<&CodecParameters>, _options: Option<&Variant>) -> Result<()> {
52        Ok(())
53    }
54
55    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
56        Ok(())
57    }
58}
59
60impl Decoder<VideoDecoder> for VTDecoder {
61    fn send_packet(&mut self, _config: &VideoDecoder, _pool: Option<&Arc<FramePool<VideoFrame<'static>>>>, packet: &Packet) -> Result<()> {
62        let data = packet.data();
63        if data.is_empty() {
64            return Ok(());
65        }
66
67        let block_buffer = unsafe {
68            CMBlockBuffer::new_with_memory_block(None, data.len(), None, 0, data.len(), 0)
69                .map_err(|_| Error::CreationFailed("CMBlockBuffer".into()))?
70        };
71
72        block_buffer.replace_data_bytes(data, 0).map_err(|err| failed_error!(err))?;
73
74        let pts = packet.pts.unwrap_or(0);
75        let dts = packet.dts.unwrap_or(pts);
76        let duration = packet.duration.unwrap_or(0);
77        let time_base = packet.time_base.unwrap_or(Rational64::new(1, 1_000_000));
78
79        let timing_info = CMSampleTimingInfo {
80            duration: CMTime::make(duration * time_base.numer(), *time_base.denom() as i32),
81            presentationTimeStamp: CMTime::make(pts * time_base.numer(), *time_base.denom() as i32),
82            decodeTimeStamp: CMTime::make(dts * time_base.numer(), *time_base.denom() as i32),
83        };
84
85        let format_desc = &self.format_desc.as_format_description();
86        let sample_buffer = unsafe {
87            CMSampleBuffer::new(Some(&block_buffer), true, None, None, Some(format_desc), 1, Some(&[timing_info]), Some(&[data.len()]))
88                .map_err(|_| Error::CreationFailed("CMSampleBuffer".into()))?
89        };
90
91        let queue = Arc::clone(&self.output_queue);
92        self.session
93            .decode_frame_with_closure(sample_buffer, VTDecodeFrameFlags::empty(), move |status, _flags, image_buffer, pts, duration| {
94                if status == 0 {
95                    if let Some(pixel_buffer) = image_buffer.downcast::<CVPixelBuffer>() {
96                        if let Ok(mut video_frame) = VideoFrame::from_pixel_buffer(&pixel_buffer) {
97                            video_frame.pts = Some(pts.value);
98                            video_frame.dts = None;
99                            video_frame.duration = Some(duration.value);
100
101                            let shared_frame: SharedFrame<VideoFrame<'static>> = SharedFrame::<VideoFrame<'static>>::new(video_frame);
102                            if let Ok(mut queue) = queue.lock() {
103                                queue.push_back(shared_frame);
104                            }
105                        }
106                    }
107                }
108            })
109            .map_err(|err| failed_error!("decode frame", err))?;
110
111        Ok(())
112    }
113
114    fn receive_frame(
115        &mut self,
116        _config: &VideoDecoder,
117        _pool: Option<&Arc<FramePool<VideoFrame<'static>>>>,
118    ) -> Result<SharedFrame<VideoFrame<'static>>> {
119        let mut queue = self.output_queue.lock().map_err(|err| failed_error!(err))?;
120
121        if let Some(frame) = queue.pop_front() {
122            Ok(frame)
123        } else {
124            Err(Error::Again("no frame available".into()))
125        }
126    }
127
128    fn flush(&mut self, _config: &VideoDecoder) -> Result<()> {
129        // Wait for all async frames
130        self.session.wait_for_asynchronous_frames().map_err(|err| failed_error!("wait for frames", err))?;
131
132        // Clear output queue
133        if let Ok(mut queue) = self.output_queue.lock() {
134            queue.clear();
135        }
136
137        Ok(())
138    }
139}
140
141impl VTDecoder {
142    pub fn new(id: CodecID, params: &VideoDecoderParameters, _options: Option<&Variant>) -> Result<Self> {
143        let format_desc = Self::create_format_description(id, params)?;
144        let output_queue = Arc::new(Mutex::new(VecDeque::new()));
145
146        let session =
147            VTDecompressionSession::new(format_desc.clone(), None, None).map_err(|_| Error::CreationFailed("VTDecompressionSession".into()))?;
148
149        Ok(Self {
150            id,
151            session,
152            format_desc,
153            output_queue,
154        })
155    }
156
157    fn codec_id_to_type(id: CodecID) -> Result<u32> {
158        match id {
159            CodecID::H264 => Ok(kCMVideoCodecType_H264),
160            CodecID::HEVC => Ok(kCMVideoCodecType_HEVC),
161            CodecID::VP9 => Ok(kCMVideoCodecType_VP9),
162            CodecID::AV1 => Ok(kCMVideoCodecType_AV1),
163            _ => Err(unsupported_error!("codec", id)),
164        }
165    }
166
167    fn get_nalu_length_size(extra_data: &Option<ExtraData>) -> u8 {
168        match extra_data {
169            Some(ExtraData::AVC {
170                nalu_length_size, ..
171            }) => *nalu_length_size,
172            Some(ExtraData::HEVC {
173                nalu_length_size, ..
174            }) => *nalu_length_size,
175            _ => 0,
176        }
177    }
178
179    fn create_format_description(id: CodecID, params: &VideoDecoderParameters) -> Result<CMVideoFormatDescription> {
180        let codec_type = Self::codec_id_to_type(id)?;
181        let width = params.video.width.ok_or_else(|| not_found_error!("width"))?;
182        let height = params.video.height.ok_or_else(|| not_found_error!("height"))?;
183
184        match &params.decoder.extra_data {
185            Some(ExtraData::AVC {
186                sps,
187                pps,
188                ..
189            }) => {
190                // Get first SPS and PPS
191                if sps.is_empty() || pps.is_empty() {
192                    return Err(not_found_error!("SPS or PPS"));
193                }
194                let sps_slice: &[u8] = &sps[0];
195                let pps_slice: &[u8] = &pps[0];
196                CMVideoFormatDescription::from_h264_parameter_sets(
197                    &[sps_slice, pps_slice],
198                    Self::get_nalu_length_size(&params.decoder.extra_data) as i32,
199                )
200                .map_err(|_| Error::CreationFailed("CMVideoFormatDescription".into()))
201            }
202            Some(ExtraData::HEVC {
203                vps,
204                sps,
205                pps,
206                ..
207            }) => {
208                // Get first SPS and PPS
209                if sps.is_empty() || pps.is_empty() {
210                    return Err(not_found_error!("SPS or PPS"));
211                }
212                let mut parameter_sets: Vec<&[u8]> = Vec::new();
213                if let Some(vps_vec) = vps {
214                    if !vps_vec.is_empty() {
215                        parameter_sets.push(&vps_vec[0]);
216                    }
217                }
218                parameter_sets.push(&sps[0]);
219                parameter_sets.push(&pps[0]);
220
221                CMVideoFormatDescription::from_hevc_parameter_sets(
222                    &parameter_sets,
223                    Self::get_nalu_length_size(&params.decoder.extra_data) as i32,
224                    None,
225                )
226                .map_err(|_| Error::CreationFailed("CMVideoFormatDescription".into()))
227            }
228            _ => CMVideoFormatDescription::new(codec_type, width.get() as i32, height.get() as i32, None)
229                .map_err(|_| Error::CreationFailed("CMVideoFormatDescription".into())),
230        }
231    }
232}
233
/// Factory that registers and instantiates [`VTDecoder`] for the codec ids in
/// `SUPPORTED_CODEC_IDS`.
pub struct VTDecoderBuilder;
235
236impl DecoderBuilder<VideoDecoder> for VTDecoderBuilder {
237    fn new_decoder(&self, codec_id: CodecID, params: &CodecParameters, options: Option<&Variant>) -> Result<Box<dyn Decoder<VideoDecoder>>> {
238        Ok(Box::new(VTDecoder::new(codec_id, &params.try_into()?, options)?))
239    }
240}
241
242impl CodecBuilder<VideoDecoder> for VTDecoderBuilder {
243    fn ids(&self) -> &'static [CodecID] {
244        &SUPPORTED_CODEC_IDS
245    }
246
247    fn name(&self) -> &'static str {
248        CODEC_NAME
249    }
250}
251
252static SUPPORTED_CODEC_IDS: LazyLock<Vec<CodecID>> = LazyLock::new(|| {
253    let mut ids = Vec::new();
254    if VTDecompressionSession::is_hardware_decode_supported(kCMVideoCodecType_H264) {
255        ids.push(CodecID::H264);
256    }
257    if_greater_than! {(10, 13) => {
258        if VTDecompressionSession::is_hardware_decode_supported(kCMVideoCodecType_HEVC) {
259            ids.push(CodecID::HEVC);
260        }
261    }}
262    if_greater_than! {(14, 0) => {
263        if VTDecompressionSession::is_hardware_decode_supported(kCMVideoCodecType_VP9) {
264            ids.push(CodecID::VP9);
265        }
266        if VTDecompressionSession::is_hardware_decode_supported(kCMVideoCodecType_AV1) {
267            ids.push(CodecID::AV1);
268        }
269    }}
270    ids
271});