1mod constant;
2mod error;
3pub use error::*;
4
5use aes::cipher::{BlockDecryptMut, KeyIvInit};
6use memchr::memmem;
7use mpeg2ts::es::StreamType;
8use mpeg2ts::pes::PesHeader;
9use mpeg2ts::ts::{
10 payload::{Bytes, Pes},
11 ContinuityCounter, ReadTsPacket, TransportScramblingControl, TsHeader, TsPacket,
12 TsPacketReader, TsPacketWriter, TsPayload, WriteTsPacket,
13};
14use std::collections::HashMap;
15use std::io::{BufRead, BufReader, BufWriter, Read, Write};
16
/// A single NAL unit cut out of an Annex B byte stream, kept together with the length of the
/// start code that preceded it so it can be written back verbatim.
pub struct NALUnit {
    /// NAL unit type, taken from the low five bits of the first payload byte.
    pub r#type: u8,
    /// Length in bytes of the start code that preceded this unit (3 or 4).
    start_code_length: u8,
    /// Payload bytes following the start code, up to the next start code.
    data: Vec<u8>,
    /// Number of meaningful bytes in `data`.
    length: usize,
}
23
24impl NALUnit {
25 pub fn get_next(input: &[u8]) -> Result<(Self, &[u8])> {
26 let start_code_length = if input.len() > 4 && &input[0..4] == b"\x00\x00\x00\x01" {
27 4
28 } else if input.len() > 3 && &input[0..3] == b"\x00\x00\x01" {
29 3
30 } else {
31 return Err(Error::InvalidStartCode);
32 };
33
34 let next = &input[start_code_length..];
35 let next_pos = if let Some(pos) = memmem::find(next, b"\x00\x00\x01") {
36 if pos > 0 && next[pos - 1] == 0x00 {
38 start_code_length + pos - 1
39 } else {
40 start_code_length + pos
41 }
42 } else {
43 input.len()
44 };
45 let next = &input[next_pos..];
46
47 let data = input[start_code_length..next_pos].to_vec();
48 Ok((
49 Self {
50 r#type: data[0] & 0x1f,
51 data,
52 length: next_pos - start_code_length,
53 start_code_length: start_code_length as u8,
54 },
55 next,
56 ))
57 }
58
59 fn remove_scep_3_bytes(&mut self) {
60 let mut i = 0;
61 let mut j = 0;
62
63 while i < self.length {
64 if self.length - i > 3 && self.data[i..i + 3] == [0x00, 0x00, 0x03] {
65 self.data[j] = 0x00;
66 self.data[j + 1] = 0x00;
67 i += 3;
68 j += 2;
69 } else {
70 self.data[j] = self.data[i];
71 i += 1;
72 j += 1;
73 }
74 }
75
76 self.data.truncate(j);
77 self.length = j;
78 }
79
80 pub fn decrypt(&mut self, key: &[u8; 16], iv: &[u8; 16]) {
91 if self.data.len() <= 48 {
92 return;
93 }
94
95 self.remove_scep_3_bytes();
96
97 let mut decryptor = cbc::Decryptor::<aes::Aes128>::new(key.into(), iv.into());
98
99 if self.data.len() < 32 {
100 return;
101 }
102
103 let mut pos = &mut self.data.as_mut_slice()[32..];
104
105 while !pos.is_empty() {
106 if pos.len() > 16 {
107 let block = &mut pos[..16];
108 decryptor.decrypt_block_mut(block.into());
109 pos = &mut pos[16..];
110 }
111
112 let remaining_len = pos.len();
113 pos = &mut pos[144.min(remaining_len)..];
114 }
115 }
116
117 pub fn write<W: Write>(&self, output: &mut W) -> Result<()> {
118 if self.start_code_length == 4 {
119 output.write_all(&[0x00, 0x00, 0x00, 0x01])?;
120 } else {
121 output.write_all(&[0x00, 0x00, 0x01])?;
122 }
123 output.write_all(&self.data)?;
124
125 Ok(())
126 }
127}
128
/// The parts of an ADTS frame header needed to locate the frame's payload.
struct AdtsHeader {
    /// Total frame length in bytes, header included (13-bit `aac_frame_length`).
    length: usize,
    /// `true` when the header carries a CRC, i.e. the protection-absent bit is 0.
    crc: bool,
}

