rvz 0.1.3

RVZ compression library.
Documentation
// SPDX-License-Identifier: LGPL-2.1-or-later OR GPL-2.0-or-later OR MPL-2.0
// SPDX-FileCopyrightText: 2024 Gabriel Marcano <gabemarcano@yahoo.com>

use std::io;
use std::io::Cursor;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;

use std::cmp;

use byteorder::BigEndian;
use byteorder::ReadBytesExt;
use byteorder::WriteBytesExt;

/// State of the lagged Fibonacci generator (LFG) used to generate padding on GCN and Wii
/// games.
///
/// Uses the lags j=32 and k=521 with XOR as the combining function.
///
/// The type is [`Clone`] so that a generator can be checkpointed before reading ahead.
#[derive(Clone, Debug)]
pub struct Prng {
    /// Current LFG state: one window of 521 words.
    state: [u32; 521],
    /// Byte offset of the read cursor within the current state window.
    position: usize,
    /// Byte offset of the read cursor counted from the first window, i.e. the number of
    /// output bytes consumed or skipped so far.
    absolute_position: u64,
}

impl Prng {
    /// Number of 32-bit words read from the input to seed the generator.
    const SEED_WORDS: usize = 17;
    /// Number of words in the LFG state (the long lag, `k`).
    const STATE_WORDS: usize = 521;
    /// The short lag, `j`.
    const SHORT_LAG: usize = 32;
    /// Number of output bytes produced by one state window.
    const WINDOW_BYTES: usize = Self::STATE_WORDS * 4;

    /// Creates a new [`Prng`] by reading 17 32-bit words (assumed big endian encoding) from the
    /// given Read-able object.
    ///
    /// The 17 seed words are expanded to the full 521-word state, after which the state is
    /// advanced four times before any output is produced.
    ///
    /// # Errors
    /// Returns [`io::Error`] if the seed cannot be read from `io`, for example when fewer than
    /// 68 bytes are available.
    pub fn new<T: Read>(io: &mut T) -> io::Result<Self> {
        let mut result = Self {
            state: [0u32; Self::STATE_WORDS],
            position: 0,
            absolute_position: 0,
        };

        let buffer = &mut result.state;
        for word in buffer.iter_mut().take(Self::SEED_WORDS) {
            *word = io.read_u32::<BigEndian>()?;
        }

        // Expand the seed into the rest of the state window.
        for i in Self::SEED_WORDS..Self::STATE_WORDS {
            buffer[i] = (buffer[i - 17] << 23) ^ (buffer[i - 16] >> 9) ^ buffer[i - 1];
        }

        for _ in 0..4 {
            result.advance_state();
        }

        Ok(result)
    }

    /// Advances the LFG state by one full window and rewinds the in-window cursor to 0.
    fn advance_state(&mut self) {
        let buffer = &mut self.state;
        // The first `j` words pair up with the tail of the previous window...
        for i in 0..Self::SHORT_LAG {
            buffer[i] ^= buffer[i + Self::STATE_WORDS - Self::SHORT_LAG];
        }

        // ...and every later word pairs up with the already-updated word `j` slots before it.
        for i in Self::SHORT_LAG..Self::STATE_WORDS {
            buffer[i] ^= buffer[i - Self::SHORT_LAG];
        }
        self.position = 0;
    }

    /// Moves the cursor forward by `count` bytes, advancing the LFG state when the cursor
    /// leaves the current window.
    ///
    /// `count` must not exceed one window; the callers in this block pass 1 or 4.
    fn consume(&mut self, count: u8) {
        self.absolute_position += u64::from(count);

        let next = self.position + usize::from(count);
        if next >= Self::WINDOW_BYTES {
            // `advance_state` rewinds `position`, so the carry-over is stored afterwards.
            self.advance_state();
            self.position = next - Self::WINDOW_BYTES;
        } else {
            self.position = next;
        }
    }

    /// Reads the next four bytes of the LFG output as a big-endian 32-bit word.
    ///
    /// This automatically advances the LFG state if 521 words have been read since the last
    /// advance. When the cursor is not on a word boundary (after an odd number of
    /// [`Prng::read_byte`] calls) the word is assembled byte by byte, so the result is always
    /// the continuation of the byte stream.
    pub fn read_word(&mut self) -> u32 {
        if self.position % 4 != 0 {
            return u32::from_be_bytes([
                self.read_byte(),
                self.read_byte(),
                self.read_byte(),
                self.read_byte(),
            ]);
        }

        let result = self.read_word_();
        self.consume(4);
        result
    }

