portForwarder 0.1.3

A TCP/UDP port multiplexer that forwards connections based on the content of the traffic.
Documentation
use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Token};
use ntest::timeout;
use portforwarder::forward_config::ForwardSessionConfig;
use portforwarder::tcp_forwarder::TcpForwarder;
use rand::Rng;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::io::{self, Read, Write};
use std::net::ToSocketAddrs;
use std::rc::Rc;
use std::sync::{Arc, atomic::AtomicBool};
use std::time::Duration;

/// Test client: sends 100 KiB of random bytes to `addr` and checks the echo.
///
/// Connects a non-blocking mio `TcpStream` to the first address `addr` resolves
/// to, writes the payload in chunks of at most 1024 bytes, half-closes the write
/// side once everything has been sent, and asserts that the bytes read back equal
/// the bytes written, in order. After the whole payload has been read back,
/// `finished` is set to `true`.
///
/// # Panics
///
/// Panics if `addr` does not resolve, on any socket error other than
/// `WouldBlock`, if the echoed bytes differ from the bytes sent, or if the peer
/// closes the connection before the whole payload has been echoed.
fn tcp_sender<T: ToSocketAddrs>(addr: T, finished: Arc<AtomicBool>) {
    let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
    let mut stream = TcpStream::connect(server_addr).expect("Failed to connect");

    let mut poll = Poll::new().unwrap();
    let token = Token(0);
    poll.registry()
        .register(&mut stream, token, Interest::WRITABLE | Interest::READABLE)
        .unwrap();

    let mut rng = rand::thread_rng();
    let target_bytes = 1024 * 100;
    // Every byte accepted by the socket so far; the echo is compared against this.
    let mut buf_storage = Vec::with_capacity(target_bytes);
    let mut send_bytes = 0;
    let mut recv_bytes = 0;
    let mut events = Events::with_capacity(1024);
    let mut buffer = vec![0; 1024];

    while recv_bytes < target_bytes {
        poll.poll(&mut events, Some(Duration::from_secs(1)))
            .unwrap();

        for event in events.iter() {
            if event.token() != token {
                continue;
            }

            if event.is_writable() {
                // mio readiness is edge-triggered: keep writing until the payload
                // is complete or the socket reports `WouldBlock`. Writing only once
                // per event gives no guarantee that another writable event follows.
                while send_bytes < target_bytes {
                    let to_send = std::cmp::min(buffer.len(), target_bytes - send_bytes);
                    rng.fill(&mut buffer[..to_send]);

                    match stream.write(&buffer[..to_send]) {
                        // The socket accepted nothing; wait for the next event
                        // instead of spinning.
                        Ok(0) => break,
                        Ok(n) => {
                            // Only the first `n` bytes were accepted; the rest of the
                            // chunk is discarded and regenerated on the next pass.
                            buf_storage.extend_from_slice(&buffer[..n]);
                            send_bytes += n;
                            println!("TCP sender sent {} bytes", n);

                            if send_bytes >= target_bytes {
                                // Nothing more to send: half-close so the peer sees EOF.
                                stream.shutdown(std::net::Shutdown::Write).unwrap();
                            }
                        }
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("Failed to send data: {}", e),
                    }
                }
            }

            if event.is_readable() {
                let mut recv_buffer = [0; 1024];
                loop {
                    match stream.read(&mut recv_buffer) {
                        Ok(0) => {
                            // EOF. If the payload is incomplete no more data can
                            // arrive, so fail now rather than poll until the test
                            // harness times out.
                            assert!(
                                recv_bytes >= target_bytes,
                                "Connection closed prematurely after {} of {} bytes",
                                recv_bytes,
                                target_bytes
                            );
                            break;
                        }
                        Ok(n) => {
                            // The echo can never be ahead of what was sent.
                            assert!(buf_storage.len() >= recv_bytes + n);
                            assert_eq!(&buf_storage[recv_bytes..recv_bytes + n], &recv_buffer[..n]);
                            recv_bytes += n;
                            println!("TCP sender received {} bytes (total {})", n, recv_bytes);
                        }
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("Failed to read: {}", e),
                    }
                }
            }
        }
    }

    finished.store(true, std::sync::atomic::Ordering::SeqCst);
}

