// makeup_console/lib.rs

1use std::os::fd::{BorrowedFd, RawFd};
2use std::os::unix::prelude::AsRawFd;
3use std::time::Duration;
4
5use async_recursion::async_recursion;
6use eyre::{eyre, Result};
7use nix::poll::{poll, PollFd, PollFlags};
8use nix::sys::select::FdSet;
9use nix::sys::signal::Signal;
10use nix::sys::signalfd::SigSet;
11use nix::sys::termios;
12use nix::sys::termios::InputFlags;
13use nix::sys::time::TimeSpec;
14
/// Handle to the terminal file descriptor that keypresses are read from.
///
/// The descriptor is only borrowed: this type never closes it. Deriving
/// `Clone` is sound because [`BorrowedFd`] is itself `Copy`; a clone is just a
/// second borrow of the same descriptor with the same lifetime `'a`.
#[derive(Debug, Clone)]
pub struct ConsoleState<'a>(#[doc(hidden)] BorrowedFd<'a>);
17
18pub async fn init(fd: Option<RawFd>) -> Result<ConsoleState<'static>> {
19    // Safety: It's impossible for these to not be valid fds
20    Ok(ConsoleState(unsafe {
21        BorrowedFd::borrow_raw(if let Some(fd) = fd {
22            fd
23        } else {
24            std::io::stderr().as_raw_fd()
25        })
26    }))
27}
28
29/// - Check if stdin is a terminal (libc::isatty == 1)
30///   - If not, open /dev/tty
31/// - Put the terminal in raw input mode
32/// - Enable TCSADRAIN
33/// - Read a byte
34///   - If \x1b, csi, so read next byte
35///     - If next byte is [, start reading control sequence
36///       - Match next byte
37///         - A => up
38///         - B => down
39///         - C => right
40///         - D => left
41///         - H => home
42///         - F => end
43///         - Z => shift-tab
44///         - _ =>
45///           - Match next byte
46///             - ~ =>
47///               - Match next byte
48///                 - 1 => home
49///                 - 2 => insert
50///                 - 3 => delete
51///                 - 4 => end
52///                 - 5 => page up
53///                 - 6 => page down
54///                 - 7 => home
55///                 - 8 => end
56///                 - Else, the escape sequence was unknown
57///             - Else, the escape sequence was unknown
58///     - Else, if next byte is not [, bail out on unknown control sequence
59///     - Else, if there was no next byte, input was <ESC>
60///   - Else, if byte & 224u8 == 192u8, Unicode 2-byte
61///   - Else, if byte & 240u8 == 224u8, Unicode 3-byte
62///   - Else, if byte & 248u8 == 240u8, Unicode 4-byte
63///   - Else:
64///     - If byte == \r || byte == \n, <RETURN>
65///     - If byte == \t, <TAB>
66///     - If byte == \x7f, <BACKSPACE>
67///     - If byte == \x1b, <ESC>
68///     - If byte == \x01, <HOME>
69///     - If byte == \x05, <END>
70///     - If byte == \x08, <BACKSPACE>
71///     - Else, char = byte
72///   - Else, if no byte to read:
73///     - If stdin is a terminal, return None
74/// - Disable TCSADRAIN
75pub async fn next_keypress(state: &ConsoleState<'static>) -> Result<Option<Keypress>> {
76    let original_termios = termios::tcgetattr(state.0)?;
77    let mut termios = original_termios.clone();
78
79    // Note: This is ONLY what termios::cfmakeraw does to input
80    termios.input_flags &= !(InputFlags::IGNBRK
81        | InputFlags::BRKINT
82        | InputFlags::PARMRK
83        | InputFlags::ISTRIP
84        | InputFlags::INLCR
85        | InputFlags::IGNCR
86        | InputFlags::ICRNL
87        | InputFlags::IXON);
88    termios.local_flags &= !(termios::LocalFlags::ECHO
89        | termios::LocalFlags::ECHONL
90        | termios::LocalFlags::ICANON
91        | termios::LocalFlags::ISIG
92        | termios::LocalFlags::IEXTEN);
93    termios::tcsetattr(state.0, termios::SetArg::TCSADRAIN, &termios)?;
94
95    let out = read_next_key(&state.0).await;
96
97    termios::tcsetattr(state.0, termios::SetArg::TCSADRAIN, &original_termios)?;
98
99    out
100}
101
102#[async_recursion]
103async fn read_next_key(fd: &BorrowedFd<'_>) -> Result<Option<Keypress>> {
104    match read_char(fd)? {
105        Some('\x1b') => match read_char(fd)? {
106            Some('[') => match read_char(fd)? {
107                Some('A') => Ok(Some(Keypress::Up)),
108                Some('B') => Ok(Some(Keypress::Down)),
109                Some('C') => Ok(Some(Keypress::Right)),
110                Some('D') => Ok(Some(Keypress::Left)),
111                Some('H') => Ok(Some(Keypress::Home)),
112                Some('F') => Ok(Some(Keypress::End)),
113                Some('Z') => Ok(Some(Keypress::ShiftTab)),
114                Some(byte3) => match read_char(fd)? {
115                    Some('~') => match read_char(fd)? {
116                        Some('1') => Ok(Some(Keypress::Home)),
117                        Some('2') => Ok(Some(Keypress::Insert)),
118                        Some('3') => Ok(Some(Keypress::Delete)),
119                        Some('4') => Ok(Some(Keypress::End)),
120                        Some('5') => Ok(Some(Keypress::PageUp)),
121                        Some('6') => Ok(Some(Keypress::PageDown)),
122                        Some('7') => Ok(Some(Keypress::Home)),
123                        Some('8') => Ok(Some(Keypress::End)),
124                        Some(byte5) => Ok(Some(Keypress::UnknownSequence(vec![
125                            '\x1b', '[', byte3, '~', byte5,
126                        ]))),
127                        None => Ok(Some(Keypress::UnknownSequence(vec![
128                            '\x1b', '[', byte3, '~',
129                        ]))),
130                    },
131                    Some(byte4) => Ok(Some(Keypress::UnknownSequence(vec![
132                        '\x1b', '[', byte3, byte4,
133                    ]))),
134                    None => Ok(Some(Keypress::UnknownSequence(vec!['\x1b', '[', byte3]))),
135                },
136                None => Ok(Some(Keypress::Escape)),
137            },
138            Some(byte) => Ok(Some(Keypress::UnknownSequence(vec!['\x1b', byte]))),
139            None => Ok(Some(Keypress::Escape)),
140        },
141        Some('\r') | Some('\n') => Ok(Some(Keypress::Return)),
142        Some('\t') => Ok(Some(Keypress::Tab)),
143        Some('\x7f') => Ok(Some(Keypress::Backspace)),
144        Some('\x01') => Ok(Some(Keypress::Home)),
145        // ^C
146        Some('\x03') => Err(ConsoleError::Interrupted.into()),
147        Some('\x05') => Ok(Some(Keypress::End)),
148        Some('\x08') => Ok(Some(Keypress::Backspace)),
149        Some(byte) => {
150            if (byte as u8) & 224u8 == 192u8 {
151                let bytes = vec![byte as u8, read_byte(fd)?.unwrap()];
152                Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
153            } else if (byte as u8) & 240u8 == 224u8 {
154                let bytes: Vec<u8> =
155                    vec![byte as u8, read_byte(fd)?.unwrap(), read_byte(fd)?.unwrap()];
156                Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
157            } else if (byte as u8) & 248u8 == 240u8 {
158                let bytes: Vec<u8> = vec![
159                    byte as u8,
160                    read_byte(fd)?.unwrap(),
161                    read_byte(fd)?.unwrap(),
162                    read_byte(fd)?.unwrap(),
163                ];
164                Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
165            } else {
166                Ok(Some(Keypress::Char(byte)))
167            }
168        }
169        None => {
170            // there is no subsequent byte ready to be read, block and wait for input
171            let pollfd = PollFd::new(&fd, PollFlags::POLLIN);
172            let ret = poll(&mut [pollfd], 0)?;
173
174            if ret < 0 {
175                let last_error = std::io::Error::last_os_error();
176                if last_error.kind() == std::io::ErrorKind::Interrupted {
177                    // User probably hit ^C, oops
178                    return Err(ConsoleError::Interrupted.into());
179                } else {
180                    return Err(ConsoleError::Io(last_error).into());
181                }
182            }
183
184            Ok(None)
185        }
186    }
187}
188
189fn read_byte(fd: &BorrowedFd<'_>) -> Result<Option<u8>> {
190    let mut buf = [0u8; 1];
191    let mut read_fds = FdSet::new();
192    read_fds.insert(fd);
193
194    let mut signals = SigSet::empty();
195    signals.add(Signal::SIGINT);
196    signals.add(Signal::SIGTERM);
197    signals.add(Signal::SIGKILL);
198
199    match nix::sys::select::pselect(
200        fd.as_raw_fd() + 1,
201        Some(&mut read_fds),
202        Some(&mut FdSet::new()),
203        Some(&mut FdSet::new()),
204        Some(&TimeSpec::new(
205            0,
206            Duration::from_millis(50).as_nanos() as i64,
207        )),
208        Some(&signals),
209    ) {
210        Ok(0) => Ok(None),
211        Ok(_) => match nix::unistd::read(fd.as_raw_fd(), &mut buf) {
212            Ok(0) => Ok(None),
213            Ok(_) => Ok(Some(buf[0])),
214            Err(err) => Err(err.into()),
215        },
216        Err(err) => Err(err.into()),
217    }
218}
219
220fn read_char(fd: &BorrowedFd<'_>) -> Result<Option<char>> {
221    read_byte(fd).map(|byte| byte.map(|byte| byte as char))
222}
223
224fn char_from_utf8(buf: &[u8]) -> Result<char> {
225    let str = std::str::from_utf8(buf)?;
226    let ch = str.chars().next();
227    match ch {
228        Some(c) => Ok(c),
229        None => Err(eyre!("invalid utf8 sequence: {:?}", buf)),
230    }
231}
232
/// A single decoded keypress.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Keypress {
    /// Cursor up.
    Up,
    /// Cursor down.
    Down,
    /// Cursor right.
    Right,
    /// Cursor left.
    Left,
    /// Home (also produced by ^A).
    Home,
    /// End (also produced by ^E).
    End,
    /// Shift+Tab (back-tab).
    ShiftTab,
    /// Insert.
    Insert,
    /// Delete (forward delete).
    Delete,
    /// Page up.
    PageUp,
    /// Page down.
    PageDown,
    /// Enter / Return (CR or LF).
    Return,
    /// Tab.
    Tab,
    /// Backspace (DEL or BS).
    Backspace,
    /// A lone Escape key.
    Escape,
    /// A character that is not mapped to one of the dedicated keys above.
    Char(char),
    /// An escape sequence that was not recognised. Holds the characters that
    /// were read, starting with ESC.
    UnknownSequence(Vec<char>),
}
253
254#[derive(thiserror::Error, Debug)]
255pub enum ConsoleError {
256    #[error("Interrupted!")]
257    Interrupted,
258    #[error("IO error: {0}")]
259    Io(#[from] std::io::Error),
260}