impl AdtsHeader {
    /// Reads the frame length and CRC flag from the header at the start of `data`.
    fn new(data: &[u8]) -> Self {
        let length = Self::read_adts_frame_length(data);
        let protection_absent = data[1] & 0x01 == 1;
        Self {
            length,
            crc: !protection_absent,
        }
    }

    /// Returns the frame payload in `input`: everything after the 7-byte header (9 bytes when
    /// a CRC is present) up to the end of the frame.
    fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
        let header_len = if self.crc { 9 } else { 7 };
        &mut input[header_len..self.length]
    }

    /// Assembles the 13-bit frame length from the low 2 bits of byte 3, all of byte 4 and the
    /// top 3 bits of byte 5.
    fn read_adts_frame_length(header: &[u8]) -> usize {
        let high = usize::from(header[3] & 0b11) << 11;
        let middle = usize::from(header[4]) << 3;
        let low = usize::from(header[5] >> 5);
        high | middle | low
    }
}
160
161struct Ac3Header {
162 length: usize,
163}
164
165impl Ac3Header {
166 fn new(data: &[u8]) -> Self {
167 Self {
168 length: Self::read_ac3_frame_length(data),
169 }
170 }
171
172 fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
173 &mut input[..self.length]
174 }
175
176 fn read_ac3_frame_length(header: &[u8]) -> usize {
177 let fscod = (header[4] >> 6) as usize;
178 let frmsizcod = (header[4] & 0b111111) as usize;
179 let frame_size = constant::AC3_FRAME_SIZE_CODE_TABLE[frmsizcod][fscod];
181 frame_size * 2
182 }
183}
184
/// The frame size of an E-AC-3 sync frame, read from its header.
struct Eac3Header {
    /// Total sync frame length in bytes.
    length: usize,
}

impl Eac3Header {
    /// Reads the frame size from the sync frame header at the start of `data`.
    fn new(data: &[u8]) -> Self {
        let length = Self::read_eac3_frame_length(data);
        Self { length }
    }

