vt_push_parser/
capture.rs

1//! Raw-input-capturing push parser.
2
3use crate::{VT_PARSER_INTEREST_DEFAULT, VTEvent, VTPushParser};
4
5pub trait VTInputCaptureCallback {
6    fn event(&mut self, event: VTCaptureEvent<'_>) -> VTInputCapture;
7}
8
9impl<F: FnMut(VTCaptureEvent<'_>) -> VTInputCapture> VTInputCaptureCallback for F {
10    #[inline(always)]
11    fn event(&mut self, event: VTCaptureEvent<'_>) -> VTInputCapture {
12        self(event)
13    }
14}
15
/// The capture mode to switch to after an event has been delivered.
///
/// A [`VTInputCaptureCallback`] returns one of these for every event. While a
/// capture is active, the raw input bytes are delivered as zero or more
/// [`VTCaptureEvent::Capture`] events, followed by a single
/// [`VTCaptureEvent::CaptureEnd`] once the capture completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VTInputCapture {
    /// No capture: keep parsing normally.
    ///
    /// Return this for [`VTCaptureEvent::Capture`] and
    /// [`VTCaptureEvent::CaptureEnd`] events; the value returned for those
    /// events is ignored, so a capture cannot be started from inside another.
    None,
    /// Capture exactly this many bytes.
    Count(usize),
    /// Capture this many UTF-8 characters.
    ///
    /// Characters are counted by their lead (non-continuation) bytes;
    /// continuation bytes belong to the character before them.
    CountUtf8(usize),
    /// Capture bytes until this terminator is seen.
    ///
    /// The terminator itself is consumed and is not part of the captured data.
    /// Bytes that only partially matched the terminator are delivered as
    /// captured data.
    Terminator(&'static [u8]),
}
31
32#[cfg_attr(feature = "serde", derive(serde::Serialize))]
33#[derive(Debug)]
34pub enum VTCaptureEvent<'a> {
35    VTEvent(VTEvent<'a>),
36    Capture(&'a [u8]),
37    CaptureEnd,
38}
39
/// State of an in-progress capture, advanced by [`VTCaptureInternal::feed`].
///
/// Internal detail of [`VTCapturePushParser`]: not part of the public API and
/// subject to change without notice.
#[doc(hidden)]
#[derive(Default, Debug)]
pub enum VTCaptureInternal {
    /// No capture is active.
    #[default]
    None,
    /// Capturing a number of bytes.
    Count(usize),
    /// Capturing a number of UTF-8 characters.
    CountUtf8(usize),
    /// Capturing until the terminator is seen; the `usize` is how many leading
    /// bytes of the terminator have been matched so far.
    Terminator(&'static [u8], usize),
}
52
53impl VTCaptureInternal {
54    pub fn feed<'a>(&mut self, input: &mut &'a [u8]) -> Option<&'a [u8]> {
55        match self {
56            VTCaptureInternal::None => None,
57            VTCaptureInternal::Count(count) => {
58                if input.len() >= *count {
59                    let (capture, rest) = input.split_at(*count);
60                    *input = rest;
61                    *self = VTCaptureInternal::None;
62                    Some(capture)
63                } else {
64                    None
65                }
66            }
67            VTCaptureInternal::CountUtf8(count) => {
68                // Count UTF-8 characters, not bytes
69                let mut chars_found = 0;
70                let mut bytes_consumed = 0;
71
72                for (i, &byte) in input.iter().enumerate() {
73                    // Check if this is the start of a new UTF-8 character
74                    if byte & 0xC0 != 0x80 {
75                        // Not a continuation byte
76                        chars_found += 1;
77                        if chars_found == *count {
78                            // We found the nth character, now we need to find where it ends
79                            // by consuming all its continuation bytes
80                            let mut j = i + 1;
81                            while j < input.len() && input[j] & 0xC0 == 0x80 {
82                                j += 1;
83                            }
84                            bytes_consumed = j;
85                            break;
86                        }
87                    }
88                }
89
90                if chars_found == *count {
91                    let (capture, rest) = input.split_at(bytes_consumed);
92                    *input = rest;
93                    *self = VTCaptureInternal::None;
94                    Some(capture)
95                } else {
96                    None
97                }
98            }
99            VTCaptureInternal::Terminator(terminator, found) => {
100                // Ground state
101                if *found == 0 {
102                    if let Some(position) = input.iter().position(|&b| b == terminator[0]) {
103                        // Advance to first match position
104                        *found = 1;
105                        let unmatched = &input[..position];
106                        *input = &input[position + 1..];
107                        return Some(unmatched);
108                    } else {
109                        let unmatched = *input;
110                        *input = &[];
111                        return Some(unmatched);
112                    }
113                }
114
115                // We've already found part of the terminator, so we can continue
116                while *found < terminator.len() {
117                    if input.is_empty() {
118                        return None;
119                    }
120
121                    if input[0] == terminator[*found] {
122                        *found += 1;
123                        *input = &input[1..];
124                    } else {
125                        // Failed a match, so return the part of the terminator we already matched
126                        let old_found = std::mem::take(found);
127                        return Some(&terminator[..old_found]);
128                    }
129                }
130
131                // We've matched the entire terminator
132                *self = VTCaptureInternal::None;
133                None
134            }
135        }
136    }
137}
138
139/// A parser that allows for "capturing" of input data, ie: temporarily
140/// transferring control of the parser to unparsed data events.
141///
142/// This functions in the same way as [`VTPushParser`], but emits
143/// [`VTCaptureEvent`]s instead of [`VTEvent`]s.
144pub struct VTCapturePushParser<const INTEREST: u8 = VT_PARSER_INTEREST_DEFAULT> {
145    parser: VTPushParser<INTEREST>,
146    capture: VTCaptureInternal,
147}
148
149impl Default for VTCapturePushParser {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl VTCapturePushParser {
156    pub const fn new() -> VTCapturePushParser {
157        VTCapturePushParser::new_with_interest::<VT_PARSER_INTEREST_DEFAULT>()
158    }
159
160    pub const fn new_with_interest<const INTEREST: u8>() -> VTCapturePushParser<INTEREST> {
161        VTCapturePushParser::new_with()
162    }
163}
164
165impl<const INTEREST: u8> VTCapturePushParser<INTEREST> {
166    const fn new_with() -> Self {
167        Self {
168            parser: VTPushParser::new_with(),
169            capture: VTCaptureInternal::None,
170        }
171    }
172
173    pub fn is_ground(&self) -> bool {
174        self.parser.is_ground()
175    }
176
177    pub fn idle(&mut self) -> Option<VTCaptureEvent<'static>> {
178        self.parser.idle().map(VTCaptureEvent::VTEvent)
179    }
180
181    pub fn feed_with<F: VTInputCaptureCallback>(&mut self, mut input: &[u8], mut cb: F) {
182        while !input.is_empty() {
183            match &mut self.capture {
184                VTCaptureInternal::None => {
185                    // Normal parsing mode - feed to the underlying parser
186                    let count = self
187                        .parser
188                        .feed_with_abortable(input, &mut |event: VTEvent| {
189                            let capture_mode = cb.event(VTCaptureEvent::VTEvent(event));
190                            match capture_mode {
191                                VTInputCapture::None => {
192                                    // Stay in normal mode
193                                }
194                                VTInputCapture::Count(count) => {
195                                    self.capture = VTCaptureInternal::Count(count);
196                                }
197                                VTInputCapture::CountUtf8(count) => {
198                                    self.capture = VTCaptureInternal::CountUtf8(count);
199                                }
200                                VTInputCapture::Terminator(terminator) => {
201                                    self.capture = VTCaptureInternal::Terminator(terminator, 0);
202                                }
203                            }
204                            false // Don't abort parsing
205                        });
206
207                    input = &input[count..];
208                }
209                capture => {
210                    // Capture mode - collect data until capture is complete
211                    if let Some(captured_data) = capture.feed(&mut input) {
212                        cb.event(VTCaptureEvent::Capture(captured_data));
213                    }
214
215                    // Check if capture is complete
216                    if matches!(self.capture, VTCaptureInternal::None) {
217                        cb.event(VTCaptureEvent::CaptureEnd);
218                    }
219                }
220            }
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Bracketed-paste style capture: everything between `ESC [200~` and
    /// `ESC [201~` is delivered as captured bytes.
    #[test]
    fn test_capture_paste() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(
            b"raw\x1b[200~paste\x1b[201~raw",
            &mut |event: VTCaptureEvent| {
                output.push_str(&format!("{event:?}\n"));
                match event {
                    VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                        if csi.params.try_parse::<usize>(0).unwrap_or(0) == 200 {
                            VTInputCapture::Terminator(b"\x1b[201~")
                        } else {
                            VTInputCapture::None
                        }
                    }
                    _ => VTInputCapture::None,
                }
            },
        );
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('200', '', '~'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    /// A fixed-size byte capture swallows exactly five bytes after `ESC [X`.
    #[test]
    fn test_capture_count() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(b"raw\x1b[Xpaste\x1b[Yraw", &mut |event: VTCaptureEvent| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::Count(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    /// A UTF-8 character capture over ASCII behaves like a byte capture.
    #[test]
    fn test_capture_count_utf8_but_ascii() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(b"raw\x1b[Xpaste\x1b[Yraw", &mut |event: VTCaptureEvent| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::CountUtf8(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    /// A UTF-8 character capture takes whole multi-byte characters.
    #[test]
    fn test_capture_count_utf8() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        let input = "raw\u{001b}[X🤖🦕✅😀🕓\u{001b}[Yraw".as_bytes();
        parser.feed_with(input, &mut |event: VTCaptureEvent| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::CountUtf8(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([240, 159, 164, 150, 240, 159, 166, 149, 226, 156, 133, 240, 159, 152, 128, 240, 159, 149, 147])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    /// A partial terminator match (`ESC [201` without `~`) is released back
    /// into the captured data.
    #[test]
    fn test_capture_terminator_partial_match() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();

        parser.feed_with(
            b"start\x1b[200~part\x1b[201ial\x1b[201~end",
            &mut |event: VTCaptureEvent| {
                output.push_str(&format!("{event:?}\n"));
                match event {
                    VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                        if csi.final_byte == b'~'
                            && csi.params.try_parse::<usize>(0).unwrap_or(0) == 200
                        {
                            VTInputCapture::Terminator(b"\x1b[201~")
                        } else {
                            VTInputCapture::None
                        }
                    }
                    _ => VTInputCapture::None,
                }
            },
        );

        assert_eq!(
            output.trim(),
            r#"VTEvent(Raw('start'))
VTEvent(Csi('200', '', '~'))
Capture([112, 97, 114, 116])
Capture([27, 91, 50, 48, 49])
Capture([105, 97, 108])
CaptureEnd
VTEvent(Raw('end'))"#
        );
    }

    /// The captured bytes do not depend on how the input is chunked.
    #[test]
    fn test_capture_terminator_partial_match_single_byte() {
        let input = b"start\x1b[200~part\x1b[201ial\x1b[201~end";

        for chunk_size in 1..5 {
            let (captured, output) = capture_chunk_size(input, chunk_size);
            assert_eq!(captured, b"part\x1b[201ial", "{output}",);
        }
    }

    /// Feeds `input` to a fresh parser in `chunk_size`-byte pieces, starting an
    /// `ESC [201~`-terminated capture whenever `ESC [200~` is parsed.
    ///
    /// Returns the concatenated captured bytes and a debug log of every event.
    fn capture_chunk_size(input: &[u8], chunk_size: usize) -> (Vec<u8>, String) {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        let mut captured = Vec::new();
        for chunk in input.chunks(chunk_size) {
            parser.feed_with(chunk, &mut |event: VTCaptureEvent| {
                output.push_str(&format!("{event:?}\n"));
                match event {
                    VTCaptureEvent::Capture(data) => {
                        captured.extend_from_slice(data);
                        VTInputCapture::None
                    }
                    VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                        if csi.final_byte == b'~'
                            && csi.params.try_parse::<usize>(0).unwrap_or(0) == 200
                        {
                            VTInputCapture::Terminator(b"\x1b[201~")
                        } else {
                            VTInputCapture::None
                        }
                    }
                    _ => VTInputCapture::None,
                }
            });
        }
        (captured, output)
    }
}