Skip to main content

rkg_utils/input_data/
mod.rs

1use crate::header::controller::Controller;
2use crate::input_data::dpad_input::{DPadButton, DPadInput};
3use crate::input_data::face_input::{FaceButton, FaceInput};
4use crate::input_data::input::Input;
5use crate::input_data::stick_input::StickInput;
6
7pub mod dpad_input;
8pub mod face_input;
9pub mod input;
10pub mod stick_input;
11
12/// Errors that can occur while parsing [`InputData`].
13#[derive(thiserror::Error, Debug)]
14pub enum InputDataError {
15    /// Input data is impossibly short.
16    #[error("Input data length is too short")]
17    InputDataLengthTooShort,
18    /// Input data is malformed.
19    #[error("Input data is malformed")]
20    InputDataMalformed,
21    /// A face input entry could not be parsed.
22    #[error("Face Input Error: {0}")]
23    FaceInputError(#[from] face_input::FaceInputError),
24    /// A D-pad input entry could not be parsed.
25    #[error("DPad Input Error: {0}")]
26    DPadInputError(#[from] dpad_input::DPadInputError),
27    /// A stick input entry could not be parsed.
28    #[error("Stick Input Error: {0}")]
29    StickInputError(#[from] stick_input::StickInputError),
30}
31
32/// The controller input stream from a Mario Kart Wii RKG ghost file.
33///
34/// Stores the raw bytes (compressed or decompressed) alongside three decoded
35/// run-length encoded input streams: face buttons, analog stick, and D-pad.
36/// Adjacent identical entries across byte boundaries are merged during parsing
37/// so that each entry in the decoded vectors represents a single contiguous
38/// hold period.
39///
40/// The binary layout is documented at
41/// <https://wiki.tockdom.com/wiki/RKG_(File_Format)#Controller_Input_Data>.
42pub struct InputData {
43    /// The raw input data bytes as they appear in the file (may be Yaz1 compressed).
44    raw_data: Vec<u8>,
45    /// The number of face input entries as recorded in the stream header.
46    face_input_count: u16,
47    /// The number of stick input entries as recorded in the stream header.
48    stick_input_count: u16,
49    /// The number of D-pad input entries as recorded in the stream header.
50    dpad_input_count: u16,
51    /// The decoded and merged face button input stream.
52    face_inputs: Vec<FaceInput>,
53    /// The decoded and merged analog stick input stream.
54    stick_inputs: Vec<StickInput>,
55    /// The decoded D-pad input stream.
56    dpad_inputs: Vec<DPadInput>,
57}
58
59impl InputData {
60    /// Parses controller input data from raw RKG bytes starting at offset `0x88`.
61    ///
62    /// If the bytes beginning at offset 4 carry a `Yaz1` magic header, the
63    /// data is decompressed before parsing. Otherwise the slice is zero-padded
64    /// to `0x2774` bytes. After parsing, adjacent identical face and stick
65    /// entries are merged to represent each continuous hold as a single entry.
66    ///
67    /// # Errors
68    ///
69    /// Returns an [`InputDataError`] variant if any individual input entry
70    /// fails to parse.
71    pub fn new(input_data: &[u8]) -> Result<Self, InputDataError> {
72        if input_data.len() < 0x08 {
73            return Err(InputDataError::InputDataLengthTooShort);
74        }
75
76        let mut raw_data = Vec::from(input_data);
77
78        let input_data = if input_data[4..8] == [0x59, 0x61, 0x7A, 0x31] {
79            // YAZ1 header, decompress
80            yaz1_decompress(&input_data[4..]).unwrap()
81        } else {
82            raw_data.resize(0x2774, 0x00);
83            Vec::from(input_data)
84        };
85
86        let face_input_count = u16::from_be_bytes([input_data[0], input_data[1]]);
87        let stick_input_count = u16::from_be_bytes([input_data[2], input_data[3]]);
88        let dpad_input_count = u16::from_be_bytes([input_data[4], input_data[5]]);
89        // bytes 6-7: padding
90
91        if (input_data.len() as u64)
92            < ((face_input_count as u64 + stick_input_count as u64 + dpad_input_count as u64) * 2)
93                + 8
94        {
95            return Err(InputDataError::InputDataMalformed);
96        }
97
98        let mut current_byte = 8;
99        let mut face_inputs: Vec<FaceInput> = Vec::with_capacity(face_input_count as usize);
100        while current_byte < 8 + face_input_count * 2 {
101            let idx = current_byte as usize;
102            let input = &input_data[idx..idx + 2];
103            face_inputs.push(FaceInput::try_from(input)?);
104            current_byte += 2;
105        }
106
107        current_byte = 8 + face_input_count * 2;
108        let mut stick_inputs: Vec<StickInput> = Vec::with_capacity(stick_input_count as usize);
109        while current_byte < 8 + (face_input_count + stick_input_count) * 2 {
110            let idx = current_byte as usize;
111            let input = &input_data[idx..idx + 2];
112            stick_inputs.push(StickInput::try_from(input)?);
113            current_byte += 2;
114        }
115
116        current_byte = 8 + (face_input_count + stick_input_count) * 2;
117        let mut dpad_inputs: Vec<DPadInput> = Vec::with_capacity(dpad_input_count as usize);
118        while current_byte < 8 + (face_input_count + stick_input_count + dpad_input_count) * 2 {
119            let idx = current_byte as usize;
120            let input = &input_data[idx..idx + 2];
121            dpad_inputs.push(DPadInput::try_from(input)?);
122            current_byte += 2;
123        }
124
125        // Combine adjacent inputs when the same button is held across multiple bytes
126        // (each input byte has a 255-frame limit, so buttons held longer need additional bytes)
127        for index in (0..face_inputs.len() - 1).rev() {
128            if face_inputs[index] == face_inputs[index + 1] {
129                let f1 = face_inputs[index].frame_duration();
130                let f2 = face_inputs[index + 1].frame_duration();
131                face_inputs[index].set_frame_duration(f1 + f2);
132                face_inputs.remove(index + 1);
133            }
134        }
135
136        for index in (0..stick_inputs.len() - 1).rev() {
137            if stick_inputs[index] == stick_inputs[index + 1] {
138                let f1 = stick_inputs[index].frame_duration();
139                let f2 = stick_inputs[index + 1].frame_duration();
140                stick_inputs[index].set_frame_duration(f1 + f2);
141                stick_inputs.remove(index + 1);
142            }
143        }
144
145        Ok(Self {
146            raw_data,
147            face_input_count,
148            stick_input_count,
149            dpad_input_count,
150            face_inputs,
151            stick_inputs,
152            dpad_inputs,
153        })
154    }
155
156    /// Returns the three input streams merged into a single frame-accurate sequence of [`Input`] values.
157    ///
158    /// The face, stick, and D-pad streams are interleaved by advancing through
159    /// all three simultaneously and emitting a new [`Input`] whenever any
160    /// stream transitions to its next entry. Each emitted [`Input`] covers
161    /// exactly the frames until the next transition across any stream.
162    pub fn inputs(&self) -> Vec<Input> {
163        let mut result = Vec::new();
164
165        // Track current position in each input stream
166        let mut face_idx = 0;
167        let mut stick_idx = 0;
168        let mut dpad_idx = 0;
169
170        // Track how many frames consumed from current input in each stream
171        let mut face_offset = 0u32;
172        let mut stick_offset = 0u32;
173        let mut dpad_offset = 0u32;
174
175        // Continue until all streams are exhausted
176        while face_idx < self.face_inputs.len()
177            || stick_idx < self.stick_inputs.len()
178            || dpad_idx < self.dpad_inputs.len()
179        {
180            // Get current input from each stream (or defaults if exhausted)
181            let face = self.face_inputs.get(face_idx);
182            let stick = self.stick_inputs.get(stick_idx);
183            let dpad = self.dpad_inputs.get(dpad_idx);
184
185            // Calculate remaining frames for current input in each stream
186            let face_remaining = face
187                .map(|f| f.frame_duration() - face_offset)
188                .unwrap_or(u32::MAX);
189            let stick_remaining = stick
190                .map(|s| s.frame_duration() - stick_offset)
191                .unwrap_or(u32::MAX);
192            let dpad_remaining = dpad
193                .map(|d| d.frame_duration() - dpad_offset)
194                .unwrap_or(u32::MAX);
195
196            // Find the minimum remaining frames (when next change occurs)
197            let duration = face_remaining.min(stick_remaining).min(dpad_remaining);
198
199            if duration == u32::MAX {
200                // if all streams exhausted
201                break;
202            }
203
204            // Create combined input for this duration
205            let combined = Input::new(
206                face.map(|f| f.buttons().clone()).unwrap_or_default(),
207                stick.map(|s| s.x()).unwrap_or(0),
208                stick.map(|s| s.y()).unwrap_or(0),
209                dpad.map(|d| d.button()).unwrap_or(DPadButton::None),
210                duration,
211            );
212            result.push(combined);
213
214            // Update offsets and advance indices where needed
215            face_offset += duration;
216            stick_offset += duration;
217            dpad_offset += duration;
218
219            if let Some(face) = face
220                && face_offset >= face.frame_duration()
221            {
222                face_idx += 1;
223                face_offset = 0;
224            }
225            if let Some(stick) = stick
226                && stick_offset >= stick.frame_duration()
227            {
228                stick_idx += 1;
229                stick_offset = 0;
230            }
231            if let Some(dpad) = dpad
232                && dpad_offset >= dpad.frame_duration()
233            {
234                dpad_idx += 1;
235                dpad_offset = 0;
236            }
237        }
238
239        result
240    }
241
242    /// Returns `true` if the face input stream contains an illegal drift or brake input.
243    ///
244    /// An input is illegal if drift is active without brake, or if brake and
245    /// accelerator are pressed simultaneously without drift when the previous
246    /// frame had accelerator but not brake (indicating a missing drift flag).
247    pub fn contains_illegal_brake_or_drift_inputs(&self) -> bool {
248        for (idx, input) in self.face_inputs().iter().enumerate() {
249            let current_buttons = input.buttons();
250            if current_buttons.contains(&FaceButton::Drift)
251                && !current_buttons.contains(&FaceButton::Brake)
252            {
253                // Illegal drift input
254                return true;
255            } else if idx > 0 {
256                let previous_buttons = self.face_inputs()[idx - 1].buttons();
257                if current_buttons.contains(&FaceButton::Brake)
258                    && current_buttons.contains(&FaceButton::Accelerator)
259                    && !current_buttons.contains(&FaceButton::Drift)
260                    && previous_buttons.contains(&FaceButton::Accelerator)
261                    && !previous_buttons.contains(&FaceButton::Brake)
262                {
263                    // Illegal brake input (drift flag isn't 1 when it should be)
264                    return true;
265                }
266            }
267        }
268        false
269    }
270
271    /// Returns `true` if the raw input data begins with a Yaz1 magic header at offset 4.
272    pub fn is_compressed(&self) -> bool {
273        self.raw_data[4..8] == [0x59, 0x61, 0x7A, 0x31]
274    }
275
276    /// Compresses the raw input data using Yaz1 encoding.
277    ///
278    /// Does nothing if the data is already compressed.
279    pub(crate) fn compress(&mut self) {
280        if !self.is_compressed() {
281            self.raw_data = yaz1_compress(&self.raw_data);
282        }
283    }
284
285    /// Decompresses the raw input data from Yaz1 encoding.
286    ///
287    /// Does nothing if the data is not compressed.
288    pub(crate) fn decompress(&mut self) {
289        if self.is_compressed() {
290            self.raw_data = yaz1_decompress(&self.raw_data[4..]).unwrap();
291        }
292    }
293
294    /// Returns the raw input data bytes as they appear in the file.
295    pub fn raw_data(&self) -> &[u8] {
296        &self.raw_data
297    }
298
299    /// Returns `true` if the stick input stream contains any illegal stick position
300    /// for the given controller type. More info on illegal input ranges here:
301    /// <https://github.com/malleoz/mkw-replay?tab=readme-ov-file#regarding-input-ranges>
302    /// <https://youtu.be/KUjS7qWWu9c?t=489>
303    ///
304    /// The Wii Wheel has a fully unrestricted input range and is never considered to
305    /// have illegal inputs.
306    pub fn contains_illegal_stick_inputs(&self, controller: Controller) -> bool {
307        // Definition of illegal stick inputs [x, y]
308        const ILLEGAL_STICK_INPUTS: [[i8; 2]; 44] = [
309            // These inputs are illegal for GCN, CCP, and Nunchuk (24 total)
310            [-7, 7],
311            [-7, 6],
312            [-7, 5],
313            [-7, -7],
314            [-7, -6],
315            [-7, -5],
316            [-6, 7],
317            [-6, 6],
318            [-6, -7],
319            [-6, -6],
320            [-5, 7],
321            [-5, -7],
322            [7, 7],
323            [7, 6],
324            [7, 5],
325            [7, -7],
326            [7, -6],
327            [7, -5],
328            [6, 7],
329            [6, 6],
330            [6, -7],
331            [6, -6],
332            [5, 7],
333            [5, -7],
334            // Illegal stick inputs for specifically GCN/CCP (additional 20)
335            [-7, 4],
336            [-6, 5],
337            [-5, 6],
338            [-4, 7],
339            [-3, 7],
340            [3, 7],
341            [4, 7],
342            [4, 6],
343            [4, -7],
344            [5, 6],
345            [5, 5],
346            [5, -6],
347            [6, 5],
348            [6, 4],
349            [6, -5],
350            [7, 4],
351            [7, 3],
352            [7, 2],
353            [7, -3],
354            [7, -4],
355        ];
356
357        let illegal_stick_inputs = match controller {
358            Controller::Nunchuk => &ILLEGAL_STICK_INPUTS[..24],
359            Controller::Classic | Controller::Gamecube => &ILLEGAL_STICK_INPUTS,
360            Controller::WiiWheel => {
361                return false;
362            }
363        };
364
365        for current_stick_input in self.stick_inputs().iter() {
366            for illegal_stick_input in illegal_stick_inputs.iter() {
367                if current_stick_input == illegal_stick_input {
368                    return true;
369                }
370            }
371        }
372
373        false
374    }
375
376    /// Returns the decoded face button input stream.
377    pub fn face_inputs(&self) -> &[FaceInput] {
378        &self.face_inputs
379    }
380
381    /// Returns the decoded analog stick input stream.
382    pub fn stick_inputs(&self) -> &[StickInput] {
383        &self.stick_inputs
384    }
385
386    /// Returns the decoded D-pad input stream.
387    pub fn dpad_inputs(&self) -> &[DPadInput] {
388        &self.dpad_inputs
389    }
390
391    /// Returns the number of face input entries as recorded in the stream header.
392    pub fn face_input_count(&self) -> u16 {
393        self.face_input_count
394    }
395
396    /// Returns the number of stick input entries as recorded in the stream header.
397    pub fn stick_input_count(&self) -> u16 {
398        self.stick_input_count
399    }
400
401    /// Returns the number of D-pad input entries as recorded in the stream header.
402    pub fn dpad_input_count(&self) -> u16 {
403        self.dpad_input_count
404    }
405}
406
407/// Decompresses a Yaz1-encoded byte slice into raw input data.
408///
409/// The slice must begin with the `Yaz1` magic followed by the uncompressed
410/// size as a big-endian `u32` and 8 bytes of padding. The result is
411/// zero-padded to `0x2774` bytes.
412///
413/// Returns `None` if the magic is missing, the data is truncated, or the
414/// decompressed output does not match the expected size.
415///
416/// Adapted from <https://github.com/AtishaRibeiro/InputDisplay/blob/master/InputDisplay/Core/Yaz1dec.cs>.
417pub(crate) fn yaz1_decompress(data: &[u8]) -> Option<Vec<u8>> {
418    // YAZ1 files start with "Yaz1" magic header
419    if data.len() < 16 || &data[0..4] != b"Yaz1" {
420        return None;
421    }
422
423    let uncompressed_size = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;
424
425    let mut result = Vec::with_capacity(uncompressed_size);
426
427    let decompressed = decompress_block(
428        data,
429        16, // Start after 16-byte header
430        uncompressed_size,
431    );
432
433    if let Some(mut dec) = decompressed {
434        result.append(&mut dec);
435    }
436
437    if result.len() == uncompressed_size {
438        result.resize(0x2774, 0);
439        Some(result)
440    } else {
441        None
442    }
443}
444
/// Decompresses a single Yaz1 block starting at `offset` within `src`.
///
/// Decoding stops as soon as at least `uncompressed_size` bytes have been
/// produced. The output can be longer than `uncompressed_size` when the last
/// back-reference run extends past it; callers that need an exact length must
/// check it themselves.
///
/// Returns `None` if the source data ends in the middle of the stream or a
/// back-reference points before the start of the output.
fn decompress_block(src: &[u8], offset: usize, uncompressed_size: usize) -> Option<Vec<u8>> {
    // `uncompressed_size` normally comes from an untrusted file header, so it
    // only serves as a capacity hint up to the size of a full input block.
    // Larger outputs still decode correctly because the vector grows on demand.
    const MAX_PREALLOC: usize = 0x2774;

    let mut dst: Vec<u8> = Vec::with_capacity(uncompressed_size.min(MAX_PREALLOC));
    let mut src_pos = offset;

    let mut valid_bit_count = 0u8; // number of valid bits left in "code" byte
    let mut curr_code_byte = 0u8;

    while dst.len() < uncompressed_size {
        // Read new "code" byte if the current one is used up
        if valid_bit_count == 0 {
            curr_code_byte = *src.get(src_pos)?;
            src_pos += 1;
            valid_bit_count = 8;
        }

        if (curr_code_byte & 0x80) != 0 {
            // Straight copy
            dst.push(*src.get(src_pos)?);
            src_pos += 1;
        } else {
            // RLE part: 12-bit distance plus a 4-bit length nibble
            let byte1 = *src.get(src_pos)?;
            let byte2 = *src.get(src_pos + 1)?;
            src_pos += 2;

            let dist = (((byte1 & 0xF) as usize) << 8) | (byte2 as usize);
            // A distance reaching before the start of the output is invalid.
            let copy_source = dst.len().checked_sub(dist + 1)?;

            let num_bytes = match (byte1 >> 4) as usize {
                // Length nibble 0 means the length is stored in a third byte.
                0 => {
                    let extra = *src.get(src_pos)?;
                    src_pos += 1;
                    extra as usize + 0x12
                }
                nibble => nibble + 2,
            };

            // Copy run byte by byte: source and destination may overlap, and
            // `copy_source + i` always stays behind the growing end of `dst`.
            for i in 0..num_bytes {
                let byte = dst[copy_source + i];
                dst.push(byte);
            }
        }

        // Use next bit from "code" byte
        curr_code_byte <<= 1;
        valid_bit_count -= 1;
    }

    Some(dst)
}
516
517/// Compresses raw input data using Yaz1 encoding.
518///
519/// Trailing zero bytes (used to pad decompressed data to `0x2774` bytes) are
520/// stripped before compression. The output includes a full Yaz1 file header
521/// containing the compressed size, the `Yaz1` magic, the uncompressed size,
522/// and 8 bytes of padding.
523///
524/// Adapted from <https://github.com/AtishaRibeiro/TT-Rec-Tools/blob/dev/ghostmanager/Scripts/YAZ1_comp.js>.
525pub(crate) fn yaz1_compress(src: &[u8]) -> Vec<u8> {
526    // first remove padded 0s (decompressed input data is padded with 0s to 0x2774 bytes)
527    let mut trailing_bytes_to_remove = 0usize;
528    for idx in (0..src.len()).rev() {
529        if src[idx] == 0 {
530            trailing_bytes_to_remove += 1;
531        } else {
532            break;
533        }
534    }
535
536    let src = &src[0..src.len() - trailing_bytes_to_remove];
537
538    let mut dst = Vec::new();
539    let src_size = src.len();
540    let mut src_pos = 0;
541    let mut prev_flag = false;
542    let mut prev_num_bytes = 0;
543    let mut prev_match_pos = 0;
544
545    let mut code_byte = 0u8;
546    let mut valid_bit_count = 0;
547    let mut chunk = Vec::with_capacity(24); // 8 codes * 3 bytes maximum
548
549    while src_pos < src_size {
550        let (num_bytes, match_pos) = nintendo_encode(
551            src,
552            src_size,
553            src_pos,
554            &mut prev_flag,
555            &mut prev_num_bytes,
556            &mut prev_match_pos,
557        );
558
559        if num_bytes < 3 {
560            // Straight copy
561            chunk.push(src[src_pos]);
562            src_pos += 1;
563            // Set flag for straight copy
564            code_byte |= 0x80 >> valid_bit_count;
565        } else {
566            // RLE part
567            let dist = src_pos - match_pos - 1;
568
569            if num_bytes >= 0x12 {
570                // 3 byte encoding
571                let byte1 = (dist >> 8) as u8;
572                let byte2 = (dist & 0xff) as u8;
573                chunk.push(byte1);
574                chunk.push(byte2);
575
576                // Maximum runlength for 3 byte encoding
577                let num_bytes = num_bytes.min(0xff + 0x12);
578                let byte3 = (num_bytes - 0x12) as u8;
579                chunk.push(byte3);
580            } else {
581                // 2 byte encoding
582                let byte1 = (((num_bytes - 2) << 4) | (dist >> 8)) as u8;
583                let byte2 = (dist & 0xff) as u8;
584                chunk.push(byte1);
585                chunk.push(byte2);
586            }
587            src_pos += num_bytes;
588        }
589
590        valid_bit_count += 1;
591
592        // Write eight codes
593        if valid_bit_count == 8 {
594            dst.push(code_byte);
595            dst.extend_from_slice(&chunk);
596
597            code_byte = 0;
598            valid_bit_count = 0;
599            chunk.clear();
600        }
601    }
602
603    // Write remaining codes
604    if valid_bit_count > 0 {
605        dst.push(code_byte);
606        dst.extend_from_slice(&chunk);
607    }
608
609    let mut compressed_data = Vec::new();
610
611    // Write Yaz1 header
612    compressed_data.extend_from_slice(&((dst.len() + 8) as u32).to_be_bytes()); // size of compressed data
613    compressed_data.extend_from_slice(b"Yaz1");
614    compressed_data.extend_from_slice(&(src_size as u32).to_be_bytes());
615    compressed_data.extend_from_slice(&[0u8; 8]); // padding
616    compressed_data.extend_from_slice(&dst);
617    compressed_data
618}
619
620/// Determines the best encoding for the byte at `pos` using the Nintendo Yaz1 heuristic.
621///
622/// If the previous call set a lookahead flag, the cached values from that call
623/// are returned immediately. Otherwise [`simple_encode`] is run at `pos` and
624/// at `pos + 1`; if the next position's match is 2 or more bytes longer, the
625/// lookahead flag is set and a literal copy is emitted for the current position.
626fn nintendo_encode(
627    src: &[u8],
628    size: usize,
629    pos: usize,
630    prev_flag: &mut bool,
631    prev_num_bytes: &mut usize,
632    prev_match_pos: &mut usize,
633) -> (usize, usize) {
634    // If prevFlag is set, use the previously calculated values
635    if *prev_flag {
636        *prev_flag = false;
637        return (*prev_num_bytes, *prev_match_pos);
638    }
639
640    *prev_flag = false;
641    let (num_bytes, match_pos) = simple_encode(src, size, pos);
642
643    // If this position is RLE encoded, compare to copying 1 byte and next position encoding
644    if num_bytes >= 3 {
645        let (num_bytes1, match_pos1) = simple_encode(src, size, pos + 1);
646        *prev_num_bytes = num_bytes1;
647        *prev_match_pos = match_pos1;
648
649        // If the next position encoding is +2 longer, choose it
650        if num_bytes1 >= num_bytes + 2 {
651            *prev_flag = true;
652            return (1, match_pos);
653        }
654    }
655
656    (num_bytes, match_pos)
657}
658
/// Finds the longest match for `src[pos..size]` within the preceding
/// `0x1000`-byte window using a simple linear scan.
///
/// A match may overlap the position being encoded: it only has to *start*
/// before `pos`. When several candidates are equally long, the earliest one
/// in the window wins.
///
/// Returns `(num_bytes, match_pos)` where `num_bytes` is the length of the
/// longest match found (1 if no match of length ≥ 3 was found) and
/// `match_pos` is the starting offset of that match in `src` (0 if nothing
/// longer than one byte matched).
///
/// # Panics
///
/// Panics if `pos > size` or `size > src.len()`.
fn simple_encode(src: &[u8], size: usize, pos: usize) -> (usize, usize) {
    // Distances are stored in 12 bits, so a match may start at most 0x1000
    // bytes back. `saturating_sub` clamps at the start of the buffer without
    // the lossy signed cast the window computation would otherwise need.
    let window_start = pos.saturating_sub(0x1000);
    let target = &src[pos..size];

    let mut num_bytes = 1;
    let mut match_pos = 0;

    for candidate in window_start..pos {
        // Length of the common prefix, capped at the bytes left from `pos`.
        let matched = src[candidate..size]
            .iter()
            .zip(target)
            .take_while(|(earlier, current)| earlier == current)
            .count();

        if matched > num_bytes {
            num_bytes = matched;
            match_pos = candidate;
        }
    }

    // A 2-byte match costs as much as two literals, so treat it as no match.
    if num_bytes == 2 {
        num_bytes = 1;
    }

    (num_bytes, match_pos)
}