masquerade_proxy/
client.rs

1use quiche;
2use quiche::h3::{NameValue, Header};
3use ring::rand::*;
4
5use std::future::Future;
6use std::net::{ToSocketAddrs, SocketAddr};
7use std::collections::HashMap;
8use std::error::{Error, self};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use tokio::io::{AsyncWriteExt, AsyncReadExt};
13use tokio::net::{UdpSocket, TcpStream, TcpListener};
14use tokio::sync::mpsc::{self, UnboundedSender, UnboundedReceiver};
15use tokio::time;
16
17use log::*;
18
19use crate::common::*;
20
21#[derive(Debug)]
22enum Content {
23    Request {
24        headers: Vec<quiche::h3::Header>,
25        stream_id_sender: mpsc::Sender<u64>,
26    },
27    Headers {
28        headers: Vec<quiche::h3::Header>,
29    },
30    Data {
31        data: Vec<u8>,
32    },
33    Datagram {
34        payload: Vec<u8>,
35    },
36    Finished,
37}
38
39#[derive(Debug)]
40struct ToSend {
41    stream_id: u64, // or flow_id for DATAGRAM
42    content: Content,
43    finished: bool,
44}
45
/// Error returned by `Client::run` when no TCP listener has been bound yet.
#[derive(Debug, Clone)]
struct RunBeforeBindError;

impl Error for RunBeforeBindError {}

