stablessh 0.1.1

Keeps SSH sessions alive even when the laptop is closed
Documentation
use crate::{pool, queue, utils};
use anyhow::Result;
use clap::Parser;
use core::time;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::sync::{Mutex, RwLock};

/// Server mode: accept QUIC clients and forward each one to a TCP endpoint.
#[derive(Parser, Debug, Clone)]
#[clap(name = "server")]
pub struct Opt {
    /// QUIC max idle timeout in seconds; 0 leaves the transport default untouched.
    #[clap(long = "idle", short = 'i', default_value = "3")]
    idle: u64,

    /// QUIC keep-alive interval in seconds; 0 leaves the transport default untouched.
    #[clap(long = "keepalive", short = 'k', default_value = "1")]
    keepalive: u64,

    /// Size parameter for the queue created for each new client.
    // NOTE(review): handed straight to `queue::Queue::new`; its unit (bytes, entries,
    // a power of two?) is not visible from this file — confirm before documenting further.
    #[clap(long = "bufsize", short = 'b', default_value = "18")]
    bufsize: u8,

    /// Hold timeout handed to the connection pool.
    // NOTE(review): the default 604800 looks like 7 days expressed in seconds — confirm
    // against `pool::ConnPool::new`.
    #[clap(long = "hold-timeout", short = 't', default_value = "604800")]
    hold_timeout: u64,

    /// Interval in seconds handed to the connection pool's collect loop.
    #[clap(long = "hold-collect-interval", short = 'c', default_value = "60")]
    hold_collect_interval: u64,

    /// Socket address the QUIC endpoint listens on.
    #[clap(long = "listen", short = 'l', default_value = "0.0.0.0:2222")]
    listen: SocketAddr,

    /// TCP address (host:port) that new client sessions are connected to.
    #[clap(long = "forward", short = 'f', default_value = "localhost:22")]
    forward: String,
}

/// Runs the server: builds the TLS and QUIC configuration, binds the endpoint
/// on `opt.listen`, serves clients until `accept_loop` returns, then closes the
/// endpoint and waits for it to become idle.
///
/// # Errors
///
/// Returns an error if certificate generation, TLS configuration, idle-timeout
/// conversion, or binding the endpoint fails.
///
/// # Panics
///
/// Panics if the transport config of the freshly created server config is
/// already shared, which would be a bug in this function.
pub async fn run(opt: Opt) -> Result<()> {
    let (cert_der, priv_key) = utils::gen_cert()?;
    // The certificate bytes are not needed afterwards, so they are moved rather than cloned.
    let mut server_crypto = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_client_cert_verifier(utils::SkipClientVerification::new())
        .with_single_cert(vec![rustls::Certificate(cert_der)], rustls::PrivateKey(priv_key))?;
    server_crypto.alpn_protocols = vec![b"stablessh".to_vec()];

    let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
    let transport_config = Arc::get_mut(&mut server_config.transport)
        .expect("transport config of a freshly built ServerConfig has no other owners");
    // No unidirectional streams are accepted from clients.
    transport_config.max_concurrent_uni_streams(0_u8.into());
    // A value of 0 means "do not override": the transport keeps its own default.
    if opt.idle > 0 {
        transport_config.max_idle_timeout(Some(Duration::from_secs(opt.idle).try_into()?));
    }
    if opt.keepalive > 0 {
        transport_config.keep_alive_interval(Some(Duration::from_secs(opt.keepalive)));
    }

    let endpoint = quinn::Endpoint::server(server_config, opt.listen)?;
    accept_loop(opt, endpoint.clone()).await?;

    // Close the endpoint with code 0 and an empty reason, then wait until it is idle.
    endpoint.close(0_u8.into(), b"");
    endpoint.wait_idle().await;

    Ok(())
}

/// Accepts incoming connections on `endpoint` in a background task and returns
/// once `utils::stop_signal_wait` completes.
///
/// Every accepted connection is served by its own spawned task running
/// `handle_connection`; a failure there is logged and does not stop the loop.
/// All tasks share one connection pool built from `opt.hold_timeout`, and that
/// pool is also handed to `pool::collect_loop` together with
/// `opt.hold_collect_interval` (in seconds).
async fn accept_loop(opt: Opt, endpoint: quinn::Endpoint) -> Result<()> {
    let collect_every = time::Duration::from_secs(opt.hold_collect_interval);
    let conn_pool = pool::ConnPool::new(opt.hold_timeout);
    pool::collect_loop(conn_pool.clone(), collect_every);

    // The acceptor task is detached; this function only waits for the stop signal.
    tokio::spawn(async move {
        while let Some(incoming) = endpoint.accept().await {
            let session = handle_connection(opt.clone(), conn_pool.clone(), incoming);
            tokio::spawn(async move {
                if let Err(e) = session.await {
                    log::error!("Connection error: {:?}", e);
                }
            });
        }
    });

    utils::stop_signal_wait().await;
    Ok(())
}

async fn handle_connection(
    opt: Opt,
    mut conn_pool: pool::ConnPool<(
        Arc<Mutex<tokio::net::TcpStream>>,
        Arc<Mutex<queue::Queue>>,
        Arc<RwLock<u32>>,
    )>,
    conn: quinn::Connecting,
) -> Result<()> {
    let conn = conn.await?;
    let pubkey = utils::x509pubkey(
        &conn
            .peer_identity()
            .unwrap()
            .downcast::<Vec<rustls::Certificate>>()
            .unwrap()
            .first()
            .unwrap(),
    )?;
    let (ssh_conn, q, last_ack) = match conn_pool.get(pubkey.clone()).await {
        Some(v) => {
            log::debug!("Reusing connection for {:?}", pubkey);
            v
        }
        None => {
            log::debug!("Creating new connection for {:?}", pubkey);
            let ssh_conn = Arc::new(Mutex::new(
                tokio::net::TcpStream::connect(opt.forward).await?,
            ));
            let q = Arc::new(Mutex::new(queue::Queue::new(opt.bufsize)));
            let last_ack = Arc::new(RwLock::new(0_u32));

            conn_pool
                .insert(pubkey.clone(), (ssh_conn, q, last_ack))
                .await
                .unwrap()
        }
    };

    let mut ssh_conn = ssh_conn.lock().await;
    let (ssh_recv, ssh_send) = ssh_conn.split();
    let _handle = conn_pool.hold(pubkey.clone()).await;
    utils::handle_connection(conn, q, last_ack, ssh_recv, ssh_send).await?;

    Ok(())
}