// assistant_daemon 0.1.0
//
// Daemon program for providing many features.
use crate::{
    config::{self, get_feature_dir, ConfigManager, FsConfigManager},
    feature::FeatureControl,
    generated::{SshLocalPortMapping, SshpmSetting},
};
use anyhow::{bail, Result};
use async_trait::async_trait;
use russh::{
    client::{self, Session},
    ChannelId,
};
use russh_keys::key;
use serde::{Deserialize, Serialize};
use std::{
    cell::RefCell, collections::HashMap, future::Future, pin::Pin, process::exit, sync::Arc,
};
use std::{error::Error, net::SocketAddr};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{tcp::WriteHalf, TcpListener, ToSocketAddrs},
    sync::oneshot,
};
use tokio::{sync::RwLock, task::JoinHandle};
use tracing::{debug, error, info};

/// Build the unique ID of a local port mapping task.
///
/// The ID has the form `"{local_port}|{remote_host}|{remote_port}|{jump_host}"`.
/// The `|` separators keep it unambiguous. Without them, for example, local port
/// `1` + remote host `10.0.0.1` and local port `11` + remote host `0.0.0.1`
/// would produce the same string. Valid host names and IP literals never
/// contain `|`.
fn get_lmp_id(jump_host: &str, remote_host: &str, remote_port: u32, local_port: u32) -> String {
    format!("{local_port}|{remote_host}|{remote_port}|{jump_host}")
}

/// russh event handler for sessions opened to the jump host.
///
/// SECURITY NOTE(review): the server host key is never verified; every key is
/// accepted. This leaves the jump-host connection open to man-in-the-middle
/// attacks. Consider checking the key against a known_hosts store.
struct Client;

#[async_trait]
impl client::Handler for Client {
    type Error = russh::Error;

    /// Trust the presented host key unconditionally (see the type-level note).
    async fn check_server_key(
        self,
        _key: &key::PublicKey,
    ) -> Result<(Self, bool), Self::Error> {
        let accept_key = true;
        Ok((self, accept_key))
    }
}

/// Spawn an blocking task to ssh local port mapping
async fn spawn_sshlmp(
    id: &String,
    jump_host: &String,
    jump_host_port: u32,
    jump_host_user: Option<&String>,
    jump_host_pwd: Option<&String>,
    target_host: &String,
    target_host_user: Option<&String>,
    target_host_pwd: Option<&String>,
    target_host_port: u32,
    local_port: u32,
) -> Result<oneshot::Sender<()>> {
    let (stop_sender, mut stop_receiver) = oneshot::channel::<()>();

    let jump_host = jump_host.clone();
    let target_host = target_host.clone();
    let jump_host_user = jump_host_user.cloned();
    let jump_host_pwd = jump_host_pwd.cloned();
    let id: String = id.clone();

    // abort() does not work for block
    tokio::task::spawn_blocking(move || {
        let jump_host = jump_host.clone();
        let target_host = target_host.clone();
        let jump_host_user = jump_host_user.clone();
        let jump_host_pwd = jump_host_pwd.clone();
        let rt = tokio::runtime::Handle::current();

        // infinit loop until stop_receiver receive the exit signal
        rt.block_on(async move {
            let listener = TcpListener::bind(format!("0.0.0.0:{}", local_port)).await?;

            loop {
                // check whether should we stop
                match stop_receiver.try_recv() {
                    Ok(_) => {
                        // if receive stop signal, terminate it
                        break;
                    }
                    Err(err) => match err {
                        oneshot::error::TryRecvError::Empty => {}
                        oneshot::error::TryRecvError::Closed => {
                            // if channel is closed, terminate it
                            break;
                        }
                    },
                }

                // Accept incoming connections
                let (mut socket, addr) = listener.accept().await?;
                let jump_host = jump_host.clone();
                let target_host = target_host.clone();
                let jump_host_user = jump_host_user.clone();
                let jump_host_pwd = jump_host_pwd.clone();
                // Spawn a new task to handle the connection
                tokio::spawn(async move {
                    // Configure the SSH client
                    let mut config = client::Config::default();
                    config.inactivity_timeout = Some(std::time::Duration::from_secs(3600 * 60));
                    let config = Arc::new(config);
                    let sh = Client {};
                    // Connect to the SSH server
                    let addr: Vec<SocketAddr> =
                        tokio::net::lookup_host(format!("{}:{}", jump_host, jump_host_port))
                            .await?
                            .collect();
                    let mut jump_handle =
                        client::connect(config.clone(), addr.as_slice(), sh).await?;

                    if jump_host_user.is_some() && jump_host_pwd.is_some() {
                        // Authenticate
                        let is_authenticated = jump_handle
                            .authenticate_password(jump_host_user.unwrap(), jump_host_pwd.unwrap())
                            .await?;

                        if !is_authenticated {
                            bail!("failed to authenticate");
                        }
                    } else {
                        bail!("TODO: authentication other than password");
                    }
                    let (mut reader, mut writer) = socket.split();
                    let mut channel = jump_handle
                        .channel_open_direct_tcpip(
                            target_host,
                            target_host_port,
                            "localhost",
                            local_port,
                        )
                        .await?;
                    let mut ch_writer = channel.make_writer();
                    let mut ch_reader = channel.make_reader();
                    loop {
                        tokio::select! {
                            _ = tokio::io::copy(&mut reader, &mut ch_writer) => {},
                            _ = tokio::io::copy(&mut ch_reader, &mut writer) => {},
                        };
                    }

                    Ok(())
                });
            }

            Ok::<(), anyhow::Error>(())
        });

        info!("lmp {} thread exited", id);
    });

    Ok(stop_sender)
}

