toolforge 5.8.0

Small library for common tasks on Wikimedia Toolforge
Documentation
/*
Copyright (C) 2020-2022 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use ini::Ini;
use std::{fmt, fs, path::PathBuf};

/// Deprecated type alias for [`Cluster`]; use [`Cluster`] directly in new
/// code.
pub type ToolsDBCluster = Cluster;
// Convenience aliases: re-export each `Cluster` variant under a
// standalone `*_CLUSTER` name.
pub use Cluster::ANALYTICS as ANALYTICS_CLUSTER;
pub use Cluster::WEB as WEB_CLUSTER;

/// Initial `pool_max` value for every `DBConnectionInfo` built in this
/// module. By default Toolforge tools have a connection limit of 10.
const DEFAULT_POOL_MAX: usize = 10;

/// Fields needed to connect to the MariaDB databases on Toolforge
///
/// Calling `.to_string()` will get the connection info in the format of the
/// [`mysql`](https://docs.rs/mysql) crate. The user, password and database
/// name are percent-encoded in that string, so values containing URL-reserved
/// characters (`@`, `:`, `/`, `?`, `#`, ...) still produce a well-formed URL.
pub struct DBConnectionInfo {
    /// Name of the database to select
    pub database: String,
    /// Hostname of the database server
    pub host: String,
    /// Username to authenticate with
    pub user: String,
    /// Password to authenticate with
    pub password: String,
    /// Value emitted as the `pool_max` parameter of the connection URL
    pool_max: usize,
}

impl DBConnectionInfo {
    /// Set the maximum number of connections the database pool can hold. By
    /// default tools are allowed 10 concurrent connections.
    ///
    /// This consumes `self` and returns the updated value, so the result must
    /// be used (e.g. `info.pool_max(20).to_string()`).
    #[must_use]
    pub fn pool_max(mut self, pool_max: usize) -> Self {
        self.pool_max = pool_max;
        self
    }
}

/// Display adapter that percent-encodes a string for use inside a URL.
///
/// Only the RFC 3986 "unreserved" characters (ASCII letters, digits, `-`,
/// `.`, `_` and `~`) are written as-is; every other byte of the UTF-8
/// representation is written as `%XX`.
struct PercentEncoded<'a>(&'a str);

impl fmt::Display for PercentEncoded<'_> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for byte in self.0.bytes() {
            match byte {
                b'A'..=b'Z'
                | b'a'..=b'z'
                | b'0'..=b'9'
                | b'-'
                | b'.'
                | b'_'
                | b'~' => write!(f, "{}", char::from(byte))?,
                _ => write!(f, "%{byte:02X}")?,
            }
        }
        Ok(())
    }
}

impl fmt::Display for DBConnectionInfo {
    /// Write the connection info as a `mysql://` URL.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // pool_min=0 means the connection pool will hold 0 active connections at minimum
        // pool_max=? means the max number of connections the pool will hold (should be no more
        //            than the max_connections_limit for your user, default 10)
        // inactive_connection_ttl=1 means inactive connections will be dropped after 1 second.
        //                           We can't do 0 because that effectively disables connection reuse.
        // ttl_check_interval=30 means it will check for inactive connections every 30sec
        //
        // user, password and database are percent-encoded so that reserved
        // characters in them cannot be mistaken for URL delimiters. The host
        // is written verbatim.
        write!(
            f,
            "mysql://{}:{}@{}:3306/{}?pool_min=0&pool_max={}&inactive_connection_ttl=1&ttl_check_interval=30",
            PercentEncoded(&self.user),
            PercentEncoded(&self.password),
            self.host,
            PercentEncoded(&self.database),
            self.pool_max
        )
    }
}

/// Settings read from the `[client]` section of a `replica.my.cnf` file
pub(crate) struct ReplicaMyCnf {
    /// Value of the `user` key
    pub(crate) user: String,
    /// Value of the `password` key
    pub(crate) password: String,
    /// Whether a `local` key is present (its value is not inspected); when
    /// set, `local.wmftest.net` hostnames are used instead of
    /// `wikimedia.cloud` ones
    pub(crate) local: bool,
}

/// Database clusters to connect to
///
/// See [Connecting to the database replicas](https://wikitech.wikimedia.org/wiki/Help:Toolforge/Database#Connecting_to_the_database_replicas)
/// for information about the differences between the web and analytics
/// clusters.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum Cluster {
    /// The analytics cluster (`*.analytics.db.svc.*` hosts)
    ANALYTICS,
    /// The web cluster (`*.web.db.svc.*` hosts)
    WEB,
}

/// Get database connection info. Note that the `connection_info!()` macro
/// is a simpler way to call this.
#[doc(hidden)]
pub fn get_db_connection_info(
    dbname: &str,
    cluster: Cluster,
    my_cnf_path: Option<PathBuf>,
) -> crate::Result<DBConnectionInfo> {
    let my_cnf = load_replica_cnf(my_cnf_path)?;

    let domain = format!(
        "{}.db.svc.{}",
        match cluster {
            Cluster::ANALYTICS => "analytics",
            Cluster::WEB => "web",
        },
        if my_cnf.local {
            "local.wmftest.net"
        } else {
            "wikimedia.cloud"
        }
    );
    let normalized_dbname = dbname.trim_end_matches("_p");
    let host = if normalized_dbname == "meta" {
        format!("s7.{domain}")
    } else {
        format!("{normalized_dbname}.{domain}")
    };

    Ok(DBConnectionInfo {
        database: format!("{normalized_dbname}_p"),
        host,
        user: my_cnf.user,
        password: my_cnf.password,
        pool_max: DEFAULT_POOL_MAX,
    })
}

/// Get a connection string for a database on toolsdb. The full name of
/// database must be provided (e.g. `u12345__cooltool_p`).
pub fn toolsdb(database: String) -> crate::Result<DBConnectionInfo> {
    let my_cnf = load_replica_cnf(None)?;
    let domain = if my_cnf.local {
        "tools.db.svc.local.wmftest.net"
    } else {
        "tools.db.svc.wikimedia.cloud"
    };
    Ok(DBConnectionInfo {
        database,
        host: domain.to_string(),
        user: my_cnf.user,
        password: my_cnf.password,
        pool_max: DEFAULT_POOL_MAX,
    })
}

/// Load database credentials from a `replica.my.cnf`-style INI file.
///
/// When `path` is `None`, `replica.my.cnf` in the current user's home
/// directory is read.
///
/// # Errors
///
/// * `Error::NotToolforge` if the file does not exist.
/// * Failures to read or parse the file are propagated.
/// * A file without a `[client]` section, or without a `user` or `password`
///   key in it, is reported as an I/O error of kind `InvalidData` rather
///   than panicking.
///
/// # Panics
///
/// Panics if `path` is `None` and the home directory cannot be determined.
pub(crate) fn load_replica_cnf(
    path: Option<PathBuf>,
) -> crate::Result<ReplicaMyCnf> {
    let path = path.unwrap_or_else(|| {
        dirs::home_dir()
            // Just panic if this happens, seriously
            .expect("Couldn't find home directory")
            .join("replica.my.cnf")
    });
    if !path.exists() {
        return Err(crate::Error::NotToolforge("replica.my.cnf".to_string()));
    }
    let contents = fs::read_to_string(&path)?;
    let conf = Ini::load_from_str(&contents)?;

    // A malformed file is bad input, not a bug: surface it as an error that
    // converts into the crate error through the same `From<io::Error>` path
    // used for read failures above.
    let missing = |what: &str| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("{} is missing {what}", path.display()),
        )
    };
    let section = conf
        .section(Some("client"))
        .ok_or_else(|| missing("the [client] section"))?;
    let user = section
        .get("user")
        .ok_or_else(|| missing("the `user` key in [client]"))?
        .to_string();
    let password = section
        .get("password")
        .ok_or_else(|| missing("the `password` key in [client]"))?
        .to_string();

    Ok(ReplicaMyCnf {
        user,
        password,
        // Only the presence of the key matters, not its value.
        local: section.get("local").is_some(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::get_db_connection_info;

    /// Resolve connection info using the credentials file at `cnf` and
    /// render it as a connection URL.
    fn connection_url(
        dbname: &str,
        cluster: Cluster,
        cnf: &str,
    ) -> crate::Result<String> {
        let info =
            get_db_connection_info(dbname, cluster, Some(PathBuf::from(cnf)))?;
        Ok(info.to_string())
    }

    /// A hand-built `DBConnectionInfo` renders as the expected URL.
    #[test]
    fn test_to_string() {
        let rendered = DBConnectionInfo {
            database: "meta_p".to_string(),
            host: "hostname".to_string(),
            user: "u12345".to_string(),
            password: "correcthorsebatterystaple".to_string(),
            pool_max: DEFAULT_POOL_MAX,
        }
        .to_string();
        assert_eq!(
            rendered,
            "mysql://u12345:correcthorsebatterystaple@hostname:3306/meta_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );
    }

    /// Connection info resolved from the fixture credential files.
    #[test]
    fn test_connection_info() -> crate::Result<()> {
        // basic
        assert_eq!(
            connection_url("enwiki", Cluster::WEB, "./tests/replica.my.cnf")?,
            "mysql://u12345:correcthorsebatterystaple@enwiki.web.db.svc.wikimedia.cloud:3306/enwiki_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        // analytics cluster
        assert_eq!(
            connection_url(
                "enwiki_p",
                Cluster::ANALYTICS,
                "./tests/replica.my.cnf"
            )?,
            "mysql://u12345:correcthorsebatterystaple@enwiki.analytics.db.svc.wikimedia.cloud:3306/enwiki_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        // meta_p database
        assert_eq!(
            connection_url("meta_p", Cluster::WEB, "./tests/replica.my.cnf")?,
            "mysql://u12345:correcthorsebatterystaple@s7.web.db.svc.wikimedia.cloud:3306/meta_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        // my_cnf with double quotes
        assert_eq!(
            connection_url(
                "enwiki",
                Cluster::WEB,
                "./tests/double_quote.my.cnf"
            )?,
            "mysql://u23456:correcthorsebatterystaples@enwiki.web.db.svc.wikimedia.cloud:3306/enwiki_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        // local=true
        assert_eq!(
            connection_url("enwiki", Cluster::WEB, "./tests/local.my.cnf")?,
            "mysql://u34567:correcthorsebatterystapler@enwiki.web.db.svc.local.wmftest.net:3306/enwiki_p?pool_min=0&pool_max=10&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        // pool_max(20)
        let enlarged = get_db_connection_info(
            "enwiki",
            Cluster::WEB,
            Some(PathBuf::from("./tests/replica.my.cnf")),
        )?
        .pool_max(20);
        assert_eq!(
            enlarged.to_string(),
            "mysql://u12345:correcthorsebatterystaple@enwiki.web.db.svc.wikimedia.cloud:3306/enwiki_p?pool_min=0&pool_max=20&inactive_connection_ttl=1&ttl_check_interval=30"
        );

        Ok(())
    }
}