    /// Reads a single byte from the LFG.
    ///
    /// This automatically advances the LFG state if 521 words have been read since the last
    /// advance.
    pub fn read_byte(&mut self) -> u8 {
        // Bytes of a word are emitted most significant first.
        let result = self.read_word_().to_be_bytes()[self.position % 4];
        self.consume(1);
        result
    }

    /// Extracts the word under the cursor from the LFG, as described by the Dolphin emulator
    /// documentation. The cursor is not moved.
    ///
    /// To quote the documentation:
    /// ```c
    ///   *(out++) = *buffer_ptr >> 24;
    ///   *(out++) = *buffer_ptr >> 18;  // NB: 18, not 16
    ///   *(out++) = *buffer_ptr >> 8;
    ///   *(out++) = *buffer_ptr;
    /// ```
    const fn read_word_(&self) -> u32 {
        let word = self.state[self.position / 4];

        // Keep bytes 0, 2 and 3 as they are; byte 1 is taken from bits 18..26 instead of 16..24.
        (word & 0xFF00_FFFF) | (((word >> 18) & 0xFF) << 16)
    }
}

impl Read for Prng {
    /// Fills the whole of `buf` with generator output and reports `buf.len()` bytes read.
    ///
    /// The output is produced in three phases: single bytes up to the next word boundary,
    /// then whole words, then the left-over bytes.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let total = buf.len();

        // Head: bytes up to the next 4-byte boundary (a full word's worth when already
        // aligned), capped at the size of the buffer. The remainder is always below 4.
        let misalignment = (self.absolute_position % 4) as usize;
        let head = cmp::min(4 - misalignment, total);

        let remaining = total - head;
        let (whole_words, tail) = (remaining / 4, remaining % 4);

        let mut sink = Cursor::new(buf);
        (0..head).try_for_each(|_| sink.write_u8(self.read_byte()))?;
        (0..whole_words).try_for_each(|_| sink.write_u32::<BigEndian>(self.read_word()))?;
        (0..tail).try_for_each(|_| sink.write_u8(self.read_byte()))?;

        Ok(total)
    }
}

impl Seek for Prng {
    /// Moves the read cursor forward through the generator output.
    ///
    /// Only forward seeks are supported, because the LFG state cannot be rewound. The cost is
    /// one state advance per window (521 words) crossed, so it grows with the distance.
    ///
    /// Returns the new absolute position in bytes.
    ///
    /// # Errors
    /// * [`io::ErrorKind::Unsupported`] for [`SeekFrom::End`]: the stream has no end.
    /// * [`io::ErrorKind::InvalidInput`] when the target lies before the current position or
    ///   does not fit in a `u64`. The generator is left untouched in both cases.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // One LFG state window yields 521 words of 4 bytes each.
        const WINDOW_BYTES: u64 = 521 * 4;

        fn invalid_input() -> io::Error {
            io::Error::from(io::ErrorKind::InvalidInput)
        }

        // Reduce every supported request to a forward distance plus the absolute target.
        let (distance, target) = match pos {
            SeekFrom::End(_) => return Err(io::Error::from(io::ErrorKind::Unsupported)),
            SeekFrom::Start(offset) => {
                let distance = offset
                    .checked_sub(self.absolute_position)
                    .ok_or_else(invalid_input)?;
                (distance, offset)
            }
            SeekFrom::Current(offset) => {
                let distance = u64::try_from(offset).map_err(|_| invalid_input())?;
                let target = self
                    .absolute_position
                    .checked_add(distance)
                    .ok_or_else(invalid_input)?;
                (distance, target)
            }
        };

        // Offset of the target measured from the start of the current window.
        let window_offset = u64::try_from(self.position)
            .ok()
            .and_then(|within| within.checked_add(distance))
            .ok_or_else(invalid_input)?;

        for _ in 0..window_offset / WINDOW_BYTES {
            self.advance_state();
        }

