Skip to main content

tui/testing/
test_terminal.rs

1use std::io::{self, Write};
2
3use crossterm::style::Color;
4
5use crate::Style;
6
7/// A single cell in the terminal buffer, storing both a character and its style.
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub struct Cell {
10    pub ch: char,
11    pub style: Style,
12}
13
14impl Default for Cell {
15    fn default() -> Self {
16        Self {
17            ch: ' ',
18            style: Style::default(),
19        }
20    }
21}
22
23impl Cell {
24    fn new(ch: char, style: Style) -> Self {
25        Self { ch, style }
26    }
27}
28
29/// A virtual terminal buffer for testing terminal output.
30/// Captures all writes, tracks cursor position, and parses ANSI escape sequences
31/// including SGR (Select Graphic Rendition) codes for style tracking.
32///
33/// Implements delayed wrapping (DEC-style): when the cursor reaches the last
34/// column, it stays there with a pending-wrap flag. The next printable character
35/// triggers the wrap to column 0 of the next line. A `\r` clears the flag.
36#[derive(Debug, Clone)]
37pub struct TestTerminal {
38    /// 2D buffer of cells (row, column)
39    buffer: Vec<Vec<Cell>>,
40    /// Rows that have scrolled off the top of the visible buffer.
41    scrollback: Vec<Vec<Cell>>,
42    /// Current cursor position (column, row)
43    cursor: (u16, u16),
44    /// Saved cursor position (for save/restore)
45    saved_cursor: Option<(u16, u16)>,
46    /// Terminal size (columns, rows)
47    size: (u16, u16),
48    /// Buffer for incomplete escape sequences
49    escape_buffer: Vec<u8>,
50    /// Delayed wrap: cursor hit last column but hasn't wrapped yet
51    pending_wrap: bool,
52    /// Current SGR style applied to newly written characters
53    current_style: Style,
54}
55
56impl TestTerminal {
57    /// Create a new test terminal with given size
58    pub fn new(columns: u16, rows: u16) -> Self {
59        let buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
60        Self {
61            buffer,
62            scrollback: Vec::new(),
63            cursor: (0, 0),
64            saved_cursor: None,
65            size: (columns, rows),
66            escape_buffer: Vec::new(),
67            pending_wrap: false,
68            current_style: Style::default(),
69        }
70    }
71
72    /// Resize terminal without preserving prior transcript content.
73    pub fn resize(&mut self, columns: u16, rows: u16) {
74        let columns = columns.max(1);
75        let rows = rows.max(1);
76        self.buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
77        self.scrollback.clear();
78        self.size = (columns, rows);
79        self.cursor = (0, rows.saturating_sub(1));
80        self.saved_cursor = None;
81        self.pending_wrap = false;
82    }
83
84    /// Resize terminal and reflow existing transcript content to match the new width.
85    pub fn resize_preserving_transcript(&mut self, columns: u16, rows: u16) {
86        let transcript = self.get_transcript_lines();
87        let wrapped = Self::reflow_lines(&transcript, columns);
88        self.apply_reflowed_lines(columns, rows, wrapped);
89    }
90
91    fn reflow_lines(lines: &[String], columns: u16) -> Vec<String> {
92        let mut wrapped = Vec::new();
93        let width = columns.max(1) as usize;
94
95        for line in lines {
96            if line.is_empty() {
97                wrapped.push(String::new());
98                continue;
99            }
100
101            let chars: Vec<char> = line.chars().collect();
102            for chunk in chars.chunks(width) {
103                wrapped.push(chunk.iter().collect());
104            }
105        }
106
107        if wrapped.is_empty() {
108            wrapped.push(String::new());
109        }
110
111        wrapped
112    }
113
114    fn apply_reflowed_lines(&mut self, columns: u16, rows: u16, wrapped: Vec<String>) {
115        let rows_usize = rows.max(1) as usize;
116        let split_at = wrapped.len().saturating_sub(rows_usize);
117        let (scrollback, visible) = wrapped.split_at(split_at);
118
119        self.scrollback = scrollback
120            .iter()
121            .map(|line| Self::line_to_row(line, columns))
122            .collect();
123
124        self.buffer = visible
125            .iter()
126            .map(|line| Self::line_to_row(line, columns))
127            .collect();
128
129        while self.buffer.len() < rows_usize {
130            self.buffer.push(vec![Cell::default(); columns as usize]);
131        }
132
133        self.size = (columns, rows);
134        self.cursor = (0, rows.saturating_sub(1));
135        self.saved_cursor = None;
136        self.pending_wrap = false;
137    }
138
139    fn line_to_row(line: &str, columns: u16) -> Vec<Cell> {
140        let mut row: Vec<Cell> = line
141            .chars()
142            .take(columns as usize)
143            .map(|ch| Cell::new(ch, Style::default()))
144            .collect();
145        row.resize(columns as usize, Cell::default());
146        row
147    }
148
149    /// Get all lines as a vector of strings (trailing whitespace trimmed)
150    pub fn get_lines(&self) -> Vec<String> {
151        self.buffer
152            .iter()
153            .map(|cells| {
154                cells
155                    .iter()
156                    .map(|c| c.ch)
157                    .collect::<String>()
158                    .trim_end()
159                    .to_string()
160            })
161            .collect()
162    }
163
164    /// Get full terminal transcript (scrollback history + visible buffer).
165    pub fn get_transcript_lines(&self) -> Vec<String> {
166        self.scrollback
167            .iter()
168            .chain(self.buffer.iter())
169            .map(|cells| {
170                cells
171                    .iter()
172                    .map(|c| c.ch)
173                    .collect::<String>()
174                    .trim_end()
175                    .to_string()
176            })
177            .collect()
178    }
179
180    /// Get current cursor position as (column, row).
181    #[allow(dead_code)]
182    pub fn cursor_position(&self) -> (u16, u16) {
183        self.cursor
184    }
185
186    /// Get the style at a specific buffer position.
187    pub fn get_style_at(&self, row: usize, col: usize) -> Style {
188        self.buffer
189            .get(row)
190            .and_then(|r| r.get(col))
191            .map_or(Style::default(), |c| c.style)
192    }
193
194    /// Find the first occurrence of `text` on the given row and return its style.
195    ///
196    /// Returns the style of the first character of the matched text.
197    pub fn style_of_text(&self, row: usize, text: &str) -> Option<Style> {
198        let row_data = self.buffer.get(row)?;
199        let row_text: String = row_data.iter().map(|c| c.ch).collect();
200        let byte_offset = row_text.find(text)?;
201        // Convert byte offset to character index
202        let char_index = row_text[..byte_offset].chars().count();
203        Some(row_data[char_index].style)
204    }
205
206    /// Clear the entire buffer
207    pub fn clear(&mut self) {
208        for row in &mut self.buffer {
209            for cell in row {
210                *cell = Cell::default();
211            }
212        }
213    }
214
215    /// Clear the current line
216    pub fn clear_line(&mut self) {
217        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
218            for cell in row {
219                *cell = Cell::default();
220            }
221        }
222    }
223
224    /// Move cursor to absolute position
225    pub fn move_to(&mut self, col: u16, row: u16) {
226        self.cursor = (
227            col.min(self.size.0.saturating_sub(1)),
228            row.min(self.size.1.saturating_sub(1)),
229        );
230        self.pending_wrap = false;
231    }
232
233    /// Move cursor to column (keep same row)
234    pub fn move_to_column(&mut self, col: u16) {
235        self.cursor.0 = col.min(self.size.0.saturating_sub(1));
236        self.pending_wrap = false;
237    }
238
239    /// Move cursor left by n positions
240    pub fn move_left(&mut self, n: u16) {
241        self.cursor.0 = self.cursor.0.saturating_sub(n);
242        self.pending_wrap = false;
243    }
244
245    /// Move cursor right by n positions
246    pub fn move_right(&mut self, n: u16) {
247        self.cursor.0 = (self.cursor.0 + n).min(self.size.0.saturating_sub(1));
248        self.pending_wrap = false;
249    }
250
251    /// Write a single character at current cursor position and advance cursor
252    fn write_char(&mut self, ch: char) {
253        match ch {
254            '\n' => {
255                self.pending_wrap = false;
256                if self.cursor.1 >= self.size.1.saturating_sub(1) {
257                    let removed = self.buffer.remove(0);
258                    self.scrollback.push(removed);
259                    self.buffer
260                        .push(vec![Cell::default(); self.size.0 as usize]);
261                } else {
262                    self.cursor.1 += 1;
263                }
264                self.cursor.0 = 0;
265            }
266            '\r' => {
267                self.cursor.0 = 0;
268                self.pending_wrap = false;
269            }
270            '\t' => {
271                for _ in 0..4 {
272                    self.write_char_at_cursor(' ');
273                }
274            }
275            _ => {
276                self.write_char_at_cursor(ch);
277            }
278        }
279    }
280
281    /// Write a character at the current cursor position (delayed wrap).
282    ///
283    /// When the cursor is at the last column with `pending_wrap` set,
284    /// the next printable character triggers the wrap first.
285    fn write_char_at_cursor(&mut self, ch: char) {
286        if self.pending_wrap {
287            self.pending_wrap = false;
288            self.cursor.0 = 0;
289            if self.cursor.1 >= self.size.1.saturating_sub(1) {
290                let removed = self.buffer.remove(0);
291                self.scrollback.push(removed);
292                self.buffer
293                    .push(vec![Cell::default(); self.size.0 as usize]);
294            } else {
295                self.cursor.1 += 1;
296            }
297        }
298
299        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize)
300            && let Some(cell) = row.get_mut(self.cursor.0 as usize)
301        {
302            *cell = Cell::new(ch, self.current_style);
303            self.cursor.0 += 1;
304            if self.cursor.0 >= self.size.0 {
305                self.cursor.0 = self.size.0 - 1;
306                self.pending_wrap = true;
307            }
308        }
309    }
310
311    /// Process a byte slice, handling ANSI escape sequences
312    fn process_bytes(&mut self, buf: &[u8]) {
313        let s = String::from_utf8_lossy(buf);
314        let mut chars = s.chars().peekable();
315
316        while let Some(ch) = chars.next() {
317            if ch == '\x1b' {
318                if chars.peek() == Some(&'[') {
319                    chars.next();
320                    self.process_csi_sequence(&mut chars);
321                } else if chars.peek() == Some(&'7') {
322                    chars.next();
323                    self.saved_cursor = Some(self.cursor);
324                } else if chars.peek() == Some(&'8') {
325                    chars.next();
326                    if let Some(saved) = self.saved_cursor {
327                        self.cursor = saved;
328                    }
329                }
330            } else {
331                self.write_char(ch);
332            }
333        }
334    }
335
336    /// Process a CSI (Control Sequence Introducer) escape sequence
337    #[allow(clippy::too_many_lines)]
338    fn process_csi_sequence(&mut self, chars: &mut std::iter::Peekable<std::str::Chars>) {
339        let private_mode = if chars.peek() == Some(&'?') {
340            chars.next();
341            true
342        } else {
343            false
344        };
345
346        let mut params = String::new();
347
348        while let Some(&ch) = chars.peek() {
349            if ch.is_ascii_digit() || ch == ';' || ch == ':' {
350                params.push(ch);
351                chars.next();
352            } else {
353                break;
354            }
355        }
356
357        if private_mode {
358            chars.next();
359            return;
360        }
361
362        if let Some(cmd) = chars.next() {
363            match cmd {
364                'H' | 'f' => {
365                    let parts: Vec<u16> =
366                        params.split(';').filter_map(|s| s.parse().ok()).collect();
367                    let row = parts.first().copied().unwrap_or(1).saturating_sub(1);
368                    let col = parts.get(1).copied().unwrap_or(1).saturating_sub(1);
369                    self.move_to(col, row);
370                }
371                'A' => {
372                    let n = params.parse().unwrap_or(1);
373                    self.cursor.1 = self.cursor.1.saturating_sub(n);
374                    self.pending_wrap = false;
375                }
376                'B' => {
377                    let n = params.parse().unwrap_or(1);
378                    self.cursor.1 = (self.cursor.1 + n).min(self.size.1.saturating_sub(1));
379                    self.pending_wrap = false;
380                }
381                'C' => {
382                    let n = params.parse().unwrap_or(1);
383                    self.move_right(n);
384                }
385                'D' => {
386                    let n = params.parse().unwrap_or(1);
387                    self.move_left(n);
388                }
389                'G' => {
390                    let col = params.parse::<u16>().unwrap_or(1).saturating_sub(1);
391                    self.move_to_column(col);
392                }
393                'J' => {
394                    let n = params.parse().unwrap_or(0);
395                    match n {
396                        0 => {
397                            for row in self.cursor.1..self.size.1 {
398                                if let Some(r) = self.buffer.get_mut(row as usize) {
399                                    let start = if row == self.cursor.1 {
400                                        self.cursor.0 as usize
401                                    } else {
402                                        0
403                                    };
404                                    for cell in r.iter_mut().skip(start) {
405                                        *cell = Cell::default();
406                                    }
407                                }
408                            }
409                        }
410                        2 => {
411                            self.clear();
412                        }
413                        _ => {}
414                    }
415                }
416                'K' => {
417                    let n = params.parse().unwrap_or(0);
418                    match n {
419                        0 => {
420                            if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
421                                for cell in row.iter_mut().skip(self.cursor.0 as usize) {
422                                    *cell = Cell::default();
423                                }
424                            }
425                        }
426                        2 => {
427                            self.clear_line();
428                        }
429                        _ => {}
430                    }
431                }
432                's' => {
433                    self.saved_cursor = Some(self.cursor);
434                }
435                'u' => {
436                    if let Some(saved) = self.saved_cursor {
437                        self.cursor = saved;
438                        self.pending_wrap = false;
439                    }
440                }
441                'm' => {
442                    self.apply_sgr(&params);
443                }
444                _ => {}
445            }
446        }
447    }
448
449    /// Apply SGR (Select Graphic Rendition) parameters to update `current_style`.
450    fn apply_sgr(&mut self, params: &str) {
451        if params.is_empty() {
452            self.current_style = Style::default();
453            return;
454        }
455
456        // Split on ';' for parameter groups, then take the first colon-delimited
457        // sub-parameter as the primary code (e.g. "4:1" → 4 for underline style).
458        let codes: Vec<u16> = params
459            .split(';')
460            .filter_map(|s| {
461                let primary = s.split(':').next().unwrap_or(s);
462                primary.parse().ok()
463            })
464            .collect();
465        let mut i = 0;
466        while i < codes.len() {
467            match codes[i] {
468                0 => self.current_style = Style::default(),
469                1 => self.current_style.bold = true,
470                2 => self.current_style.dim = true,
471                3 => self.current_style.italic = true,
472                4 => self.current_style.underline = true,
473                9 => self.current_style.strikethrough = true,
474                22 => {
475                    self.current_style.bold = false;
476                    self.current_style.dim = false;
477                }
478                23 => self.current_style.italic = false,
479                24 => self.current_style.underline = false,
480                29 => self.current_style.strikethrough = false,
481                30..=37 => {
482                    self.current_style.fg = Some(standard_color(codes[i] as u8 - 30));
483                }
484                38 => {
485                    i += 1;
486                    if i < codes.len() {
487                        match codes[i] {
488                            5 if i + 1 < codes.len() => {
489                                i += 1;
490                                self.current_style.fg = Some(Color::AnsiValue(codes[i] as u8));
491                            }
492                            2 if i + 3 < codes.len() => {
493                                self.current_style.fg = Some(Color::Rgb {
494                                    r: codes[i + 1] as u8,
495                                    g: codes[i + 2] as u8,
496                                    b: codes[i + 3] as u8,
497                                });
498                                i += 3;
499                            }
500                            _ => {}
501                        }
502                    }
503                }
504                39 => self.current_style.fg = None,
505                40..=47 => {
506                    self.current_style.bg = Some(standard_color(codes[i] as u8 - 40));
507                }
508                48 => {
509                    i += 1;
510                    if i < codes.len() {
511                        match codes[i] {
512                            5 if i + 1 < codes.len() => {
513                                i += 1;
514                                self.current_style.bg = Some(Color::AnsiValue(codes[i] as u8));
515                            }
516                            2 if i + 3 < codes.len() => {
517                                self.current_style.bg = Some(Color::Rgb {
518                                    r: codes[i + 1] as u8,
519                                    g: codes[i + 2] as u8,
520                                    b: codes[i + 3] as u8,
521                                });
522                                i += 3;
523                            }
524                            _ => {}
525                        }
526                    }
527                }
528                49 => self.current_style.bg = None,
529                90..=97 => {
530                    self.current_style.fg = Some(bright_color(codes[i] as u8 - 90));
531                }
532                100..=107 => {
533                    self.current_style.bg = Some(bright_color(codes[i] as u8 - 100));
534                }
535                _ => {}
536            }
537            i += 1;
538        }
539    }
540}
541
542/// Map ANSI standard color index (0-7) to crossterm Color.
543fn standard_color(index: u8) -> Color {
544    match index {
545        0 => Color::Black,
546        1 => Color::DarkRed,
547        2 => Color::DarkGreen,
548        3 => Color::DarkYellow,
549        4 => Color::DarkBlue,
550        5 => Color::DarkMagenta,
551        6 => Color::DarkCyan,
552        7 => Color::Grey,
553        _ => Color::Reset,
554    }
555}
556
557/// Map ANSI bright color index (0-7) to crossterm Color.
558fn bright_color(index: u8) -> Color {
559    match index {
560        0 => Color::DarkGrey,
561        1 => Color::Red,
562        2 => Color::Green,
563        3 => Color::Yellow,
564        4 => Color::Blue,
565        5 => Color::Magenta,
566        6 => Color::Cyan,
567        7 => Color::White,
568        _ => Color::Reset,
569    }
570}
571
572impl Write for TestTerminal {
573    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
574        self.escape_buffer.extend_from_slice(buf);
575        Ok(buf.len())
576    }
577
578    fn flush(&mut self) -> io::Result<()> {
579        if !self.escape_buffer.is_empty() {
580            let bytes = std::mem::take(&mut self.escape_buffer);
581            self.process_bytes(&bytes);
582        }
583        Ok(())
584    }
585}
586
587/// Asserts a test terminal buffer matches the expected output.
588/// Each element of the expected vector represents a row.
589/// Trailing whitespace is ignored on each line.
590pub fn assert_buffer_eq<S: AsRef<str>>(terminal: &TestTerminal, expected: &[S]) {
591    let actual_lines = terminal.get_lines();
592    let max_lines = expected.len().max(actual_lines.len());
593
594    for i in 0..max_lines {
595        let expected_line = expected.get(i).map_or("", AsRef::as_ref);
596        let actual_line = actual_lines.get(i).map_or("", String::as_str);
597
598        assert_eq!(
599            actual_line,
600            expected_line,
601            "Line {i} mismatch:\n  Expected: '{expected_line}'\n  Got:      '{actual_line}'\n\nFull buffer:\n{}",
602            actual_lines.join("\n")
603        );
604    }
605}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a `columns` x `rows` terminal, write `input` to it and flush.
    fn render(columns: u16, rows: u16, input: &str) -> TestTerminal {
        let mut term = TestTerminal::new(columns, rows);
        term.write_all(input.as_bytes()).unwrap();
        term.flush().unwrap();
        term
    }

    /// Visible lines of an 80x24 terminal after writing `input`.
    fn lines_after(input: &str) -> Vec<String> {
        render(80, 24, input).get_lines()
    }

    #[test]
    fn test_basic_write() {
        assert_eq!(lines_after("Hello")[0], "Hello");
    }

    #[test]
    fn test_newline() {
        let term = render(80, 24, "Line 1\nLine 2");
        assert_buffer_eq(&term, &["Line 1", "Line 2"]);
    }

    #[test]
    fn test_carriage_return() {
        assert_eq!(lines_after("Hello\rWorld")[0], "World");
    }

    #[test]
    fn test_ansi_cursor_position() {
        let lines = lines_after("\x1b[3;5HX");
        assert_eq!(&lines[2][4..5], "X");
    }

    #[test]
    fn test_ansi_clear_line() {
        // Write text, then go to column 1 and erase to the end of the line.
        let lines = lines_after(concat!("Hello World", "\x1b[1G\x1b[K"));
        assert_eq!(lines[0], "");
    }

    #[test]
    fn test_assert_buffer_eq() {
        let term = render(80, 24, "Line 1\nLine 2\nLine 3");
        assert_buffer_eq(&term, &["Line 1", "Line 2", "Line 3"]);
    }

    #[test]
    #[should_panic(expected = "Line 0 mismatch")]
    fn test_assert_buffer_eq_fails() {
        let term = render(80, 24, "Wrong");
        assert_buffer_eq(&term, &["Expected"]);
    }

    #[test]
    fn test_private_mode_sequences_ignored() {
        assert_eq!(lines_after("\x1b[?2026hHello\x1b[?2026l")[0], "Hello");
    }

    #[test]
    fn test_cursor_save_restore() {
        let lines = lines_after(concat!(
            "\x1b[6;11HFirst", // row 6, column 11 (1-based)
            "\x1b7",           // save the cursor just after "First"
            "\x1b[1;1HSecond", // jump to the top-left corner
            "\x1b8Third",      // restore and continue after "First"
        ));
        assert_eq!(lines[0], "Second");
        assert_eq!(lines[5], "          FirstThird");
    }

    #[test]
    fn test_transcript_includes_scrolled_off_lines() {
        let term = render(6, 2, "L1\nL2\nL3");

        let visible = term.get_lines();
        assert_eq!(visible[0], "L2");
        assert_eq!(visible[1], "L3");

        assert_eq!(term.get_transcript_lines(), vec!["L1", "L2", "L3"]);
    }

    #[test]
    fn test_sgr_bold() {
        let term = render(80, 24, "\x1b[1mbold\x1b[0m");
        assert_eq!(term.get_lines()[0], "bold");
        assert!(term.get_style_at(0, 0).bold);
        assert!(!term.get_style_at(0, 4).bold);
    }

    #[test]
    fn test_sgr_fg_color() {
        let term = render(80, 24, "\x1b[31mred\x1b[0m");
        assert_eq!(term.get_style_at(0, 0).fg, Some(Color::DarkRed));
        assert_eq!(term.get_style_at(0, 3).fg, None);
    }

    #[test]
    fn test_sgr_rgb_color() {
        let term = render(80, 24, "\x1b[38;2;255;128;0mrgb\x1b[0m");
        let orange = Color::Rgb {
            r: 255,
            g: 128,
            b: 0,
        };
        assert_eq!(term.get_style_at(0, 0).fg, Some(orange));
    }

    #[test]
    fn test_style_of_text() {
        let term = render(80, 24, "plain \x1b[1mbold\x1b[0m rest");
        assert!(term.style_of_text(0, "bold").unwrap().bold);
        assert!(!term.style_of_text(0, "plain").unwrap().bold);
    }
}