    /// Returns the whole sync frame (header included) from the start of `input`.
    fn data<'a>(&self, input: &'a mut [u8]) -> &'a mut [u8] {
        &mut input[0..self.length]
    }

    /// Decodes the 11-bit frame-size field spanning bytes 2-3. The field stores the size in
    /// 16-bit words minus one, so it is incremented and doubled to get bytes.
    fn read_eac3_frame_length(header: &[u8]) -> usize {
        let field = u16::from_be_bytes([header[2], header[3]]) & 0x07ff;
        (usize::from(field) + 1) * 2
    }
}
206
207fn decrypt_aac_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
216 let adts = AdtsHeader::new(input);
217 let data = adts.data(input);
218
219 decrypt_raw_sample(data, key, iv);
220 adts.length
221}
222
223fn decrypt_ac3_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
231 let ac3 = Ac3Header::new(input);
232 let data = ac3.data(input);
233
234 decrypt_raw_sample(data, key, iv);
235 ac3.length
236}
237
238fn decrypt_eac3_frame(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) -> usize {
246 let eac3 = Eac3Header::new(input);
247 let data = eac3.data(input);
248
249 decrypt_raw_sample(data, key, iv);
250 eac3.length
251}
252
253fn decrypt_raw_sample(input: &mut [u8], key: [u8; 16], iv: [u8; 16]) {
254 let mut decryptor = cbc::Decryptor::<aes::Aes128>::new(&key.into(), &iv.into());
255
256 let mut is_first = true;
257 let chunks = input.chunks_mut(16);
258 for chunk in chunks {
259 if chunk.len() < 16 || is_first {
260 is_first = false;
261 continue;
262 }
263 decryptor.decrypt_block_mut(chunk.into());
264 }
265}
266
267struct PESSegment {
268 stream_type: StreamType,
269
270 pes_ts_header: TsHeader,
271 pes_header: PesHeader,
272 pes_packet_len: u16,
273 initial_size: usize,
274
275 data: Vec<u8>,
276 data_packet_num: usize,
277}
278
279impl PESSegment {
280 fn decrypt_and_write<W: Write>(
281 mut self,
282 key: [u8; 16],
283 iv: [u8; 16],
284 writer: &mut IoriTsPacketWriter<W>,
285 ) -> Result<()> {
286 match self.stream_type {
288 StreamType::H264 | StreamType::H264WithAes128Cbc => self.decrypt_video(key, iv)?,
290 StreamType::AdtsAac
292 | StreamType::AdtsAacWithAes128Cbc
293 | StreamType::DolbyDigitalUpToSixChannelAudio
295 | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc
296 | StreamType::DolbyDigitalPlusUpTo16ChannelAudio
298 | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
299 self.decrypt_audio(key, iv)
300 }
301 _ => unreachable!("Unsupported stream type: {:?}", self.stream_type),
302 }
303
304 let pid = self.pes_ts_header.pid;
305
306 let mut input = self.data.as_slice();
309 let initial_size = input.len().min(self.initial_size);
310 writer.write_packet(&mut TsPacket {
311 header: self.pes_ts_header,
312 adaptation_field: None,
313 payload: Some(TsPayload::Pes(Pes {
314 header: self.pes_header,
315 pes_packet_len: self.pes_packet_len,
316 data: Bytes::new(&self.data[..initial_size])?,
317 })),
318 })?;
319
320 input = &input[initial_size..];
321 let mut remaining_packets = self.data_packet_num;
322
323 while !input.is_empty() {
324 let size = input.len() / remaining_packets;
326 let data = &input[..size];
327 input = &input[size..];
328
329 let mut packet = TsPacket {
330 header: TsHeader {
331 pid,
332 transport_scrambling_control: TransportScramblingControl::NotScrambled,
333 transport_error_indicator: false,
334 transport_priority: false,
335 continuity_counter: ContinuityCounter::new(), },
337 adaptation_field: None,
338 payload: Some(TsPayload::Raw(Bytes::new(data).unwrap())),
340 };
341 writer.write_packet(&mut packet)?;
342
343 remaining_packets -= 1;
344 }
345
346 Ok(())
347 }
348
349 fn decrypt_video(&mut self, key: [u8; 16], iv: [u8; 16]) -> Result<()> {
350 let mut input = self.data.as_slice();
351 let output = Vec::with_capacity(self.data.len() * 2);
352 let mut output = BufWriter::new(output);
353
354 loop {
355 let (mut nal_unit, data_new) = NALUnit::get_next(input)?;
356 input = data_new;
357
358 if nal_unit.r#type == 5 || nal_unit.r#type == 1 {
359 nal_unit.decrypt(&key, &iv);
360 }
361
362 nal_unit.write(&mut output)?;
363
364 if input.is_empty() {
365 break;
366 }
367 }
368
369 self.data = output.into_inner().map_err(|e| e.into_error())?;
370
371 Ok(())
372 }
373
374 fn decrypt_audio(&mut self, key: [u8; 16], iv: [u8; 16]) {
375 let mut input = self.data.as_mut_slice();
376 while !input.is_empty() {
377 match self.stream_type {
378 StreamType::AdtsAac | StreamType::AdtsAacWithAes128Cbc => {
380 let size = decrypt_aac_frame(input, key, iv);
381 input = &mut input[size..];
382 }
383 StreamType::DolbyDigitalUpToSixChannelAudio
385 | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc => {
386 let size = decrypt_ac3_frame(input, key, iv);
387 input = &mut input[size..];
388 }
389 StreamType::DolbyDigitalPlusUpTo16ChannelAudio
391 | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
392 let size = decrypt_eac3_frame(input, key, iv);
393 input = &mut input[size..];
394 }
395 _ => unimplemented!("Unsupported stream type: {:?}", self.stream_type),
396 }
397 }
398 }
399}
400
401struct IoriTsPacketWriter<W> {
402 inner: TsPacketWriter<W>,
403 counters: HashMap<u16, ContinuityCounter>,
404}
405
406impl<W: Write> IoriTsPacketWriter<W> {
407 fn new(inner: W) -> Self {
408 Self {
409 inner: TsPacketWriter::new(inner),
410 counters: HashMap::new(),
411 }
412 }
413
414 fn get_counter(
415 &mut self,
416 pid: u16,
417 default_counter: ContinuityCounter,
418 ) -> &mut ContinuityCounter {
419 self.counters.entry(pid).or_insert(default_counter)
420 }
421
422 fn write_packet(&mut self, packet: &mut TsPacket) -> mpeg2ts::Result<()> {
423 let counter =
424 self.get_counter(packet.header.pid.as_u16(), packet.header.continuity_counter);
425 packet.header.continuity_counter = *counter;
426
427 if !matches!(packet.payload, None | Some(TsPayload::Null(_))) {
428 counter.increment();
429 }
430
431 self.inner.write_ts_packet(packet)
432 }
433}
434
435fn should_decrypt_stream(id_map: &HashMap<u16, StreamType>, pid: u16) -> bool {
436 let stream_type = id_map.get(&pid);
437
438 match stream_type {
439 Some(
440 StreamType::H264WithAes128Cbc
442 | StreamType::H264
443 | StreamType::AdtsAacWithAes128Cbc
445 | StreamType::AdtsAac
446 | StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc
448 | StreamType::DolbyDigitalUpToSixChannelAudio
449 | StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc
451 | StreamType::DolbyDigitalPlusUpTo16ChannelAudio,
452 ) => true,
453 _ => false,
454 }
455}
456
/// Decrypts a SAMPLE-AES protected MPEG transport stream from `input` into `output`.
///
/// Packets are read with [`TsPacketReader`] and written through [`IoriTsPacketWriter`], which
/// renumbers continuity counters per PID.
///
/// * PMT entries with an `…WithAes128Cbc` stream type are rewritten to the matching clear type.
/// * PES packets on PIDs accepted by [`should_decrypt_stream`] are buffered as [`PESSegment`]s.
///   A segment is decrypted with `key`/`iv` and re-emitted when the next PES starts on the same
///   PID, or at end of input.
/// * All other packets are copied through.
///
/// # Errors
/// Returns an error when a packet cannot be written or a buffered segment fails to decrypt.
pub fn decrypt_mpegts<R, W>(input: R, output: W, key: [u8; 16], iv: [u8; 16]) -> Result<()>
where
    R: Read,
    W: Write,
{
    let mut reader = TsPacketReader::new(input);
    let mut writer = IoriTsPacketWriter::new(output);

    // PID -> PES segment currently being reassembled.
    let mut streams = HashMap::new();
    // Elementary PID -> stream type as announced in the PMT, recorded before it is rewritten.
    let mut pid_map = HashMap::new();

    // NOTE(review): a read error ends this loop exactly like end of input; it is not reported,
    // and segments buffered so far are still flushed below. Confirm this best-effort behaviour
    // is intended.
    while let Ok(Some(TsPacket {
        header,
        adaptation_field,
        payload,
    })) = reader.read_ts_packet()
    {
        // NOTE(review): packets without a payload (adaptation-field-only) are never written
        // to `output`.
        if let Some(payload) = payload {
            // PES, raw continuation and null payloads never request a flush; any other payload
            // kind (e.g. a PMT) yields this packet's own PID.
            let flush = if matches!(
                payload,
                TsPayload::Pes(_) |
                TsPayload::Raw(_) |
                TsPayload::Null(_)
            ) {
                None
            } else {
                Some(header.pid)
            };

            match payload {
                TsPayload::Pmt(mut pmt) => {
                    for es in pmt.es_info.iter_mut() {
                        // Remember the original type so the PES data on this PID gets decrypted.
                        pid_map.insert(es.elementary_pid.as_u16(), es.stream_type);

                        // The output is decrypted, so advertise the clear-text equivalent.
                        es.stream_type = match es.stream_type {
                            StreamType::H264WithAes128Cbc => StreamType::H264,
                            StreamType::AdtsAacWithAes128Cbc => StreamType::AdtsAac,
                            StreamType::DolbyDigitalUpToSixChannelAudioWithAes128Cbc => {
                                StreamType::DolbyDigitalUpToSixChannelAudio
                            }
                            StreamType::DolbyDigitalPlusUpToSixChannelAudioWithAes128Cbc => {
                                StreamType::DolbyDigitalPlusUpTo16ChannelAudio
                            }
                            _ => es.stream_type,
                        };
                    }
                    writer.write_packet(&mut TsPacket {
                        header,
                        adaptation_field,
                        payload: Some(TsPayload::Pmt(pmt)),
                    })?;
                }
                TsPayload::Pes(pes) if should_decrypt_stream(&pid_map, header.pid.as_u16()) => {
                    // Start of a new PES on a decryptable PID: begin buffering it. This packet's
                    // adaptation field is not kept. The guard above ensures the PID is in `pid_map`.
                    let stream_type = pid_map.get(&header.pid.as_u16());

                    let prev_pes = streams.insert(
                        header.pid,
                        PESSegment {
                            stream_type: *stream_type.unwrap(),

                            pes_ts_header: header,
                            pes_header: pes.header,
                            pes_packet_len: pes.pes_packet_len,
                            initial_size: pes.data.len(),
                            data: pes.data.to_vec(),
                            data_packet_num: 0,
                        },
                    );

                    // The previous PES on this PID is complete; decrypt and emit it now.
                    if let Some(pes) = prev_pes {
                        pes.decrypt_and_write(key, iv, &mut writer)?;
                    }
                }
                TsPayload::Raw(bytes) if streams.contains_key(&header.pid) => {
                    // Continuation of a buffered PES; this packet's header and adaptation field
                    // are dropped and only its bytes are kept.
                    let pes = streams.get_mut(&header.pid).unwrap();
                    pes.data_packet_num += 1;
                    pes.data.extend_from_slice(&bytes);
                }
                _ => writer.write_packet(&mut TsPacket {
                    header,
                    adaptation_field,
                    payload: Some(payload),
                })?,
            }

            // NOTE(review): `streams` is keyed by PIDs that carried decryptable PES data, while
            // `flush` is only set for non-PES/raw/null payloads. This lookup therefore succeeds
            // only if such a packet shares a PID with a buffered stream. Confirm whether every
            // pending segment was meant to be flushed here.
            if let Some(flush) = flush {
                if let Some(pes) = streams.remove(&flush) {
                    pes.decrypt_and_write(key, iv, &mut writer)?;
                };
            }
        }
    }

    // Emit the segments still buffered at end of input (in unspecified `HashMap` order).
    for pes in streams.into_values() {
        pes.decrypt_and_write(key, iv, &mut writer)?;
    }

    Ok(())
}
568
/// Audio codec of a packed-audio segment, as named by the four-character code in its
/// `com.apple.streaming.audioDescription` ID3 `PRIV` frame.
enum AudioSetupType {
    /// `zac3`: AC-3.
    Ac3,
    /// `zec3`: Enhanced AC-3.
    EnhancedAc3,
    /// `zaac`: AAC-LC.
    AacLc,
    /// `zach`: HE-AAC v1.
    AacHeV1,
    /// `zacp`: HE-AAC v2.
    AacHeV2,
}
581
582pub fn decrypt<R, W>(input: R, mut output: W, key: [u8; 16], iv: [u8; 16]) -> Result<()>
583where
584 R: Read,
585 W: Write,
586{
587 let mut input = BufReader::new(input);
588 let magic = input.fill_buf()?;
589
590 if magic.is_empty() {
591 return Ok(());
592 }
593
594 if magic[0] == 0x47 {
596 return decrypt_mpegts(input, output, key, iv);
597 }
598
599 let mut audio_format = None;
600 let mut is_id3 = &magic[0..3] == b"ID3";
601 while is_id3 {
602 #[allow(deprecated)]
603 let tag = id3::Tag::read_from(&mut input)?;
604 tag.write_to(&mut output, tag.version())?;
605
606 let format = tag.frames().find(|f| f.id() == "PRIV").and_then(|p| {
609 if let id3::Content::Private(p) = p.content() {
610 if p.owner_identifier == "com.apple.streaming.audioDescription" {
611 let data = &p.private_data;
619 if data.len() >= 4 {
620 let format = &data[0..4];
621 return match format {
622 b"zaac" => Some(AudioSetupType::AacLc),
623 b"zach" => Some(AudioSetupType::AacHeV1),
624 b"zacp" => Some(AudioSetupType::AacHeV2),
625 b"zac3" => Some(AudioSetupType::Ac3),
626 b"zec3" => Some(AudioSetupType::EnhancedAc3),
627 _ => None,
628 };
629 }
630 }
631 }
632
633 None
634 });
635
636 if let Some(format) = format {
637 audio_format = Some(format);
638 }
639
640 let magic = input.fill_buf()?;
641 is_id3 = magic.len() >= 3 && &magic[0..3] == b"ID3";
642 }
643
644 let Some(audio_format) = audio_format else {
645 return Ok(());
646 };
647
648 let mut buf = Vec::new();
649 input.read_to_end(&mut buf)?;
650
651 let mut data = &mut buf[..];
652 loop {
653 if data.is_empty() {
654 break;
655 }
656
657 let size = match audio_format {
658 AudioSetupType::AacLc | AudioSetupType::AacHeV1 | AudioSetupType::AacHeV2 => {
659 decrypt_aac_frame(data, key, iv)
660 }
661 AudioSetupType::Ac3 => decrypt_ac3_frame(data, key, iv),
662 AudioSetupType::EnhancedAc3 => decrypt_eac3_frame(data, key, iv),
663 };
664
665 let decrypted = &data[..size];
666 output.write_all(decrypted)?;
667
668 data = &mut data[size..];
669 }
670
671 Ok(())
672}