        // `advance_state` rewinds the in-window cursor, so it is set last.
        self.absolute_position = target;
        self.position = usize::try_from(window_offset % WINDOW_BYTES)
            .expect("remainder is smaller than one window");

        Ok(self.absolute_position)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const SEED: [u8; 68] = [
        0x0e, 0x41, 0xc8, 0xa2, 0x1d, 0x10, 0x21, 0xdc, 0xed, 0x01, 0x55, 0xeb, 0x6e, 0x8c, 0xa5,
        0x79, 0x3e, 0x07, 0xb0, 0xf5, 0x54, 0xd4, 0x67, 0x60, 0x8c, 0xe7, 0x41, 0x11, 0x7f, 0xc0,
        0x71, 0x97, 0x8a, 0x1e, 0xe6, 0xdb, 0x9a, 0x64, 0xaa, 0x9e, 0x8c, 0x98, 0xce, 0x1b, 0x01,
        0x83, 0x90, 0x9b, 0xbe, 0xc1, 0xca, 0x7c, 0x15, 0x65, 0x16, 0x83, 0x9a, 0xe8, 0xc6, 0xe2,
        0xb7, 0x2e, 0xb3, 0x7f, 0x10, 0x33, 0xf1, 0x0f,
    ];

    /// Builds a generator from 17 seed words, serialised big endian as [`Prng::new`] expects.
    fn prng_from_words(words: &[u32; 17]) -> io::Result<Prng> {
        let bytes: Vec<u8> = words.iter().flat_map(|word| word.to_be_bytes()).collect();
        Prng::new(&mut Cursor::new(bytes))
    }

    /// Scans the first 0x8000 output bytes word by word and returns the byte offset of the
    /// first word equal to `needle`, or `None` when it does not occur.
    fn offset_of_word(prng: &mut Prng, needle: u32) -> Option<usize> {
        (0..0x8000 / 4)
            .position(|_| prng.read_word() == needle)
            .map(|index| index * 4)
    }

    #[test]
    fn test_prng() -> io::Result<()> {
        let mut prng = Prng::new(&mut Cursor::new(&SEED))?;
        assert_eq!(prng.read_byte(), 0xFCu8);
        assert_eq!(prng.read_byte(), 0x1Du8);
        assert_eq!(prng.read_byte(), 0x18u8);
        assert_eq!(prng.read_byte(), 0xCAu8);
        Ok(())
    }

    /// `Read::read` must yield the same stream as repeated `read_byte` calls, including across
    /// state-window boundaries and when a read starts off a word boundary.
    #[test]
    fn test_read_matches_bytes() -> io::Result<()> {
        let mut by_byte = Prng::new(&mut Cursor::new(&SEED))?;
        let mut by_read = Prng::new(&mut Cursor::new(&SEED))?;

        let mut chunk = [0u8; 4200];
        by_read.read_exact(&mut chunk[..3])?;
        by_read.read_exact(&mut chunk[3..])?;

        for (index, &byte) in chunk.iter().enumerate() {
            assert_eq!(byte, by_byte.read_byte(), "mismatch at byte {}", index);
        }
        Ok(())
    }

    // This is an actual seed from a real RVZ file
    const SEED2: [u32; 17] = [
        3372614438, 158434837, 3234245604, 4115409617, 995691059, 2259600081, 2567596901,
        3224889622, 2252069048, 2066444621, 1716761720, 4279803570, 4002705656, 3252961509,
        1672553374, 521454735, 808455626,
    ];

    #[test]
    fn test_find() -> io::Result<()> {
        let mut prng = prng_from_words(&SEED2)?;
        // Asserting on the `Option` makes the test fail when the word is never produced.
        assert_eq!(offset_of_word(&mut prng, 0xd29e_ef8c), Some(0x27fc));
        Ok(())
    }

    const SEED3: [u32; 17] = [
        2533838971, 2111335816, 3192687302, 3877980478, 3258345385, 3525162821, 990536516,
        3587986763, 1332061773, 524688494, 3109810193, 41400884, 3342497737, 3070796892,
        3033303145, 3632200354, 895704912,
    ];

    #[test]
    fn test_find2() -> io::Result<()> {
        let mut prng = prng_from_words(&SEED3)?;
        assert_eq!(offset_of_word(&mut prng, 0xeeb3_ecd5), Some(0x4C00));
        Ok(())
    }
}