/// Runtime status of the sshpm feature, stored through its status manager
/// (`status.json` in the feature directory).
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct SshpmStatus {
    /// Recorded state of each local port mapping, keyed by its mapping ID.
    pub lmp_status: HashMap<String, bool>,
}

/// Use case behind the `sshpm` feature: runs SSH local port mappings and
/// records their state.
pub struct SshpmUsecaseImpl {
    /// Status store. `None` until the feature is enabled, and reset to `None`
    /// when it is disabled.
    status_mngr: RwLock<Option<Arc<dyn ConfigManager<SshpmStatus>>>>,
    /// Stop-signal senders of the running mappings, keyed by mapping ID.
    lmp_exit: RwLock<HashMap<String, oneshot::Sender<()>>>,
}

impl SshpmUsecaseImpl {
    pub fn new() -> Self {
        Self {
            status_mngr: Default::default(),
            lmp_exit: Default::default(),
        }
    }

    async fn deactivate_lmp(&self, mp: &SshLocalPortMapping) -> Result<()> {
        let lmp_id = get_lmp_id(
            &mp.jump_host,
            &mp.target_host,
            mp.target_host_port.unwrap_or(22),
            mp.local_port,
        );

        Ok(())
    }

    async fn activate_lmp(&self, mp: &SshLocalPortMapping) -> Result<()> {
        let lmp_id = get_lmp_id(
            &mp.jump_host,
            &mp.target_host,
            mp.target_host_port.unwrap_or(22),
            mp.local_port,
        );

        let exit_sender = spawn_sshlmp(
            &lmp_id,
            &mp.jump_host,
            mp.jump_host_port.unwrap_or(22),
            mp.jump_host_user.as_ref().or_else(|| {
                if mp.same_pwd.is_some_and(|s| s) {
                    mp.target_host_user.as_ref()
                } else {
                    None
                }
            }),
            mp.jump_host_pwd.as_ref().or_else(|| {
                if mp.same_pwd.is_some_and(|s| s) {
                    mp.target_host_pwd.as_ref()
                } else {
                    None
                }
            }),
            &mp.target_host,
            mp.target_host_user.as_ref().or_else(|| {
                if mp.same_pwd.is_some_and(|s| s) {
                    mp.jump_host_pwd.as_ref()
                } else {
                    None
                }
            }),
            mp.target_host_pwd.as_ref().or_else(|| {
                if mp.same_pwd.is_some_and(|s| s) {
                    mp.jump_host_pwd.as_ref()
                } else {
                    None
                }
            }),
            mp.target_host_port.unwrap_or(22),
            mp.local_port,
        )
        .await?;

        self.lmp_exit
            .write()
            .await
            .insert(lmp_id.clone(), exit_sender);
        info!("lmp '{}' is spawned", lmp_id);

        match &mut *self.status_mngr.write().await {
            Some(status_mngr) => {
                let status = status_mngr.get().await?;
                if let Some(mut status) = status {
                    status.lmp_status.insert(lmp_id, true);
                    status_mngr.update(status).await?;
                }
            }
            None => {}
        }

        Ok(())
    }
}

#[async_trait]
impl FeatureControl for SshpmUsecaseImpl {
    fn name(&self) -> &str {
        "sshpm"
    }
    async fn enable(&self, settings: Option<serde_json::Value>) -> Result<()> {
        if self.status_mngr.read().await.is_none() {
            let feature_dir = get_feature_dir(self.name());
            config::init_dir(&feature_dir)?;

            let status = Arc::new(FsConfigManager::new(feature_dir, "status.json".to_owned()));
            status.init(None).await?;
            *self.status_mngr.write().await = Some(status);
        }

        match settings {
            Some(settings) => {
                let settings: SshpmSetting = serde_json::from_value(settings)?;
                for lmp in settings.lpmappings.iter().flatten() {
                    match self.activate_lmp(lmp).await {
                        Ok(_) => {
                            info!("spawn ssh local port mapping '{:?}'", lmp)
                        }
                        Err(err) => {
                            error!(
                                "failed to spawn ssh local port mapping '{:?}': {}",
                                lmp, err
                            )
                        }
                    }
                }
            }
            None => {}
        }

        Ok(())
    }
    async fn disable(&self) -> Result<()> {
        for (lmpid, exit) in self.lmp_exit.write().await.drain().into_iter() {
            match exit.send(()) {
                Ok(_) => {}
                Err(err) => {
                    error!("failed to send exit to {}", lmpid);
                }
            }
        }
        self.lmp_exit.write().await.clear();
        *self.status_mngr.write().await = None;
        Ok(())
    }
    async fn update(
        &self,
        old_settings: Option<serde_json::Value>,
        setting: Option<serde_json::Value>,
    ) -> Result<()> {
        // loop seeting to see whether any lmp setting is changed
        Ok(())
    }
    async fn lid_change(&self, open: bool) {}
}