Skip to main content

tui/testing/
test_terminal.rs

1use std::io::{self, Write};
2
3use crossterm::style::Color;
4
5use crate::Style;
6
7/// A single cell in the terminal buffer, storing both a character and its style.
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub struct Cell {
10    pub ch: char,
11    pub style: Style,
12}
13
14impl Default for Cell {
15    fn default() -> Self {
16        Self { ch: ' ', style: Style::default() }
17    }
18}
19
20impl Cell {
21    fn new(ch: char, style: Style) -> Self {
22        Self { ch, style }
23    }
24}
25
26/// A virtual terminal buffer for testing terminal output.
27/// Captures all writes, tracks cursor position, and parses ANSI escape sequences
28/// including SGR (Select Graphic Rendition) codes for style tracking.
29///
30/// Implements delayed wrapping (DEC-style): when the cursor reaches the last
31/// column, it stays there with a pending-wrap flag. The next printable character
32/// triggers the wrap to column 0 of the next line. A `\r` clears the flag.
33#[derive(Debug, Clone)]
34pub struct TestTerminal {
35    /// 2D buffer of cells (row, column)
36    buffer: Vec<Vec<Cell>>,
37    /// Rows that have scrolled off the top of the visible buffer.
38    scrollback: Vec<Vec<Cell>>,
39    /// Current cursor position (column, row)
40    cursor: (u16, u16),
41    /// Saved cursor position (for save/restore)
42    saved_cursor: Option<(u16, u16)>,
43    /// Terminal size (columns, rows)
44    size: (u16, u16),
45    /// Buffer for incomplete escape sequences
46    escape_buffer: Vec<u8>,
47    /// Delayed wrap: cursor hit last column but hasn't wrapped yet
48    pending_wrap: bool,
49    /// Current SGR style applied to newly written characters
50    current_style: Style,
51}
52
53impl TestTerminal {
54    /// Create a new test terminal with given size
55    pub fn new(columns: u16, rows: u16) -> Self {
56        let buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
57        Self {
58            buffer,
59            scrollback: Vec::new(),
60            cursor: (0, 0),
61            saved_cursor: None,
62            size: (columns, rows),
63            escape_buffer: Vec::new(),
64            pending_wrap: false,
65            current_style: Style::default(),
66        }
67    }
68
69    /// Resize terminal without preserving prior transcript content.
70    pub fn resize(&mut self, columns: u16, rows: u16) {
71        let columns = columns.max(1);
72        let rows = rows.max(1);
73        self.buffer = vec![vec![Cell::default(); columns as usize]; rows as usize];
74        self.scrollback.clear();
75        self.size = (columns, rows);
76        self.cursor = (0, rows.saturating_sub(1));
77        self.saved_cursor = None;
78        self.pending_wrap = false;
79    }
80
81    /// Resize terminal and reflow existing transcript content to match the new width.
82    pub fn resize_preserving_transcript(&mut self, columns: u16, rows: u16) {
83        let transcript = self.get_transcript_lines();
84        let wrapped = Self::reflow_lines(&transcript, columns);
85        self.apply_reflowed_lines(columns, rows, &wrapped);
86    }
87
88    fn reflow_lines(lines: &[String], columns: u16) -> Vec<String> {
89        let mut wrapped = Vec::new();
90        let width = columns.max(1) as usize;
91
92        for line in lines {
93            if line.is_empty() {
94                wrapped.push(String::new());
95                continue;
96            }
97
98            let chars: Vec<char> = line.chars().collect();
99            for chunk in chars.chunks(width) {
100                wrapped.push(chunk.iter().collect());
101            }
102        }
103
104        if wrapped.is_empty() {
105            wrapped.push(String::new());
106        }
107
108        wrapped
109    }
110
111    fn apply_reflowed_lines(&mut self, columns: u16, rows: u16, wrapped: &[String]) {
112        let rows_usize = rows.max(1) as usize;
113        let split_at = wrapped.len().saturating_sub(rows_usize);
114        let (scrollback, visible) = wrapped.split_at(split_at);
115
116        self.scrollback = scrollback.iter().map(|line| Self::line_to_row(line, columns)).collect();
117
118        self.buffer = visible.iter().map(|line| Self::line_to_row(line, columns)).collect();
119
120        while self.buffer.len() < rows_usize {
121            self.buffer.push(vec![Cell::default(); columns as usize]);
122        }
123
124        self.size = (columns, rows);
125        self.cursor = (0, rows.saturating_sub(1));
126        self.saved_cursor = None;
127        self.pending_wrap = false;
128    }
129
130    fn line_to_row(line: &str, columns: u16) -> Vec<Cell> {
131        let mut row: Vec<Cell> =
132            line.chars().take(columns as usize).map(|ch| Cell::new(ch, Style::default())).collect();
133        row.resize(columns as usize, Cell::default());
134        row
135    }
136
137    /// Get all lines as a vector of strings (trailing whitespace trimmed)
138    pub fn get_lines(&self) -> Vec<String> {
139        self.buffer.iter().map(|cells| cells.iter().map(|c| c.ch).collect::<String>().trim_end().to_string()).collect()
140    }
141
142    /// Get full terminal transcript (scrollback history + visible buffer).
143    pub fn get_transcript_lines(&self) -> Vec<String> {
144        self.scrollback
145            .iter()
146            .chain(self.buffer.iter())
147            .map(|cells| cells.iter().map(|c| c.ch).collect::<String>().trim_end().to_string())
148            .collect()
149    }
150
151    /// Get current cursor position as (column, row).
152    #[allow(dead_code)]
153    pub fn cursor_position(&self) -> (u16, u16) {
154        self.cursor
155    }
156
157    /// Get the style at a specific buffer position.
158    pub fn get_style_at(&self, row: usize, col: usize) -> Style {
159        self.buffer.get(row).and_then(|r| r.get(col)).map_or(Style::default(), |c| c.style)
160    }
161
162    /// Find the first occurrence of `text` on the given row and return its style.
163    ///
164    /// Returns the style of the first character of the matched text.
165    pub fn style_of_text(&self, row: usize, text: &str) -> Option<Style> {
166        let row_data = self.buffer.get(row)?;
167        let row_text: String = row_data.iter().map(|c| c.ch).collect();
168        let byte_offset = row_text.find(text)?;
169        // Convert byte offset to character index
170        let char_index = row_text[..byte_offset].chars().count();
171        Some(row_data[char_index].style)
172    }
173
174    /// Clear the entire buffer
175    pub fn clear(&mut self) {
176        for row in &mut self.buffer {
177            for cell in row {
178                *cell = Cell::default();
179            }
180        }
181    }
182
183    /// Clear the current line
184    pub fn clear_line(&mut self) {
185        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
186            for cell in row {
187                *cell = Cell::default();
188            }
189        }
190    }
191
192    /// Move cursor to absolute position
193    pub fn move_to(&mut self, col: u16, row: u16) {
194        self.cursor = (col.min(self.size.0.saturating_sub(1)), row.min(self.size.1.saturating_sub(1)));
195        self.pending_wrap = false;
196    }
197
198    /// Move cursor to column (keep same row)
199    pub fn move_to_column(&mut self, col: u16) {
200        self.cursor.0 = col.min(self.size.0.saturating_sub(1));
201        self.pending_wrap = false;
202    }
203
204    /// Move cursor left by n positions
205    pub fn move_left(&mut self, n: u16) {
206        self.cursor.0 = self.cursor.0.saturating_sub(n);
207        self.pending_wrap = false;
208    }
209
210    /// Move cursor right by n positions
211    pub fn move_right(&mut self, n: u16) {
212        self.cursor.0 = (self.cursor.0 + n).min(self.size.0.saturating_sub(1));
213        self.pending_wrap = false;
214    }
215
216    /// Write a single character at current cursor position and advance cursor
217    fn write_char(&mut self, ch: char) {
218        match ch {
219            '\n' => {
220                self.pending_wrap = false;
221                if self.cursor.1 >= self.size.1.saturating_sub(1) {
222                    let removed = self.buffer.remove(0);
223                    self.scrollback.push(removed);
224                    self.buffer.push(vec![Cell::default(); self.size.0 as usize]);
225                } else {
226                    self.cursor.1 += 1;
227                }
228                self.cursor.0 = 0;
229            }
230            '\r' => {
231                self.cursor.0 = 0;
232                self.pending_wrap = false;
233            }
234            '\t' => {
235                for _ in 0..4 {
236                    self.write_char_at_cursor(' ');
237                }
238            }
239            _ => {
240                self.write_char_at_cursor(ch);
241            }
242        }
243    }
244
245    /// Write a character at the current cursor position (delayed wrap).
246    ///
247    /// When the cursor is at the last column with `pending_wrap` set,
248    /// the next printable character triggers the wrap first.
249    fn write_char_at_cursor(&mut self, ch: char) {
250        if self.pending_wrap {
251            self.pending_wrap = false;
252            self.cursor.0 = 0;
253            if self.cursor.1 >= self.size.1.saturating_sub(1) {
254                let removed = self.buffer.remove(0);
255                self.scrollback.push(removed);
256                self.buffer.push(vec![Cell::default(); self.size.0 as usize]);
257            } else {
258                self.cursor.1 += 1;
259            }
260        }
261
262        if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize)
263            && let Some(cell) = row.get_mut(self.cursor.0 as usize)
264        {
265            *cell = Cell::new(ch, self.current_style);
266            self.cursor.0 += 1;
267            if self.cursor.0 >= self.size.0 {
268                self.cursor.0 = self.size.0 - 1;
269                self.pending_wrap = true;
270            }
271        }
272    }
273
274    /// Process a byte slice, handling ANSI escape sequences
275    fn process_bytes(&mut self, buf: &[u8]) {
276        let s = String::from_utf8_lossy(buf);
277        let mut chars = s.chars().peekable();
278
279        while let Some(ch) = chars.next() {
280            if ch == '\x1b' {
281                if chars.peek() == Some(&'[') {
282                    chars.next();
283                    self.process_csi_sequence(&mut chars);
284                } else if chars.peek() == Some(&'7') {
285                    chars.next();
286                    self.saved_cursor = Some(self.cursor);
287                } else if chars.peek() == Some(&'8') {
288                    chars.next();
289                    if let Some(saved) = self.saved_cursor {
290                        self.cursor = saved;
291                    }
292                }
293            } else {
294                self.write_char(ch);
295            }
296        }
297    }
298
299    /// Process a CSI (Control Sequence Introducer) escape sequence
300    #[allow(clippy::too_many_lines)]
301    fn process_csi_sequence(&mut self, chars: &mut std::iter::Peekable<std::str::Chars>) {
302        let private_mode = if chars.peek() == Some(&'?') {
303            chars.next();
304            true
305        } else {
306            false
307        };
308
309        let mut params = String::new();
310
311        while let Some(&ch) = chars.peek() {
312            if ch.is_ascii_digit() || ch == ';' || ch == ':' {
313                params.push(ch);
314                chars.next();
315            } else {
316                break;
317            }
318        }
319
320        if private_mode {
321            chars.next();
322            return;
323        }
324
325        if let Some(cmd) = chars.next() {
326            match cmd {
327                'H' | 'f' => {
328                    let parts: Vec<u16> = params.split(';').filter_map(|s| s.parse().ok()).collect();
329                    let row = parts.first().copied().unwrap_or(1).saturating_sub(1);
330                    let col = parts.get(1).copied().unwrap_or(1).saturating_sub(1);
331                    self.move_to(col, row);
332                }
333                'A' => {
334                    let n = params.parse().unwrap_or(1);
335                    self.cursor.1 = self.cursor.1.saturating_sub(n);
336                    self.pending_wrap = false;
337                }
338                'B' => {
339                    let n = params.parse().unwrap_or(1);
340                    self.cursor.1 = (self.cursor.1 + n).min(self.size.1.saturating_sub(1));
341                    self.pending_wrap = false;
342                }
343                'C' => {
344                    let n = params.parse().unwrap_or(1);
345                    self.move_right(n);
346                }
347                'D' => {
348                    let n = params.parse().unwrap_or(1);
349                    self.move_left(n);
350                }
351                'G' => {
352                    let col = params.parse::<u16>().unwrap_or(1).saturating_sub(1);
353                    self.move_to_column(col);
354                }
355                'J' => {
356                    let n = params.parse().unwrap_or(0);
357                    match n {
358                        0 => {
359                            for row in self.cursor.1..self.size.1 {
360                                if let Some(r) = self.buffer.get_mut(row as usize) {
361                                    let start = if row == self.cursor.1 { self.cursor.0 as usize } else { 0 };
362                                    for cell in r.iter_mut().skip(start) {
363                                        *cell = Cell::default();
364                                    }
365                                }
366                            }
367                        }
368                        2 => {
369                            self.clear();
370                        }
371                        _ => {}
372                    }
373                }
374                'K' => {
375                    let n = params.parse().unwrap_or(0);
376                    match n {
377                        0 => {
378                            if let Some(row) = self.buffer.get_mut(self.cursor.1 as usize) {
379                                for cell in row.iter_mut().skip(self.cursor.0 as usize) {
380                                    *cell = Cell::default();
381                                }
382                            }
383                        }
384                        2 => {
385                            self.clear_line();
386                        }
387                        _ => {}
388                    }
389                }
390                's' => {
391                    self.saved_cursor = Some(self.cursor);
392                }
393                'u' => {
394                    if let Some(saved) = self.saved_cursor {
395                        self.cursor = saved;
396                        self.pending_wrap = false;
397                    }
398                }
399                'm' => {
400                    self.apply_sgr(&params);
401                }
402                _ => {}
403            }
404        }
405    }
406
407    /// Apply SGR (Select Graphic Rendition) parameters to update `current_style`.
408    #[allow(clippy::cast_possible_truncation)]
409    fn apply_sgr(&mut self, params: &str) {
410        if params.is_empty() {
411            self.current_style = Style::default();
412            return;
413        }
414
415        // Split on ';' for parameter groups, then take the first colon-delimited
416        // sub-parameter as the primary code (e.g. "4:1" → 4 for underline style).
417        let codes: Vec<u16> = params
418            .split(';')
419            .filter_map(|s| {
420                let primary = s.split(':').next().unwrap_or(s);
421                primary.parse().ok()
422            })
423            .collect();
424        let mut i = 0;
425        while i < codes.len() {
426            match codes[i] {
427                0 => self.current_style = Style::default(),
428                1 => self.current_style.bold = true,
429                2 => self.current_style.dim = true,
430                3 => self.current_style.italic = true,
431                4 => self.current_style.underline = true,
432                9 => self.current_style.strikethrough = true,
433                22 => {
434                    self.current_style.bold = false;
435                    self.current_style.dim = false;
436                }
437                23 => self.current_style.italic = false,
438                24 => self.current_style.underline = false,
439                29 => self.current_style.strikethrough = false,
440                30..=37 => {
441                    self.current_style.fg = Some(standard_color(codes[i] as u8 - 30));
442                }
443                38 => {
444                    i += 1;
445                    if i < codes.len() {
446                        match codes[i] {
447                            5 if i + 1 < codes.len() => {
448                                i += 1;
449                                self.current_style.fg = Some(Color::AnsiValue(codes[i] as u8));
450                            }
451                            2 if i + 3 < codes.len() => {
452                                self.current_style.fg = Some(Color::Rgb {
453                                    r: codes[i + 1] as u8,
454                                    g: codes[i + 2] as u8,
455                                    b: codes[i + 3] as u8,
456                                });
457                                i += 3;
458                            }
459                            _ => {}
460                        }
461                    }
462                }
463                39 => self.current_style.fg = None,
464                40..=47 => {
465                    self.current_style.bg = Some(standard_color(codes[i] as u8 - 40));
466                }
467                48 => {
468                    i += 1;
469                    if i < codes.len() {
470                        match codes[i] {
471                            5 if i + 1 < codes.len() => {
472                                i += 1;
473                                self.current_style.bg = Some(Color::AnsiValue(codes[i] as u8));
474                            }
475                            2 if i + 3 < codes.len() => {
476                                self.current_style.bg = Some(Color::Rgb {
477                                    r: codes[i + 1] as u8,
478                                    g: codes[i + 2] as u8,
479                                    b: codes[i + 3] as u8,
480                                });
481                                i += 3;
482                            }
483                            _ => {}
484                        }
485                    }
486                }
487                49 => self.current_style.bg = None,
488                90..=97 => {
489                    self.current_style.fg = Some(bright_color(codes[i] as u8 - 90));
490                }
491                100..=107 => {
492                    self.current_style.bg = Some(bright_color(codes[i] as u8 - 100));
493                }
494                _ => {}
495            }
496            i += 1;
497        }
498    }
499}
500
501/// Map ANSI standard color index (0-7) to crossterm Color.
502fn standard_color(index: u8) -> Color {
503    match index {
504        0 => Color::Black,
505        1 => Color::DarkRed,
506        2 => Color::DarkGreen,
507        3 => Color::DarkYellow,
508        4 => Color::DarkBlue,
509        5 => Color::DarkMagenta,
510        6 => Color::DarkCyan,
511        7 => Color::Grey,
512        _ => Color::Reset,
513    }
514}
515
516/// Map ANSI bright color index (0-7) to crossterm Color.
517fn bright_color(index: u8) -> Color {
518    match index {
519        0 => Color::DarkGrey,
520        1 => Color::Red,
521        2 => Color::Green,
522        3 => Color::Yellow,
523        4 => Color::Blue,
524        5 => Color::Magenta,
525        6 => Color::Cyan,
526        7 => Color::White,
527        _ => Color::Reset,
528    }
529}
530
531impl Write for TestTerminal {
532    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
533        self.escape_buffer.extend_from_slice(buf);
534        Ok(buf.len())
535    }
536
537    fn flush(&mut self) -> io::Result<()> {
538        if !self.escape_buffer.is_empty() {
539            let bytes = std::mem::take(&mut self.escape_buffer);
540            self.process_bytes(&bytes);
541        }
542        Ok(())
543    }
544}
545
546/// Asserts a test terminal buffer matches the expected output.
547/// Each element of the expected vector represents a row.
548/// Trailing whitespace is ignored on each line.
549pub fn assert_buffer_eq<S: AsRef<str>>(terminal: &TestTerminal, expected: &[S]) {
550    let actual_lines = terminal.get_lines();
551    let max_lines = expected.len().max(actual_lines.len());
552
553    for i in 0..max_lines {
554        let expected_line = expected.get(i).map_or("", AsRef::as_ref);
555        let actual_line = actual_lines.get(i).map_or("", String::as_str);
556
557        assert_eq!(
558            actual_line,
559            expected_line,
560            "Line {i} mismatch:\n  Expected: '{expected_line}'\n  Got:      '{actual_line}'\n\nFull buffer:\n{}",
561            actual_lines.join("\n")
562        );
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a `columns` x `rows` terminal, write `input` to it, and flush so the
    /// bytes are parsed into the buffer.
    fn render(columns: u16, rows: u16, input: &str) -> TestTerminal {
        let mut term = TestTerminal::new(columns, rows);
        term.write_all(input.as_bytes()).unwrap();
        term.flush().unwrap();
        term
    }

    #[test]
    fn test_basic_write() {
        let term = render(80, 24, "Hello");
        assert_eq!(term.get_lines()[0], "Hello");
    }

    #[test]
    fn test_newline() {
        let term = render(80, 24, "Line 1\nLine 2");
        assert_buffer_eq(&term, &["Line 1", "Line 2"]);
    }

    #[test]
    fn test_carriage_return() {
        let term = render(80, 24, "Hello\rWorld");
        assert_eq!(term.get_lines()[0], "World");
    }

    #[test]
    fn test_ansi_cursor_position() {
        // 1-based row 3, column 5 is buffer row 2, column 4.
        let term = render(80, 24, "\x1b[3;5HX");
        let lines = term.get_lines();
        assert_eq!(lines[2].chars().nth(4), Some('X'));
    }

    #[test]
    fn test_ansi_clear_line() {
        // Jump back to column 1, then erase to the end of the line.
        let term = render(80, 24, "Hello World\x1b[1G\x1b[K");
        assert_eq!(term.get_lines()[0], "");
    }

    #[test]
    fn test_assert_buffer_eq() {
        let term = render(80, 24, "Line 1\nLine 2\nLine 3");
        assert_buffer_eq(&term, &["Line 1", "Line 2", "Line 3"]);
    }

    #[test]
    #[should_panic(expected = "Line 0 mismatch")]
    fn test_assert_buffer_eq_fails() {
        let term = render(80, 24, "Wrong");
        assert_buffer_eq(&term, &["Expected"]);
    }

    #[test]
    fn test_private_mode_sequences_ignored() {
        // `ESC[?...` sequences must not leave any visible residue.
        let term = render(80, 24, "\x1b[?2026hHello\x1b[?2026l");
        assert_eq!(term.get_lines()[0], "Hello");
    }

    #[test]
    fn test_cursor_save_restore() {
        let term = render(80, 24, concat!("\x1b[6;11HFirst", "\x1b7", "\x1b[1;1HSecond", "\x1b8Third"));
        let lines = term.get_lines();
        assert_eq!(lines[0], "Second");
        assert_eq!(lines[5], "          FirstThird");
    }

    #[test]
    fn test_transcript_includes_scrolled_off_lines() {
        let term = render(6, 2, "L1\nL2\nL3");
        assert_eq!(term.get_lines(), vec!["L2", "L3"]);
        assert_eq!(term.get_transcript_lines(), vec!["L1", "L2", "L3"]);
    }

    #[test]
    fn test_sgr_bold() {
        let term = render(80, 24, "\x1b[1mbold\x1b[0m");
        assert_eq!(term.get_lines()[0], "bold");
        assert!(term.get_style_at(0, 0).bold);
        assert!(!term.get_style_at(0, 4).bold);
    }

    #[test]
    fn test_sgr_fg_color() {
        let term = render(80, 24, "\x1b[31mred\x1b[0m");
        assert_eq!(term.get_style_at(0, 0).fg, Some(Color::DarkRed));
        assert_eq!(term.get_style_at(0, 3).fg, None);
    }

    #[test]
    fn test_sgr_rgb_color() {
        let term = render(80, 24, "\x1b[38;2;255;128;0mrgb\x1b[0m");
        assert_eq!(term.get_style_at(0, 0).fg, Some(Color::Rgb { r: 255, g: 128, b: 0 }));
    }

    #[test]
    fn test_style_of_text() {
        let term = render(80, 24, "plain \x1b[1mbold\x1b[0m rest");
        assert!(term.style_of_text(0, "bold").is_some_and(|s| s.bold));
        assert!(term.style_of_text(0, "plain").is_some_and(|s| !s.bold));
    }
}