pwn_helper/
io.rs

1use std::{
2    io::{stdin, stdout, ErrorKind, Read, Write},
3    net::TcpStream,
4    sync::mpsc::TryRecvError,
5    thread,
6    time::Duration,
7};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11/// Errors that can be encountered during IO operations
12pub enum PwnIoError {
13    #[error("IO: {0}")]
14    Io(std::io::Error),
15    #[error("End of file")]
16    Eof,
17}
18
19impl From<std::io::Error> for PwnIoError {
20    fn from(x: std::io::Error) -> Self {
21        match x.kind() {
22            std::io::ErrorKind::UnexpectedEof => Self::Eof,
23            _ => Self::Io(x),
24        }
25    }
26}
27
28/// Trait for read IO operations
29pub trait PwnIoRead {
30    /// Attempts to read from the stream until the pattern specified in `bytes` is reached, optionally trimming the pattern from the returned bytes
31    fn receive_until(&mut self, pattern: &[u8], trim_pattern: bool) -> Result<Vec<u8>, PwnIoError>;
32
33    /// Attempts to read from the stream until it reaches a newline character, optionally trimming the newline from the returned bytes
34    fn receive_line(&mut self, trim_newline: bool) -> Result<Vec<u8>, PwnIoError>;
35
36    /// Attempts to read all bytes currently available
37    fn receive(&mut self) -> Result<Vec<u8>, PwnIoError>;
38
39    /// Attempts to read all bytes until the stream ends
40    fn receive_all(&mut self) -> Result<Vec<u8>, PwnIoError>;
41
42    /// Attempts to read `count` bytes until the stream ends
43    fn receive_count(&mut self, count: usize) -> Result<Vec<u8>, PwnIoError>;
44}
45
46impl<T> PwnIoRead for T
47where
48    T: Read,
49{
50    fn receive_until(&mut self, pattern: &[u8], trim_pattern: bool) -> Result<Vec<u8>, PwnIoError> {
51        log::debug!("Receiving until: {:?}", pattern);
52        let mut out = Vec::new();
53        let mut buffer = [0u8];
54        while !out.ends_with(pattern) {
55            self.read_exact(&mut buffer)?;
56            out.push(buffer[0]);
57        }
58
59        if trim_pattern {
60            for _ in 0..pattern.len() {
61                out.pop();
62            }
63        }
64        Ok(out)
65    }
66
67    fn receive_line(&mut self, trim_newline: bool) -> Result<Vec<u8>, PwnIoError> {
68        log::debug!("Receiving line");
69        self.receive_until(b"\n", trim_newline)
70    }
71
72    fn receive(&mut self) -> Result<Vec<u8>, PwnIoError> {
73        log::debug!("Receiving");
74        let mut out = Vec::new();
75        let mut buffer = [0u8];
76        loop {
77            if self.read(&mut buffer)? == 0 {
78                return Ok(out);
79            } else {
80                out.push(buffer[0]);
81            }
82        }
83    }
84
85    fn receive_all(&mut self) -> Result<Vec<u8>, PwnIoError> {
86        log::debug!("Receiving all");
87        let mut out = Vec::new();
88        let mut buffer = [0u8];
89        loop {
90            match self.read_exact(&mut buffer) {
91                Ok(_) => out.push(buffer[0]),
92                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(out),
93                Err(e) => return Err(PwnIoError::Io(e)),
94            }
95        }
96    }
97
98    fn receive_count(&mut self, count: usize) -> Result<Vec<u8>, PwnIoError> {
99        log::debug!("Receiving {:x} bytes", count);
100        let mut out = Vec::with_capacity(count);
101        self.read_exact(&mut out)?;
102        Ok(out)
103    }
104}
105
/// Write-side helpers for sending raw payloads to a stream.
pub trait PwnIoWrite {
    /// Writes all of `data` to the stream.
    fn send(&mut self, data: &[u8]) -> std::io::Result<()>;

    /// Writes all of `data` to the stream, followed by a single `\n`.
    fn send_line(&mut self, data: &[u8]) -> std::io::Result<()>;
}
113
114impl<T> PwnIoWrite for T
115where
116    T: Write,
117{
118    fn send_line(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
119        self.write_all(data)?;
120        self.write_all(b"\n")?;
121        Ok(())
122    }
123
124    fn send(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
125        self.write_all(data)?;
126        Ok(())
127    }
128}
129
/// Hand-off of a connection to the user's terminal.
pub trait PwnIoInteractive {
    /// Enters an interactive session that relays data between this stream and the process's
    /// stdin/stdout, returning once the session ends.
    fn interactive(&mut self) -> std::io::Result<()>;
}
134
135impl PwnIoInteractive for TcpStream {
136    fn interactive(&mut self) -> Result<(), std::io::Error> {
137        log::debug!("Entering interactive");
138        let mut buffer = [0u8; 2048];
139        let mut send_buffer: Option<u8> = None;
140
141        let (s, r) = std::sync::mpsc::channel();
142
143        thread::spawn(move || {
144            let mut buffer = [0; 2048];
145            let mut stdin = stdin();
146            'outer_loop: loop {
147                let read_amount = stdin.read(&mut buffer);
148
149                match read_amount {
150                    Ok(read_amount) if read_amount > 0 => {
151                        for x in buffer.iter().copied().take(read_amount) {
152                            s.send(Ok(x)).unwrap();
153                        }
154                        buffer = [0; 2048];
155                    }
156                    Ok(_) => {}
157                    Err(x) => {
158                        s.send(Err(x)).unwrap();
159                        break 'outer_loop;
160                    }
161                };
162            }
163        });
164
165        'outer_loop: loop {
166            self.set_read_timeout(Some(Duration::from_millis(200)))?;
167            self.set_write_timeout(Some(Duration::from_millis(200)))?;
168            let read_amount = match self.read(&mut buffer) {
169                Ok(x) => x,
170                Err(x) if x.kind() == ErrorKind::WouldBlock => 0,
171                Err(x) => return Err(x),
172            };
173
174            if read_amount > 0 {
175                print!("{}", String::from_utf8_lossy(&buffer[0..read_amount]));
176                stdout().flush()?;
177            }
178
179            match send_buffer {
180                Some(x) => match self.write(&[x]) {
181                    Ok(_) => {
182                        send_buffer = None;
183                    }
184                    Err(err) if err.kind() == ErrorKind::WouldBlock => {}
185                    Err(err) => return Err(err),
186                },
187                None => 'inner_loop: loop {
188                    match r.try_recv() {
189                        Ok(Ok(x)) => match self.write(&[x]) {
190                            Ok(_) => {}
191                            Err(err) if err.kind() == ErrorKind::WouldBlock => {
192                                send_buffer = Some(x);
193                                break 'inner_loop;
194                            }
195                            Err(err) => return Err(err),
196                        },
197                        Ok(Err(_)) | Err(TryRecvError::Disconnected) => {
198                            break 'outer_loop;
199                        }
200                        Err(TryRecvError::Empty) => {
201                            break 'inner_loop;
202                        }
203                    }
204                },
205            };
206        }
207
208        Ok(())
209    }
210}