1use std::os::fd::{BorrowedFd, RawFd};
2use std::os::unix::prelude::AsRawFd;
3use std::time::Duration;
4
5use async_recursion::async_recursion;
6use eyre::{eyre, Result};
7use nix::poll::{poll, PollFd, PollFlags};
8use nix::sys::select::FdSet;
9use nix::sys::signal::Signal;
10use nix::sys::signalfd::SigSet;
11use nix::sys::termios;
12use nix::sys::termios::InputFlags;
13use nix::sys::time::TimeSpec;
14
/// Handle to the terminal file descriptor that console input is read from.
///
/// Created by `init`. The descriptor is only borrowed (`BorrowedFd`), not
/// owned, so dropping a `ConsoleState` never closes it.
#[derive(Debug, Clone)]
pub struct ConsoleState<'a>(
    #[doc(hidden)] BorrowedFd<'a>,
);
17
18pub async fn init(fd: Option<RawFd>) -> Result<ConsoleState<'static>> {
19 Ok(ConsoleState(unsafe {
21 BorrowedFd::borrow_raw(if let Some(fd) = fd {
22 fd
23 } else {
24 std::io::stderr().as_raw_fd()
25 })
26 }))
27}
28
29pub async fn next_keypress(state: &ConsoleState<'static>) -> Result<Option<Keypress>> {
76 let original_termios = termios::tcgetattr(state.0)?;
77 let mut termios = original_termios.clone();
78
79 termios.input_flags &= !(InputFlags::IGNBRK
81 | InputFlags::BRKINT
82 | InputFlags::PARMRK
83 | InputFlags::ISTRIP
84 | InputFlags::INLCR
85 | InputFlags::IGNCR
86 | InputFlags::ICRNL
87 | InputFlags::IXON);
88 termios.local_flags &= !(termios::LocalFlags::ECHO
89 | termios::LocalFlags::ECHONL
90 | termios::LocalFlags::ICANON
91 | termios::LocalFlags::ISIG
92 | termios::LocalFlags::IEXTEN);
93 termios::tcsetattr(state.0, termios::SetArg::TCSADRAIN, &termios)?;
94
95 let out = read_next_key(&state.0).await;
96
97 termios::tcsetattr(state.0, termios::SetArg::TCSADRAIN, &original_termios)?;
98
99 out
100}
101
102#[async_recursion]
103async fn read_next_key(fd: &BorrowedFd<'_>) -> Result<Option<Keypress>> {
104 match read_char(fd)? {
105 Some('\x1b') => match read_char(fd)? {
106 Some('[') => match read_char(fd)? {
107 Some('A') => Ok(Some(Keypress::Up)),
108 Some('B') => Ok(Some(Keypress::Down)),
109 Some('C') => Ok(Some(Keypress::Right)),
110 Some('D') => Ok(Some(Keypress::Left)),
111 Some('H') => Ok(Some(Keypress::Home)),
112 Some('F') => Ok(Some(Keypress::End)),
113 Some('Z') => Ok(Some(Keypress::ShiftTab)),
114 Some(byte3) => match read_char(fd)? {
115 Some('~') => match read_char(fd)? {
116 Some('1') => Ok(Some(Keypress::Home)),
117 Some('2') => Ok(Some(Keypress::Insert)),
118 Some('3') => Ok(Some(Keypress::Delete)),
119 Some('4') => Ok(Some(Keypress::End)),
120 Some('5') => Ok(Some(Keypress::PageUp)),
121 Some('6') => Ok(Some(Keypress::PageDown)),
122 Some('7') => Ok(Some(Keypress::Home)),
123 Some('8') => Ok(Some(Keypress::End)),
124 Some(byte5) => Ok(Some(Keypress::UnknownSequence(vec![
125 '\x1b', '[', byte3, '~', byte5,
126 ]))),
127 None => Ok(Some(Keypress::UnknownSequence(vec![
128 '\x1b', '[', byte3, '~',
129 ]))),
130 },
131 Some(byte4) => Ok(Some(Keypress::UnknownSequence(vec![
132 '\x1b', '[', byte3, byte4,
133 ]))),
134 None => Ok(Some(Keypress::UnknownSequence(vec!['\x1b', '[', byte3]))),
135 },
136 None => Ok(Some(Keypress::Escape)),
137 },
138 Some(byte) => Ok(Some(Keypress::UnknownSequence(vec!['\x1b', byte]))),
139 None => Ok(Some(Keypress::Escape)),
140 },
141 Some('\r') | Some('\n') => Ok(Some(Keypress::Return)),
142 Some('\t') => Ok(Some(Keypress::Tab)),
143 Some('\x7f') => Ok(Some(Keypress::Backspace)),
144 Some('\x01') => Ok(Some(Keypress::Home)),
145 Some('\x03') => Err(ConsoleError::Interrupted.into()),
147 Some('\x05') => Ok(Some(Keypress::End)),
148 Some('\x08') => Ok(Some(Keypress::Backspace)),
149 Some(byte) => {
150 if (byte as u8) & 224u8 == 192u8 {
151 let bytes = vec![byte as u8, read_byte(fd)?.unwrap()];
152 Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
153 } else if (byte as u8) & 240u8 == 224u8 {
154 let bytes: Vec<u8> =
155 vec![byte as u8, read_byte(fd)?.unwrap(), read_byte(fd)?.unwrap()];
156 Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
157 } else if (byte as u8) & 248u8 == 240u8 {
158 let bytes: Vec<u8> = vec![
159 byte as u8,
160 read_byte(fd)?.unwrap(),
161 read_byte(fd)?.unwrap(),
162 read_byte(fd)?.unwrap(),
163 ];
164 Ok(Some(Keypress::Char(char_from_utf8(&bytes)?)))
165 } else {
166 Ok(Some(Keypress::Char(byte)))
167 }
168 }
169 None => {
170 let pollfd = PollFd::new(&fd, PollFlags::POLLIN);
172 let ret = poll(&mut [pollfd], 0)?;
173
174 if ret < 0 {
175 let last_error = std::io::Error::last_os_error();
176 if last_error.kind() == std::io::ErrorKind::Interrupted {
177 return Err(ConsoleError::Interrupted.into());
179 } else {
180 return Err(ConsoleError::Io(last_error).into());
181 }
182 }
183
184 Ok(None)
185 }
186 }
187}
188
189fn read_byte(fd: &BorrowedFd<'_>) -> Result<Option<u8>> {
190 let mut buf = [0u8; 1];
191 let mut read_fds = FdSet::new();
192 read_fds.insert(fd);
193
194 let mut signals = SigSet::empty();
195 signals.add(Signal::SIGINT);
196 signals.add(Signal::SIGTERM);
197 signals.add(Signal::SIGKILL);
198
199 match nix::sys::select::pselect(
200 fd.as_raw_fd() + 1,
201 Some(&mut read_fds),
202 Some(&mut FdSet::new()),
203 Some(&mut FdSet::new()),
204 Some(&TimeSpec::new(
205 0,
206 Duration::from_millis(50).as_nanos() as i64,
207 )),
208 Some(&signals),
209 ) {
210 Ok(0) => Ok(None),
211 Ok(_) => match nix::unistd::read(fd.as_raw_fd(), &mut buf) {
212 Ok(0) => Ok(None),
213 Ok(_) => Ok(Some(buf[0])),
214 Err(err) => Err(err.into()),
215 },
216 Err(err) => Err(err.into()),
217 }
218}
219
220fn read_char(fd: &BorrowedFd<'_>) -> Result<Option<char>> {
221 read_byte(fd).map(|byte| byte.map(|byte| byte as char))
222}
223
224fn char_from_utf8(buf: &[u8]) -> Result<char> {
225 let str = std::str::from_utf8(buf)?;
226 let ch = str.chars().next();
227 match ch {
228 Some(c) => Ok(c),
229 None => Err(eyre!("invalid utf8 sequence: {:?}", buf)),
230 }
231}
232
/// A single decoded keypress, as returned by `next_keypress`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Keypress {
    /// Arrow up (`ESC [ A`).
    Up,
    /// Arrow down (`ESC [ B`).
    Down,
    /// Arrow right (`ESC [ C`).
    Right,
    /// Arrow left (`ESC [ D`).
    Left,
    /// Home: `ESC [ H`, Ctrl-A (`0x01`), or a `~`-terminated escape sequence.
    Home,
    /// End: `ESC [ F`, Ctrl-E (`0x05`), or a `~`-terminated escape sequence.
    End,
    /// Shift-Tab (`ESC [ Z`).
    ShiftTab,
    /// Insert, decoded from a `~`-terminated escape sequence.
    Insert,
    /// Delete, decoded from a `~`-terminated escape sequence.
    Delete,
    /// Page Up, decoded from a `~`-terminated escape sequence.
    PageUp,
    /// Page Down, decoded from a `~`-terminated escape sequence.
    PageDown,
    /// Carriage return (`\r`) or line feed (`\n`).
    Return,
    /// Horizontal tab (`\t`).
    Tab,
    /// DEL (`0x7f`) or BS (`0x08`).
    Backspace,
    /// A lone ESC, or `ESC [` with nothing following before the read timed out.
    Escape,
    /// Any other character; multi-byte UTF-8 input is decoded into one `char`.
    Char(char),
    /// The raw characters of an escape sequence that was not recognised,
    /// starting with ESC.
    UnknownSequence(Vec<char>),
}
253
/// Errors specific to reading console input.
#[derive(thiserror::Error, Debug)]
pub enum ConsoleError {
    /// Reading was interrupted, e.g. the user pressed Ctrl-C (byte `0x03`)
    /// while the terminal was in raw mode.
    #[error("Interrupted!")]
    Interrupted,
    /// Wraps an underlying `std::io::Error`. The `#[from]` attribute lets `?`
    /// convert one automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}