1use futures::future::poll_fn;
2use tokio::{
3 io::{self, AsyncReadExt, AsyncWriteExt, Interest},
4 net::{TcpListener, TcpStream},
5};
6
7use crate::{config::Config, Request};
8
9pub struct Server {
11 pub config: Config,
12}
13
14impl Server {
15 pub fn new(config: Config) -> Self {
16 Self { config }
17 }
18 pub async fn start(self) -> Result<(), Box<dyn std::error::Error>> {
33 let address = format!("{}:{}", self.config.host, self.config.port);
34 let listener = TcpListener::bind(address.clone()).await?;
35 println!("Serving on {}", address);
36 loop {
37 let mut client = listener.accept().await?.0;
38 let config = self.config.clone();
39 tokio::task::spawn(async move { handle_client(&mut client, config).await });
40 }
41 }
42}
43
44async fn handle_client(
46 client: &mut TcpStream,
47 config: Config,
48) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
49 let request_buffer = read_stream(client).await?;
50 println!(
51 "******************* Request Received *****************\n{}\n",
52 String::from_utf8_lossy(&request_buffer).trim()
53 );
54 let request = Request::from(request_buffer);
55 connect_and_handle_client_request(client, request, &config).await?;
57 Ok(())
58}
59
60pub async fn read_stream(
62 stream: &mut TcpStream,
63) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
64 let mut buffer: Vec<u8> = Vec::new();
65 let ready = stream.ready(Interest::READABLE).await?;
66 if ready.is_readable() {
69 let buffer_size: usize = 1024;
70 loop {
71 let mut fixed_buffer = vec![0; buffer_size];
72 match stream.read(&mut fixed_buffer).await {
73 Ok(n) if n == 0 => break,
74 Ok(n) if n < buffer_size => {
75 buffer.append(&mut fixed_buffer[..n].to_vec());
76 break;
77 }
78 Ok(_) => {
79 buffer.append(&mut fixed_buffer);
80 }
81 Err(e) => {
82 println!("Error in reading stram data: {}", e);
83 break;
84 }
85 }
86 }
87 }
88 Ok(buffer)
89}
90
91async fn connect_and_handle_client_request(
93 client: &mut TcpStream,
94 request: Request,
95 config: &Config,
96) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
97 println!("Handling client request....");
98 let address = if config.enable_proxy {
99 format!("{}:{}", config.proxy_host, config.proxy_port)
100 } else {
101 format!("{}:{}", request.host, request.port)
102 };
103 println!("Connecting to the remote host ({})", address);
104 let mut remote = TcpStream::connect(address.clone()).await?;
105 println!("Connected to the remote host ({})", address);
106 match request.method.as_str() {
107 "CONNECT" => handle_connect(client, request, &mut remote).await?,
108 _ => handle_default(client, request, &mut remote).await?,
109 }
110
111 println!("******** Complete Response sent to the client ********\n");
116
117 Ok(())
118}
119
120async fn handle_default(
122 client: &mut TcpStream,
123 request: Request,
124 remote: &mut TcpStream,
125) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
126 println!("Handling non-HTTPS request....");
127 match remote.write(&request.raw_data).await {
128 Ok(n) => println!(
129 "Wrote {} bytes and data to remote: {:?}",
130 n,
131 String::from_utf8_lossy(&request.raw_data)
132 ),
133 Err(e) => println!("Write error in remote: {}", e),
134 }
135 match read_stream(remote).await {
136 Ok(response) => {
137 println!(
138 "Read {} bytes and data from server: {:?}",
139 response.len(),
140 String::from_utf8_lossy(&response)
141 );
142 match client.write(&response).await {
143 Ok(n) => println!(
144 "Wrote {} bytes and data to client: {:?}",
145 n,
146 String::from_utf8_lossy(&response)
147 ),
148 Err(e) => println!("Write error in client: {}", e),
149 }
150 }
151 Err(e) => println!("Write error in client: {}", e),
152 }
153 Ok(())
154}
155
156async fn handle_connect(
160 client: &mut TcpStream,
161 request: Request,
162 remote: &mut TcpStream,
163) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
164 println!("Handling HTTPS request....");
165 let empty_response = format!("{} 200 Connection Established\r\n\r\n", request.version);
166 println!(
167 "********** Sending Response to client **********\n{}",
168 empty_response.trim()
169 );
170 client.write_all(empty_response.as_bytes()).await?;
171
172 let (mut cr, mut cw) = client.split();
183 let (mut rr, mut rw) = remote.split();
184
185 let client_to_remote = async {
186 let mut buffer = vec![0; 8096];
187 let mut read_half = tokio::io::ReadBuf::new(&mut buffer);
188 let _peeked_data_len = poll_fn(|cx| cr.poll_peek(cx, &mut read_half)).await?;
189 io::copy(&mut cr, &mut rw).await?;
194 rw.shutdown().await
195 };
196
197 let remote_to_client = async {
198 let mut buffer = vec![0; 8096];
199 let mut read_half = tokio::io::ReadBuf::new(&mut buffer);
200 let _peeked_data_len = poll_fn(|cx| rr.poll_peek(cx, &mut read_half)).await?;
201 io::copy(&mut rr, &mut cw).await?;
206 cw.shutdown().await
207 };
208
209 tokio::try_join!(client_to_remote, remote_to_client)?;
210 Ok(())
214}