symphonia-wem 0.1.0

Symphonia demuxer for Wwise Encoded Media files
Documentation
//! Read the packets that make up the data chunk.

use bitvec::{array::BitArray, order::Lsb0, vec::BitVec, view::BitView as _};
use symphonia_core::{
    errors::Result,
    io::{MediaSourceStream, ReadBytes as _},
    packet::{Packet, PacketBuilder},
    units::{Duration, Timestamp},
};

use crate::bits::{read, read_write};

/// Reader that walks the packets of a data chunk one at a time.
pub(crate) struct PacketReader {
    /// Stream parameters needed to rebuild each packet.
    packet_info: PacketInfo,
    /// Block flag of the packet returned before the current one.
    previous_window_flag: bool,
    /// Reusable buffer in which each outgoing packet is assembled.
    output: BitVec<u8, Lsb0>,
    /// Raw bytes of the packet that follows the current one.
    ///
    /// Fetched ahead of time because building the current packet depends on them.
    next_packet_bytes: Option<Box<[u8]>>,
}

impl PacketReader {
    /// Construct a new packet reader.
    ///
    /// No packet is buffered yet, so [`Self::read`] returns `Ok(None)` until
    /// [`Self::read_first`] has been called once.
    pub(crate) fn new(packet_info: PacketInfo) -> Self {
        Self {
            packet_info,
            previous_window_flag: false,
            output: BitVec::new(),
            next_packet_bytes: None,
        }
    }

    /// Buffer the bytes of the next packet.
    ///
    /// The buffer is cleared, which ends iteration, in two cases. The first is when at most
    /// two bytes remain before `PacketInfo::last_offset` (not enough for the two-byte size
    /// plus any data). The second is when the stream position is already past that offset.
    ///
    /// [`Self::read`] calls this itself. Outside of it, this only needs to be called once,
    /// to prime the reader.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from reading the packet size or payload.
    pub(crate) fn read_first(&mut self, mss: &mut MediaSourceStream) -> Result<()> {
        // `saturating_sub` so that a position past `last_offset` ends iteration instead of
        // underflowing (a panic in debug builds, a huge wrapped value in release builds).
        let remaining = self.packet_info.last_offset.saturating_sub(mss.pos());

        self.next_packet_bytes = if remaining <= 2 {
            // No more packets in the buffer
            None
        } else {
            Self::read_packet_bytes(mss)?
        };

        Ok(())
    }

    /// Read a single packet, moving the read-ahead buffer forward.
    ///
    /// The stored packet lacks some leading header bits, which are rebuilt here:
    /// - a packet type bit of `0` (audio) is prepended;
    /// - the mode number (`mode_bits` wide) is copied;
    /// - if that mode's block flag is set, the previous and next window flags are inserted.
    ///   The next window flag is taken from the mode of the buffered next packet, or
    ///   `false` when there is none.
    ///
    /// The rest of the packet is then copied unchanged. The returned packet's duration comes
    /// from `parse_duration`, and its timestamp is always zero.
    ///
    /// Returns `Ok(None)` once no packet is buffered.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from reading ahead the next packet.
    pub(crate) fn read(
        &mut self,
        mss: &mut MediaSourceStream,
        track_id: u32,
    ) -> Result<Option<Packet>> {
        let Some(current_packet_bytes) = self.next_packet_bytes.take() else {
            // No more packets in the buffer
            return Ok(None);
        };

        // Read the next packet now, because its mode decides the next window flag below.
        // NOTE(review): if this fails, the error is returned and the current packet is lost.
        self.read_first(mss)?;

        let current_packet_bits = current_packet_bytes.view_bits();

        // Start a new packet. The packet type bit is 0 for audio.
        self.output.clear();
        self.output.push(false);

        // The mode number comes first in the stored packet
        let (current_packet_bits, mode_number): (_, usize) = read_write(
            current_packet_bits,
            &mut self.output,
            self.packet_info.mode_bits,
        );

        let current_mode_block_flag = self.packet_info.modes_block_flags[mode_number];
        if current_mode_block_flag {
            // Long window: the previous and next window flags are not stored, so rebuild them
            let next_window_flag = if let Some(next_packet) = &self.next_packet_bytes {
                // Only read the mode bits from the next packet
                let (_, next_mode_number): (_, usize) =
                    read(next_packet.view_bits(), self.packet_info.mode_bits);

                self.packet_info.modes_block_flags[next_mode_number]
            } else {
                // No next packet, so there is no next window to describe
                false
            };

            // Previous window type bit
            self.output.push(self.previous_window_flag);

            // Next window type bit
            self.output.push(next_window_flag);
        }

        // This packet's block flag becomes the previous window flag of the next packet
        self.previous_window_flag = current_mode_block_flag;

        // Copy the remaining bits of the first byte, then the rest of the packet verbatim
        let (_, _remainder): (_, u8) = read_write(
            current_packet_bits,
            &mut self.output,
            8 - self.packet_info.mode_bits,
        );
        self.output.extend(&current_packet_bytes[1..]);

        // The inserted header bits (1 or 3) shift the data, so the last byte is only partly
        // filled. `as_raw_slice` exposes its unused bits, and bitvec does not guarantee their
        // value, so zero them to keep the emitted bytes deterministic.
        self.output.set_uninitialized(false);

        let (dur, _discard) = self.parse_duration(mode_number);

        let packet = PacketBuilder::new()
            .track_id(track_id)
            .dur(dur)
            // TODO: track a running timestamp instead of stamping every packet with zero.
            .pts(Timestamp::ZERO)
            .data(self.output.as_raw_slice())
            .build();

        Ok(Some(packet))
    }

