kcptun-rust 1.1.0

A Rust implementation of kcptun, a fast and reliable tunnel based on the KCP protocol.
Documentation
use anyhow::Result;
use clap::Parser;
use kcptun_rust::{ClientConfig, CompStream, create_block_crypt, derive_key, wrap_with_qpp};
use rust_tokio_kcp::{KcpConfig, KcpNoDelayConfig, KcpStream};
use smux_rust::{client, Config as SmuxConfig, Session};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tracing::{error, info};

// Command-line flags for the kcptun client.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, which would change the program's output.
//
// NOTE(review): `version = "0.1.0"` below (and the "version: 0.1.0" log line
// in `main`) differs from the crate header at the top of this file (1.1.0) —
// confirm which is current.
#[derive(Parser)]
#[command(name = "kcptun-client")]
#[command(about = "KCP tunnel client", version = "0.1.0")]
struct Args {
    // Local TCP address that accepts client connections (passed to
    // `TcpListener::bind`). NOTE(review): a bare Go-style `:port` string is not
    // accepted by `TcpListener::bind` as-is; confirm it is expanded (e.g. to
    // `0.0.0.0:port`) before binding.
    #[arg(short = 'l', long, default_value = ":12948")]
    localaddr: String,
    
    // kcptun server address to dial over KCP. NOTE(review): the default is a
    // host name, so it must be resolved rather than parsed as an IP literal.
    #[arg(short = 'r', long, default_value = "vps:29900")]
    remoteaddr: String,
    
    // Pre-shared secret: fed to `derive_key` for the block cipher, and used as
    // the QPP seed when `--qpp` is set. The misspelled default is kept
    // verbatim — presumably it must match the server's default; confirm
    // before changing it.
    #[arg(long, default_value = "it's a secrect")]
    key: String,
    
    // Block-cipher name passed to `create_block_crypt`.
    #[arg(long, default_value = "aes")]
    crypt: String,
    
    // Tuning preset name; `ClientConfig::apply_mode` is called after the flags
    // are copied in, presumably deriving nodelay/interval/resend/nc from it —
    // confirm against `ClientConfig`.
    #[arg(long, default_value = "fast")]
    mode: String,
    
    // Number of KCP sessions; accepted connections are assigned to them
    // round-robin. Must be greater than 0.
    #[arg(long, default_value = "1")]
    conn: u32,
    
    // Seconds after which a session is replaced by a freshly dialed one for
    // subsequent connections; 0 disables expiry.
    #[arg(long, default_value = "0")]
    autoexpire: u32,
    
    // Copied into `ClientConfig`; not otherwise read in this file.
    #[arg(long, default_value = "600")]
    scavengettl: u32,
    
    // KCP MTU (`KcpConfig::mtu`).
    #[arg(long, default_value = "1350")]
    mtu: u32,
    
    // Copied into `ClientConfig`; not otherwise read in this file.
    #[arg(long, default_value = "0")]
    ratelimit: u32,
    
    // KCP send window (first element of `KcpConfig::wnd_size`, a u16).
    #[arg(long, default_value = "128")]
    sndwnd: u32,
    
    // KCP receive window (second element of `KcpConfig::wnd_size`, a u16).
    #[arg(long, default_value = "512")]
    rcvwnd: u32,
    
    // FEC data shards (`KcpConfig::fec_data_shards`).
    #[arg(long, default_value = "10")]
    datashard: u32,
    
    // FEC parity shards (`KcpConfig::fec_parity_shards`).
    #[arg(long, default_value = "3")]
    parityshard: u32,
    
    // Copied into `ClientConfig`; not otherwise read in this file.
    #[arg(long, default_value = "0")]
    dscp: u32,
    
    // Disables the `CompStream` compression layer between KCP and SMUX.
    #[arg(long)]
    nocomp: bool,
    
    // Maps to `KcpConfig::flush_acks_input`.
    #[arg(long)]
    acknodelay: bool,
    
    // SMUX protocol version (narrowed to u8 for `SmuxConfig::version`).
    #[arg(long, default_value = "2")]
    smuxver: u32,
    
    // SMUX session receive buffer in bytes (`SmuxConfig::max_receive_buffer`).
    #[arg(long, default_value = "4194304")]
    smuxbuf: u32,
    
    // SMUX maximum frame size in bytes (`SmuxConfig::max_frame_size`).
    #[arg(long, default_value = "8192")]
    framesize: u32,
    
    // SMUX per-stream buffer in bytes (`SmuxConfig::max_stream_buffer`).
    #[arg(long, default_value = "2097152")]
    streambuf: u32,
    
    // SMUX keep-alive interval in seconds; the keep-alive timeout is 3x this.
    #[arg(long, default_value = "10")]
    keepalive: u32,
    
