// rkg_utils/input_data/mod.rs

1use crate::header::controller::Controller;
2use crate::input_data::dpad_input::{DPadButton, DPadInput};
3use crate::input_data::face_input::{FaceButton, FaceInput};
4use crate::input_data::input::Input;
5use crate::input_data::stick_input::StickInput;
6
7pub mod dpad_input;
8pub mod face_input;
9pub mod input;
10pub mod stick_input;
11
12/// Errors that can occur while parsing [`InputData`].
13#[derive(thiserror::Error, Debug)]
14pub enum InputDataError {
15    /// A face input entry could not be parsed.
16    #[error("Face Input Error: {0}")]
17    FaceInputError(#[from] face_input::FaceInputError),
18    /// A D-pad input entry could not be parsed.
19    #[error("DPad Input Error: {0}")]
20    DPadInputError(#[from] dpad_input::DPadInputError),
21    /// A stick input entry could not be parsed.
22    #[error("Stick Input Error: {0}")]
23    StickInputError(#[from] stick_input::StickInputError),
24}
25
26/// The controller input stream from a Mario Kart Wii RKG ghost file.
27///
28/// Stores the raw bytes (compressed or decompressed) alongside three decoded
29/// run-length encoded input streams: face buttons, analog stick, and D-pad.
30/// Adjacent identical entries across byte boundaries are merged during parsing
31/// so that each entry in the decoded vectors represents a single contiguous
32/// hold period.
33///
34/// The binary layout is documented at
35/// <https://wiki.tockdom.com/wiki/RKG_(File_Format)#Controller_Input_Data>.
36pub struct InputData {
37    /// The raw input data bytes as they appear in the file (may be Yaz1 compressed).
38    raw_data: Vec<u8>,
39    /// The number of face input entries as recorded in the stream header.
40    face_input_count: u16,
41    /// The number of stick input entries as recorded in the stream header.
42    stick_input_count: u16,
43    /// The number of D-pad input entries as recorded in the stream header.
44    dpad_input_count: u16,
45    /// The decoded and merged face button input stream.
46    face_inputs: Vec<FaceInput>,
47    /// The decoded and merged analog stick input stream.
48    stick_inputs: Vec<StickInput>,
49    /// The decoded D-pad input stream.
50    dpad_inputs: Vec<DPadInput>,
51}
52
53impl InputData {
54    /// Parses controller input data from raw RKG bytes starting at offset `0x88`.
55    ///
56    /// If the bytes beginning at offset 4 carry a `Yaz1` magic header, the
57    /// data is decompressed before parsing. Otherwise the slice is zero-padded
58    /// to `0x2774` bytes. After parsing, adjacent identical face and stick
59    /// entries are merged to represent each continuous hold as a single entry.
60    ///
61    /// # Errors
62    ///
63    /// Returns an [`InputDataError`] variant if any individual input entry
64    /// fails to parse.
65    pub fn new(input_data: &[u8]) -> Result<Self, InputDataError> {
66        let mut raw_data = Vec::from(input_data);
67
68        let input_data = if input_data[4..8] == [0x59, 0x61, 0x7A, 0x31] {
69            // YAZ1 header, decompress
70            yaz1_decompress(&input_data[4..]).unwrap()
71        } else {
72            raw_data.resize(0x2774, 0x00);
73            Vec::from(input_data)
74        };
75
76        let face_input_count = u16::from_be_bytes([input_data[0], input_data[1]]);
77        let stick_input_count = u16::from_be_bytes([input_data[2], input_data[3]]);
78        let dpad_input_count = u16::from_be_bytes([input_data[4], input_data[5]]);
79        // bytes 6-7: padding
80
81        let mut current_byte = 8;
82        let mut face_inputs: Vec<FaceInput> = Vec::with_capacity(face_input_count as usize);
83        while current_byte < 8 + face_input_count * 2 {
84            let idx = current_byte as usize;
85            let input = &input_data[idx..idx + 2];
86            face_inputs.push(FaceInput::try_from(input)?);
87            current_byte += 2;
88        }
89
90        current_byte = 8 + face_input_count * 2;
91        let mut stick_inputs: Vec<StickInput> = Vec::with_capacity(stick_input_count as usize);
92        while current_byte < 8 + (face_input_count + stick_input_count) * 2 {
93            let idx = current_byte as usize;
94            let input = &input_data[idx..idx + 2];
95            stick_inputs.push(StickInput::try_from(input)?);
96            current_byte += 2;
97        }
98
99        current_byte = 8 + (face_input_count + stick_input_count) * 2;
100        let mut dpad_inputs: Vec<DPadInput> = Vec::with_capacity(dpad_input_count as usize);
101        while current_byte < 8 + (face_input_count + stick_input_count + dpad_input_count) * 2 {
102            let idx = current_byte as usize;
103            let input = &input_data[idx..idx + 2];
104            dpad_inputs.push(DPadInput::try_from(input)?);
105            current_byte += 2;
106        }
107
108        // Combine adjacent inputs when the same button is held across multiple bytes
109        // (each input byte has a 255-frame limit, so buttons held longer need additional bytes)
110        for index in (0..face_inputs.len() - 1).rev() {
111            if face_inputs[index] == face_inputs[index + 1] {
112                let f1 = face_inputs[index].frame_duration();
113                let f2 = face_inputs[index + 1].frame_duration();
114                face_inputs[index].set_frame_duration(f1 + f2);
115                face_inputs.remove(index + 1);
116            }
117        }
118
119        for index in (0..stick_inputs.len() - 1).rev() {
120            if stick_inputs[index] == stick_inputs[index + 1] {
121                let f1 = stick_inputs[index].frame_duration();
122                let f2 = stick_inputs[index + 1].frame_duration();
123                stick_inputs[index].set_frame_duration(f1 + f2);
124                stick_inputs.remove(index + 1);
125            }
126        }
127
128        Ok(Self {
129            raw_data,
130            face_input_count,
131            stick_input_count,
132            dpad_input_count,
133            face_inputs,
134            stick_inputs,
135            dpad_inputs,
136        })
137    }
138
139    /// Returns the three input streams merged into a single frame-accurate sequence of [`Input`] values.
140    ///
141    /// The face, stick, and D-pad streams are interleaved by advancing through
142    /// all three simultaneously and emitting a new [`Input`] whenever any
143    /// stream transitions to its next entry. Each emitted [`Input`] covers
144    /// exactly the frames until the next transition across any stream.
145    pub fn inputs(&self) -> Vec<Input> {
146        let mut result = Vec::new();
147
148        // Track current position in each input stream
149        let mut face_idx = 0;
150        let mut stick_idx = 0;
151        let mut dpad_idx = 0;
152
153        // Track how many frames consumed from current input in each stream
154        let mut face_offset = 0u32;
155        let mut stick_offset = 0u32;
156        let mut dpad_offset = 0u32;
157
158        // Continue until all streams are exhausted
159        while face_idx < self.face_inputs.len()
160            || stick_idx < self.stick_inputs.len()
161            || dpad_idx < self.dpad_inputs.len()
162        {
163            // Get current input from each stream (or defaults if exhausted)
164            let face = self.face_inputs.get(face_idx);
165            let stick = self.stick_inputs.get(stick_idx);
166            let dpad = self.dpad_inputs.get(dpad_idx);
167
168            // Calculate remaining frames for current input in each stream
169            let face_remaining = face
170                .map(|f| f.frame_duration() - face_offset)
171                .unwrap_or(u32::MAX);
172            let stick_remaining = stick
173                .map(|s| s.frame_duration() - stick_offset)
174                .unwrap_or(u32::MAX);
175            let dpad_remaining = dpad
176                .map(|d| d.frame_duration() - dpad_offset)
177                .unwrap_or(u32::MAX);
178
179            // Find the minimum remaining frames (when next change occurs)
180            let duration = face_remaining.min(stick_remaining).min(dpad_remaining);
181
182            if duration == u32::MAX {
183                // if all streams exhausted
184                break;
185            }
186
187            // Create combined input for this duration
188            let combined = Input::new(
189                face.map(|f| f.buttons().clone()).unwrap_or_default(),
190                stick.map(|s| s.x()).unwrap_or(0),
191                stick.map(|s| s.y()).unwrap_or(0),
192                dpad.map(|d| d.button()).unwrap_or(DPadButton::None),
193                duration,
194            );
195            result.push(combined);
196
197            // Update offsets and advance indices where needed
198            face_offset += duration;
199            stick_offset += duration;
200            dpad_offset += duration;
201
202            if let Some(face) = face
203                && face_offset >= face.frame_duration()
204            {
205                face_idx += 1;
206                face_offset = 0;
207            }
208            if let Some(stick) = stick
209                && stick_offset >= stick.frame_duration()
210            {
211                stick_idx += 1;
212                stick_offset = 0;
213            }
214            if let Some(dpad) = dpad
215                && dpad_offset >= dpad.frame_duration()
216            {
217                dpad_idx += 1;
218                dpad_offset = 0;
219            }
220        }
221
222        result
223    }
224
225    /// Returns `true` if the face input stream contains an illegal drift or brake input.
226    ///
227    /// An input is illegal if drift is active without brake, or if brake and
228    /// accelerator are pressed simultaneously without drift when the previous
229    /// frame had accelerator but not brake (indicating a missing drift flag).
230    pub fn contains_illegal_brake_or_drift_inputs(&self) -> bool {
231        for (idx, input) in self.face_inputs().iter().enumerate() {
232            let current_buttons = input.buttons();
233            if current_buttons.contains(&FaceButton::Drift)
234                && !current_buttons.contains(&FaceButton::Brake)
235            {
236                // Illegal drift input
237                return true;
238            } else if idx > 0 {
239                let previous_buttons = self.face_inputs()[idx - 1].buttons();
240                if current_buttons.contains(&FaceButton::Brake)
241                    && current_buttons.contains(&FaceButton::Accelerator)
242                    && !current_buttons.contains(&FaceButton::Drift)
243                    && previous_buttons.contains(&FaceButton::Accelerator)
244                    && !previous_buttons.contains(&FaceButton::Brake)
245                {
246                    // Illegal brake input (drift flag isn't 1 when it should be)
247                    return true;
248                }
249            }
250        }
251        false
252    }
253
254    /// Returns `true` if the raw input data begins with a Yaz1 magic header at offset 4.
255    pub fn is_compressed(&self) -> bool {
256        self.raw_data[4..8] == [0x59, 0x61, 0x7A, 0x31]
257    }
258
259    /// Compresses the raw input data using Yaz1 encoding.
260    ///
261    /// Does nothing if the data is already compressed.
262    pub(crate) fn compress(&mut self) {
263        if !self.is_compressed() {
264            self.raw_data = yaz1_compress(&self.raw_data);
265        }
266    }
267
268    /// Decompresses the raw input data from Yaz1 encoding.
269    ///
270    /// Does nothing if the data is not compressed.
271    pub(crate) fn decompress(&mut self) {
272        if self.is_compressed() {
273            self.raw_data = yaz1_decompress(&self.raw_data[4..]).unwrap();
274        }
275    }
276
277    /// Returns the raw input data bytes as they appear in the file.
278    pub fn raw_data(&self) -> &[u8] {
279        &self.raw_data
280    }
281
282    /// Returns `true` if the stick input stream contains any illegal stick position
283    /// for the given controller type. More info on illegal input ranges here:
284    /// <https://github.com/malleoz/mkw-replay?tab=readme-ov-file#regarding-input-ranges>
285    /// <https://youtu.be/KUjS7qWWu9c?t=489>
286    ///
287    /// The Wii Wheel has a fully unrestricted input range and is never considered to
288    /// have illegal inputs.
289    pub fn contains_illegal_stick_inputs(&self, controller: Controller) -> bool {
290        // Definition of illegal stick inputs [x, y]
291        const ILLEGAL_STICK_INPUTS: [[i8; 2]; 44] = [
292            // These inputs are illegal for GCN, CCP, and Nunchuk (24 total)
293            [-7, 7],
294            [-7, 6],
295            [-7, 5],
296            [-7, -7],
297            [-7, -6],
298            [-7, -5],
299            [-6, 7],
300            [-6, 6],
301            [-6, -7],
302            [-6, -6],
303            [-5, 7],
304            [-5, -7],
305            [7, 7],
306            [7, 6],
307            [7, 5],
308            [7, -7],
309            [7, -6],
310            [7, -5],
311            [6, 7],
312            [6, 6],
313            [6, -7],
314            [6, -6],
315            [5, 7],
316            [5, -7],
317            // Illegal stick inputs for specifically GCN/CCP (additional 20)
318            [-7, 4],
319            [-6, 5],
320            [-5, 6],
321            [-4, 7],
322            [-3, 7],
323            [3, 7],
324            [4, 7],
325            [4, 6],
326            [4, -7],
327            [5, 6],
328            [5, 5],
329            [5, -6],
330            [6, 5],
331            [6, 4],
332            [6, -5],
333            [7, 4],
334            [7, 3],
335            [7, 2],
336            [7, -3],
337            [7, -4],
338        ];
339
340        let illegal_stick_inputs = match controller {
341            Controller::Nunchuk => &ILLEGAL_STICK_INPUTS[..24],
342            Controller::Classic | Controller::Gamecube => &ILLEGAL_STICK_INPUTS,
343            Controller::WiiWheel => {
344                return false;
345            }
346        };
347
348        for current_stick_input in self.stick_inputs().iter() {
349            for illegal_stick_input in illegal_stick_inputs.iter() {
350                if current_stick_input == illegal_stick_input {
351                    return true;
352                }
353            }
354        }
355
356        false
357    }
358
359    /// Returns the decoded face button input stream.
360    pub fn face_inputs(&self) -> &[FaceInput] {
361        &self.face_inputs
362    }
363
364    /// Returns the decoded analog stick input stream.
365    pub fn stick_inputs(&self) -> &[StickInput] {
366        &self.stick_inputs
367    }
368
369    /// Returns the decoded D-pad input stream.
370    pub fn dpad_inputs(&self) -> &[DPadInput] {
371        &self.dpad_inputs
372    }
373
374    /// Returns the number of face input entries as recorded in the stream header.
375    pub fn face_input_count(&self) -> u16 {
376        self.face_input_count
377    }
378
379    /// Returns the number of stick input entries as recorded in the stream header.
380    pub fn stick_input_count(&self) -> u16 {
381        self.stick_input_count
382    }
383
384    /// Returns the number of D-pad input entries as recorded in the stream header.
385    pub fn dpad_input_count(&self) -> u16 {
386        self.dpad_input_count
387    }
388}
389
/// Decompresses a Yaz1-encoded byte slice into raw input data.
///
/// The slice must begin with the `Yaz1` magic followed by the uncompressed
/// size as a big-endian `u32` and 8 bytes of padding; the encoded stream
/// starts at offset 16. On success the decoded bytes are resized to exactly
/// `0x2774` bytes: shorter output is zero-padded, longer output is truncated.
///
/// Returns `None` if the slice is shorter than the 16-byte header, the magic
/// is missing, the data is truncated, a back-reference points before the
/// start of the output, or the decoded output does not match the declared
/// size exactly.
///
/// Adapted from <https://github.com/AtishaRibeiro/InputDisplay/blob/master/InputDisplay/Core/Yaz1dec.cs>.
pub(crate) fn yaz1_decompress(data: &[u8]) -> Option<Vec<u8>> {
    // Magic (4) + uncompressed size (4) + padding (8).
    const HEADER_LEN: usize = 16;
    // Length every decoded input section is normalised to.
    const PADDED_LEN: usize = 0x2774;

    if data.len() < HEADER_LEN || !data.starts_with(b"Yaz1") {
        return None;
    }

    let declared = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
    let uncompressed_size = usize::try_from(declared).ok()?;

    // `decompress_block` yields exactly `uncompressed_size` bytes or `None`,
    // so its buffer can be reused directly without a second copy.
    let mut result = decompress_block(data, HEADER_LEN, uncompressed_size)?;
    result.resize(PADDED_LEN, 0);
    Some(result)
}

/// Decompresses a single Yaz1 block starting at `offset` within `src`.
///
/// The stream is a sequence of groups: one code byte followed by up to eight
/// items, one per code bit, most significant bit first. A set bit means the
/// next source byte is copied verbatim. A clear bit means a back-reference:
/// two bytes give a 12-bit distance and a 4-bit length (`nibble + 2`); a
/// zero nibble means a third byte holds the length minus `0x12`.
///
/// Returns `None` if the source is truncated, a back-reference reaches before
/// the start of the output, or a run would push the output past
/// `uncompressed_size`. On success the output is exactly `uncompressed_size`
/// bytes long.
fn decompress_block(src: &[u8], offset: usize, uncompressed_size: usize) -> Option<Vec<u8>> {
    // Longest run a single back-reference can describe (0xff + 0x12).
    const MAX_RUN: usize = 0x111;

    // `uncompressed_size` comes from an untrusted header, so the preallocation
    // is capped by a bound on what `src` can expand to: no source byte can
    // account for more than `MAX_RUN` output bytes.
    let capacity = uncompressed_size.min(src.len().saturating_mul(MAX_RUN));
    let mut dst = Vec::with_capacity(capacity);

    let mut input = src.get(offset..).unwrap_or(&[]).iter().copied();

    let mut bits_left = 0u8; // number of unused bits left in `code_byte`
    let mut code_byte = 0u8;

    while dst.len() < uncompressed_size {
        // Fetch a new code byte once the current one is used up.
        if bits_left == 0 {
            code_byte = input.next()?;
            bits_left = 8;
        }

        if code_byte & 0x80 != 0 {
            // Straight copy
            dst.push(input.next()?);
        } else {
            // Back-reference into the output produced so far.
            let byte1 = input.next()?;
            let byte2 = input.next()?;

            let dist = (usize::from(byte1 & 0x0F) << 8) | usize::from(byte2);
            let copy_source = dst.len().checked_sub(dist + 1)?;

            let num_bytes = match byte1 >> 4 {
                0 => usize::from(input.next()?) + 0x12,
                nibble => usize::from(nibble) + 2,
            };

            // A valid stream never decodes past the declared size.
            if num_bytes > uncompressed_size - dst.len() {
                return None;
            }

            // Byte-by-byte on purpose: source and destination may overlap, in
            // which case freshly written bytes are read again. `copy_source`
            // is below `dst.len()` and both advance together, so the index
            // stays in bounds.
            for i in 0..num_bytes {
                let byte = dst[copy_source + i];
                dst.push(byte);
            }
        }

        // Move on to the next bit of the code byte.
        code_byte <<= 1;
        bits_left -= 1;
    }

    Some(dst)
}
499
/// Compresses raw input data using Yaz1 encoding.
///
/// Trailing zero bytes (used to pad decompressed data to `0x2774` bytes) are
/// stripped before compression. The output is laid out as:
///
/// 1. a big-endian `u32` equal to the length of the encoded stream plus 8,
/// 2. the `Yaz1` magic,
/// 3. the uncompressed (stripped) size as a big-endian `u32`,
/// 4. 8 bytes of zero padding,
/// 5. the encoded stream.
///
/// Matches longer than the format's maximum run length (`0xff + 0x12`) are
/// emitted as several back-references.
///
/// Adapted from <https://github.com/AtishaRibeiro/TT-Rec-Tools/blob/dev/ghostmanager/Scripts/YAZ1_comp.js>.
pub(crate) fn yaz1_compress(src: &[u8]) -> Vec<u8> {
    // Longest run a single (3-byte) back-reference can describe.
    const MAX_RUN: usize = 0xff + 0x12;

    // First remove padded 0s (decompressed input data is padded with 0s to 0x2774 bytes).
    let payload_len = src
        .iter()
        .rposition(|&byte| byte != 0)
        .map_or(0, |last| last + 1);
    let src = &src[..payload_len];
    let src_size = src.len();

    let mut dst = Vec::new();
    let mut src_pos = 0;

    // Lookahead state shared with `nintendo_encode`.
    let mut prev_flag = false;
    let mut prev_num_bytes = 0;
    let mut prev_match_pos = 0;

    let mut code_byte = 0u8;
    let mut pending_codes = 0u8; // items collected for the current code byte
    let mut chunk = Vec::with_capacity(24); // 8 codes * 3 bytes maximum

    while src_pos < src_size {
        let (num_bytes, match_pos) = nintendo_encode(
            src,
            src_size,
            src_pos,
            &mut prev_flag,
            &mut prev_num_bytes,
            &mut prev_match_pos,
        );

        if num_bytes < 3 {
            // Straight copy, flagged by a set bit in the code byte.
            chunk.push(src[src_pos]);
            src_pos += 1;
            code_byte |= 0x80 >> pending_codes;
        } else {
            // Back-reference. Clamp BEFORE advancing so the source position
            // moves by exactly the number of bytes the reference encodes;
            // the rest of a longer match is picked up by the next iteration.
            let num_bytes = num_bytes.min(MAX_RUN);
            let dist = src_pos - match_pos - 1;
            let dist_hi = (dist >> 8) as u8; // dist < 0x1000, so this fits a nibble
            let dist_lo = (dist & 0xff) as u8;

            if num_bytes >= 0x12 {
                // 3 byte encoding: zero length nibble, explicit length byte
                chunk.extend_from_slice(&[dist_hi, dist_lo, (num_bytes - 0x12) as u8]);
            } else {
                // 2 byte encoding: length - 2 in the high nibble
                chunk.extend_from_slice(&[(((num_bytes - 2) << 4) as u8) | dist_hi, dist_lo]);
            }
            src_pos += num_bytes;
        }

        pending_codes += 1;

        // Flush a full group of eight items.
        if pending_codes == 8 {
            dst.push(code_byte);
            dst.extend_from_slice(&chunk);

            code_byte = 0;
            pending_codes = 0;
            chunk.clear();
        }
    }

    // Flush the final, partially filled group.
    if pending_codes > 0 {
        dst.push(code_byte);
        dst.extend_from_slice(&chunk);
    }

    let mut compressed_data = Vec::with_capacity(dst.len() + 20);

    // Length field followed by the Yaz1 header and the encoded stream.
    compressed_data.extend_from_slice(&((dst.len() + 8) as u32).to_be_bytes());
    compressed_data.extend_from_slice(b"Yaz1");
    compressed_data.extend_from_slice(&(src_size as u32).to_be_bytes());
    compressed_data.extend_from_slice(&[0u8; 8]); // padding
    compressed_data.extend_from_slice(&dst);
    compressed_data
}

/// Determines the best encoding for the byte at `pos` using the Nintendo Yaz1 heuristic.
///
/// If the previous call set the lookahead flag, the values cached by that
/// call are returned immediately and the flag is cleared. Otherwise
/// [`simple_encode`] is run at `pos`; when it finds a match of 3 or more
/// bytes it is also run at `pos + 1`, and if that match is at least 2 bytes
/// longer the flag is set, the longer match is cached for the next call, and
/// a literal copy (`num_bytes == 1`) is returned for the current position.
///
/// Returns `(num_bytes, match_pos)` in the same form as [`simple_encode`].
fn nintendo_encode(
    src: &[u8],
    size: usize,
    pos: usize,
    prev_flag: &mut bool,
    prev_num_bytes: &mut usize,
    prev_match_pos: &mut usize,
) -> (usize, usize) {
    // Consume the lookahead result left behind by the previous call, if any.
    if std::mem::take(prev_flag) {
        return (*prev_num_bytes, *prev_match_pos);
    }

    let (num_bytes, match_pos) = simple_encode(src, size, pos);

    // If this position can be a back-reference, compare against emitting one
    // literal and starting the back-reference at the next position instead.
    if num_bytes >= 3 {
        let (next_num_bytes, next_match_pos) = simple_encode(src, size, pos + 1);
        *prev_num_bytes = next_num_bytes;
        *prev_match_pos = next_match_pos;

        // The later match wins only if it is at least 2 bytes longer.
        if next_num_bytes >= num_bytes + 2 {
            *prev_flag = true;
            return (1, match_pos);
        }
    }

    (num_bytes, match_pos)
}

/// Finds the longest match for `src[pos..size]` that starts within the
/// preceding `0x1000`-byte window, using a simple linear scan.
///
/// Matches may run past `pos` (overlapping the data being matched). When
/// several candidates are equally long, the earliest one is kept.
///
/// Returns `(num_bytes, match_pos)` where `num_bytes` is the length of the
/// longest match found (1 if no match of length ≥ 3 was found) and
/// `match_pos` is the starting offset of that match in `src`.
fn simple_encode(src: &[u8], size: usize, pos: usize) -> (usize, usize) {
    let window_start = pos.saturating_sub(0x1000);
    let target = &src[pos..size];

    let mut num_bytes = 1;
    let mut match_pos = 0;

    for candidate in window_start..pos {
        // Length of the common prefix; `zip` stops at the end of `target`.
        let matched = src[candidate..]
            .iter()
            .zip(target)
            .take_while(|(a, b)| a == b)
            .count();

        if matched > num_bytes {
            num_bytes = matched;
            match_pos = candidate;
        }
    }

    // A 2-byte match is not worth a 2-byte back-reference; treat it as a literal.
    if num_bytes == 2 {
        num_bytes = 1;
    }

    (num_bytes, match_pos)
}