Skip to main content

rkg_utils/input_data/
mod.rs

1use crate::header::controller::Controller;
2use crate::input_data::dpad_input::{DPadButton, DPadInput};
3use crate::input_data::face_input::{FaceButton, FaceInput};
4use crate::input_data::input::Input;
5use crate::input_data::stick_input::StickInput;
6
7pub mod dpad_input;
8pub mod face_input;
9pub mod input;
10pub mod stick_input;
11
/// Errors that can occur while parsing [`InputData`].
///
/// The three wrapping variants carry `#[from]` conversions, which is what
/// lets [`InputData::new`] propagate the per-entry parser errors with `?`.
#[derive(thiserror::Error, Debug)]
pub enum InputDataError {
    /// Input data is impossibly short.
    ///
    /// [`InputData::new`] returns this when fewer than 8 bytes (the size of
    /// the stream header holding the three entry counts) are supplied.
    #[error("Input data length is too short")]
    InputDataLengthTooShort,
    /// A face input entry could not be parsed.
    #[error("Face Input Error: {0}")]
    FaceInputError(#[from] face_input::FaceInputError),
    /// A D-pad input entry could not be parsed.
    #[error("DPad Input Error: {0}")]
    DPadInputError(#[from] dpad_input::DPadInputError),
    /// A stick input entry could not be parsed.
    #[error("Stick Input Error: {0}")]
    StickInputError(#[from] stick_input::StickInputError),
}
28
/// The controller input stream from a Mario Kart Wii RKG ghost file.
///
/// Stores the raw bytes (compressed or decompressed) alongside three decoded
/// run-length encoded input streams: face buttons, analog stick, and D-pad.
/// Adjacent identical face and stick entries are merged during parsing so
/// that each entry in those two vectors represents a single contiguous hold
/// period; the D-pad stream is stored exactly as decoded.
///
/// The binary layout is documented at
/// <https://wiki.tockdom.com/wiki/RKG_(File_Format)#Controller_Input_Data>.
pub struct InputData {
    /// The raw input data bytes. Yaz1-compressed input is stored as supplied;
    /// uncompressed input is resized to `0x2774` bytes on construction.
    raw_data: Vec<u8>,
    /// The number of face input entries as recorded in the stream header
    /// (counted before merging, so it may exceed `face_inputs.len()`).
    face_input_count: u16,
    /// The number of stick input entries as recorded in the stream header
    /// (counted before merging, so it may exceed `stick_inputs.len()`).
    stick_input_count: u16,
    /// The number of D-pad input entries as recorded in the stream header.
    dpad_input_count: u16,
    /// The decoded and merged face button input stream.
    face_inputs: Vec<FaceInput>,
    /// The decoded and merged analog stick input stream.
    stick_inputs: Vec<StickInput>,
    /// The decoded D-pad input stream (not merged).
    dpad_inputs: Vec<DPadInput>,
}
55
56impl InputData {
57    /// Parses controller input data from raw RKG bytes starting at offset `0x88`.
58    ///
59    /// If the bytes beginning at offset 4 carry a `Yaz1` magic header, the
60    /// data is decompressed before parsing. Otherwise the slice is zero-padded
61    /// to `0x2774` bytes. After parsing, adjacent identical face and stick
62    /// entries are merged to represent each continuous hold as a single entry.
63    ///
64    /// # Errors
65    ///
66    /// Returns an [`InputDataError`] variant if any individual input entry
67    /// fails to parse.
68    pub fn new(input_data: &[u8]) -> Result<Self, InputDataError> {
69        if input_data.len() < 0x08 {
70            return Err(InputDataError::InputDataLengthTooShort);
71        }
72
73        let mut raw_data = Vec::from(input_data);
74
75        let input_data = if input_data[4..8] == [0x59, 0x61, 0x7A, 0x31] {
76            // YAZ1 header, decompress
77            yaz1_decompress(&input_data[4..]).unwrap()
78        } else {
79            raw_data.resize(0x2774, 0x00);
80            Vec::from(input_data)
81        };
82
83        let face_input_count = u16::from_be_bytes([input_data[0], input_data[1]]);
84        let stick_input_count = u16::from_be_bytes([input_data[2], input_data[3]]);
85        let dpad_input_count = u16::from_be_bytes([input_data[4], input_data[5]]);
86        // bytes 6-7: padding
87
88        let mut current_byte = 8;
89        let mut face_inputs: Vec<FaceInput> = Vec::with_capacity(face_input_count as usize);
90        while current_byte < 8 + face_input_count * 2 {
91            let idx = current_byte as usize;
92            let input = &input_data[idx..idx + 2];
93            face_inputs.push(FaceInput::try_from(input)?);
94            current_byte += 2;
95        }
96
97        current_byte = 8 + face_input_count * 2;
98        let mut stick_inputs: Vec<StickInput> = Vec::with_capacity(stick_input_count as usize);
99        while current_byte < 8 + (face_input_count + stick_input_count) * 2 {
100            let idx = current_byte as usize;
101            let input = &input_data[idx..idx + 2];
102            stick_inputs.push(StickInput::try_from(input)?);
103            current_byte += 2;
104        }
105
106        current_byte = 8 + (face_input_count + stick_input_count) * 2;
107        let mut dpad_inputs: Vec<DPadInput> = Vec::with_capacity(dpad_input_count as usize);
108        while current_byte < 8 + (face_input_count + stick_input_count + dpad_input_count) * 2 {
109            let idx = current_byte as usize;
110            let input = &input_data[idx..idx + 2];
111            dpad_inputs.push(DPadInput::try_from(input)?);
112            current_byte += 2;
113        }
114
115        // Combine adjacent inputs when the same button is held across multiple bytes
116        // (each input byte has a 255-frame limit, so buttons held longer need additional bytes)
117        for index in (0..face_inputs.len() - 1).rev() {
118            if face_inputs[index] == face_inputs[index + 1] {
119                let f1 = face_inputs[index].frame_duration();
120                let f2 = face_inputs[index + 1].frame_duration();
121                face_inputs[index].set_frame_duration(f1 + f2);
122                face_inputs.remove(index + 1);
123            }
124        }
125
126        for index in (0..stick_inputs.len() - 1).rev() {
127            if stick_inputs[index] == stick_inputs[index + 1] {
128                let f1 = stick_inputs[index].frame_duration();
129                let f2 = stick_inputs[index + 1].frame_duration();
130                stick_inputs[index].set_frame_duration(f1 + f2);
131                stick_inputs.remove(index + 1);
132            }
133        }
134
135        Ok(Self {
136            raw_data,
137            face_input_count,
138            stick_input_count,
139            dpad_input_count,
140            face_inputs,
141            stick_inputs,
142            dpad_inputs,
143        })
144    }
145
146    /// Returns the three input streams merged into a single frame-accurate sequence of [`Input`] values.
147    ///
148    /// The face, stick, and D-pad streams are interleaved by advancing through
149    /// all three simultaneously and emitting a new [`Input`] whenever any
150    /// stream transitions to its next entry. Each emitted [`Input`] covers
151    /// exactly the frames until the next transition across any stream.
152    pub fn inputs(&self) -> Vec<Input> {
153        let mut result = Vec::new();
154
155        // Track current position in each input stream
156        let mut face_idx = 0;
157        let mut stick_idx = 0;
158        let mut dpad_idx = 0;
159
160        // Track how many frames consumed from current input in each stream
161        let mut face_offset = 0u32;
162        let mut stick_offset = 0u32;
163        let mut dpad_offset = 0u32;
164
165        // Continue until all streams are exhausted
166        while face_idx < self.face_inputs.len()
167            || stick_idx < self.stick_inputs.len()
168            || dpad_idx < self.dpad_inputs.len()
169        {
170            // Get current input from each stream (or defaults if exhausted)
171            let face = self.face_inputs.get(face_idx);
172            let stick = self.stick_inputs.get(stick_idx);
173            let dpad = self.dpad_inputs.get(dpad_idx);
174
175            // Calculate remaining frames for current input in each stream
176            let face_remaining = face
177                .map(|f| f.frame_duration() - face_offset)
178                .unwrap_or(u32::MAX);
179            let stick_remaining = stick
180                .map(|s| s.frame_duration() - stick_offset)
181                .unwrap_or(u32::MAX);
182            let dpad_remaining = dpad
183                .map(|d| d.frame_duration() - dpad_offset)
184                .unwrap_or(u32::MAX);
185
186            // Find the minimum remaining frames (when next change occurs)
187            let duration = face_remaining.min(stick_remaining).min(dpad_remaining);
188
189            if duration == u32::MAX {
190                // if all streams exhausted
191                break;
192            }
193
194            // Create combined input for this duration
195            let combined = Input::new(
196                face.map(|f| f.buttons().clone()).unwrap_or_default(),
197                stick.map(|s| s.x()).unwrap_or(0),
198                stick.map(|s| s.y()).unwrap_or(0),
199                dpad.map(|d| d.button()).unwrap_or(DPadButton::None),
200                duration,
201            );
202            result.push(combined);
203
204            // Update offsets and advance indices where needed
205            face_offset += duration;
206            stick_offset += duration;
207            dpad_offset += duration;
208
209            if let Some(face) = face
210                && face_offset >= face.frame_duration()
211            {
212                face_idx += 1;
213                face_offset = 0;
214            }
215            if let Some(stick) = stick
216                && stick_offset >= stick.frame_duration()
217            {
218                stick_idx += 1;
219                stick_offset = 0;
220            }
221            if let Some(dpad) = dpad
222                && dpad_offset >= dpad.frame_duration()
223            {
224                dpad_idx += 1;
225                dpad_offset = 0;
226            }
227        }
228
229        result
230    }
231
232    /// Returns `true` if the face input stream contains an illegal drift or brake input.
233    ///
234    /// An input is illegal if drift is active without brake, or if brake and
235    /// accelerator are pressed simultaneously without drift when the previous
236    /// frame had accelerator but not brake (indicating a missing drift flag).
237    pub fn contains_illegal_brake_or_drift_inputs(&self) -> bool {
238        for (idx, input) in self.face_inputs().iter().enumerate() {
239            let current_buttons = input.buttons();
240            if current_buttons.contains(&FaceButton::Drift)
241                && !current_buttons.contains(&FaceButton::Brake)
242            {
243                // Illegal drift input
244                return true;
245            } else if idx > 0 {
246                let previous_buttons = self.face_inputs()[idx - 1].buttons();
247                if current_buttons.contains(&FaceButton::Brake)
248                    && current_buttons.contains(&FaceButton::Accelerator)
249                    && !current_buttons.contains(&FaceButton::Drift)
250                    && previous_buttons.contains(&FaceButton::Accelerator)
251                    && !previous_buttons.contains(&FaceButton::Brake)
252                {
253                    // Illegal brake input (drift flag isn't 1 when it should be)
254                    return true;
255                }
256            }
257        }
258        false
259    }
260
261    /// Returns `true` if the raw input data begins with a Yaz1 magic header at offset 4.
262    pub fn is_compressed(&self) -> bool {
263        self.raw_data[4..8] == [0x59, 0x61, 0x7A, 0x31]
264    }
265
266    /// Compresses the raw input data using Yaz1 encoding.
267    ///
268    /// Does nothing if the data is already compressed.
269    pub(crate) fn compress(&mut self) {
270        if !self.is_compressed() {
271            self.raw_data = yaz1_compress(&self.raw_data);
272        }
273    }
274
275    /// Decompresses the raw input data from Yaz1 encoding.
276    ///
277    /// Does nothing if the data is not compressed.
278    pub(crate) fn decompress(&mut self) {
279        if self.is_compressed() {
280            self.raw_data = yaz1_decompress(&self.raw_data[4..]).unwrap();
281        }
282    }
283
284    /// Returns the raw input data bytes as they appear in the file.
285    pub fn raw_data(&self) -> &[u8] {
286        &self.raw_data
287    }
288
289    /// Returns `true` if the stick input stream contains any illegal stick position
290    /// for the given controller type. More info on illegal input ranges here:
291    /// <https://github.com/malleoz/mkw-replay?tab=readme-ov-file#regarding-input-ranges>
292    /// <https://youtu.be/KUjS7qWWu9c?t=489>
293    ///
294    /// The Wii Wheel has a fully unrestricted input range and is never considered to
295    /// have illegal inputs.
296    pub fn contains_illegal_stick_inputs(&self, controller: Controller) -> bool {
297        // Definition of illegal stick inputs [x, y]
298        const ILLEGAL_STICK_INPUTS: [[i8; 2]; 44] = [
299            // These inputs are illegal for GCN, CCP, and Nunchuk (24 total)
300            [-7, 7],
301            [-7, 6],
302            [-7, 5],
303            [-7, -7],
304            [-7, -6],
305            [-7, -5],
306            [-6, 7],
307            [-6, 6],
308            [-6, -7],
309            [-6, -6],
310            [-5, 7],
311            [-5, -7],
312            [7, 7],
313            [7, 6],
314            [7, 5],
315            [7, -7],
316            [7, -6],
317            [7, -5],
318            [6, 7],
319            [6, 6],
320            [6, -7],
321            [6, -6],
322            [5, 7],
323            [5, -7],
324            // Illegal stick inputs for specifically GCN/CCP (additional 20)
325            [-7, 4],
326            [-6, 5],
327            [-5, 6],
328            [-4, 7],
329            [-3, 7],
330            [3, 7],
331            [4, 7],
332            [4, 6],
333            [4, -7],
334            [5, 6],
335            [5, 5],
336            [5, -6],
337            [6, 5],
338            [6, 4],
339            [6, -5],
340            [7, 4],
341            [7, 3],
342            [7, 2],
343            [7, -3],
344            [7, -4],
345        ];
346
347        let illegal_stick_inputs = match controller {
348            Controller::Nunchuk => &ILLEGAL_STICK_INPUTS[..24],
349            Controller::Classic | Controller::Gamecube => &ILLEGAL_STICK_INPUTS,
350            Controller::WiiWheel => {
351                return false;
352            }
353        };
354
355        for current_stick_input in self.stick_inputs().iter() {
356            for illegal_stick_input in illegal_stick_inputs.iter() {
357                if current_stick_input == illegal_stick_input {
358                    return true;
359                }
360            }
361        }
362
363        false
364    }
365
366    /// Returns the decoded face button input stream.
367    pub fn face_inputs(&self) -> &[FaceInput] {
368        &self.face_inputs
369    }
370
371    /// Returns the decoded analog stick input stream.
372    pub fn stick_inputs(&self) -> &[StickInput] {
373        &self.stick_inputs
374    }
375
376    /// Returns the decoded D-pad input stream.
377    pub fn dpad_inputs(&self) -> &[DPadInput] {
378        &self.dpad_inputs
379    }
380
    /// Returns the number of face input entries as recorded in the stream header.
    ///
    /// This is the count read from the data, taken before adjacent identical
    /// entries are merged, so it can be larger than `face_inputs().len()`.
    pub fn face_input_count(&self) -> u16 {
        self.face_input_count
    }
385
    /// Returns the number of stick input entries as recorded in the stream header.
    ///
    /// This is the count read from the data, taken before adjacent identical
    /// entries are merged, so it can be larger than `stick_inputs().len()`.
    pub fn stick_input_count(&self) -> u16 {
        self.stick_input_count
    }
390
    /// Returns the number of D-pad input entries as recorded in the stream header.
    ///
    /// This is the count read from the data when the stream was parsed.
    pub fn dpad_input_count(&self) -> u16 {
        self.dpad_input_count
    }
395}
396
397/// Decompresses a Yaz1-encoded byte slice into raw input data.
398///
399/// The slice must begin with the `Yaz1` magic followed by the uncompressed
400/// size as a big-endian `u32` and 8 bytes of padding. The result is
401/// zero-padded to `0x2774` bytes.
402///
403/// Returns `None` if the magic is missing, the data is truncated, or the
404/// decompressed output does not match the expected size.
405///
406/// Adapted from <https://github.com/AtishaRibeiro/InputDisplay/blob/master/InputDisplay/Core/Yaz1dec.cs>.
407pub(crate) fn yaz1_decompress(data: &[u8]) -> Option<Vec<u8>> {
408    // YAZ1 files start with "Yaz1" magic header
409    if data.len() < 16 || &data[0..4] != b"Yaz1" {
410        return None;
411    }
412
413    let uncompressed_size = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;
414
415    let mut result = Vec::with_capacity(uncompressed_size);
416
417    let decompressed = decompress_block(
418        data,
419        16, // Start after 16-byte header
420        uncompressed_size,
421    );
422
423    if let Some(mut dec) = decompressed {
424        result.append(&mut dec);
425    }
426
427    if result.len() == uncompressed_size {
428        result.resize(0x2774, 0);
429        Some(result)
430    } else {
431        None
432    }
433}
434
/// Decompresses a single Yaz1 block starting at `offset` within `src`.
///
/// The stream is a sequence of code bytes, each followed by up to eight
/// operations selected by its bits (most significant first): a set bit copies
/// one literal byte, a clear bit copies a run from earlier output.
///
/// Returns `None` if the source data is truncated mid-block, if a run refers
/// to output that has not been produced yet, or if a run would write past
/// `uncompressed_size`. The output is exactly `uncompressed_size` bytes when
/// successful.
fn decompress_block(src: &[u8], offset: usize, uncompressed_size: usize) -> Option<Vec<u8>> {
    // Longest run a single back-reference can express (3-byte form).
    const MAX_RUN: usize = 0xFF + 0x12;

    // `uncompressed_size` comes from the file header and is untrusted, so the
    // preallocation is bounded by what the remaining source could expand to.
    let remaining_src = src.len().saturating_sub(offset);
    let mut dst = Vec::with_capacity(uncompressed_size.min(remaining_src.saturating_mul(MAX_RUN)));
    let mut src_pos = offset;

    let mut valid_bit_count = 0u8; // number of valid bits left in "code" byte
    let mut curr_code_byte = 0u8;

    while dst.len() < uncompressed_size {
        // Read new "code" byte if the current one is used up
        if valid_bit_count == 0 {
            curr_code_byte = *src.get(src_pos)?;
            src_pos += 1;
            valid_bit_count = 8;
        }

        if curr_code_byte & 0x80 != 0 {
            // Straight copy
            dst.push(*src.get(src_pos)?);
            src_pos += 1;
        } else {
            // RLE part: 12-bit distance plus a length in the high nibble
            let byte1 = *src.get(src_pos)?;
            let byte2 = *src.get(src_pos + 1)?;
            src_pos += 2;

            let dist = (usize::from(byte1 & 0x0F) << 8) | usize::from(byte2);
            // A distance reaching before the start of the output is invalid.
            let copy_source = dst.len().checked_sub(dist + 1)?;

            let num_bytes = match byte1 >> 4 {
                // Zero nibble: the length lives in a third byte.
                0 => {
                    let extra = *src.get(src_pos)?;
                    src_pos += 1;
                    usize::from(extra) + 0x12
                }
                nibble => usize::from(nibble) + 2,
            };

            // Reject runs that would overshoot the announced size so the
            // result is never longer than `uncompressed_size`.
            if num_bytes > uncompressed_size - dst.len() {
                return None;
            }

            // Copy run byte by byte: source and destination may overlap, in
            // which case freshly written bytes are read back.
            for i in 0..num_bytes {
                let byte = dst[copy_source + i];
                dst.push(byte);
            }
        }

        // Use next bit from "code" byte
        curr_code_byte <<= 1;
        valid_bit_count -= 1;
    }

    Some(dst)
}
506
507/// Compresses raw input data using Yaz1 encoding.
508///
509/// Trailing zero bytes (used to pad decompressed data to `0x2774` bytes) are
510/// stripped before compression. The output includes a full Yaz1 file header
511/// containing the compressed size, the `Yaz1` magic, the uncompressed size,
512/// and 8 bytes of padding.
513///
514/// Adapted from <https://github.com/AtishaRibeiro/TT-Rec-Tools/blob/dev/ghostmanager/Scripts/YAZ1_comp.js>.
515pub(crate) fn yaz1_compress(src: &[u8]) -> Vec<u8> {
516    // first remove padded 0s (decompressed input data is padded with 0s to 0x2774 bytes)
517    let mut trailing_bytes_to_remove = 0usize;
518    for idx in (0..src.len()).rev() {
519        if src[idx] == 0 {
520            trailing_bytes_to_remove += 1;
521        } else {
522            break;
523        }
524    }
525
526    let src = &src[0..src.len() - trailing_bytes_to_remove];
527
528    let mut dst = Vec::new();
529    let src_size = src.len();
530    let mut src_pos = 0;
531    let mut prev_flag = false;
532    let mut prev_num_bytes = 0;
533    let mut prev_match_pos = 0;
534
535    let mut code_byte = 0u8;
536    let mut valid_bit_count = 0;
537    let mut chunk = Vec::with_capacity(24); // 8 codes * 3 bytes maximum
538
539    while src_pos < src_size {
540        let (num_bytes, match_pos) = nintendo_encode(
541            src,
542            src_size,
543            src_pos,
544            &mut prev_flag,
545            &mut prev_num_bytes,
546            &mut prev_match_pos,
547        );
548
549        if num_bytes < 3 {
550            // Straight copy
551            chunk.push(src[src_pos]);
552            src_pos += 1;
553            // Set flag for straight copy
554            code_byte |= 0x80 >> valid_bit_count;
555        } else {
556            // RLE part
557            let dist = src_pos - match_pos - 1;
558
559            if num_bytes >= 0x12 {
560                // 3 byte encoding
561                let byte1 = (dist >> 8) as u8;
562                let byte2 = (dist & 0xff) as u8;
563                chunk.push(byte1);
564                chunk.push(byte2);
565
566                // Maximum runlength for 3 byte encoding
567                let num_bytes = num_bytes.min(0xff + 0x12);
568                let byte3 = (num_bytes - 0x12) as u8;
569                chunk.push(byte3);
570            } else {
571                // 2 byte encoding
572                let byte1 = (((num_bytes - 2) << 4) | (dist >> 8)) as u8;
573                let byte2 = (dist & 0xff) as u8;
574                chunk.push(byte1);
575                chunk.push(byte2);
576            }
577            src_pos += num_bytes;
578        }
579
580        valid_bit_count += 1;
581
582        // Write eight codes
583        if valid_bit_count == 8 {
584            dst.push(code_byte);
585            dst.extend_from_slice(&chunk);
586
587            code_byte = 0;
588            valid_bit_count = 0;
589            chunk.clear();
590        }
591    }
592
593    // Write remaining codes
594    if valid_bit_count > 0 {
595        dst.push(code_byte);
596        dst.extend_from_slice(&chunk);
597    }
598
599    let mut compressed_data = Vec::new();
600
601    // Write Yaz1 header
602    compressed_data.extend_from_slice(&((dst.len() + 8) as u32).to_be_bytes()); // size of compressed data
603    compressed_data.extend_from_slice(b"Yaz1");
604    compressed_data.extend_from_slice(&(src_size as u32).to_be_bytes());
605    compressed_data.extend_from_slice(&[0u8; 8]); // padding
606    compressed_data.extend_from_slice(&dst);
607    compressed_data
608}
609
610/// Determines the best encoding for the byte at `pos` using the Nintendo Yaz1 heuristic.
611///
612/// If the previous call set a lookahead flag, the cached values from that call
613/// are returned immediately. Otherwise [`simple_encode`] is run at `pos` and
614/// at `pos + 1`; if the next position's match is 2 or more bytes longer, the
615/// lookahead flag is set and a literal copy is emitted for the current position.
616fn nintendo_encode(
617    src: &[u8],
618    size: usize,
619    pos: usize,
620    prev_flag: &mut bool,
621    prev_num_bytes: &mut usize,
622    prev_match_pos: &mut usize,
623) -> (usize, usize) {
624    // If prevFlag is set, use the previously calculated values
625    if *prev_flag {
626        *prev_flag = false;
627        return (*prev_num_bytes, *prev_match_pos);
628    }
629
630    *prev_flag = false;
631    let (num_bytes, match_pos) = simple_encode(src, size, pos);
632
633    // If this position is RLE encoded, compare to copying 1 byte and next position encoding
634    if num_bytes >= 3 {
635        let (num_bytes1, match_pos1) = simple_encode(src, size, pos + 1);
636        *prev_num_bytes = num_bytes1;
637        *prev_match_pos = match_pos1;
638
639        // If the next position encoding is +2 longer, choose it
640        if num_bytes1 >= num_bytes + 2 {
641            *prev_flag = true;
642            return (1, match_pos);
643        }
644    }
645
646    (num_bytes, match_pos)
647}
648
/// Finds the longest match for `src[pos..size]` within the preceding
/// `0x1000`-byte window using a simple linear scan.
///
/// Candidates are tried from the oldest position forward and only a strictly
/// longer match replaces the current best, so the earliest longest match
/// wins. A match may run past `pos` into the bytes being encoded.
///
/// Returns `(num_bytes, match_pos)` where `num_bytes` is the length of the
/// longest match found (1 if no match of length ≥ 3 was found) and
/// `match_pos` is the starting offset of that match in `src`.
fn simple_encode(src: &[u8], size: usize, pos: usize) -> (usize, usize) {
    // Back-references carry a 12-bit distance, hence the window size.
    const WINDOW: usize = 0x1000;

    // Saturating keeps the window start at 0 near the beginning of the data
    // without a lossy signed cast.
    let window_start = pos.saturating_sub(WINDOW);
    let max_len = size - pos;
    let target = &src[pos..];

    let mut num_bytes = 1;
    let mut match_pos = 0;

    for candidate in window_start..pos {
        let matched = src[candidate..]
            .iter()
            .zip(target)
            .take(max_len)
            .take_while(|(a, b)| a == b)
            .count();

        if matched > num_bytes {
            num_bytes = matched;
            match_pos = candidate;
        }
    }

    // A 2-byte match is not worth a back-reference; treat it as a literal.
    if num_bytes == 2 {
        num_bytes = 1;
    }

    (num_bytes, match_pos)
}