Skip to main content

tui/testing/
test_terminal.rs

1use crossterm::style::{Color, force_color_output};
2use std::io::{self, Write};
3
4use crate::Style;
5use crate::terminal_codes::BELL;
6
7/// A single cell in the terminal buffer, storing both a character and its style.
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub struct Cell {
10    pub ch: char,
11    pub style: Style,
12}
13
14impl Default for Cell {
15    fn default() -> Self {
16        Self { ch: ' ', style: Style::default() }
17    }
18}
19
20impl Cell {
21    fn new(ch: char, style: Style) -> Self {
22        Self { ch, style }
23    }
24}
25
/// A virtual terminal buffer for testing terminal output.
///
/// Bytes written through `std::io::Write` are queued and, on `flush`, decoded
/// into a grid of [`Cell`]s. While decoding, the terminal tracks the cursor
/// position and interprets ANSI escape sequences, including SGR (Select
/// Graphic Rendition) codes, so the style of every written character is
/// recorded. Rows pushed off the top by scrolling are kept in a scrollback list.
///
/// Implements delayed wrapping (DEC-style): when the cursor reaches the last
/// column, it stays there with a pending-wrap flag. The next printable character
/// triggers the wrap to column 0 of the next line. A `\r` clears the flag.
#[derive(Debug, Clone)]
pub struct TestTerminal {
    /// Visible screen: one `Vec<Cell>` per row, indexed `buffer[row][column]`.
    buffer: Vec<Vec<Cell>>,
    /// Rows that have scrolled off the top of the visible buffer, oldest first.
    scrollback: Vec<Vec<Cell>>,
    /// Current cursor position as (column, row), zero-based.
    cursor: (u16, u16),
    /// Cursor position recorded by a save command (`ESC 7` / CSI `s`), if any.
    saved_cursor: Option<(u16, u16)>,
    /// Terminal size as (columns, rows).
    size: (u16, u16),
    /// Bytes received through `write` that have not been processed yet;
    /// drained and parsed on `flush`.
    escape_buffer: Vec<u8>,
    /// Delayed wrap: cursor hit last column but hasn't wrapped yet.
    pending_wrap: bool,
    /// Current SGR style applied to newly written characters.
    current_style: Style,
    /// Number of `BELL` characters encountered while processing output.
    bell_count: usize,
}
53
54impl TestTerminal {
55    /// Create a new test terminal with given size
56    pub fn new(columns: u16, rows: u16) -> Self {
57        force_color_output(true);
58        let buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
59        Self {
60            buffer,
61            scrollback: Vec::new(),
62            cursor: (0, 0),
63            saved_cursor: None,
64            size: (columns, rows),
65            escape_buffer: Vec::new(),
66            pending_wrap: false,
67            current_style: Style::default(),
68            bell_count: 0,
69        }
70    }
71
    /// Returns the bell counter, which is incremented once for every `BELL`
    /// character found in the processed output.
    pub fn bell_count(&self) -> usize {
        self.bell_count
    }
75
76    /// Resize terminal without preserving prior transcript content.
77    pub fn resize(&mut self, columns: u16, rows: u16) {
78        let columns = columns.max(1);
79        let rows = rows.max(1);
80        self.buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
81        self.scrollback.clear();
82        self.size = (columns, rows);
83        self.cursor = (0, rows.saturating_sub(1));
84        self.saved_cursor = None;
85        self.pending_wrap = false;
86    }
87
88    /// Resize terminal and reflow existing transcript content to match the new width.
89    pub fn resize_preserving_transcript(&mut self, columns: u16, rows: u16) {
90        let transcript = self.get_transcript_lines();
91        let wrapped = Self::reflow_lines(&transcript, columns);
92        self.apply_reflowed_lines(columns, rows, &wrapped);
93    }
94
95    fn reflow_lines(lines: &[String], columns: u16) -> Vec<String> {
96        let mut wrapped = Vec::new();
97        let width = columns.max(1) as usize;
98
99        for line in lines {
100            if line.is_empty() {
101                wrapped.push(String::new());
102                continue;
103            }
104
105            let chars: Vec<char> = line.chars().collect();
106            for chunk in chars.chunks(width) {
107                wrapped.push(chunk.iter().collect());
108            }
109        }
110
111        if wrapped.is_empty() {
112            wrapped.push(String::new());
113        }
114
115        wrapped
116    }
117
118    fn apply_reflowed_lines(&mut self, columns: u16, rows: u16, wrapped: &[String]) {
119        let rows_usize = rows.max(1) as usize;
120        let split_at = wrapped.len().saturating_sub(rows_usize);
121        let (scrollback, visible) = wrapped.split_at(split_at);
122
123        self.scrollback = scrollback.iter().map(|line| Self::line_to_row(line, columns)).collect();
124
125        self.buffer = visible.iter().map(|line| Self::line_to_row(line, columns)).collect();
126
127        while self.buffer.len() < rows_usize {
128            self.buffer.push(vec![Cell::default(); columns as usize]);
129        }
130
131        self.size = (columns, rows);
132        self.cursor = (0, rows.saturating_sub(1));
133        self.saved_cursor = None;
134        self.pending_wrap = false;
135    }
136
137    fn line_to_row(line: &str, columns: u16) -> Vec<Cell> {
138        let mut row: Vec<Cell> =
139            line.chars().take(columns as usize).map(|ch| Cell::new(ch, Style::default())).collect();
140        row.resize(columns as usize, Cell::default());
141        row
142    }
143
144    /// Get all lines as a vector of strings (trailing whitespace trimmed)
145    pub fn get_lines(&self) -> Vec<String> {
146        self.buffer.iter().map(|cells| cells.iter().map(|c| c.ch).collect::<String>().trim_end().to_string()).collect()
147    }
148
149    /// Get full terminal transcript (scrollback history + visible buffer).
150    pub fn get_transcript_lines(&self) -> Vec<String> {
151        self.scrollback
152            .iter()
153            .chain(self.buffer.iter())
154            .map(|cells| cells.iter().map(|c| c.ch).collect::<String>().trim_end().to_string())
155            .collect()
156    }
157
    /// Get the current cursor position as (column, row), both zero-based.
    ///
    /// While a delayed wrap is pending the reported column is still the last
    /// column of the row, because the cursor only moves when the next
    /// printable character is written.
    #[allow(dead_code)]
    pub fn cursor_position(&self) -> (u16, u16) {
        self.cursor
    }
163
164    /// Get the style at a specific buffer position.
165    pub fn get_style_at(&self, row: usize, col: usize) -> Style {
166        self.buffer.get(row).and_then(|r| r.get(col)).map_or(Style::default(), |c| c.style)
167    }
168
169    /// Find the first occurrence of `text` on the given row and return its style.
170    ///
171    /// Returns the style of the first character of the matched text.
172    pub fn style_of_text(&self, row: usize, text: &str) -> Option<Style> {
173        let row_data = self.buffer.get(row)?;
174        let row_text: String = row_data.iter().map(|c| c.ch).collect();
175        let byte_offset = row_text.find(text)?;
176        // Convert byte offset to character index
177        let char_index = row_text[..byte_offset].chars().count();
178        Some(row_data[char_index].style)
179    }
180
181    /// Clear the entire buffer
182    pub fn clear(&mut self) {
183        for row in &mut self.buffer {
184            for cell in row {
185                *cell = Cell::default();
186            }
187        }
188    }
189
190    /// Clear the current line
191    pub fn clear_line(&mut self) {
192        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
193            for cell in row {
194                *cell = Cell::default();
195            }
196        }
197    }
198
199    /// Move cursor to absolute position
200    pub fn move_to(&mut self, col: u16, row: u16) {
201        self.cursor = (col.min(self.size.0.saturating_sub(1)), row.min(self.size.1.saturating_sub(1)));
202        self.pending_wrap = false;
203    }
204
205    /// Move cursor to column (keep same row)
206    pub fn move_to_column(&mut self, col: u16) {
207        self.cursor.0 = col.min(self.size.0.saturating_sub(1));
208        self.pending_wrap = false;
209    }
210
211    /// Move cursor left by n positions
212    pub fn move_left(&mut self, n: u16) {
213        self.cursor.0 = self.cursor.0.saturating_sub(n);
214        self.pending_wrap = false;
215    }
216
217    /// Move cursor right by n positions
218    pub fn move_right(&mut self, n: u16) {
219        self.cursor.0 = (self.cursor.0 + n).min(self.size.0.saturating_sub(1));
220        self.pending_wrap = false;
221    }
222
223    /// Write a single character at current cursor position and advance cursor
224    fn write_char(&mut self, ch: char) {
225        match ch {
226            '\n' => {
227                self.pending_wrap = false;
228                if self.cursor.1 >= self.size.1.saturating_sub(1) {
229                    let removed = self.buffer.remove(0);
230                    self.scrollback.push(removed);
231                    self.buffer.push(vec![Cell::default(); self.size.0 as usize]);
232                } else {
233                    self.cursor.1 += 1;
234                }
235                self.cursor.0 = 0;
236            }
237            '\r' => {
238                self.cursor.0 = 0;
239                self.pending_wrap = false;
240            }
241            '\t' => {
242                for _ in 0..4 {
243                    self.write_char_at_cursor(' ');
244                }
245            }
246            _ => {
247                self.write_char_at_cursor(ch);
248            }
249        }
250    }
251
252    /// Write a character at the current cursor position (delayed wrap).
253    ///
254    /// When the cursor is at the last column with `pending_wrap` set,
255    /// the next printable character triggers the wrap first.
256    fn write_char_at_cursor(&mut self, ch: char) {
257        if self.pending_wrap {
258            self.pending_wrap = false;
259            self.cursor.0 = 0;
260            if self.cursor.1 >= self.size.1.saturating_sub(1) {
261                let removed = self.buffer.remove(0);
262                self.scrollback.push(removed);
263                self.buffer.push(vec![Cell::default(); self.size.0 as usize]);
264            } else {
265                self.cursor.1 += 1;
266            }
267        }
268
269        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize)
270            && let Some(cell) = row.get_mut(self.cursor.0 as usize)
271        {
272            *cell = Cell::new(ch, self.current_style);
273            self.cursor.0 += 1;
274            if self.cursor.0 >= self.size.0 {
275                self.cursor.0 = self.size.0 - 1;
276                self.pending_wrap = true;
277            }
278        }
279    }
280
281    /// Process a byte slice, handling ANSI escape sequences
282    fn process_bytes(&mut self, buf: &[u8]) {
283        let s = String::from_utf8_lossy(buf);
284        let mut chars = s.chars().peekable();
285
286        while let Some(ch) = chars.next() {
287            if ch == '\x1b' {
288                if chars.peek() == Some(&'[') {
289                    chars.next();
290                    self.process_csi_sequence(&mut chars);
291                } else if chars.peek() == Some(&'7') {
292                    chars.next();
293                    self.saved_cursor = Some(self.cursor);
294                } else if chars.peek() == Some(&'8') {
295                    chars.next();
296                    if let Some(saved) = self.saved_cursor {
297                        self.cursor = saved;
298                    }
299                }
300            } else if ch == char::from(BELL) {
301                self.bell_count += 1;
302            } else {
303                self.write_char(ch);
304            }
305        }
306    }
307
308    /// Process a CSI (Control Sequence Introducer) escape sequence
309    #[allow(clippy::too_many_lines)]
310    fn process_csi_sequence(&mut self, chars: &mut std::iter::Peekable<std::str::Chars>) {
311        let private_mode = if chars.peek() == Some(&'?') {
312            chars.next();
313            true
314        } else {
315            false
316        };
317
318        let mut params = String::new();
319
320        while let Some(&ch) = chars.peek() {
321            if ch.is_ascii_digit() || ch == ';' || ch == ':' {
322                params.push(ch);
323                chars.next();
324            } else {
325                break;
326            }
327        }
328
329        if private_mode {
330            chars.next();
331            return;
332        }
333
334        if let Some(cmd) = chars.next() {
335            match cmd {
336                'H' | 'f' => {
337                    let parts: Vec<u16> = params.split(';').filter_map(|s| s.parse().ok()).collect();
338                    let row = parts.first().copied().unwrap_or(1).saturating_sub(1);
339                    let col = parts.get(1).copied().unwrap_or(1).saturating_sub(1);
340                    self.move_to(col, row);
341                }
342                'A' => {
343                    let n = params.parse().unwrap_or(1);
344                    self.cursor.1 = self.cursor.1.saturating_sub(n);
345                    self.pending_wrap = false;
346                }
347                'B' => {
348                    let n = params.parse().unwrap_or(1);
349                    self.cursor.1 = (self.cursor.1 + n).min(self.size.1.saturating_sub(1));
350                    self.pending_wrap = false;
351                }
352                'C' => {
353                    let n = params.parse().unwrap_or(1);
354                    self.move_right(n);
355                }
356                'D' => {
357                    let n = params.parse().unwrap_or(1);
358                    self.move_left(n);
359                }
360                'G' => {
361                    let col = params.parse::<u16>().unwrap_or(1).saturating_sub(1);
362                    self.move_to_column(col);
363                }
364                'J' => {
365                    let n = params.parse().unwrap_or(0);
366                    match n {
367                        0 => {
368                            for row in self.cursor.1..self.size.1 {
369                                if let Some(r) = self.buffer.get_mut(row as usize) {
370                                    let start = if row == self.cursor.1 { self.cursor.0 as usize } else { 0 };
371                                    for cell in r.iter_mut().skip(start) {
372                                        *cell = Cell::default();
373                                    }
374                                }
375                            }
376                        }
377                        2 => {
378                            self.clear();
379                        }
380                        3 => {
381                            self.scrollback.clear();
382                        }
383                        _ => {}
384                    }
385                }
386                'K' => {
387                    let n = params.parse().unwrap_or(0);
388                    match n {
389                        0 => {
390                            if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
391                                for cell in row.iter_mut().skip(self.cursor.0 as usize) {
392                                    *cell = Cell::default();
393                                }
394                            }
395                        }
396                        2 => {
397                            self.clear_line();
398                        }
399                        _ => {}
400                    }
401                }
402                's' => {
403                    self.saved_cursor = Some(self.cursor);
404                }
405                'u' => {
406                    if let Some(saved) = self.saved_cursor {
407                        self.cursor = saved;
408                        self.pending_wrap = false;
409                    }
410                }
411                'm' => {
412                    self.apply_sgr(&params);
413                }
414                _ => {}
415            }
416        }
417    }
418
419    /// Apply SGR (Select Graphic Rendition) parameters to update `current_style`.
420    #[allow(clippy::cast_possible_truncation)]
421    fn apply_sgr(&mut self, params: &str) {
422        if params.is_empty() {
423            self.current_style = Style::default();
424            return;
425        }
426
427        // Split on ';' for parameter groups, then take the first colon-delimited
428        // sub-parameter as the primary code (e.g. "4:1" → 4 for underline style).
429        let codes: Vec<u16> = params
430            .split(';')
431            .filter_map(|s| {
432                let primary = s.split(':').next().unwrap_or(s);
433                primary.parse().ok()
434            })
435            .collect();
436        let mut i = 0;
437        while i < codes.len() {
438            match codes[i] {
439                0 => self.current_style = Style::default(),
440                1 => self.current_style.bold = true,
441                2 => self.current_style.dim = true,
442                3 => self.current_style.italic = true,
443                4 => self.current_style.underline = true,
444                9 => self.current_style.strikethrough = true,
445                22 => {
446                    self.current_style.bold = false;
447                    self.current_style.dim = false;
448                }
449                23 => self.current_style.italic = false,
450                24 => self.current_style.underline = false,
451                29 => self.current_style.strikethrough = false,
452                30..=37 => {
453                    self.current_style.fg = Some(standard_color(codes[i] as u8 - 30));
454                }
455                38 => {
456                    i += 1;
457                    if i < codes.len() {
458                        match codes[i] {
459                            5 if i + 1 < codes.len() => {
460                                i += 1;
461                                self.current_style.fg = Some(Color::AnsiValue(codes[i] as u8));
462                            }
463                            2 if i + 3 < codes.len() => {
464                                self.current_style.fg = Some(Color::Rgb {
465                                    r: codes[i + 1] as u8,
466                                    g: codes[i + 2] as u8,
467                                    b: codes[i + 3] as u8,
468                                });
469                                i += 3;
470                            }
471                            _ => {}
472                        }
473                    }
474                }
475                39 => self.current_style.fg = None,
476                40..=47 => {
477                    self.current_style.bg = Some(standard_color(codes[i] as u8 - 40));
478                }
479                48 => {
480                    i += 1;
481                    if i < codes.len() {
482                        match codes[i] {
483                            5 if i + 1 < codes.len() => {
484                                i += 1;
485                                self.current_style.bg = Some(Color::AnsiValue(codes[i] as u8));
486                            }
487                            2 if i + 3 < codes.len() => {
488                                self.current_style.bg = Some(Color::Rgb {
489                                    r: codes[i + 1] as u8,
490                                    g: codes[i + 2] as u8,
491                                    b: codes[i + 3] as u8,
492                                });
493                                i += 3;
494                            }
495                            _ => {}
496                        }
497                    }
498                }
499                49 => self.current_style.bg = None,
500                90..=97 => {
501                    self.current_style.fg = Some(bright_color(codes[i] as u8 - 90));
502                }
503                100..=107 => {
504                    self.current_style.bg = Some(bright_color(codes[i] as u8 - 100));
505                }
506                _ => {}
507            }
508            i += 1;
509        }
510    }
511}
512
513/// Map ANSI standard color index (0-7) to crossterm Color.
514fn standard_color(index: u8) -> Color {
515    match index {
516        0 => Color::Black,
517        1 => Color::DarkRed,
518        2 => Color::DarkGreen,
519        3 => Color::DarkYellow,
520        4 => Color::DarkBlue,
521        5 => Color::DarkMagenta,
522        6 => Color::DarkCyan,
523        7 => Color::Grey,
524        _ => Color::Reset,
525    }
526}
527
528/// Map ANSI bright color index (0-7) to crossterm Color.
529fn bright_color(index: u8) -> Color {
530    match index {
531        0 => Color::DarkGrey,
532        1 => Color::Red,
533        2 => Color::Green,
534        3 => Color::Yellow,
535        4 => Color::Blue,
536        5 => Color::Magenta,
537        6 => Color::Cyan,
538        7 => Color::White,
539        _ => Color::Reset,
540    }
541}
542
543impl Write for TestTerminal {
544    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
545        self.escape_buffer.extend_from_slice(buf);
546        Ok(buf.len())
547    }
548
549    fn flush(&mut self) -> io::Result<()> {
550        if !self.escape_buffer.is_empty() {
551            let bytes = std::mem::take(&mut self.escape_buffer);
552            self.process_bytes(&bytes);
553        }
554        Ok(())
555    }
556}
557
558/// Asserts a test terminal buffer matches the expected output.
559/// Each element of the expected vector represents a row.
560/// Trailing whitespace is ignored on each line.
561pub fn assert_buffer_eq<S: AsRef<str>>(terminal: &TestTerminal, expected: &[S]) {
562    let actual_lines = terminal.get_lines();
563    let max_lines = expected.len().max(actual_lines.len());
564
565    for i in 0..max_lines {
566        let expected_line = expected.get(i).map_or("", AsRef::as_ref);
567        let actual_line = actual_lines.get(i).map_or("", String::as_str);
568
569        assert_eq!(
570            actual_line,
571            expected_line,
572            "Line {i} mismatch:\n  Expected: '{expected_line}'\n  Got:      '{actual_line}'\n\nFull buffer:\n{}",
573            actual_lines.join("\n")
574        );
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `columns` x `rows` terminal, feed it `input`, and flush so the
    /// queued bytes are actually parsed.
    fn rendered(columns: u16, rows: u16, input: &str) -> TestTerminal {
        let mut term = TestTerminal::new(columns, rows);
        term.write_all(input.as_bytes()).unwrap();
        term.flush().unwrap();
        term
    }

    /// Same as `rendered`, on the 80x24 screen most tests use.
    fn rendered_80x24(input: &str) -> TestTerminal {
        rendered(80, 24, input)
    }

    // Plain text lands on the first row.
    #[test]
    fn test_basic_write() {
        let term = rendered_80x24("Hello");
        assert_eq!(term.get_lines()[0], "Hello");
    }

    // `\n` starts a new row at column 0.
    #[test]
    fn test_newline() {
        let term = rendered_80x24("Line 1\nLine 2");
        assert_buffer_eq(&term, &["Line 1", "Line 2"]);
    }

    // `\r` returns to column 0 so later text overwrites earlier text.
    #[test]
    fn test_carriage_return() {
        let term = rendered_80x24("Hello\rWorld");
        assert_eq!(term.get_lines()[0], "World");
    }

    // CSI `row;col H` is 1-based: 3;5 addresses row index 2, column index 4.
    #[test]
    fn test_ansi_cursor_position() {
        let term = rendered_80x24("\x1b[3;5HX");
        let third_row = &term.get_lines()[2];
        assert_eq!(third_row.chars().nth(4), Some('X'));
    }

    // Column 1 followed by erase-to-end-of-line wipes the whole row.
    #[test]
    fn test_ansi_clear_line() {
        let term = rendered_80x24(concat!("Hello World", "\x1b[1G\x1b[K"));
        assert_eq!(term.get_lines()[0], "");
    }

    #[test]
    fn test_assert_buffer_eq() {
        let term = rendered_80x24("Line 1\nLine 2\nLine 3");
        assert_buffer_eq(&term, &["Line 1", "Line 2", "Line 3"]);
    }

    #[test]
    #[should_panic(expected = "Line 0 mismatch")]
    fn test_assert_buffer_eq_fails() {
        let term = rendered_80x24("Wrong");
        assert_buffer_eq(&term, &["Expected"]);
    }

    // Private-mode sequences must not leave any text on screen.
    #[test]
    fn test_private_mode_sequences_ignored() {
        let term = rendered_80x24("\x1b[?2026hHello\x1b[?2026l");
        assert_eq!(term.get_lines()[0], "Hello");
    }

    // `ESC 7` remembers the position after "First"; `ESC 8` returns to it.
    #[test]
    fn test_cursor_save_restore() {
        let term = rendered_80x24(concat!(
            "\x1b[6;11HFirst",
            "\x1b7",
            "\x1b[1;1HSecond",
            "\x1b8Third",
        ));

        let lines = term.get_lines();
        assert_eq!(lines[0], "Second");
        assert_eq!(lines[5], "          FirstThird");
    }

    // Rows scrolled off a 2-row screen are still part of the transcript.
    #[test]
    fn test_transcript_includes_scrolled_off_lines() {
        let term = rendered(6, 2, "L1\nL2\nL3");

        let visible = term.get_lines();
        assert_eq!(visible[0], "L2");
        assert_eq!(visible[1], "L3");

        assert_eq!(term.get_transcript_lines(), vec!["L1", "L2", "L3"]);
    }

    #[test]
    fn test_sgr_bold() {
        let term = rendered_80x24("\x1b[1mbold\x1b[0m");

        assert_eq!(term.get_lines()[0], "bold");
        assert!(term.get_style_at(0, 0).bold);
        // First cell after the styled text was never written.
        assert!(!term.get_style_at(0, 4).bold);
    }

    #[test]
    fn test_sgr_fg_color() {
        let term = rendered_80x24("\x1b[31mred\x1b[0m");

        assert_eq!(term.get_style_at(0, 0).fg, Some(Color::DarkRed));
        assert_eq!(term.get_style_at(0, 3).fg, None);
    }

    #[test]
    fn test_sgr_rgb_color() {
        let term = rendered_80x24("\x1b[38;2;255;128;0mrgb\x1b[0m");

        let orange = Color::Rgb { r: 255, g: 128, b: 0 };
        assert_eq!(term.get_style_at(0, 0).fg, Some(orange));
    }

    #[test]
    fn test_style_of_text() {
        let term = rendered_80x24("plain \x1b[1mbold\x1b[0m rest");

        let bold_style = term.style_of_text(0, "bold").unwrap();
        assert!(bold_style.bold);

        let plain_style = term.style_of_text(0, "plain").unwrap();
        assert!(!plain_style.bold);
    }
}