Skip to main content

iori_ssa/
lib.rs

1mod constant;
2mod error;
3pub use error::*;
4
5use aes::cipher::{BlockDecryptMut, KeyIvInit};
6use memchr::memmem;
7use mpeg2ts::es::StreamType;
8use mpeg2ts::pes::PesHeader;
9use mpeg2ts::ts::{
10    payload::{Bytes, Pes},
11    ContinuityCounter, ReadTsPacket, TransportScramblingControl, TsHeader, TsPacket,
12    TsPacketReader, TsPacketWriter, TsPayload, WriteTsPacket,
13};
14use std::collections::HashMap;
15use std::io::{BufRead, BufReader, BufWriter, Read, Write};
16
17pub struct NALUnit {
18    data: Vec<u8>,
19    pub r#type: u8,
20    length: usize,
21    start_code_length: u8,
22}
23
24impl NALUnit {
25    pub fn get_next(input: &[u8]) -> Result<(Self, &[u8])> {
26        let start_code_length = if input.len() > 4 && &input[0..4] == b"\x00\x00\x00\x01" {
27            4
28        } else if input.len() > 3 && &input[0..3] == b"\x00\x00\x01" {
29            3
30        } else {
31            return Err(Error::InvalidStartCode);
32        };
33
34        let next = &input[start_code_length..];
35        let next_pos = if let Some(pos) = memmem::find(next, b"\x00\x00\x01") {
36            // check pos - 1 for 0x00
37            if pos > 0 && next[pos - 1] == 0x00 {
38                start_code_length + pos - 1
39            } else {
40                start_code_length + pos
41            }
42        } else {
43            input.len()
44        };
45        let next = &input[next_pos..];
46
47        let data = input[start_code_length..next_pos].to_vec();
48        Ok((
49            Self {
50                r#type: data[0] & 0x1f,
51                data,
52                length: next_pos - start_code_length,
53                start_code_length: start_code_length as u8,
54            },
55            next,
56        ))
57    }
58
59    fn remove_scep_3_bytes(&mut self) {
60        let mut i = 0;
61        let mut j = 0;
62
63        while i < self.length {
64            if self.length - i > 3 && self.data[i..i + 3] == [0x00, 0x00, 0x03] {
65                self.data[j] = 0x00;
66                self.data[j + 1] = 0x00;
67                i += 3;
68                j += 2;
69            } else {
70                self.data[j] = self.data[i];
71                i += 1;
72                j += 1;
73            }
74        }
75
76        self.data.truncate(j);
77        self.length = j;
78    }
79
80    /// Encrypted_nal_unit () {
81    ///     nal_unit_type_byte                // 1 byte
82    ///     unencrypted_leader                // 31 bytes
83    ///     while (bytes_remaining() > 0) {
84    ///         if (bytes_remaining() > 16) {
85    ///             encrypted_block           // 16 bytes
86    ///         }
87    ///         unencrypted_block           // MIN(144, bytes_remaining()) bytes
88    ///     }
89    /// }
90    pub fn decrypt(&mut self, key: &[u8; 16], iv: &[u8; 16]) {
91        if self.data.len() <= 48 {
92            return;
93        }
94
95        self.remove_scep_3_bytes();
96
97        let mut decryptor = cbc::Decryptor::<aes::Aes128>::new(key.into(), iv.into());
98
99        if self.data.len() < 32 {
100            return;
101        }
102
103        let mut pos = &mut self.data.as_mut_slice()[32..];
104
105        while !pos.is_empty() {
106            if pos.len() > 16 {
107                let block = &mut pos[..16];
108                decryptor.decrypt_block_mut(block.into());
109                pos = &mut pos[16..];
110            }
111
112            let remaining_len = pos.len();
113            pos = &mut pos[144.min(remaining_len)..];
114        }
115    }
116
117    pub fn write<W: Write>(&self, output: &mut W) -> Result<()> {
118        if self.start_code_length == 4 {
119            output.write_all(&[0x00, 0x00, 0x00, 0x01])?;
120        } else {
121            output.write_all(&[0x00, 0x00, 0x01])?;
122        }
123        output.write_all(&self.data)?;
124
125        Ok(())
126    }
127}
128
/// The fields of an ADTS frame header needed to locate the AAC payload.
struct AdtsHeader {
    /// `aac_frame_length`: size of the whole frame in bytes, header included.
    length: usize,
    /// `true` when a CRC follows the fixed header (`protection_absent == 0`),
    /// which makes the header 9 bytes long instead of 7.
    crc: bool,
}

