getch_rs/
lib.rs

1//! # getch-rs
2//!
//! `getch` is a C function that reads a single character from the keyboard without requiring the user to press Enter, suspending the program until a key is pressed. It is typically used in console programs for menu selection or for waiting on a key press. This crate provides the same behavior for Rust via [`Getch::getch`].
4//!
5//! ## Example
6//!
7//! ```no_run
8//! use getch_rs::{Getch, Key};
9//!
10//! fn main() {
11//!     let g = Getch::new();
12//!
13//!     println!("press `q` to exit");
14//!
15//!     loop {
16//!         match g.getch() {
17//!             Ok(Key::Char('q')) => break,
18//!             Ok(key) => println!("{:?}", key),
19//!             Err(e) => println!("{}", e),
20//!         }
21//!     }
22//! }
23//! ```
24
25#[cfg(windows)]
26use winapi::{
27    shared::minwindef::DWORD,
28    um::consoleapi::{GetConsoleMode, SetConsoleMode},
29    um::handleapi::INVALID_HANDLE_VALUE,
30    um::processenv::GetStdHandle,
31    um::winbase::STD_INPUT_HANDLE,
32    um::wincon::{ENABLE_ECHO_INPUT, ENABLE_VIRTUAL_TERMINAL_INPUT},
33};
34
35#[cfg(not(windows))]
36use nix::sys::termios;
37
38use std::cell::RefCell;
39use std::io::Read;
40
41#[cfg(windows)]
42pub struct Getch {
43    orig_term: DWORD,
44    leftover: RefCell<Option<u8>>,
45}
46
47#[cfg(not(windows))]
48pub struct Getch {
49    orig_term: termios::Termios,
50    leftover: RefCell<Option<u8>>,
51}
52
/// A decoded key press, as returned by [`Getch::getch`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Key {
    /// A NUL (`0x00`) byte was read.
    EOF,
    /// Backspace (a `0x08` byte).
    Backspace,
    /// Delete key; also reported for a raw DEL (`0x7F`) byte.
    ///
    /// NOTE(review): many terminals send `0x7F` for the Backspace key, so a
    /// Backspace press may arrive as `Delete` — confirm for your terminal.
    Delete,
    /// A bare Esc press: an `ESC` byte with no further input behind it.
    Esc,
    /// Up arrow.
    Up,
    /// Down arrow.
    Down,
    /// Right arrow.
    Right,
    /// Left arrow.
    Left,
    /// End key.
    End,
    /// Home key.
    Home,
    /// Backward Tab (Shift+Tab) key.
    BackTab,
    /// Insert key.
    Insert,
    /// Page Up key.
    PageUp,
    /// Page Down key.
    PageDown,
    /// Function key; the value is the key number.
    ///
    /// Only function keys 1 through 12 are supported.
    F(u8),
    /// A regular character. Enter is reported as `'\r'` and Tab as `'\t'`.
    Char(char),
    /// A character preceded by `ESC`, as sent for Alt-modified keys.
    Alt(char),
    /// Ctrl modified character.
    ///
    /// Note that certain keys may not be modifiable with `ctrl`, due to limitations of terminals.
    Ctrl(char),
    /// Input that was not recognized as any other key, as the raw bytes read.
    Other(Vec<u8>),
}
99
100impl Getch {
101    #[cfg(windows)]
102    #[allow(clippy::new_without_default)]
103    pub fn new() -> Self {
104        let mut console_mode: DWORD = 0;
105
106        unsafe {
107            let input_handle = GetStdHandle(STD_INPUT_HANDLE);
108            if GetConsoleMode(input_handle, &mut console_mode) != 0 {
109                SetConsoleMode(input_handle, ENABLE_VIRTUAL_TERMINAL_INPUT);
110            }
111        }
112
113        Self {
114            orig_term: console_mode,
115            leftover: RefCell::new(None),
116        }
117    }
118    #[cfg(not(windows))]
119    #[allow(clippy::new_without_default)]
120    pub fn new() -> Self {
121        // Quering original as a separate, since `Termios` does not implement copy
122        let orig_term       = termios::tcgetattr(0).unwrap();
123        let mut raw_termios = termios::tcgetattr(0).unwrap();
124
125        // Unset canonical mode, so we get characters immediately
126        raw_termios.local_flags.remove(termios::LocalFlags::ICANON);
127        // Don't generate signals on Ctrl-C and friends
128        raw_termios.local_flags.remove(termios::LocalFlags::ISIG);
129        // Disable local echo
130        raw_termios.local_flags.remove(termios::LocalFlags::ECHO);
131
132        termios::tcsetattr(0, termios::SetArg::TCSADRAIN, &raw_termios).unwrap();
133
134        Self {
135            orig_term,
136            leftover: RefCell::new(None),
137        }
138    }
139
140    #[allow(clippy::unused_io_amount)]
141    pub fn getch(&self) -> Result<Key, std::io::Error> {
142        let source = &mut std::io::stdin();
143        let mut buf: [u8; 2] = [0; 2];
144
145        if self.leftover.borrow().is_some() {
146            // we have a leftover byte, use it
147            let c = self.leftover.borrow().unwrap();
148            self.leftover.replace(None);
149            return parse_key(c, &mut source.bytes());
150        }
151
152        match source.read(&mut buf) {
153            Ok(0) => Ok(Key::Ctrl('z')),
154            Ok(1) => match buf[0] {
155                b'\x1B' => Ok(Key::Esc),
156                c => parse_key(c, &mut source.bytes()),
157            },
158            Ok(2) => {
159                let option_iter = &mut Some(buf[1]).into_iter();
160                let result = {
161                    let mut iter = option_iter.map(Ok).chain(source.bytes());
162                    parse_key(buf[0], &mut iter)
163                };
164                // If the option_iter wasn't consumed, keep the byte for later.
165                self.leftover.replace(option_iter.next());
166                result
167            }
168            Ok(_) => unreachable!(),
169            Err(e) => Err(e),
170        }
171    }
172}
173
174/// Enable local echo
175pub fn enable_echo_input() {
176    #[cfg(windows)]
177    unsafe {
178        let input_handle = GetStdHandle(STD_INPUT_HANDLE);
179        let mut console_mode: DWORD = 0;
180
181        if input_handle == INVALID_HANDLE_VALUE {
182            return;
183        }
184
185        if GetConsoleMode(input_handle, &mut console_mode) != 0 {
186            SetConsoleMode(input_handle, console_mode | ENABLE_ECHO_INPUT);
187        }
188    }
189
190    #[cfg(not(windows))]
191    {
192        let mut raw_termios = termios::tcgetattr(0).unwrap();
193        raw_termios.local_flags.insert(termios::LocalFlags::ECHO);
194        termios::tcsetattr(0, termios::SetArg::TCSADRAIN, &raw_termios).unwrap();
195    }
196}
197
198/// Disable local echo
199pub fn disable_echo_input() {
200    #[cfg(windows)]
201    unsafe {
202        let input_handle = GetStdHandle(STD_INPUT_HANDLE);
203        let mut console_mode: DWORD = 0;
204
205        if input_handle == INVALID_HANDLE_VALUE {
206            return;
207        }
208
209        if GetConsoleMode(input_handle, &mut console_mode) != 0 {
210            SetConsoleMode(input_handle, console_mode & !ENABLE_ECHO_INPUT);
211        }
212    }
213
214    #[cfg(not(windows))]
215    {
216        let mut raw_termios = termios::tcgetattr(0).unwrap();
217        raw_termios.local_flags.remove(termios::LocalFlags::ECHO);
218        termios::tcsetattr(0, termios::SetArg::TCSADRAIN, &raw_termios).unwrap();
219    }
220}
221
222/// Parse an Event from `item` and possibly subsequent bytes through `iter`.
223fn parse_key<I>(item: u8, iter: &mut I) -> Result<Key, std::io::Error>
224where
225    I: Iterator<Item = Result<u8, std::io::Error>>,
226{
227    match item {
228        b'\x1B' => {
229            Ok(match iter.next() {
230                Some(Ok(b'[')) => parse_csi(iter)?,
231                Some(Ok(b'O')) => {
232                    match iter.next() {
233                        // F1-F4
234                        Some(Ok(val @ b'P'..=b'S')) => Key::F(1 + val - b'P'),
235                        Some(Ok(val)) => Key::Other(vec![b'\x1B', b'O', val]),
236                        _ => Key::Other(vec![b'\x1B', b'O']),
237                    }
238                }
239                Some(Ok(c)) => match parse_utf8_char(c, iter)? {
240                    Ok(ch)   => Key::Alt(ch),
241                    Err(vec) => Key::Other(vec),
242                },
243                Some(Err(e)) => return Err(e),
244                None => Key::Esc,
245            })
246        }
247        b'\n' | b'\r'         => Ok(Key::Char('\r')),
248        b'\t'                 => Ok(Key::Char('\t')),
249        b'\x08'               => Ok(Key::Backspace),
250        b'\x7F'               => Ok(Key::Delete),
251        c @ b'\x01'..=b'\x1A' => Ok(Key::Ctrl((c - 0x1 + b'a') as char)),
252        c @ b'\x1C'..=b'\x1F' => Ok(Key::Ctrl((c - 0x1C + b'4') as char)),
253        b'\0'                 => Ok(Key::EOF),
254        c => Ok(match parse_utf8_char(c, iter)? {
255            Ok(ch)   => Key::Char(ch),
256            Err(vec) => Key::Other(vec),
257        }),
258    }
259}
260
261/// Parses a CSI sequence, just after reading ^[
262///
263/// Returns None if an unrecognized sequence is found.
264fn parse_csi<I>(iter: &mut I) -> Result<Key, std::io::Error>
265where
266    I: Iterator<Item = Result<u8, std::io::Error>>,
267{
268    Ok(match iter.next() {
269        Some(Ok(b'[')) => match iter.next() {
270            Some(Ok(val @ b'A'..=b'E')) => Key::F(1 + val - b'A'),
271            Some(Ok(val)) => Key::Other(vec![b'\x1B', b'[', b'[', val]),
272            _ => Key::Other(vec![b'\x1B', b'[', b'[']),
273        },
274        Some(Ok(b'A')) => Key::Up,
275        Some(Ok(b'B')) => Key::Down,
276        Some(Ok(b'C')) => Key::Right,
277        Some(Ok(b'D')) => Key::Left,
278        Some(Ok(b'F')) => Key::End,
279        Some(Ok(b'H')) => Key::Home,
280        Some(Ok(b'Z')) => Key::BackTab,
281        Some(Ok(c @ b'0'..=b'9')) => {
282            // Numbered escape code.
283            let mut buf = vec![c];
284            let mut c = iter.next().unwrap().unwrap();
285            // The final byte of a CSI sequence can be in the range 64-126, so
286            // let's keep reading anything else.
287            while !(64..=126).contains(&c) {  // c < 64 || 126 < c
288                buf.push(c);
289                c = iter.next().unwrap().unwrap();
290            }
291            match c {
292                // Special key code.
293                b'~' => {
294                    let str_buf = std::str::from_utf8(&buf).unwrap();
295
296                    // This CSI sequence can be a list of semicolon-separated
297                    // numbers.
298                    let nums: Vec<u8> = str_buf.split(';').map(|n| n.parse().unwrap()).collect();
299
300                    if nums.is_empty() || nums.len() > 1 {
301                        let mut keys = vec![b'\x1B', b'['];
302                        keys.append(&mut buf);
303                        return Ok(Key::Other(keys));
304                    }
305
306                    match nums[0] {
307                        1 | 7 => Key::Home,
308                        2     => Key::Insert,
309                        3     => Key::Delete,
310                        4 | 8 => Key::End,
311                        5     => Key::PageUp,
312                        6     => Key::PageDown,
313                        v @ 11..=15 => Key::F(v - 10),
314                        v @ 17..=21 => Key::F(v - 11),
315                        v @ 23..=24 => Key::F(v - 12),
316                        _ => {
317                            let mut keys = vec![b'\x1B', b'['];
318                            keys.append(&mut buf);
319                            keys.push(nums[0]);
320                            return Ok(Key::Other(keys));
321                        }
322                    }
323                }
324                _ => {
325                    let mut keys = vec![b'\x1B', b'['];
326                    keys.append(&mut buf);
327                    keys.push(c);
328                    return Ok(Key::Other(keys));
329                }
330            }
331        }
332        Some(Ok(c)) => Key::Other(vec![b'\x1B', b'[', c]),
333        _ => Key::Other(vec![b'\x1B', b'[']),
334    })
335}
336
/// Parses `c` as either a single-byte ASCII character or the leading byte of
/// a multi-byte UTF-8 character, pulling continuation bytes from `iter`.
///
/// Returns `Ok(Ok(ch))` for a complete, valid character. Returns
/// `Ok(Err(bytes))` with the raw bytes consumed when they do not form valid
/// UTF-8, when `c` cannot start a character, or when `iter` runs out before
/// the character is complete.
///
/// # Errors
///
/// Returns the first I/O error yielded by `iter`.
fn parse_utf8_char<I>(c: u8, iter: &mut I) -> Result<Result<char, Vec<u8>>, std::io::Error>
where
    I: Iterator<Item = Result<u8, std::io::Error>>,
{
    if c.is_ascii() {
        return Ok(Ok(c as char));
    }

    let mut bytes = vec![c];

    // The number of leading one bits of a UTF-8 leading byte is the encoded
    // width. A width of 1 (continuation byte) or above 4 can never start a
    // character, so give up without waiting for more input.
    let width = c.leading_ones() as usize;
    if !(2..=4).contains(&width) {
        return Ok(Err(bytes));
    }

    while bytes.len() < width {
        match iter.next().transpose()? {
            Some(next) => bytes.push(next),
            None => return Ok(Err(bytes)),
        }
    }

    let decoded = std::str::from_utf8(&bytes)
        .ok()
        .and_then(|s| s.chars().next());
    Ok(decoded.ok_or(bytes))
}
364
365impl Drop for Getch {
366    #[cfg(windows)]
367    fn drop(&mut self) {
368        unsafe {
369            let input_handle = GetStdHandle(STD_INPUT_HANDLE);
370            SetConsoleMode(input_handle, self.orig_term);
371        }
372    }
373
374    #[cfg(not(windows))]
375    fn drop(&mut self) {
376        termios::tcsetattr(0, termios::SetArg::TCSADRAIN, &self.orig_term).unwrap();
377    }
378}