shvbroker 3.26.6

Rust implementation of the SHV broker
Documentation
use std::{collections::BTreeMap, fs, path::{Path, PathBuf}};

use log::{debug, info};
use async_sqlite::{ClientBuilder, Client};
use shvproto::RpcValue;

use crate::{brokerimpl::LastLogin, config::{AccessConfig, UpdateSqlOperation}};

pub const TBL_MOUNTS: &str = "mounts";
pub const TBL_USERS: &str = "users";
pub const TBL_ROLES: &str = "roles";
pub const TBL_ALLOWED_IPS: &str = "allowed_ips";
pub const TBL_LAST_LOGIN: &str = "last_login";

pub async fn migrate_sqlite_connection(sql_config_file: &PathBuf, access: &AccessConfig) -> shvrpc::Result<(Client, AccessConfig, LastLogin)> {
    info!("Opening SQLite access db file: {}", sql_config_file.to_str().expect("Valid path"));
    let (sql_connection, db_is_empty) = if sql_config_file == ":memory:" {
        // In memoty database is the default.
        (ClientBuilder::new().open().await?, true)
    } else {
        if let Some(path) = sql_config_file.parent() {
            fs::create_dir_all(path)?;
        }
        let db_file_exists = Path::new(&sql_config_file).exists();
        if !db_file_exists {
            info!("Creating new db file: {}", sql_config_file.to_str().expect("Valid path"));
        }

        let sql_connection = ClientBuilder::new()
            .path(sql_config_file)
            .open()
            .await?;

        sql_connection.conn_and_then(|sql_connection| {
            if sql_connection.is_readonly(sql_connection.db_name(0)?.as_str())? {
                return shvrpc::Result::Err("Couldn't open SQLite database as read-write".into());
            }
            Ok(())
        }).await?;

        (sql_connection, !db_file_exists)
    };
    let (access_config, last_login) = init_access_db(&sql_connection, db_is_empty, access).await?;

    Ok((sql_connection, access_config, last_login))
}

async fn init_access_db(sql_connection: &Client, db_is_empty: bool, access: &AccessConfig) -> shvrpc::Result<(AccessConfig, LastLogin)> {
    let loaded_data = if db_is_empty {
        let last_login = LastLogin::default();
        create_access_sqlite(sql_connection, access, &last_login).await?;
        (access.clone(), last_login)
    } else {
        load_access_sqlite(sql_connection).await?
    };
    Ok(loaded_data)
}

async fn create_access_sqlite(sql_conn: &Client, access: &AccessConfig, last_login: &LastLogin) -> shvrpc::Result<()> {
    async fn save_table<TableElementType: serde::Serialize + Send + 'static>(sql_conn: &Client, tbl_name: &'static str, items: BTreeMap<String, TableElementType>) -> shvrpc::Result<()> {
        sql_conn.conn_and_then(move |sql_conn| {
            sql_conn.execute(&format!(r#"
                CREATE TABLE {tbl_name} (
                    id character varying PRIMARY KEY,
                    def character varying
                );
            "#), [])?;
            let query = format!(r#"INSERT INTO {tbl_name} (id, def) VALUES (?1, ?2);"#);
            let mut stmt = sql_conn.prepare(&query)?;
            for (id, def) in items {
                debug!("Inserting {id} into {tbl_name}");
                stmt.execute((id, serde_json::to_string(&def)?))?;
            }
            Ok(())
        }).await
    }

    info!("Creating SQLite access db");
    save_table(sql_conn, TBL_MOUNTS, access.mounts().clone()).await?;
    save_table(sql_conn, TBL_USERS, access.users().clone()).await?;
    save_table(sql_conn, TBL_ROLES, access.roles().clone()).await?;
    save_table(sql_conn, TBL_ALLOWED_IPS, access.allowed_ips().clone()).await?;
    save_table(sql_conn, TBL_LAST_LOGIN, last_login.0.clone()).await?;

    Ok(())
}

pub(crate) async fn update_sql(oper: Vec<UpdateSqlOperation<'_>>, sql_connection: &async_sqlite::Client) -> shvrpc::Result<RpcValue> {
    let query = oper.into_iter().fold(String::new(), |mut acc, oper| {
        match oper {
            UpdateSqlOperation::Insert { table, id, json } => {
                acc += &format!("INSERT INTO {table} (id, def) VALUES ('{id}', '{json}');");
            }
            UpdateSqlOperation::Update { table, id, json } => {
                acc += &format!("UPDATE {table} SET def = '{json}' WHERE id = '{id}';");
            }
            UpdateSqlOperation::Delete { table, id } => {
                acc += &format!("DELETE FROM {table} WHERE id = '{id}';");
            }
        };
        acc
    });

    sql_connection.conn(move |sql_connection| {
        sql_connection.execute(&query, ())
    }).await
    .map(|v| RpcValue::from(v as i64))
        .map_err(|err| shvrpc::rpcmessage::RpcError::new(shvrpc::rpcmessage::RpcErrorCode::MethodCallException, err.to_string()).into())
}

async fn load_access_sqlite(sql_conn: &Client) -> shvrpc::Result<(AccessConfig, LastLogin)> {
    async fn load_table<TableElementType: for <'a> serde::Deserialize<'a> + 'static + Send>(sql_conn: &Client, table_name: &'static str) -> shvrpc::Result<BTreeMap<String, TableElementType>> {
        sql_conn.conn_and_then(move |sql_conn| {
            sql_conn.execute(&format!(r#"
                CREATE TABLE IF NOT EXISTS {table_name} (
                    id character varying PRIMARY KEY,
                    def character varying
                );
            "#), [])
                .map_err(|err| err.to_string())?;

            let mut stmt = sql_conn.prepare(&format!("SELECT id, def FROM {table_name}"))?;
            let rows = stmt.query([])?;
            let first_two_columns = rows.mapped(|row| {
                let id: String = row.get(0)?;
                let def: String = row.get(1)?;
                Ok((id, def))
            }).collect::<Result<Vec<_>,_>>()?;

            let parsed_rows = first_two_columns
                .into_iter()
                .map(|(id, def)| serde_json::from_str(&def).map(|parsed| (id, parsed)))
                .collect::<Result<BTreeMap<_,_>,_>>()?;

            Ok(parsed_rows)
        }).await
    }

    let access = AccessConfig::new(
        load_table(sql_conn, TBL_USERS).await?,
        load_table(sql_conn, TBL_ROLES).await?,
        load_table(sql_conn, TBL_MOUNTS).await?,
        load_table(sql_conn, TBL_ALLOWED_IPS).await?,
    );

    let last_login = load_table(sql_conn, TBL_LAST_LOGIN).await?
        .into_iter()
        .filter_map(|(key, val): (String, String)|{
            let date_time = shvproto::DateTime::from_iso_str(&val).inspect_err(|err| log::error!("Couldn't parse last_login datetime string: {err}")).ok();
            date_time.map(|date_time| (key, date_time))
        }).collect();

    Ok((access, LastLogin(last_login)))
}