ctf_pwn/io/stdio/
shell_bridge.rs

1use crate::io::stdio::{is_stop_terminal, is_terminate_process, TerminalBridge, TerminalResult};
2
3use std::io::stdout;
4use std::io::ErrorKind::TimedOut;
5use std::io::Write;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use crossterm::cursor::{DisableBlinking, EnableBlinking, MoveTo};
12use crossterm::event::{
13    DisableBracketedPaste, EnableBracketedPaste, KeyCode, KeyEvent, KeyEventKind,
14};
15use crossterm::style::Print;
16use crossterm::terminal::{Clear, ClearType, EnterAlternateScreen, LeaveAlternateScreen};
17use crossterm::*;
18use tokio::join;
19use tokio::sync::mpsc::error::TryRecvError;
20use tokio::sync::mpsc::{channel, Receiver, Sender};
21use crate::io::{AsyncReadTimeoutExt, TerminalError};
22
/// Terminal bridge that gives the user a locally line-edited prompt: input is
/// edited on this terminal and only sent to the remote side when Enter is pressed.
pub struct ShellTerminalBridge;
24
/// Line-editing state for the local input prompt shown below remote output.
///
/// Positions are `(column, row)` pairs as returned by `crossterm::cursor::position`.
/// Trait bounds on `W` live on the `impl` blocks that need them, not on the type.
struct StdoutState<'a, W> {
    /// Input typed so far that has not been sent yet.
    text: String,
    /// Screen position where the editable input begins.
    start_position: (u16, u16),
    /// Current on-screen position of the editing cursor.
    cursor_position: (u16, u16),
    /// Terminal size as `(columns, rows)`, captured when the state is created.
    // TODO: handle resize correctly; this is never refreshed after creation.
    current_dimensions: (u16, u16),
    /// Destination that completed input lines are written to.
    writer: &'a mut W,
    /// Shared flag; setting it asks both bridge tasks to stop.
    stop_signal: Arc<AtomicBool>,
}
34
35impl<'a, W: AsyncWrite + Unpin> StdoutState<'a, W> {
36    pub fn new(writer: &mut W, stop_signal: Arc<AtomicBool>) -> TerminalResult<StdoutState<W>> {
37        let cursor_position = cursor::position()?;
38        let current_dimensions = terminal::size()?;
39        Ok(StdoutState {
40            text: String::new(),
41            start_position: cursor_position,
42            cursor_position,
43            current_dimensions,
44            writer,
45            stop_signal,
46        })
47    }
48
49    pub async fn insert(&mut self, key_event: KeyEvent) -> TerminalResult<()> {
50        if is_terminate_process(key_event) {
51            return Err(TerminalError::Terminate);
52        }
53
54        if is_stop_terminal(key_event) {
55            self.stop_signal.store(true, Ordering::SeqCst);
56            return Ok(());
57        }
58
59        if key_event.kind != KeyEventKind::Press && key_event.kind != KeyEventKind::Repeat {
60            return Ok(());
61        }
62
63        match key_event.code {
64            KeyCode::Char(c) => self.insert_char(c)?,
65            KeyCode::Left => self.decrement_cursor()?,
66            KeyCode::Right => self.increment_cursor()?,
67            KeyCode::Backspace => self.backspace()?,
68            KeyCode::Delete => self.del()?,
69            KeyCode::Enter => self.send_data().await?,
70            KeyCode::Home => self.home()?,
71            KeyCode::End => self.end()?,
72            _ => {}
73        };
74
75        Ok(())
76    }
77
78    fn get_cursor_relative_index(&self) -> usize {
79        let (start_x, start_y) = self.start_position;
80        let (end_x, end_y) = self.cursor_position;
81        let (w, _h) = self.current_dimensions;
82
83        let full_lines = if end_y > start_y {
84            end_y - start_y - 1
85        } else {
86            0
87        };
88        let last_line_chars = if end_y > start_y {
89            end_x
90        } else {
91            end_x - start_x
92        };
93
94        full_lines as usize * w as usize + last_line_chars as usize
95    }
96
97    fn set_cursor_relative_index(&mut self, index: usize) -> TerminalResult<()> {
98        if index > self.text.len() {
99            return Ok(());
100        }
101
102        let (start_x, start_y) = self.start_position;
103        let (w, _) = self.current_dimensions;
104
105        let lines_down = index / w as usize;
106        let new_y = start_y + lines_down as u16;
107
108        let new_x = (start_x as usize + index % w as usize) as u16;
109
110        let (final_x, final_y) = if new_x >= w {
111            (new_x - w, new_y + 1)
112        } else {
113            (new_x, new_y)
114        };
115
116        self.set_cursor_position(final_x, final_y)
117    }
118
119    fn increment_cursor(&mut self) -> TerminalResult<()> {
120        let index = self.get_cursor_relative_index();
121        if index >= self.text.len() {
122            return Ok(());
123        }
124        self.set_cursor_relative_index(index + 1)
125    }
126
127    fn decrement_cursor(&mut self) -> TerminalResult<()> {
128        let index = self.get_cursor_relative_index();
129        if index <= 0 {
130            return Ok(());
131        }
132        self.set_cursor_relative_index(index - 1)
133    }
134
135    pub fn set_cursor_position(&mut self, x: u16, y: u16) -> TerminalResult<()> {
136        self.cursor_position = (x, y);
137        execute!(stdout(), MoveTo(x, y))?;
138        Ok(())
139    }
140
141    pub fn print(&mut self, data: &[u8]) -> TerminalResult<()> {
142        self.clear()?;
143        terminal::disable_raw_mode().unwrap();
144        stdout().write_all(data)?;
145        stdout().flush()?;
146        terminal::enable_raw_mode().unwrap();
147        self.start_position = cursor::position()?;
148        self.cursor_position = self.start_position;
149        self.redraw()?;
150        Ok(())
151    }
152
153    pub fn insert_str(&mut self, text: &str) -> TerminalResult<()> {
154        let index = self.get_cursor_relative_index();
155        self.text.insert_str(index, text);
156        self.redraw()?;
157        self.set_cursor_relative_index(index + text.len())?;
158        Ok(())
159    }
160
161    pub fn insert_char(&mut self, c: char) -> TerminalResult<()> {
162        let index = self.get_cursor_relative_index();
163        self.text.insert(index, c);
164        self.redraw()?;
165        self.set_cursor_relative_index(index + 1)?;
166        Ok(())
167    }
168
169    fn del(&mut self) -> TerminalResult<()> {
170        let index = self.get_cursor_relative_index();
171        if index >= self.text.len() {
172            return Ok(());
173        }
174        let _ = self.text.remove(index);
175        self.redraw()?;
176        self.set_cursor_relative_index(index)?;
177        Ok(())
178    }
179
180    fn backspace(&mut self) -> TerminalResult<()> {
181        let index = self.get_cursor_relative_index();
182        if index <= 0 {
183            return Ok(());
184        }
185        let _ = self.text.remove(index - 1);
186        self.redraw()?;
187        self.set_cursor_relative_index(index - 1)?;
188        Ok(())
189    }
190
191    pub fn home(&mut self) -> TerminalResult<()> {
192        self.set_cursor_relative_index(0)
193    }
194
195    pub fn end(&mut self) -> TerminalResult<()> {
196        self.set_cursor_relative_index(self.text.len())
197    }
198
199    pub async fn send_data(&mut self) -> TerminalResult<()> {
200        self.end()?;
201        self.redraw()?;
202        let (_, y) = cursor::position()?;
203        println!();
204        self.set_cursor_position(0, y + 1)?;
205        self.start_position = self.cursor_position;
206        let mut text = self.text.clone();
207        self.text.clear();
208        text.push('\n');
209
210        self.writer.write_all(text.as_bytes()).await?;
211        self.writer.flush().await?;
212        Ok(())
213    }
214
215    fn clear(&mut self) -> TerminalResult<()> {
216        let (start_x, start_y) = self.start_position;
217        self.set_cursor_position(start_x, start_y)?;
218
219        execute!(
220            stdout(),
221            MoveTo(start_x, start_y),
222            Clear(ClearType::FromCursorDown)
223        )?;
224        Ok(())
225    }
226
227    pub fn redraw(&mut self) -> TerminalResult<()> {
228        self.clear()?;
229        let (start_x, start_y) = self.start_position;
230
231        execute!(stdout(), MoveTo(start_x, start_y), Print(&self.text))?;
232        Ok(())
233    }
234}
235
236async fn read_task<R>(
237    reader: &mut R,
238    stop_signal: Arc<AtomicBool>,
239    sender: Sender<Vec<u8>>,
240) -> TerminalResult<()>
241where
242    R: AsyncRead + Send + Unpin,
243{
244    let mut buffer = [0; 1024];
245    loop {
246        if stop_signal.load(Ordering::SeqCst) {
247            break;
248        }
249
250        let n = match reader.read_timeout(&mut buffer, Duration::from_secs(1)).await {
251            Ok(n) if n == 0 => break, //EOF
252            Ok(n) => n,
253            Err(e) if e.kind() == TimedOut =>{
254                continue;
255            }
256            Err(e) => return Err(e.into()),
257        };
258
259        sender.send(buffer[..n].to_vec()).await?
260    }
261    Ok(())
262}
263
264async fn write_task<W>(
265    writer: &mut W,
266    stop_signal: Arc<AtomicBool>,
267    mut receiver: Receiver<Vec<u8>>,
268) -> TerminalResult<()>
269where
270    W: AsyncWrite + Send + Unpin,
271{
272    let mut stdout = StdoutState::new(writer, stop_signal.clone())?;
273
274    loop {
275        if stop_signal.load(Ordering::SeqCst) {
276            return Ok(());
277        }
278
279        match receiver.try_recv() {
280            Ok(data) => {
281                stdout.print(&data)?;
282            }
283            Err(TryRecvError::Empty) => {}
284            Err(TryRecvError::Disconnected) => return Err(TryRecvError::Disconnected.into()),
285        }
286
287        if let Ok(true) = event::poll(Duration::from_millis(0)) {
288            match event::read() {
289                Ok(event::Event::Key(key_event)) => {
290                    stdout.insert(key_event).await?;
291                }
292                Ok(event::Event::Resize(_width, _height)) => {
293                    //TODO: recalculate cursor after resize
294                }
295                Ok(event::Event::Paste(text)) => {
296                    stdout.insert_str(&text)?;
297                }
298                _ => {}
299            }
300        }
301    }
302}
303
304impl TerminalBridge for ShellTerminalBridge {
305    async fn bridge<R: AsyncRead + Send + Unpin, W: AsyncWrite + Send + Unpin>(
306        reader: &mut R,
307        writer: &mut W,
308    ) {
309        let reader_ptr = reader as *mut R as usize;
310        let writer_ptr = writer as *mut W as usize;
311
312        let (rx, tx) = channel(100);
313
314        let _ = execute!(stdout(), EnterAlternateScreen);
315        terminal::enable_raw_mode().unwrap();
316        let _ = execute!(stdout(), EnableBlinking, EnableBracketedPaste);
317
318        let stop_signal = Arc::new(AtomicBool::new(false));
319        let read_stop_signal = stop_signal.clone();
320
321        let reader_task = tokio::spawn(async move {
322            let reader_ptr = reader_ptr as *mut R;
323            let reader = unsafe { &mut *reader_ptr };
324            let res = read_task(reader, read_stop_signal.clone(), rx).await;
325            read_stop_signal.store(true, Ordering::SeqCst);
326            res
327        });
328
329        let writer_task = tokio::spawn(async move {
330            let writer_ptr = writer_ptr as *mut W;
331            let writer = unsafe { &mut *writer_ptr };
332            let res = write_task(writer, stop_signal.clone(), tx).await;
333            stop_signal.store(true, Ordering::SeqCst);
334            res
335        });
336
337        let (read_res, write_res) = join!(reader_task, writer_task);
338
339        let _ = execute!(stdout(), DisableBlinking, DisableBracketedPaste);
340
341        terminal::disable_raw_mode().unwrap();
342
343        let _ = execute!(stdout(), LeaveAlternateScreen);
344
345        if let Ok(Err(TerminalError::Terminate)) = read_res
346        {
347            std::process::exit(130); //SIGINT
348        }
349        else if  let Ok(Err(TerminalError::Terminate)) = write_res
350        {
351            std::process::exit(130); //SIGINT
352        }
353    }
354}