    // Copied into `ClientConfig`; not otherwise read in this file.
    #[arg(long, default_value = "0")]
    closewait: u32,
    
    // Path to a JSON config file. When given, the configuration is loaded
    // entirely from it and the other flags are ignored.
    #[arg(short = 'c', long)]
    config: Option<String>,
    
    // Suppresses per-connection info/error log lines.
    #[arg(long)]
    quiet: bool,
    
    // Wraps each tunnelled stream with QPP (`wrap_with_qpp`), seeded by `key`.
    #[arg(long)]
    qpp: bool,
    
    // Count parameter passed to `wrap_with_qpp`.
    #[arg(long, default_value = "61")]
    qpp_count: u32,
}

/// A multiplexed session held in one connection slot, with an optional
/// deadline after which the accept loop replaces it with a freshly dialed
/// session for new connections.
struct TimedSession {
    /// Shared SMUX session; each local connection opens its own stream on it.
    session: Arc<Session>,
    /// `Some(deadline)` when `autoexpire > 0`; `None` means the session never
    /// expires by age.
    expiry_date: Option<Instant>,
}

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"))
        )
        .init();
    
    let args = Args::parse();
    
    let config = if let Some(config_path) = args.config {
        ClientConfig::from_json(config_path)?
    } else {
        let mut cfg = ClientConfig {
            localaddr: args.localaddr,
            remoteaddr: args.remoteaddr,
            key: args.key,
            crypt: args.crypt,
            mode: args.mode,
            conn: args.conn,
            autoexpire: args.autoexpire,
            scavengettl: args.scavengettl,
            mtu: args.mtu,
            ratelimit: args.ratelimit,
            sndwnd: args.sndwnd,
            rcvwnd: args.rcvwnd,
            datashard: args.datashard,
            parityshard: args.parityshard,
            dscp: args.dscp,
            nocomp: args.nocomp,
            acknodelay: args.acknodelay,
            nodelay: 0,
            interval: 50,
            resend: 0,
            nc: 0,
            sockbuf: 4194304,
            smuxver: args.smuxver,
            smuxbuf: args.smuxbuf,
            framesize: args.framesize,
            streambuf: args.streambuf,
            keepalive: args.keepalive,
            log: String::new(),
            snmplog: String::new(),
            snmpperiod: 60,
            quiet: args.quiet,
            tcp: false,
            pprof: false,
            qpp: args.qpp,
            qpp_count: args.qpp_count,
            closewait: args.closewait,
        };
        cfg.apply_mode();
        cfg
    };
    
    if config.conn == 0 {
        anyhow::bail!("conn must be greater than 0");
    }
    
    info!("version: 0.1.0");
    info!("listening on: {}", config.localaddr);
    info!("remote address: {}", config.remoteaddr);
    info!("encryption: {}", config.crypt);
    info!("mode: {}", config.mode);
    info!("nodelay parameters: {} {} {} {}", 
          config.nodelay, config.interval, config.resend, config.nc);
    info!("sndwnd: {}, rcvwnd: {}", config.sndwnd, config.rcvwnd);
    info!("compression: {}", !config.nocomp);
    info!("mtu: {}", config.mtu);
    info!("datashard: {}, parityshard: {}", config.datashard, config.parityshard);
    info!("conn: {}", config.conn);
    
    let key = derive_key(&config.key);
    
    let listener = TcpListener::bind(&config.localaddr).await?;
    info!("Listening on {}", listener.local_addr()?);
    
    let num_conn = config.conn as usize;
    let muxes: Vec<Arc<Mutex<Option<TimedSession>>>> = 
        (0..num_conn).map(|_| Arc::new(Mutex::new(None))).collect();
    let mut rr = 0usize;
    
    let config_arc = Arc::new(config);
    
    loop {
        let (stream, _addr) = listener.accept().await?;
        let idx = rr % num_conn;
        rr += 1;
        
        let mux_arc = muxes[idx].clone();
        let config_clone = config_arc.clone();
        let key_clone = key.clone();
        
        tokio::spawn(async move {
            let mut mux_guard = mux_arc.lock().await;
            
            let need_new = mux_guard.is_none() || 
                mux_guard.as_ref().unwrap().expiry_date
                    .map(|e| Instant::now() > e)
                    .unwrap_or(false);
            
            if need_new {
                if !config_clone.quiet {
                    info!("Creating new KCP connection...");
                }
                
                // 添加超时保护
                let create_result = tokio::time::timeout(
                    Duration::from_secs(10),
                    create_session(&config_clone, &key_clone)
                ).await;
                
                match create_result {
                    Ok(Ok(session)) => {
                        let expiry = if config_clone.autoexpire > 0 {
                            Some(Instant::now() + Duration::from_secs(config_clone.autoexpire as u64))
                        } else {
                            None
                        };
                        *mux_guard = Some(TimedSession { session, expiry_date: expiry });
                        if !config_clone.quiet {
                            info!("Created new KCP connection");
                        }
                    }
                    Ok(Err(e)) => {
                        error!("Failed to create session: {}", e);
                        return;
                    }
                    Err(_) => {
                        error!("Timeout creating KCP connection (10s)");
                        return;
                    }
                }
            }
            
            if let Some(ref timed) = *mux_guard {
                if let Err(e) = handle_client(stream, timed.session.clone(), &config_clone).await {
                    if !config_clone.quiet {
                        error!("Error handling client: {}", e);
                    }
                }
            } else {
                error!("No session available after creation attempt");
            }
        });
    }
}