impl AdtsHeader {
    /// Parses the ADTS header at the start of `data`.
    ///
    /// # Panics
    ///
    /// If `data` is shorter than 6 bytes.
    fn new(data: &[u8]) -> Self {
        let length = Self::read_adts_frame_length(data);
        // protection_absent: 1 means no CRC, 0 means a CRC is present.
        let crc = data[1] & 0x01 == 0;
        Self { length, crc }
    }

    /// Returns the frame payload: everything after the 7- or 9-byte header up to the
    /// frame length.
    ///
    /// # Panics
    ///
    /// If `input` is shorter than the frame, or the frame length is below the header size.
    fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
        let header_len = if self.crc { 9 } else { 7 };
        &mut input[header_len..self.length]
    }

    /// Extracts the 13-bit `aac_frame_length`: 2 bits from byte 3, all 8 bits of
    /// byte 4 and the top 3 bits of byte 5.
    fn read_adts_frame_length(header: &[u8]) -> usize {
        let high = usize::from(header[3] & 0b11);
        let middle = usize::from(header[4]);
        let low = usize::from(header[5] >> 5);
        (high << 11) | (middle << 3) | low
    }
}
160
161struct Ac3Header {
162    length: usize,
163}
164
165impl Ac3Header {
166    fn new(data: &[u8]) -> Self {
167        Self {
168            length: Self::read_ac3_frame_length(data),
169        }
170    }
171
172    fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
173        &mut input[..self.length]
174    }
175
176    fn read_ac3_frame_length(header: &[u8]) -> usize {
177        let fscod = (header[4] >> 6) as usize;
178        let frmsizcod = (header[4] & 0b111111) as usize;
179        // the number of (2-byte) words before the next syncword
180        let frame_size = constant::AC3_FRAME_SIZE_CODE_TABLE[frmsizcod][fscod];
181        frame_size * 2
182    }
183}
184
/// Size information of an E-AC-3 sync frame.
struct Eac3Header {
    /// Frame size in bytes.
    length: usize,
}

impl Eac3Header {
    /// Parses the E-AC-3 header at the start of `data`.
    ///
    /// # Panics
    ///
    /// If `data` is shorter than 4 bytes.
    fn new(data: &[u8]) -> Self {
        let length = Self::read_eac3_frame_length(data);
        Self { length }
    }