/// Echo server used as the upstream end of the test.
///
/// Binds a non-blocking listener on the first address `listen_addr` resolves to
/// and runs a mio poll loop (100 ms timeout per iteration) until `finished`
/// reads `true`. Bytes read from a connection are queued and written back to the
/// same connection. A connection that has reached end-of-stream is dropped once
/// its queue has been flushed.
///
/// # Panics
///
/// Panics if `listen_addr` does not resolve, if binding fails, on accept, read or
/// write errors other than `WouldBlock`, and on exit if the number of bytes
/// written back differs from the number of bytes read.
fn tcp_echo<T: ToSocketAddrs>(listen_addr: T, finished: Arc<AtomicBool>) {
    let bind_addr = listen_addr.to_socket_addrs().unwrap().next().unwrap();
    let mut listener = TcpListener::bind(bind_addr).expect("Failed to bind TCP listener");

    let mut poll = Poll::new().unwrap();
    poll.registry()
        .register(&mut listener, Token(0), Interest::READABLE)
        .unwrap();

    // Token -> (stream, bytes read but not yet written back).
    let mut connections = HashMap::new();
    // Connections that hit EOF and are waiting for their queue to drain.
    let mut finished_connections = HashSet::new();
    // Token(0) is the listener, so connection tokens start at 1.
    let mut next_token = 1;
    let mut bytes_in: u64 = 0;
    let mut bytes_out: u64 = 0;

    let mut events = Events::with_capacity(1024);

    while !finished.load(std::sync::atomic::Ordering::SeqCst) {
        poll.poll(&mut events, Some(Duration::from_millis(100)))
            .unwrap();

        for event in events.iter() {
            let token = event.token();

            if token == Token(0) {
                // Listener event: accept until the listener would block.
                loop {
                    let mut stream = match listener.accept() {
                        Ok((stream, _peer)) => stream,
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("Failed to accept: {}", e),
                    };

                    let conn_token = Token(next_token);
                    next_token += 1;

                    poll.registry()
                        .register(&mut stream, conn_token, Interest::READABLE)
                        .unwrap();

                    connections.insert(conn_token, (Rc::new(RefCell::new(stream)), Vec::new()));
                }
                continue;
            }

            if event.is_readable() {
                if let Some((stream, pending)) = connections.get_mut(&token) {
                    let mut chunk = [0; 1024];

                    // Drain the socket; the loop yields `true` when EOF was seen.
                    let peer_closed = loop {
                        match stream.borrow_mut().read(&mut chunk) {
                            Ok(0) => break true,
                            Ok(n) => {
                                bytes_in += n as u64;
                                pending.extend_from_slice(&chunk[..n]);
                            }
                            Err(e) if e.kind() == io::ErrorKind::WouldBlock => break false,
                            Err(e) => panic!("Failed to read: {}", e),
                        }
                    };

                    // There is something to echo: also ask for writable events.
                    if !pending.is_empty() {
                        poll.registry()
                            .reregister(
                                &mut *stream.borrow_mut(),
                                token,
                                Interest::READABLE | Interest::WRITABLE,
                            )
                            .unwrap();
                    }

                    // After EOF only writability matters; the writable handler
                    // below drops the connection once the queue is empty.
                    if peer_closed {
                        finished_connections.insert(token);
                        poll.registry()
                            .reregister(&mut *stream.borrow_mut(), token, Interest::WRITABLE)
                            .unwrap();
                    }
                }
            }

            if event.is_writable() {
                let mut drop_connection = false;

                if let Some((stream, pending)) = connections.get_mut(&token) {
                    let mut written = 0;
                    while written < pending.len() {
                        match stream.borrow_mut().write(&pending[written..]) {
                            Ok(n) => written += n,
                            Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                            Err(e) => panic!("Failed to write: {}", e),
                        }
                    }
                    bytes_out += written as u64;
                    pending.drain(..written);

                    if pending.is_empty() {
                        // Queue flushed: go back to read-only interest.
                        poll.registry()
                            .reregister(&mut *stream.borrow_mut(), token, Interest::READABLE)
                            .unwrap();

                        // `remove` reports whether the token was marked as finished.
                        drop_connection = finished_connections.remove(&token);
                    }
                }

                if drop_connection {
                    connections.remove(&token);
                }
            }
        }
    }

    assert_eq!(bytes_in, bytes_out);
}

/// Builds a TCP-only `TcpForwarder` and runs its `listen` call.
///
/// The session is configured with local address `localhost:33833` and a single
/// `remoteMap` entry pairing `.*` with `localhost:32345` (presumably a
/// content pattern and its upstream; confirm against `ForwardSessionConfig`).
/// `finished` is passed straight to `listen`.
///
/// # Panics
///
/// Panics if constructing the forwarder or `listen` returns an error.
fn run_tcp_forwarder(finished: Arc<AtomicBool>) {
    let config = ForwardSessionConfig {
        local: "localhost:33833",
        remoteMap: vec![(String::from(".*"), String::from("localhost:32345"))],
        enable_tcp: true,
        enable_udp: false,
        conn_bufsize: 1024 * 1024,
        allow_nets: vec![String::from("127.0.0.1/24"), String::from("::1/128")],
        max_connections: 10,
    };

    TcpForwarder::from(&config)
        .unwrap()
        .listen(finished)
        .unwrap();
}

// End-to-end check of the forwarder: `tcp_sender` connects to the forwarder on
// `localhost:33833`, which is configured with `localhost:32345` as its remote,
// where `tcp_echo` listens. The sender verifies the echoed bytes itself and sets
// the shared `finished` flag when done; the other two threads receive a clone of it.
#[test]
#[timeout(8000)]
fn test_tcp_forwarder() {
    env_logger::init_from_env(
        env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "debug"),
    );

    let finished: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

    // Start BOTH listening ends before the settle delay. Spawning the echo server
    // right before the sender would let the sender reach the forwarder while the
    // echo listener may not be bound yet.
    let echo_flag = Arc::clone(&finished);
    let echo_thread = std::thread::spawn(move || tcp_echo("localhost:32345", echo_flag));
    let forwarder_flag = Arc::clone(&finished);
    let forwarder_thread = std::thread::spawn(move || run_tcp_forwarder(forwarder_flag));
    std::thread::sleep(Duration::from_millis(200));

    let sender_flag = Arc::clone(&finished);
    let sender_thread = std::thread::spawn(move || tcp_sender("localhost:33833", sender_flag));

    // Joining the sender first surfaces its assertion failures as soon as they occur.
    sender_thread.join().unwrap();
    echo_thread.join().unwrap();
    forwarder_thread.join().unwrap();

    assert!(finished.load(std::sync::atomic::Ordering::SeqCst));
}