async fn create_session(
    config: &ClientConfig,
    key: &[u8],
) -> Result<Arc<Session>> {
    // Parse remote address
    let addr: std::net::SocketAddr = config.remoteaddr.parse()?;
    
    // Create block cipher
    let block_crypt = create_block_crypt(&config.crypt, key)?;
    
    // Configure KCP
    let kcp_config = KcpConfig {
        mtu: config.mtu as usize,
        nodelay: KcpNoDelayConfig {
            nodelay: config.nodelay != 0,
            interval: config.interval as i32,
            resend: config.resend as i32,
            nc: config.nc != 0,
        },
        wnd_size: (config.sndwnd as u16, config.rcvwnd as u16),
        stream: true,
        flush_write: false,
        flush_acks_input: config.acknodelay,
        fec_data_shards: config.datashard as usize,
        fec_parity_shards: config.parityshard as usize,
        crypt: block_crypt,
        ..Default::default()
    };
    
    // Connect to server
    let kcp_stream = KcpStream::connect(&kcp_config, addr).await?;
    
    // Create SMUX session
    let smux_config = SmuxConfig {
        version: config.smuxver as u8,
        keep_alive_disabled: false,
        keep_alive_interval: Duration::from_secs(config.keepalive as u64),
        keep_alive_timeout: Duration::from_secs(config.keepalive as u64 * 3),
        max_frame_size: config.framesize as usize,
        max_receive_buffer: config.smuxbuf as usize,
        max_stream_buffer: config.streambuf as usize,
    };
    
    // 根据 nocomp 配置决定是否使用压缩层
    let session = if config.nocomp {
        // 不使用压缩,直接使用 KCP stream
        client(Box::new(kcp_stream), Some(smux_config)).await?
    } else {
        // 使用 Snappy 压缩
        let comp_stream = CompStream::new(kcp_stream);
        client(Box::new(comp_stream), Some(smux_config)).await?
    };
    Ok(session)
}

/// Bridges one accepted local TCP connection to a new stream on `session`.
///
/// When `config.qpp` is set, the remote side is wrapped with QPP, seeded by
/// the configured key. Both directions are pumped concurrently. When one
/// direction reaches EOF or hits an I/O error, its destination writer is shut
/// down, so the peer sees end-of-stream instead of the tunnel lingering
/// half-open.
///
/// # Errors
/// Returns an error only if `session.open_stream()` fails. Copy errors simply
/// end the affected direction.
async fn handle_client(
    local: TcpStream,
    session: Arc<Session>,
    config: &ClientConfig,
) -> Result<()> {
    let remote = session.open_stream().await?;

    if !config.quiet {
        info!("Stream opened");
    }

    let (local_rd, local_wr) = local.into_split();
    let (remote_rd, remote_wr) = tokio::io::split(remote);
    let (remote_rd, remote_wr): (
        Box<dyn tokio::io::AsyncRead + Unpin + Send>,
        Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
    ) = if config.qpp {
        let seed = config.key.as_bytes();
        let (r, w) = wrap_with_qpp(remote_rd, remote_wr, seed, config.qpp_count);
        (Box::new(r), Box::new(w))
    } else {
        (Box::new(remote_rd), Box::new(remote_wr))
    };

    tokio::join!(pump(local_rd, remote_wr), pump(remote_rd, local_wr));

    if !config.quiet {
        info!("Stream closed");
    }

    Ok(())
}

/// Copies bytes from `reader` to `writer` until EOF or an I/O error, flushing
/// after every chunk. It then shuts `writer` down so the peer observes
/// end-of-stream.
///
/// Errors are deliberately swallowed: a broken pipe on either side just ends
/// this direction of the tunnel.
async fn pump<R, W>(mut reader: R, mut writer: W)
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut buf = vec![0u8; 8192];
    loop {
        match reader.read(&mut buf).await {
            Ok(0) | Err(_) => break,
            Ok(n) => {
                if writer.write_all(&buf[..n]).await.is_err() || writer.flush().await.is_err() {
                    break;
                }
            }
        }
    }
    // Propagate the close; otherwise the other side keeps waiting for data.
    let _ = writer.shutdown().await;
}