    /// Compute `(duration, discard)` for a packet using mode `mode_number`.
    ///
    /// Both values are currently half the block size selected by the mode's block flag.
    /// Mode numbers at or beyond `num_modes` yield zero for both.
    fn parse_duration(&self, mode_number: usize) -> (Duration, Duration) {
        if mode_number >= usize::from(self.packet_info.num_modes) {
            return (Duration::ZERO, Duration::ZERO);
        }

        // Block sizes are stored as base-2 exponents
        let block_size_exponent = if self.packet_info.modes_block_flags[mode_number] {
            self.packet_info.block_size_1
        } else {
            self.packet_info.block_size_0
        };
        let block_size = 1_u64 << block_size_exponent;

        // TODO: the exact duration also depends on the previous packet's block size. With a
        // previous block of size `prev`, it is `prev / 4 + block_size / 4` and nothing is
        // discarded. Until the previous block size is tracked, every packet is treated like
        // the first one: half the block is the duration and the lapped half is discarded.
        let half_block = block_size >> 1;

        (Duration::new(half_block), Duration::new(half_block))
    }

    /// Read the bytes of a single size-prefixed packet from the stream.
    ///
    /// The size is read with `read_u16`. A size of zero yields `Ok(None)`.
    /// `Self::read_first` stores that as "no next packet", so an empty packet ends iteration.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors, including a stream that ends before `size` bytes are read.
    fn read_packet_bytes(mss: &mut MediaSourceStream) -> Result<Option<Box<[u8]>>> {
        let size = usize::from(mss.read_u16()?);

        // Don't read empty packets
        if size == 0 {
            return Ok(None);
        }

        Ok(Some(mss.read_boxed_slice_exact(size)?))
    }
}

/// Stream parameters needed to read the packets of a data chunk.
#[derive(Debug, Default)]
pub(crate) struct PacketInfo {
    /// Last byte offset in the reader before the end-of-file.
    ///
    /// Packets are read only while more than two bytes remain before this offset.
    pub(crate) last_offset: u64,
    /// Base-2 exponent of the block size used by modes whose block flag is not set.
    pub(crate) block_size_0: u8,
    /// Base-2 exponent of the block size used by modes whose block flag is set.
    pub(crate) block_size_1: u8,
    /// Width in bits of the mode number stored at the start of each packet.
    ///
    /// Presumably derived from `num_modes` by an integer logarithm; confirm at the call site.
    pub(crate) mode_bits: usize,
    /// Number of modes, i.e. how many leading entries of `modes_block_flags` are meaningful.
    ///
    /// Mode numbers at or above this count are treated as invalid when computing durations.
    pub(crate) num_modes: u8,
    /// Block flag of each mode, indexed by mode number.
    pub(crate) modes_block_flags: BitArray<u64, Lsb0>,
}