1use std::{
2 io::{stdin, stdout, ErrorKind, Read, Write},
3 net::TcpStream,
4 sync::mpsc::TryRecvError,
5 thread,
6 time::Duration,
7};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum PwnIoError {
13 #[error("IO: {0}")]
14 Io(std::io::Error),
15 #[error("End of file")]
16 Eof,
17}
18
19impl From<std::io::Error> for PwnIoError {
20 fn from(x: std::io::Error) -> Self {
21 match x.kind() {
22 std::io::ErrorKind::UnexpectedEof => Self::Eof,
23 _ => Self::Io(x),
24 }
25 }
26}
27
28pub trait PwnIoRead {
30 fn receive_until(&mut self, pattern: &[u8], trim_pattern: bool) -> Result<Vec<u8>, PwnIoError>;
32
33 fn receive_line(&mut self, trim_newline: bool) -> Result<Vec<u8>, PwnIoError>;
35
36 fn receive(&mut self) -> Result<Vec<u8>, PwnIoError>;
38
39 fn receive_all(&mut self) -> Result<Vec<u8>, PwnIoError>;
41
42 fn receive_count(&mut self, count: usize) -> Result<Vec<u8>, PwnIoError>;
44}
45
46impl<T> PwnIoRead for T
47where
48 T: Read,
49{
50 fn receive_until(&mut self, pattern: &[u8], trim_pattern: bool) -> Result<Vec<u8>, PwnIoError> {
51 log::debug!("Receiving until: {:?}", pattern);
52 let mut out = Vec::new();
53 let mut buffer = [0u8];
54 while !out.ends_with(pattern) {
55 self.read_exact(&mut buffer)?;
56 out.push(buffer[0]);
57 }
58
59 if trim_pattern {
60 for _ in 0..pattern.len() {
61 out.pop();
62 }
63 }
64 Ok(out)
65 }
66
67 fn receive_line(&mut self, trim_newline: bool) -> Result<Vec<u8>, PwnIoError> {
68 log::debug!("Receiving line");
69 self.receive_until(b"\n", trim_newline)
70 }
71
72 fn receive(&mut self) -> Result<Vec<u8>, PwnIoError> {
73 log::debug!("Receiving");
74 let mut out = Vec::new();
75 let mut buffer = [0u8];
76 loop {
77 if self.read(&mut buffer)? == 0 {
78 return Ok(out);
79 } else {
80 out.push(buffer[0]);
81 }
82 }
83 }
84
85 fn receive_all(&mut self) -> Result<Vec<u8>, PwnIoError> {
86 log::debug!("Receiving all");
87 let mut out = Vec::new();
88 let mut buffer = [0u8];
89 loop {
90 match self.read_exact(&mut buffer) {
91 Ok(_) => out.push(buffer[0]),
92 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(out),
93 Err(e) => return Err(PwnIoError::Io(e)),
94 }
95 }
96 }
97
98 fn receive_count(&mut self, count: usize) -> Result<Vec<u8>, PwnIoError> {
99 log::debug!("Receiving {:x} bytes", count);
100 let mut out = Vec::with_capacity(count);
101 self.read_exact(&mut out)?;
102 Ok(out)
103 }
104}
105
/// Byte-oriented send helpers.
///
/// A blanket implementation makes these methods available on every `std::io::Write` type.
pub trait PwnIoWrite {
    /// Writes all of `data` to the output.
    fn send(&mut self, data: &[u8]) -> Result<(), std::io::Error>;

    /// Writes all of `data` to the output, followed by a single `\n`.
    fn send_line(&mut self, data: &[u8]) -> Result<(), std::io::Error>;
}
113
114impl<T> PwnIoWrite for T
115where
116 T: Write,
117{
118 fn send_line(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
119 self.write_all(data)?;
120 self.write_all(b"\n")?;
121 Ok(())
122 }
123
124 fn send(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
125 self.write_all(data)?;
126 Ok(())
127 }
128}
129
/// Interactive sessions on a connection.
pub trait PwnIoInteractive {
    /// Hands the connection over to the user.
    ///
    /// The implementation for `TcpStream` prints received data to stdout and forwards data read
    /// from stdin to the connection until the session ends.
    fn interactive(&mut self) -> Result<(), std::io::Error>;
}
134
135impl PwnIoInteractive for TcpStream {
136 fn interactive(&mut self) -> Result<(), std::io::Error> {
137 log::debug!("Entering interactive");
138 let mut buffer = [0u8; 2048];
139 let mut send_buffer: Option<u8> = None;
140
141 let (s, r) = std::sync::mpsc::channel();
142
143 thread::spawn(move || {
144 let mut buffer = [0; 2048];
145 let mut stdin = stdin();
146 'outer_loop: loop {
147 let read_amount = stdin.read(&mut buffer);
148
149 match read_amount {
150 Ok(read_amount) if read_amount > 0 => {
151 for x in buffer.iter().copied().take(read_amount) {
152 s.send(Ok(x)).unwrap();
153 }
154 buffer = [0; 2048];
155 }
156 Ok(_) => {}
157 Err(x) => {
158 s.send(Err(x)).unwrap();
159 break 'outer_loop;
160 }
161 };
162 }
163 });
164
165 'outer_loop: loop {
166 self.set_read_timeout(Some(Duration::from_millis(200)))?;
167 self.set_write_timeout(Some(Duration::from_millis(200)))?;
168 let read_amount = match self.read(&mut buffer) {
169 Ok(x) => x,
170 Err(x) if x.kind() == ErrorKind::WouldBlock => 0,
171 Err(x) => return Err(x),
172 };
173
174 if read_amount > 0 {
175 print!("{}", String::from_utf8_lossy(&buffer[0..read_amount]));
176 stdout().flush()?;
177 }
178
179 match send_buffer {
180 Some(x) => match self.write(&[x]) {
181 Ok(_) => {
182 send_buffer = None;
183 }
184 Err(err) if err.kind() == ErrorKind::WouldBlock => {}
185 Err(err) => return Err(err),
186 },
187 None => 'inner_loop: loop {
188 match r.try_recv() {
189 Ok(Ok(x)) => match self.write(&[x]) {
190 Ok(_) => {}
191 Err(err) if err.kind() == ErrorKind::WouldBlock => {
192 send_buffer = Some(x);
193 break 'inner_loop;
194 }
195 Err(err) => return Err(err),
196 },
197 Ok(Err(_)) | Err(TryRecvError::Disconnected) => {
198 break 'outer_loop;
199 }
200 Err(TryRecvError::Empty) => {
201 break 'inner_loop;
202 }
203 }
204 },
205 };
206 }
207
208 Ok(())
209 }
210}