impl std::fmt::Display for RunBeforeBindError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("bind(listen_addr) has to be called before run()")
    }
}
55
56struct Client {
57    listener: Option<TcpListener>,
58}
59
60
61impl Client {
62    pub fn new() -> Client {
63        Client { listener: None }
64    }
65
66    /**
67     * returns None if client is not bound to a socket yet
68     */
69    pub fn listen_addr(&self) -> Option<SocketAddr> {
70        return self.listener.as_ref().map(|listener| listener.local_addr().unwrap())
71    }
72
73    /**
74     * Bind the server to listen to an address
75     */
76    pub async fn bind<T: tokio::net::ToSocketAddrs>(&mut self, bind_addr: T) -> Result<(), Box<dyn Error>> {
77        debug!("creating TCP listener");
78
79        let mut listener = TcpListener::bind(bind_addr).await?;
80        debug!("listening on {}", listener.local_addr().unwrap());
81        
82        self.listener = Some(listener);
83        Ok(())
84    }
85    
86    /**
87     * Run client to receive TCP connections at the binded address, and handle 
88     * incoming streams with stream_handler (e.g. handshake, negotiation, proxying traffic)
89     * 
90     * This enables any protocol that accepts TCP connection to start with, such as HTTP1.1
91     * CONNECT and SOCKS5 as implemented below. Similarly, UDP listening can be easily 
92     * added if necessary.
93     */
94    pub async fn run<F, Fut>(&mut self, server_addr: &String, mut stream_handler: F) -> Result<(), Box<dyn Error>> 
95    where
96        F: FnMut(TcpStream, UnboundedSender<ToSend>, Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>, Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>) -> Fut,
97        Fut: Future<Output = ()> + Send + 'static,
98    {
99        if self.listener.is_none() {
100            return Err(Box::new(RunBeforeBindError));
101        }
102        let listener = self.listener.as_mut().unwrap();
103
104        let server_name = format!("https://{}", server_addr); // TODO: avoid duplicate https://
105    
106        // Resolve server address.
107        let url = url::Url::parse(&server_name).unwrap();
108        let peer_addr = url.to_socket_addrs().unwrap().next().unwrap();
109        
110        debug!("creating socket");
111        let socket = UdpSocket::bind("0.0.0.0:0".parse::<SocketAddr>().unwrap()).await?;
112        socket.connect(peer_addr.clone()).await?;
113        let socket = Arc::new(socket);
114        debug!("connecting to {} at {}", server_name, peer_addr);
115        
116    
117        let mut buf = [0; 65535];
118        let mut out = [0; MAX_DATAGRAM_SIZE];
119    
120        let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();
121        // TODO: *CAUTION*: this should not be set to `false` in production!!!
122        config.verify_peer(false);
123    
124        config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap();
125        
126        config.set_max_idle_timeout(1000);
127        config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
128        config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE);
129        config.set_initial_max_data(10_000_000);
130        config.set_initial_max_stream_data_bidi_local(1_000_000);
131        config.set_initial_max_stream_data_bidi_remote(1_000_000);
132        config.set_initial_max_stream_data_uni(1_000_000);
133        config.set_initial_max_streams_bidi(100);
134        config.set_initial_max_streams_uni(100);
135        config.set_disable_active_migration(true);
136        config.enable_dgram(true, 1000, 1000); 
137
138    
139        let mut scid = [0; quiche::MAX_CONN_ID_LEN];
140        let rng = SystemRandom::new();
141        rng.fill(&mut scid[..]).unwrap();
142        let scid = quiche::ConnectionId::from_ref(&scid);
143        
144        // Client connection.
145        let local_addr = socket.local_addr().unwrap();
146        let mut conn = quiche::connect(url.domain(), &scid, local_addr, peer_addr, &mut config).expect("quic connection failed");
147        info!(
148            "connecting to {:} from {:} with scid {}",
149            peer_addr,
150            socket.local_addr().unwrap(),
151            hex_dump(&scid)
152        );
153    
154        let (write, send_info) = conn.send(&mut out).expect("initial send failed"); 
155        while let Err(e) = socket.send_to(&out[..write], send_info.to).await {
156            if e.kind() == std::io::ErrorKind::WouldBlock {
157                debug!("send_to() would block");
158                continue;
159            }
160            panic!("UDP socket send_to() failed: {:?}", e);
161        }
162        debug!("written {}", write);
163    
164        let mut http3_conn: Option<quiche::h3::Connection> = None;
165        let (http3_sender, mut http3_receiver) = mpsc::unbounded_channel::<ToSend>();
166        let connect_streams: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>> = Arc::new(Mutex::new(HashMap::new()));
167        let connect_sockets: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>> = Arc::new(Mutex::new(HashMap::new()));
168        let mut http3_retry_send: Option<ToSend> = None;
169        let mut interval = time::interval(Duration::from_millis(20));
170        interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
171        loop {
172            if conn.is_closed() {
173                info!("connection closed, {:?}", conn.stats());
174                break;
175            }
176    
177            tokio::select! {
178                // handle QUIC received data
179                recvd = socket.recv_from(&mut buf) => {
180                    let (read, from) = match recvd {
181                        Ok(v) => v,
182                        Err(e) => {
183                            error!("error when reading from UDP socket");
184                            continue
185                        },
186                    };
187                    debug!("received {} bytes", read);
188                    let recv_info = quiche::RecvInfo {
189                        to: local_addr,
190                        from,
191                    };
192    
193                    // Process potentially coalesced packets.
194                    let read = match conn.recv(&mut buf[..read], recv_info) {
195                        Ok(v) => v,
196    
197                        Err(e) => {
198                            error!("QUIC recv failed: {:?}", e);
199                            continue
200                        },
201                    };
202                    debug!("processed {} bytes", read);
203    
204                    if let Some(http3_conn) = &mut http3_conn {
205                        // Process HTTP/3 events.
206                        loop {
207                            match http3_conn.poll(&mut conn) {
208                                Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
209                                    info!("got response headers {:?} on stream id {}", hdrs_to_strings(&list), stream_id);
210                                    let connect_streams = connect_streams.lock().unwrap();
211                                    if let Some(sender) = connect_streams.get(&stream_id) {
212                                        sender.send(Content::Headers { headers: list });
213                                    }
214                                },
215            
216                                Ok((stream_id, quiche::h3::Event::Data)) => {
217                                    let connect_streams = connect_streams.lock().unwrap();
218                                    if let Some(sender) = connect_streams.get(&stream_id) {
219                                        while let Ok(read) = http3_conn.recv_body(&mut conn, stream_id, &mut buf) {
220                                            debug!("got {} bytes of response data on stream {}", read, stream_id);
221                                            trace!("{}", unsafe {std::str::from_utf8_unchecked(&buf[..read])});
222                                            sender.send(Content::Data { data: buf[..read].to_vec() });
223                                        }
224                                    }
225                                },
226            
227                                Ok((stream_id, quiche::h3::Event::Finished)) => {
228                                    info!("finished received, stream id: {} closing", stream_id);
229                                    let connect_streams = connect_streams.lock().unwrap();
230                                    if let Some(sender) = connect_streams.get(&stream_id) {
231                                        sender.send(Content::Finished {});
232                                    }
233                                },
234            
235                                Ok((stream_id, quiche::h3::Event::Reset(e))) => {
236                                    error!("request was reset by peer with {}, stream id: {} closed", e, stream_id);
237                                    let connect_streams = connect_streams.lock().unwrap();
238                                    if let Some(sender) = connect_streams.get(&stream_id) {
239                                        sender.send(Content::Finished {});
240                                    }
241                                },
242            
243                                Ok((flow_id, quiche::h3::Event::Datagram)) => {
244                                    debug!("got {} bytes of datagram on flow {}", read, flow_id);
245                                    let connect_sockets = connect_sockets.lock().unwrap();
246                                    if let Some(sender) = connect_sockets.get(&flow_id) {
247                                        match http3_conn.recv_dgram(&mut conn, &mut buf) {
248                                            Ok((read, recvd_flow_id, flow_id_len)) => {
249                                                debug!("got {} bytes of datagram on flow {}", read, flow_id);
250                                                assert_eq!(flow_id, recvd_flow_id, "flow id by recv_dgram does not match");
251                                                trace!("{}", unsafe {std::str::from_utf8_unchecked(&buf[flow_id_len..read])});
252                                                sender.send(Content::Datagram { payload: buf[flow_id_len..read].to_vec() });
253                                            },
254                                            Err(e) => {
255                                                error!("error recv_dgram(): {}", e);
256                                                break;
257                                            }
258                                        }
259                                    }
260                                },
261            
262                                Ok((_, quiche::h3::Event::PriorityUpdate)) => unreachable!(),
263            
264                                Ok((goaway_id, quiche::h3::Event::GoAway)) => {
265                                    info!("GOAWAY id={}", goaway_id);
266                                },
267            
268                                Err(quiche::h3::Error::Done) => {
269                                    break;
270                                },
271            
272                                Err(e) => {
273                                    error!("HTTP/3 processing failed: {:?}", e);
274            
275                                    break;
276                                },
277                            }
278                        }
279                    }
280                },
281                // Send pending HTTP3 data in channel to HTTP3 connection on QUIC
282                http3_to_send = http3_receiver.recv(), if http3_conn.is_some() && http3_retry_send.is_none() => {
283                    if http3_to_send.is_none() {
284                        unreachable!()
285                    }
286                    let mut to_send = http3_to_send.unwrap();
287                    let http3_conn = http3_conn.as_mut().unwrap();
288                    loop {
289                        let result = match &to_send.content {
290                            Content::Headers { .. } => unreachable!(),
291                            Content::Request { headers, stream_id_sender } => {
292                                debug!("sending http3 request {:?}", hdrs_to_strings(&headers));
293                                match http3_conn.send_request(&mut conn, headers, to_send.finished) {
294                                    Ok(stream_id) => {
295                                        stream_id_sender.send(stream_id).await;
296                                        Ok(())
297                                    },
298                                    Err(e) => {
299                                        error!("http3 request send failed");
300                                        Err(e)
301                                    },
302                                }
303                            },
304                            Content::Data { data } => {
305                                debug!("sending http3 data of {} bytes", data.len());
306                                let mut written = 0;
307                                loop {
308                                    if written >= data.len() {
309                                        break Ok(())
310                                    }
311                                    match http3_conn.send_body(&mut conn, to_send.stream_id, &data[written..], to_send.finished) {
312                                        Ok(v) => written += v,
313                                        Err(e) => {
314                                            to_send = ToSend { stream_id: to_send.stream_id, content: Content::Data { data: data[written..].to_vec() }, finished: to_send.finished };
315                                            break Err(e)
316                                        },
317                                    }
318                                    debug!("written http3 data {} of {} bytes", written, data.len());
319                                }
320                            },
321                            Content::Datagram { payload } => {
322                                debug!("sending http3 datagram of {} bytes", payload.len());
323                                http3_conn.send_dgram(&mut conn, to_send.stream_id, &payload)
324                            },
325                            Content::Finished => todo!(),
326                        };
327                        match result {
328                            Ok(_) => {},
329                            Err(quiche::h3::Error::StreamBlocked | quiche::h3::Error::Done) => {
330                                debug!("Connection {} stream {} stream blocked, retry later", conn.trace_id(), to_send.stream_id);
331                                http3_retry_send = Some(to_send);
332                                break; 
333                            },
334                            Err(e) => {
335                                error!("Connection {} stream {} send failed {:?}", conn.trace_id(), to_send.stream_id, e);
336                                conn.stream_shutdown(to_send.stream_id, quiche::Shutdown::Write, 0);
337                                {
338                                    let mut connect_streams = connect_streams.lock().unwrap();
339                                    connect_streams.remove(&to_send.stream_id);
340                                }
341                            }
342                        };
343                        to_send = match http3_receiver.try_recv() {
344                            Ok(v) => v,
345                            Err(e) => break,
346                        };
347                    }
348                },
349
350                // Accept a new TCP connection
351                tcp_accepted = listener.accept() => {
352                    match tcp_accepted {
353                        Ok((tcp_socket, addr)) => {
354                            debug!("accepted connection from {}", addr);
355                            tokio::spawn(stream_handler(tcp_socket, http3_sender.clone(), connect_streams.clone(), connect_sockets.clone()));
356                        },
357                        Err(_) => todo!(),
358                    };
359                },
360
361                // Retry sending in case of stream blocking
362                _ = interval.tick(), if http3_conn.is_some() && http3_retry_send.is_some() => {
363                    let mut to_send = http3_retry_send.unwrap();
364                    let http3_conn = http3_conn.as_mut().unwrap();
365                    let result = match &to_send.content {
366                        Content::Headers { .. } => unreachable!(),
367                        Content::Request { headers, stream_id_sender } => {
368                            debug!("retry sending http3 request {:?}", hdrs_to_strings(&headers));
369                            match http3_conn.send_request(&mut conn, headers, to_send.finished) {
370                                Ok(stream_id) => {
371                                    stream_id_sender.send(stream_id).await;
372                                    Ok(())
373                                },
374                                Err(e) => {
375                                    error!("http3 request send failed");
376                                    Err(e)
377                                },
378                            }
379                        },
380                        Content::Data { data } => {
381                            debug!("retry sending http3 data of {} bytes", data.len());
382                            let mut written = 0;
383                            loop {
384                                if written >= data.len() {
385                                    break Ok(())
386                                }
387                                match http3_conn.send_body(&mut conn, to_send.stream_id, &data[written..], to_send.finished) {
388                                    Ok(v) => written += v,
389                                    Err(e) => {
390                                        to_send = ToSend { stream_id: to_send.stream_id, content: Content::Data { data: data[written..].to_vec() }, finished: to_send.finished };
391                                        break Err(e)
392                                    },
393                                }
394                                debug!("written http3 data {} of {} bytes", written, data.len());
395                            }
396                        },
397                        Content::Datagram { payload } => {
398                            debug!("retry sending http3 datagram of {} bytes", payload.len());
399                            http3_conn.send_dgram(&mut conn, to_send.stream_id, &payload)
400                        },
401                        Content::Finished => todo!(),
402                    };
403                    match result {
404                        Ok(_) => {
405                            http3_retry_send = None;
406                        },
407                        Err(quiche::h3::Error::StreamBlocked | quiche::h3::Error::Done) => {
408                            debug!("Connection {} stream {} stream blocked, retry later", conn.trace_id(), to_send.stream_id);
409                            http3_retry_send = Some(to_send);
410                        },
411                        Err(e) => {
412                            error!("Connection {} stream {} send failed {:?}", conn.trace_id(), to_send.stream_id, e);
413                            conn.stream_shutdown(to_send.stream_id, quiche::Shutdown::Write, 0);
414                            {
415                                let mut connect_streams = connect_streams.lock().unwrap();
416                                connect_streams.remove(&to_send.stream_id);
417                            }
418                            http3_retry_send = None;
419                        }
420                    };
421                },
422    
423                else => break,
424            }
425            
426            // Create a new HTTP/3 connection once the QUIC connection is established.
427            if conn.is_established() && http3_conn.is_none() {
428                let h3_config = quiche::h3::Config::new().unwrap();
429                http3_conn = Some(
430                    quiche::h3::Connection::with_transport(&mut conn, &h3_config)
431                    .expect("Unable to create HTTP/3 connection, check the server's uni stream limit and window size"),
432                );
433            }
434        // Send pending QUIC packets
435            loop {
436                let (write, send_info) = match conn.send(&mut out) {
437                    Ok(v) => v,
438    
439                    Err(quiche::Error::Done) => {
440                        debug!("QUIC connection {} done writing", conn.trace_id());
441                        break;
442                    },
443    
444                    Err(e) => {
445                        error!("QUIC connection {} send failed: {:?}", conn.trace_id(), e);
446    
447                        conn.close(false, 0x1, b"fail").ok();
448                        break;
449                    },
450                };
451    
452                match socket.send_to(&out[..write], send_info.to).await {
453                    Ok(written) => debug!("{} written {} bytes out of {}", conn.trace_id(), written, write),
454                    Err(e) => panic!("UDP socket send_to() failed: {:?}", e),
455                }
456            }
457    
458        }
459    
460        Ok(())
461    }
462}
463
464async fn handle_http1_stream(mut stream: TcpStream, http3_sender: UnboundedSender<ToSend>, connect_streams: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>, _connect_sockets: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>) {
465    let mut buf = [0; 65535];
466    let mut pos = match stream.read(&mut buf).await {
467        Ok(v) => v,
468        Err(e) => {
469            error!("Error reading from TCP stream: {}", e);
470            return
471        },
472    };
473    loop {
474        match stream.try_read(&mut buf[pos..]) {
475            Ok(read) => pos += read,
476            Err(ref e) if would_block(e) => break,
477            Err(ref e) if interrupted(e) => continue,
478            Err(e) => {
479                error!("Error reading from TCP stream: {}", e);
480                return
481            }
482        };
483    }
484    let peer_addr = stream.peer_addr().unwrap();
485
486    let mut headers = [httparse::EMPTY_HEADER; 16];
487    let mut req = httparse::Request::new(&mut headers);
488    let res = req.parse(&buf[..pos]).unwrap();
489    if let Some(method) = req.method {
490        if let Some(path) = req.path {
491            if method.eq_ignore_ascii_case("CONNECT") {
492                // TODO: Check Host?
493                let headers = vec![
494                    quiche::h3::Header::new(b":method", b"CONNECT"),
495                    quiche::h3::Header::new(b":authority", path.as_bytes()),
496                    quiche::h3::Header::new(b":authorization", b"dummy-authorization"),    
497                ];
498                info!("sending HTTP3 request {:?}", headers);
499                let (stream_id_sender, mut stream_id_receiver) = mpsc::channel(1);
500                let (response_sender, mut response_receiver) = mpsc::unbounded_channel::<Content>();
501                http3_sender.send(ToSend { content: Content::Request { headers, stream_id_sender }, finished: false, stream_id: 0});
502                let stream_id = stream_id_receiver.recv().await.expect("stream_id receiver error");
503                {
504                    let mut connect_streams = connect_streams.lock().unwrap();
505                    connect_streams.insert(stream_id, response_sender); 
506                    // TODO: potential race condition: the response could be received before connect_streams is even inserted and get dropped
507                }
508
509                let response = response_receiver.recv().await.expect("http3 response receiver error");
510                if let Content::Headers { headers } = response {
511                    info!("Got response {:?}", hdrs_to_strings(&headers));
512                    let mut status = None;
513                    for hdr in headers {
514                        match hdr.name() {
515                            b":status" => status = Some(hdr.value().to_owned()),
516                            _ => (),
517                        }
518                    }
519                    if let Some(status) = status {
520                        if let Ok(status_str) = std::str::from_utf8(&status) {
521                            if let Ok(status_code) = status_str.parse::<i32>() {
522                                if status_code >= 200 && status_code < 300 {
523                                    info!("connection established, sending 200 OK");
524                                    stream.write(&b"HTTP/1.1 200 OK\r\n\r\n".to_vec()).await;
525                                }
526                            }
527                        }
528                    }
529                } else {
530                    error!("received others when expecting headers for connect");
531                }
532
533                let (mut read_half, mut write_half) = stream.into_split();
534                let http3_sender_clone = http3_sender.clone();
535                let read_task = tokio::spawn(async move {
536                    let mut buf = [0; 65535];
537                    loop {
538                        let read = match read_half.read(&mut buf).await {
539                            Ok(v) => v,
540                            Err(e) => {
541                                error!("Error reading from TCP {}: {}", peer_addr, e);
542                                break
543                            },
544                        };
545                        if read == 0 {
546                            debug!("TCP connection closed from {}", peer_addr);
547                            break
548                        }
549                        debug!("read {} bytes from TCP from {} for stream {}", read, peer_addr, stream_id);
550                        http3_sender_clone.send(ToSend { stream_id: stream_id, content: Content::Data { data: buf[..read].to_vec() }, finished: false });
551                    }
552                });
553                let write_task = tokio::spawn(async move {
554                    loop {
555                        let data = match response_receiver.recv().await {
556                            Some(v) => v,
557                            None => {
558                                debug!("TCP receiver channel closed for stream {}", stream_id);
559                                break
560                            },
561                        };
562                        match data {
563                            Content::Request { .. } => unreachable!(),
564                            Content::Headers { .. } => unreachable!(),
565                            Content::Data { data } => {
566                                let mut pos = 0;
567                                while pos < data.len() {
568                                    let bytes_written = match write_half.write(&data[pos..]).await {
569                                        Ok(v) => v,
570                                        Err(e) => {
571                                            error!("Error writing to TCP {} on stream id {}: {}", peer_addr, stream_id, e);
572                                            return
573                                        },
574                                    };
575                                    pos += bytes_written;
576                                }
577                                debug!("written {} bytes from TCP to {} for stream {}", data.len(), peer_addr, stream_id);
578                            },
579                            Content::Datagram { .. } => unreachable!(),
580                            Content::Finished => todo!(),
581                        };
582                        
583                    }
584                });
585                tokio::join!(read_task, write_task);
586                
587                {
588                    let mut connect_streams = connect_streams.lock().unwrap();
589                    connect_streams.remove(&stream_id);
590                }
591                return
592            }
593        }
594    }
595    stream.write(&b"HTTP/1.1 400 Bad Request\r\n\r\n".to_vec()).await;
596}
597
598pub struct Http1Client {
599    client: Client,
600}
601
602impl Http1Client {
603    pub fn new() -> Http1Client {
604        Http1Client { client: Client::new() }
605    }
606
607    pub fn listen_addr(&self) -> Option<SocketAddr> {
608        return self.client.listen_addr()
609    }
610
611    pub async fn bind<T: tokio::net::ToSocketAddrs>(&mut self, bind_addr: T) -> Result<(), Box<dyn Error>> {
612        self.client.bind(bind_addr).await
613    }
614
615    pub async fn run(&mut self, server_addr: &String) -> Result<(), Box<dyn Error>> {
616        self.client.run(server_addr, handle_http1_stream).await
617    }
618}
619
620async fn handle_socks5_stream(mut stream: TcpStream, http3_sender: UnboundedSender<ToSend>, connect_streams: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>, connect_sockets: Arc<Mutex<HashMap<u64, UnboundedSender<Content>>>>) {
621    let peer_addr = stream.peer_addr().unwrap();
622    let hs_req = match socks5_proto::HandshakeRequest::read_from(&mut stream).await {
623        Ok(v) => v,
624        Err(e) => {
625            error!("socks5 handshake request read failed: {}", e);
626            return
627        }
628    };
629
630    if hs_req.methods.contains(&socks5_proto::HandshakeMethod::None) {
631        let hs_resp = socks5_proto::HandshakeResponse::new(socks5_proto::HandshakeMethod::None);
632        match hs_resp.write_to(&mut stream).await {
633            Ok(_) => {},
634            Err(e) => {
635                error!("socks5 handshake write response failed: {}", e);
636                return
637            }
638        };
639    } else {
640        error!("No available handshake method provided by client, currently only support no auth");
641        let hs_resp = socks5_proto::HandshakeResponse::new(socks5_proto::HandshakeMethod::Unacceptable);
642        match hs_resp.write_to(&mut stream).await {
643            Ok(_) => {},
644            Err(e) => {
645                error!("socks5 handshake write response failed: {}", e);
646                return
647            }
648        };
649        let _ = stream.shutdown().await;
650        return
651    }
652
653    let req = match socks5_proto::Request::read_from(&mut stream).await {
654        Ok(v) => v,
655        Err(e) => {
656            error!("socks5 request parse failed: {}", e);
657            let resp = socks5_proto::Response::new(socks5_proto::Reply::GeneralFailure, socks5_proto::Address::unspecified());
658            match resp.write_to(&mut stream).await {
659                Ok(_) => {},
660                Err(e) => {
661                    error!("socks5 write response failed: {}", e);
662                    return
663                }
664            };
665            let _ = stream.shutdown().await;
666            return
667        }
668    };
669
670    match req.command {
671        socks5_proto::Command::Connect => {
672            let path = socks5_addr_to_string(&req.address);
673            let headers = vec![
674                quiche::h3::Header::new(b":method", b"CONNECT"),
675                quiche::h3::Header::new(b":authority", path.as_bytes()),
676                quiche::h3::Header::new(b":authorization", b"dummy-authorization"),    
677            ];
678            info!("sending HTTP3 request {:?}", headers);
679            let (stream_id_sender, mut stream_id_receiver) = mpsc::channel(1);
680            let (response_sender, mut response_receiver) = mpsc::unbounded_channel::<Content>();
681            http3_sender.send(ToSend { content: Content::Request { headers, stream_id_sender }, finished: false, stream_id: 0});
682            let stream_id = stream_id_receiver.recv().await.expect("stream_id receiver error");
683            {
684                let mut connect_streams = connect_streams.lock().unwrap();
685                connect_streams.insert(stream_id, response_sender); 
686                // TODO: potential race condition: the response could be received before connect_streams is even inserted and get dropped
687            }
688
689            let response = response_receiver.recv().await.expect("http3 response receiver error");
690            let mut succeeded = false;
691            if let Content::Headers { headers } = response {
692                info!("Got response {:?}", hdrs_to_strings(&headers));
693                let mut status = None;
694                for hdr in headers {
695                    match hdr.name() {
696                        b":status" => status = Some(hdr.value().to_owned()),
697                        _ => (),
698                    }
699                }
700                if let Some(status) = status {
701                    if let Ok(status_str) = std::str::from_utf8(&status) {
702                        if let Ok(status_code) = status_str.parse::<i32>() {
703                            if status_code >= 200 && status_code < 300 {
704                                info!("connection established, sending OK socks response");
705                                let response = socks5_proto::Response::new(socks5_proto::Reply::Succeeded, socks5_proto::Address::unspecified());
706                                succeeded = true;
707                                match response.write_to(&mut stream).await {
708                                    Ok(_) => {},
709                                    Err(e) => {
710                                        error!("socks5 response write error: {}", e);
711                                        let _ = stream.shutdown().await;
712                                        return
713                                    }
714                                }
715                            }
716                        }
717                    }
718                }
719            } else {
720                error!("received others when expecting headers for connect");
721            }
722            if !succeeded {
723                error!("http3 CONNECT failed");
724                let response = socks5_proto::Response::new(socks5_proto::Reply::GeneralFailure, socks5_proto::Address::unspecified());
725                let _ = response.write_to(&mut stream).await;
726                let _ = stream.shutdown().await;
727                return
728            }
729
730            let (mut read_half, mut write_half) = stream.into_split();
731            let http3_sender_clone = http3_sender.clone();
732            let read_task = tokio::spawn(async move {
733                let mut buf = [0; 65535];
734                loop {
735                    let read = match read_half.read(&mut buf).await {
736                        Ok(v) => v,
737                        Err(e) => {
738                            error!("Error reading from TCP {}: {}", peer_addr, e);
739                            break
740                        },
741                    };
742                    if read == 0 {
743                        debug!("TCP connection closed from {}", peer_addr);
744                        break
745                    }
746                    debug!("read {} bytes from TCP from {} for stream {}", read, peer_addr, stream_id);
747                    http3_sender_clone.send(ToSend { stream_id: stream_id, content: Content::Data { data: buf[..read].to_vec() }, finished: false });
748                }
749            });
750            let write_task = tokio::spawn(async move {
751                loop {
752                    let data = match response_receiver.recv().await {
753                        Some(v) => v,
754                        None => {
755                            debug!("TCP receiver channel closed for stream {}", stream_id);
756                            break
757                        },
758                    };
759                    match data {
760                        Content::Request { .. } => unreachable!(),
761                        Content::Headers { .. } => unreachable!(),
762                        Content::Data { data } => {
763                            let mut pos = 0;
764                            while pos < data.len() {
765                                let bytes_written = match write_half.write(&data[pos..]).await {
766                                    Ok(v) => v,
767                                    Err(e) => {
768                                        error!("Error writing to TCP {} on stream id {}: {}", peer_addr, stream_id, e);
769                                        return
770                                    },
771                                };
772                                pos += bytes_written;
773                            }
774                            debug!("written {} bytes from TCP to {} for stream {}", data.len(), peer_addr, stream_id);
775                        },
776                        Content::Datagram { .. } => unreachable!(),
777                        Content::Finished => todo!(),
778                    };
779                    
780                }
781            });
782            tokio::join!(read_task, write_task);
783            
784            {
785                let mut connect_streams = connect_streams.lock().unwrap();
786                connect_streams.remove(&stream_id);
787            }
788        },
789        socks5_proto::Command::Associate => {
790            // NOTE: Currently do not support fragmentation
791            let mut local_addr = stream.local_addr().unwrap(); // bind on the same ip address of the tcp connection
792            local_addr.set_port(0); // let the OS assign a port
793            if let Ok(bind_socket) = UdpSocket::bind(local_addr).await { 
794                if let Ok(local_addr) = bind_socket.local_addr() {
795                    let response = socks5_proto::Response::new(socks5_proto::Reply::Succeeded, socks5_proto::Address::SocketAddress(local_addr));
796                    match response.write_to(&mut stream).await {
797                        Ok(_) => {},
798                        Err(e) => {
799                            error!("socks5 response write error: {}", e);
800                            let _ = stream.shutdown().await;
801                            return
802                        }
803                    }
804                    let bind_socket = Arc::new(bind_socket);
805                    let http3_sender_clone = http3_sender.clone();
806                    let listen_task = tokio::spawn(async move {
807                        let mut buf = [0; 65535];
808                        let mut dest_to_flow: HashMap<socks5_proto::Address, u64> = HashMap::new();
809                        loop {
810                            match bind_socket.recv_from(&mut buf).await {
811                                Ok((read, recv_addr)) => {
812                                    debug!("read {} bytes from UDP from {}", read, recv_addr);
813                                    let socks5_udp_header = match socks5_proto::UdpHeader::read_from(&mut &buf[..read]).await {
814                                        Ok(v) => v,
815                                        Err(e) => {
816                                            error!("udp socks5 socket received packet cannot be parsed: {}", e);
817                                            continue
818                                        },
819                                    };
820                                    let payload = &buf[socks5_udp_header.serialized_len()..read];
821                                    let flow_id = match dest_to_flow.get(&socks5_udp_header.address) {
822                                        Some(flow_id) => {
823                                            *flow_id
824                                        },
825                                        None => {
826                                            // New destination address to proxy, set up connect-udp flow
827                                            let path = socks5_addr_to_connect_udp_path(&socks5_udp_header.address);
828                                            let headers = vec![
829                                                quiche::h3::Header::new(b":method", b"CONNECT"),
830                                                quiche::h3::Header::new(b":path", path.as_bytes()),
831                                                quiche::h3::Header::new(b":protocol", b"connect-udp"),
832                                                quiche::h3::Header::new(b":scheme", b"dummy-scheme"),
833                                                quiche::h3::Header::new(b":authority", b"dummy-authority"),
834                                                quiche::h3::Header::new(b":authorization", b"dummy-authorization"),
835                                            ];
836                                            debug!("sending HTTP3 request {:?}", headers);
837                                            let (stream_id_sender, mut stream_id_receiver) = mpsc::channel(1);
838                                            let (stream_response_sender, mut stream_response_receiver) = mpsc::unbounded_channel::<Content>();
839                                            let (flow_response_sender, mut flow_response_receiver) = mpsc::unbounded_channel::<Content>();
840                                            http3_sender.send(ToSend { content: Content::Request { headers, stream_id_sender }, finished: false, stream_id: 0});
841                                            let stream_id = stream_id_receiver.recv().await.expect("stream_id receiver error");
842                                            let flow_id = stream_id / 4;
843                                            {
844                                                let mut connect_streams = connect_streams.lock().unwrap();
845                                                connect_streams.insert(stream_id, stream_response_sender); 
846                                                // TODO: potential race condition: the response could be received before connect_streams is even inserted and get dropped
847                                            }
848                                            {
849                                                let mut connect_sockets = connect_sockets.lock().unwrap();
850                                                connect_sockets.insert(flow_id, flow_response_sender); 
851                                            }
852                                            let mut succeeded = false;
853                                            let response = stream_response_receiver.recv().await.expect("http3 response receiver error");
854                                            if let Content::Headers { headers } = response {
855                                                debug!("Got response {:?}", hdrs_to_strings(&headers));
856                                                let mut status = None;
857                                                for hdr in headers {
858                                                    match hdr.name() {
859                                                        b":status" => status = Some(hdr.value().to_owned()),
860                                                        _ => (),
861                                                    }
862                                                }
863                                                if let Some(status) = status {
864                                                    if let Ok(status_str) = std::str::from_utf8(&status) {
865                                                        if let Ok(status_code) = status_str.parse::<i32>() {
866                                                            if status_code >= 200 && status_code < 300 {
867                                                                succeeded = true;
868                                                                debug!("UDP CONNECT connection established for flow {}", flow_id);
869                                                                dest_to_flow.insert(socks5_udp_header.address, flow_id);
870                                                            }
871                                                        }
872                                                    }
873                                                }
874                                            } else {
875                                                error!("received others when expecting headers for connect");
876                                            }
877                                            if !succeeded {
878                                                error!("http3 CONNECT UDP failed");
879                                                continue
880                                            }
881                                            let bind_socket_clone = bind_socket.clone();
882                                            let _write_task = tokio::spawn(async move {
883                                                loop {
884                                                    let data = match flow_response_receiver.recv().await {
885                                                        Some(v) => v,
886                                                        None => {
887                                                            debug!("receiver channel closed for flow {}", flow_id);
888                                                            break
889                                                        },
890                                                    };
891                                                    match data {
892                                                        Content::Request { .. } => unreachable!(),
893                                                        Content::Headers { .. } => unreachable!(),
894                                                        Content::Data { .. } => unreachable!(),
895                                                        Content::Datagram { payload } => {
896                                                            trace!("raw UDP datagram is {} bytes long", payload.len());
897                                                            let (context_id, payload) = decode_var_int(&payload);
898                                                            trace!("UDP datagram payload without context id is {} bytes long", payload.len());
899                                                            assert_eq!(context_id, 0, "received UDP Proxying Datagram with non-zero Context ID");
900                                
901                                                            let udp_header = socks5_proto::UdpHeader::new(0, socks5_proto::Address::SocketAddress("0.0.0.0:0".parse().unwrap()));
902                                                            trace!("appending SOCKS5 UDP request header of length {}", udp_header.serialized_len());
903                                                            let mut serialized_udp_header = Vec::new();
904                                                            udp_header.write_to_buf(&mut serialized_udp_header);
905                                                            trace!("SOCKS5 UDP request header: {:02x?}", serialized_udp_header);
906                                                            let payload = [&serialized_udp_header, payload].concat();
907                                                            trace!("start sending on UDP");
908                                                            let bytes_written = match bind_socket_clone.send_to(&payload, recv_addr).await {
909                                                                Ok(v) => v,
910                                                                Err(e) => {
911                                                                    error!("Error writing to UDP {} on flow id {}: {}", recv_addr, flow_id, e);
912                                                                    continue
913                                                                },
914                                                            };
915                                                            if bytes_written < payload.len() {
916                                                                debug!("Partially sent {} bytes of UDP packet of length {}", bytes_written, payload.len());
917                                                            }
918                                                            debug!("written {} bytes from UDP to {} for flow {}", payload.len(), recv_addr, flow_id);
919                                                        },
920                                                        Content::Finished => todo!(),
921                                                    };
922                                                    
923                                                }
924                                            });
925                                            flow_id
926                                        },
927                                    };
928                                    debug!("sending {} bytes of data to flow {}", payload.len(), flow_id);
929                                    let data = wrap_udp_connect_payload(0, payload);
930                                    http3_sender_clone.send(ToSend { stream_id: flow_id, content: Content::Datagram { payload: data }, finished: false });
931                                },
932                                Err(e) => {
933                                    error!("udp socks5 socket recv failed: {}", e);
934                                    break
935                                },
936                            }
937                        }
938                    });
939                    tokio::join!(listen_task);
940                }
941            }
942            // TODO: handle termination of UDP assoiciate correctly
943            
944            // {
945            //     let mut connect_sockets = connect_sockets.lock().unwrap();
946            //     connect_sockets.remove(&flow_id);
947            // }
948            // {
949            //     let mut connect_streams = connect_streams.lock().unwrap();
950            //     connect_streams.remove(&stream_id);
951            // }
952        },
953        socks5_proto::Command::Bind => unimplemented!(),
954    }
955
956
957    
958    
959    
960}
961
962pub struct Socks5Client {
963    client: Client,
964}
965
966impl Socks5Client {
967    pub fn new() -> Socks5Client {
968        Socks5Client { client: Client::new() }
969    }
970
971    pub fn listen_addr(&self) -> Option<SocketAddr> {
972        return self.client.listen_addr()
973    }
974
975    pub async fn bind<T: tokio::net::ToSocketAddrs>(&mut self, bind_addr: T) -> Result<(), Box<dyn Error>> {
976        self.client.bind(bind_addr).await
977    }
978
979    pub async fn run(&mut self, server_addr: &String) -> Result<(), Box<dyn Error>> {
980        self.client.run(server_addr, handle_socks5_stream).await
981    }
982}
983
984fn socks5_addr_to_string(addr: &socks5_proto::Address) -> String {
985    match addr {
986        socks5_proto::Address::SocketAddress(socket_addr) => socket_addr.to_string(),
987        socks5_proto::Address::DomainAddress(domain, port) => format!("{}:{}", domain, port),
988    }
989}
990
991/**
992 * RFC9298 specify connect-udp path should be a template like /.well-known/masque/udp/192.0.2.6/443/
993 */
994fn socks5_addr_to_connect_udp_path(addr: &socks5_proto::Address) -> String {
995    let (host, port) = match addr {
996        socks5_proto::Address::SocketAddress(socket_addr) => {
997            let ip_string = socket_addr.ip().to_string();
998            ip_string.replace(":", "%3A"); // encode ':' in IPv6 address in URI
999            (ip_string, socket_addr.port())
1000        },
1001        socks5_proto::Address::DomainAddress(domain, port) => (domain.to_owned(), port.to_owned()),
1002    };
1003    format!("/.well_known/masque/udp/{}/{}/", host, port)
1004}