    /// Returns the whole frame; the encrypted layout starts at the frame's first byte.
    fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
        let frame_len = self.length;
        &mut input[..frame_len]
    }

    /// `frmsiz` is the low 11 bits of bytes 2..=3 and holds the frame size in 16-bit
    /// words minus one.
    fn read_eac3_frame_length(header: &[u8]) -> usize {
        let frmsiz = u16::from_be_bytes([header[2], header[3]]) & 0x07ff;
        (usize::from(frmsiz) + 1) * 2
    }
}
206
207/// Encrypted_AAC_Frame () {
208///     ADTS_Header                        // 7 or 9 bytes
209///     unencrypted_leader                 // 16 bytes
210///     while (bytes_remaining() >= 16) {
211///         encrypted_block                // 16 bytes
212///     }
213///     unencrypted_trailer                // 0-15 bytes
214/// }
215fn decrypt_aac_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
216    let adts = AdtsHeader::new(input);
217    let data = adts.data(input);
218
219    decrypt_raw_sample(data, key, iv);
220    adts.length
221}
222
223/// Encrypted_AC3_Frame () {
224///     unencrypted_leader                 // 16 bytes
225///     while (bytes_remaining() >= 16) {
226///         encrypted_block                // 16 bytes
227///     }
228///     unencrypted_trailer                // 0-15 bytes
229/// }
230fn decrypt_ac3_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
231    let ac3 = Ac3Header::new(input);
232    let data = ac3.data(input);
233
234    decrypt_raw_sample(data, key, iv);
235    ac3.length
236}
237
238/// Encrypted_Enhanced_AC3_syncframe () {
239///     unencrypted_leader                 // 16 bytes
240///     while (bytes_remaining() >= 16) {
241///         encrypted_block                // 16 bytes
242///     }
243///     unencrypted_trailer                // 0-15 bytes
244/// }
245fn decrypt_eac3_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
246    let eac3 = Eac3Header::new(input);
247    let data = eac3.data(input);
248
249    decrypt_raw_sample(data, key, iv);
250    eac3.length
251}
252
253fn decrypt_raw_sample(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) {
254    let mut decryptor = cbc::Decryptor::<aes::Aes128>::new(&key.into(), &iv.into());
255
256    let mut is_first = true;
257    let chunks = input.chunks_mut(16);
258    for chunk in chunks {
259        if chunk.len() < 16 || is_first {
260            is_first = false;
261            continue;
262        }
263        decryptor.decrypt_block_mut(chunk.into());
264    }
265}
266
267struct PESSegment {
268    stream_type: StreamType,
269
270    pes_ts_header: TsHeader,
271    pes_header: PesHeader,
272    pes_packet_len: u16,
273    initial_size: usize,
274
275    data: Vec<u8>,
276    data_packet_num: usize,
277}
278
279impl PESSegment {
280    fn decrypt_and_write<W: Write>(
281        mut self,
282        key: [u8; 16],
283        iv: [u8; 16],
284        writer: &mut IoriTsPacketWriter<W>,
285    ) -> Result<()> {
286        // do decrypt first
287        match self.stream_type {
288            // avc
289            StreamType::H264 | StreamType::H264WithAes128Cbc => self.decrypt_video(key, iv)?,
290            // adts
291            StreamType::AdtsAac
292            | StreamType::AdtsAacWithAes128Cbc
293            // ac3
294            | StreamType::DolbyDigitalUpToSixChannelAudio
295            | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc
296            // eac3
297            | StreamType::DolbyDigitalPlusUpTo16ChannelAudio
298            | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
299                self.decrypt_audio(key, iv)
300            }
301            _ => unreachable!("Unsupported stream type: {:?}", self.stream_type),
302        }
303
304        let pid = self.pes_ts_header.pid;
305
306        // split data into PES packets and write
307        // max TS packet size is 188
308        let mut input = self.data.as_slice();
309        let initial_size = input.len().min(self.initial_size);
310        writer.write_packet(&mut TsPacket {
311            header: self.pes_ts_header,
312            adaptation_field: None,
313            payload: Some(TsPayload::Pes(Pes {
314                header: self.pes_header,
315                pes_packet_len: self.pes_packet_len,
316                data: Bytes::new(&self.data[..initial_size])?,
317            })),
318        })?;
319
320        input = &input[initial_size..];
321        let mut remaining_packets = self.data_packet_num;
322
323        while !input.is_empty() {
324            // We need to make sure the total count of packets not change after decryption
325            let size = input.len() / remaining_packets;
326            let data = &input[..size];
327            input = &input[size..];
328
329            let mut packet = TsPacket {
330                header: TsHeader {
331                    pid,
332                    transport_scrambling_control: TransportScramblingControl::NotScrambled,
333                    transport_error_indicator: false,
334                    transport_priority: false,
335                    continuity_counter: ContinuityCounter::new(), // will be set by writer
336                },
337                adaptation_field: None,
338                // SAFETY: unwrap here is safe because we know the data length <= Bytes::MAX_SIZE
339                payload: Some(TsPayload::Raw(Bytes::new(data).unwrap())),
340            };
341            writer.write_packet(&mut packet)?;
342
343            remaining_packets -= 1;
344        }
345
346        Ok(())
347    }
348
349    fn decrypt_video(&mut self, key: [u8; 16], iv: [u8; 16]) -> Result<()> {
350        let mut input = self.data.as_slice();
351        let output = Vec::with_capacity(self.data.len() * 2);
352        let mut output = BufWriter::new(output);
353
354        loop {
355            let (mut nal_unit, data_new) = NALUnit::get_next(input)?;
356            input = data_new;
357
358            if nal_unit.r#type == 5 || nal_unit.r#type == 1 {
359                nal_unit.decrypt(&key, &iv);
360            }
361
362            nal_unit.write(&mut output)?;
363
364            if input.is_empty() {
365                break;
366            }
367        }
368
369        self.data = output.into_inner().map_err(|e| e.into_error())?;
370
371        Ok(())
372    }
373
374    fn decrypt_audio(&mut self, key: [u8; 16], iv: [u8; 16]) {
375        let mut input = self.data.as_mut_slice();
376        while !input.is_empty() {
377            match self.stream_type {
378                // adts
379                StreamType::AdtsAac | StreamType::AdtsAacWithAes128Cbc => {
380                    let size = decrypt_aac_frame(input, key, iv);
381                    input = &mut input[size..];
382                }
383                // ac3
384                StreamType::DolbyDigitalUpToSixChannelAudio
385                | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc => {
386                    let size = decrypt_ac3_frame(input, key, iv);
387                    input = &mut input[size..];
388                }
389                // eac3
390                StreamType::DolbyDigitalPlusUpTo16ChannelAudio
391                | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
392                    let size = decrypt_eac3_frame(input, key, iv);
393                    input = &mut input[size..];
394                }
395                _ => unimplemented!("Unsupported stream type: {:?}", self.stream_type),
396            }
397        }
398    }
399}
400
401struct IoriTsPacketWriter<W> {
402    inner: TsPacketWriter<W>,
403    counters: HashMap<u16, ContinuityCounter>,
404}
405
406impl<W: Write> IoriTsPacketWriter<W> {
407    fn new(inner: W) -> Self {
408        Self {
409            inner: TsPacketWriter::new(inner),
410            counters: HashMap::new(),
411        }
412    }
413
414    fn get_counter(
415        &mut self,
416        pid: u16,
417        default_counter: ContinuityCounter,
418    ) -> &mut ContinuityCounter {
419        self.counters.entry(pid).or_insert(default_counter)
420    }
421
422    fn write_packet(&mut self, packet: &mut TsPacket) -> mpeg2ts::Result<()> {
423        let counter =
424            self.get_counter(packet.header.pid.as_u16(), packet.header.continuity_counter);
425        packet.header.continuity_counter = *counter;
426
427        if !matches!(packet.payload, None | Some(TsPayload::Null(_))) {
428            counter.increment();
429        }
430
431        self.inner.write_ts_packet(packet)
432    }
433}
434
435fn should_decrypt_stream(id_map: &HashMap<u16, StreamType>, pid: u16) -> bool {
436    let stream_type = id_map.get(&pid);
437
438    match stream_type {
439        Some(
440            // avc
441            StreamType::H264WithAes128Cbc
442            | StreamType::H264
443            // adts
444            | StreamType::AdtsAacWithAes128Cbc
445            | StreamType::AdtsAac
446            // ac3
447            | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc
448            | StreamType::DolbyDigitalUpToSixChannelAudio
449            // eac3
450            | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc
451            | StreamType::DolbyDigitalPlusUpTo16ChannelAudio,
452        ) => true,
453        _ => false,
454    }
455}
456
457pub fn decrypt_mpegts<R, W>(input: R, output: W, key: [u8; 16], iv: [u8; 16]) -> Result<()>
458where
459    R: Read,
460    W: Write,
461{
462    let mut reader = TsPacketReader::new(input);
463    let mut writer = IoriTsPacketWriter::new(output);
464
465    let mut streams = HashMap::new();
466    let mut pid_map = HashMap::new();
467
468    while let Ok(Some(TsPacket {
469        header,
470        adaptation_field,
471        payload,
472    })) = reader.read_ts_packet()
473    {
474        if let Some(payload) = payload {
475            // do not flush after receiving the following payloads
476            let flush = if matches!(
477                payload,
478                // PES is the start of a new stream
479                TsPayload::Pes(_) |
480                // RAW is part of the current stream
481                TsPayload::Raw(_) |
482                // NULL is just placeholder, no need to flush
483                TsPayload::Null(_)
484            ) {
485                None
486            } else {
487                Some(header.pid)
488            };
489
490            match payload {
491                TsPayload::Pmt(mut pmt) => {
492                    // modify from encrypted to clear stream
493                    for es in pmt.es_info.iter_mut() {
494                        // save stream type before modify
495                        pid_map.insert(es.elementary_pid.as_u16(), es.stream_type);
496
497                        // map stream types to its unencrypted version
498                        es.stream_type = match es.stream_type {
499                            StreamType::H264WithAes128Cbc => StreamType::H264,
500                            StreamType::AdtsAacWithAes128Cbc => StreamType::AdtsAac,
501                            StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc => {
502                                StreamType::DolbyDigitalUpToSixChannelAudio
503                            }
504                            StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
505                                StreamType::DolbyDigitalPlusUpTo16ChannelAudio
506                            }
507                            _ => es.stream_type,
508                        };
509                    }
510                    writer.write_packet(&mut TsPacket {
511                        header,
512                        adaptation_field,
513                        payload: Some(TsPayload::Pmt(pmt)),
514                    })?;
515                }
516                // only decrypt stream that should be decrypted
517                TsPayload::Pes(pes) if should_decrypt_stream(&pid_map, header.pid.as_u16()) => {
518                    let stream_type = pid_map.get(&header.pid.as_u16());
519
520                    let prev_pes = streams.insert(
521                        header.pid,
522                        PESSegment {
523                            // SAFETY: we know the stream type is valid
524                            stream_type: *stream_type.unwrap(),
525
526                            pes_ts_header: header,
527                            pes_header: pes.header,
528                            pes_packet_len: pes.pes_packet_len,
529                            initial_size: pes.data.len(),
530                            data: pes.data.to_vec(),
531                            data_packet_num: 0,
532                        },
533                    );
534
535                    if let Some(pes) = prev_pes {
536                        pes.decrypt_and_write(key, iv, &mut writer)?;
537                    }
538                }
539                TsPayload::Raw(bytes) if streams.contains_key(&header.pid) => {
540                    // SAFETY: We've validated the stream exist in streams
541                    let pes = streams.get_mut(&header.pid).unwrap();
542                    pes.data_packet_num += 1;
543                    pes.data.extend_from_slice(&bytes);
544                }
545                // for other payload, just write it without modification
546                _ => writer.write_packet(&mut TsPacket {
547                    header,
548                    adaptation_field,
549                    payload: Some(payload),
550                })?,
551            }
552
553            if let Some(flush) = flush {
554                if let Some(pes) = streams.remove(&flush) {
555                    pes.decrypt_and_write(key, iv, &mut writer)?;
556                };
557            }
558        }
559    }
560
561    // handle remaining streams
562    for pes in streams.into_values() {
563        pes.decrypt_and_write(key, iv, &mut writer)?;
564    }
565
566    Ok(())
567}
568
/// Audio codec announced by the 4-byte `audio_type` of the
/// `com.apple.streaming.audioDescription` ID3 PRIV frame.
enum AudioSetupType {
    /// `zaac`: AAC-LC.
    AacLc,
    /// `zach`: HE-AAC v1.
    AacHeV1,
    /// `zacp`: HE-AAC v2.
    AacHeV2,
    /// `zac3`: AC-3.
    Ac3,
    /// `zec3`: Enhanced AC-3.
    EnhancedAc3,
}
581
582pub fn decrypt<R, W>(input: R, mut output: W, key: [u8; 16], iv: [u8; 16]) -> Result<()>
583where
584    R: Read,
585    W: Write,
586{
587    let mut input = BufReader::new(input);
588    let magic = input.fill_buf()?;
589
590    if magic.is_empty() {
591        return Ok(());
592    }
593
594    // MPEG-TS
595    if magic[0] == 0x47 {
596        return decrypt_mpegts(input, output, key, iv);
597    }
598
599    let mut audio_format = None;
600    let mut is_id3 = &magic[0..3] == b"ID3";
601    while is_id3 {
602        #[allow(deprecated)]
603        let tag = id3::Tag::read_from(&mut input)?;
604        tag.write_to(&mut output, tag.version())?;
605
606        // In elementary streams the audio setup information is carried inside an ID3 Private Frame, as defined in ID3 tag version 2.4.0.
607        // The owner identifier is com.apple.streaming.audioDescription.
608        let format = tag.frames().find(|f| f.id() == "PRIV").and_then(|p| {
609            if let id3::Content::Private(p) = p.content() {
610                if p.owner_identifier == "com.apple.streaming.audioDescription" {
611                    // audio_setup_information() {
612                    //     audio_type               // 4 bytes
613                    //     priming                  // 2 bytes
614                    //     version                  // 1 byte
615                    //     setup_data_length        // 1 byte
616                    //     setup_data               // setup_data_length
617                    // }
618                    let data = &p.private_data;
619                    if data.len() >= 4 {
620                        let format = &data[0..4];
621                        return match format {
622                            b"zaac" => Some(AudioSetupType::AacLc),
623                            b"zach" => Some(AudioSetupType::AacHeV1),
624                            b"zacp" => Some(AudioSetupType::AacHeV2),
625                            b"zac3" => Some(AudioSetupType::Ac3),
626                            b"zec3" => Some(AudioSetupType::EnhancedAc3),
627                            _ => None,
628                        };
629                    }
630                }
631            }
632
633            None
634        });
635
636        if let Some(format) = format {
637            audio_format = Some(format);
638        }
639
640        let magic = input.fill_buf()?;
641        is_id3 = magic.len() >= 3 && &magic[0..3] == b"ID3";
642    }
643
644    let Some(audio_format) = audio_format else {
645        return Ok(());
646    };
647
648    let mut buf = Vec::new();
649    input.read_to_end(&mut buf)?;
650
651    let mut data = &mut buf[..];
652    loop {
653        if data.is_empty() {
654            break;
655        }
656
657        let size = match audio_format {
658            AudioSetupType::AacLc | AudioSetupType::AacHeV1 | AudioSetupType::AacHeV2 => {
659                decrypt_aac_frame(data, key, iv)
660            }
661            AudioSetupType::Ac3 => decrypt_ac3_frame(data, key, iv),
662            AudioSetupType::EnhancedAc3 => decrypt_eac3_frame(data, key, iv),
663        };
664
665        let decrypted = &data[..size];
666        output.write_all(decrypted)?;
667
668        data = &mut data[size..];
669    }
670
671    Ok(())
672}