//! toolforge 5.8.0
//!
//! Small library for common tasks on Wikimedia Toolforge.
//!
//! Documentation
/*
Copyright (C) 2022 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use crate::db::Cluster;
pub use mysql_async;
use mysql_async::{prelude::*, Conn, Pool};
use std::collections::{BTreeSet, HashMap};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{OnceCell, RwLock};

use crate::db::ReplicaMyCnf;
pub use crate::db::{ANALYTICS_CLUSTER, WEB_CLUSTER};
use crate::{Error, Result};

/// By default Toolforge tools have a connection limit of 10 concurrent
/// connections per database server; this is used as the per-server
/// `pool_max` when [`Builder::pool_max`] is not called.
const DEFAULT_POOL_MAX: usize = 10;

/// Builder for configuring and constructing a [`WikiPool`].
///
/// Obtained via [`WikiPool::builder`]; finalized with [`Builder::build`].
pub struct Builder {
    // Which replica cluster (web or analytics) connections will target.
    cluster: Cluster,
    // Override for the credentials file path; `None` means the default
    // `~/replica.my.cnf` location.
    my_cnf_path: Option<PathBuf>,
    // Override for the per-server pool size; `None` means DEFAULT_POOL_MAX.
    pool_max: Option<usize>,
}

impl Builder {
    /// Select which replica cluster (web or analytics) the pool
    /// should connect to.
    pub fn cluster(mut self, cluster: Cluster) -> Self {
        self.cluster = cluster;
        self
    }

    /// Override where the replica credentials are read from.
    /// Primarily useful for testing. Defaults to `~/replica.my.cnf`.
    pub fn my_cnf_path(mut self, path: PathBuf) -> Self {
        self.my_cnf_path = Some(path);
        self
    }

    /// Raise the number of connections each server's pool may hold.
    /// By default tools are allowed 10 concurrent connections.
    pub fn pool_max(mut self, pool_max: usize) -> Self {
        self.pool_max = Some(pool_max);
        self
    }

    /// Consume the builder and produce a ready-to-use [`WikiPool`].
    ///
    /// # Errors
    /// Fails if the replica credentials file cannot be loaded.
    pub fn build(self) -> Result<WikiPool> {
        let my_cnf = crate::db::load_replica_cnf(self.my_cnf_path)?;
        let pool_max = self.pool_max.unwrap_or(DEFAULT_POOL_MAX);
        let inner = Inner {
            cluster: self.cluster,
            my_cnf,
            pool_max,
            mapping: RwLock::new(HashMap::new()),
            domains: RwLock::new(HashMap::new()),
            // TODO: Can we macro this repetition?
            pool_s1: OnceCell::new(),
            pool_s2: OnceCell::new(),
            pool_s3: OnceCell::new(),
            pool_s4: OnceCell::new(),
            pool_s5: OnceCell::new(),
            pool_s6: OnceCell::new(),
            pool_s7: OnceCell::new(),
            pool_s8: OnceCell::new(),
            pool_s9: OnceCell::new(),
            pool_s10: OnceCell::new(),
            pool_s11: OnceCell::new(),
            pool_s12: OnceCell::new(),
            pool_s13: OnceCell::new(),
            pool_s14: OnceCell::new(),
            pool_s15: OnceCell::new(),
        };
        Ok(WikiPool {
            inner: Arc::new(inner),
        })
    }
}

/// WikiPool is a wrapper around `mysql_async::Pool` that supports
/// connecting to all wiki replica databases. It specifically knows
/// which databases are on which backend servers, allowing for further
/// optimization.
///
/// ```rust
/// use toolforge::pool::{WikiPool, WEB_CLUSTER};
/// # async fn demo() -> toolforge::Result<()> {
/// // Or use ANALYTICS_CLUSTER
/// let pool = WikiPool::new(WEB_CLUSTER).expect("failed to load db config");
/// let mut conn = pool.connect("enwiki").await?;
/// // conn.query(...);
/// # Ok(())
/// # }
/// ```
///
/// If you need to e.g. increase the pool size (because your tool was given
/// an increased quota limit), you can use the builder.
///
/// ```rust
/// use toolforge::pool::{WikiPool, WEB_CLUSTER};
/// # async fn demo() -> toolforge::Result<()> {
/// // Or use ANALYTICS_CLUSTER
/// let pool = WikiPool::builder(WEB_CLUSTER).pool_max(20).build().expect("failed to load db config");
/// let mut conn = pool.connect("enwiki").await?;
/// // conn.query(...);
/// # Ok(())
/// # }
/// ```
///
/// Cloning a `WikiPool` is cheap: all state lives behind an `Arc`,
/// so clones share the same underlying pools and mappings.
#[derive(Clone)]
pub struct WikiPool {
    // Shared state (config, slice mapping caches and lazily-created
    // per-slice pools); shared across clones via the Arc.
    inner: Arc<Inner>,
}

/// Shared state behind [`WikiPool`]'s `Arc`.
struct Inner {
    // Which replica cluster (web or analytics) to connect to.
    cluster: Cluster,
    // Credentials parsed from the replica .my.cnf file.
    my_cnf: ReplicaMyCnf,
    // Maximum connections per server pool (see DEFAULT_POOL_MAX).
    pool_max: usize,
    // dbname (without "_p") -> slice number, populated from meta_p.wiki
    // by load_mapping(); starts empty and is replaced wholesale on reload.
    mapping: RwLock<HashMap<String, usize>>,
    // domain (url with "https://" stripped) -> dbname, populated alongside
    // `mapping` and used by connect_by_domain().
    domains: RwLock<HashMap<String, String>>,
    // Lazily-initialized connection pool per slice (sN server).
    // As of 2022-12, covering s1 to s15 should cover us for at least
    // a decade. See T325951#8491619.
    // TODO: Can we macro this repetition?
    pool_s1: OnceCell<Pool>,
    pool_s2: OnceCell<Pool>,
    pool_s3: OnceCell<Pool>,
    pool_s4: OnceCell<Pool>,
    pool_s5: OnceCell<Pool>,
    pool_s6: OnceCell<Pool>,
    pool_s7: OnceCell<Pool>,
    pool_s8: OnceCell<Pool>,
    pool_s9: OnceCell<Pool>,
    pool_s10: OnceCell<Pool>,
    pool_s11: OnceCell<Pool>,
    pool_s12: OnceCell<Pool>,
    pool_s13: OnceCell<Pool>,
    pool_s14: OnceCell<Pool>,
    pool_s15: OnceCell<Pool>,
}

impl WikiPool {
    /// Create a [`Builder`] to customize the pool (credentials path,
    /// pool size) before construction.
    pub fn builder(cluster: Cluster) -> Builder {
        Builder {
            cluster,
            my_cnf_path: None,
            pool_max: None,
        }
    }

    /// Create a new WikiPool for the given cluster, either
    /// `WEB_CLUSTER` or `ANALYTICS_CLUSTER`, using default builder options.
    pub fn new(cluster: Cluster) -> Result<Self> {
        Self::builder(cluster).build()
    }

    /// Build the MySQL connection string for the given slice.
    ///
    /// The query string configures mysql_async pooling: no idle minimum,
    /// `pool_max` connections at most, and aggressive reaping of
    /// inactive connections.
    fn connection_string(&self, slice: usize) -> String {
        let cluster = match self.inner.cluster {
            Cluster::ANALYTICS => "analytics",
            Cluster::WEB => "web",
        };
        let mut port = 3306;
        let base_host = if self.inner.my_cnf.local {
            // local.wmftest.net has a wildcard DNS that points to localhost;
            // each slice is distinguished by port (3306 + slice) instead of
            // by hostname.
            port += slice;
            "local.wmftest.net"
        } else {
            "svc.wikimedia.cloud"
        };

        // NOTE(review): user/password are interpolated verbatim; assumes
        // replica credentials never contain URL-special characters — confirm.
        format!(
            "mysql://{}:{}@s{slice}.{cluster}.db.{base_host}:{port}/?pool_min=0&pool_max={}&inactive_connection_ttl=1&ttl_check_interval=30",
            self.inner.my_cnf.user, self.inner.my_cnf.password, self.inner.pool_max
        )
    }

    /// Get the pool for the corresponding slice, creating it if necessary.
    ///
    /// Pools are created lazily via `OnceCell::get_or_init`, so each
    /// slice's pool is initialized at most once and then reused.
    ///
    /// # Panics
    /// Panics if `slice` is outside 1..=15 (i.e. a server this crate
    /// doesn't know about yet).
    async fn pool(&self, slice: usize) -> &Pool {
        // TODO: Can we macro this repetition?
        let cell = match slice {
            1 => &self.inner.pool_s1,
            2 => &self.inner.pool_s2,
            3 => &self.inner.pool_s3,
            4 => &self.inner.pool_s4,
            5 => &self.inner.pool_s5,
            6 => &self.inner.pool_s6,
            7 => &self.inner.pool_s7,
            8 => &self.inner.pool_s8,
            9 => &self.inner.pool_s9,
            10 => &self.inner.pool_s10,
            11 => &self.inner.pool_s11,
            12 => &self.inner.pool_s12,
            13 => &self.inner.pool_s13,
            14 => &self.inner.pool_s14,
            15 => &self.inner.pool_s15,
            num => panic!("Unknown slice s{num}, please update or report a bug in the toolforge crate")
        };
        cell.get_or_init(|| async {
            Pool::new(self.connection_string(slice).as_str())
        })
        .await
    }

    /// Load the mapping of database servers to their slice by querying the
    /// meta_p.wiki table. As an optimization, return the slice of the
    /// database server we're looking for, if we found it.
    ///
    /// Both the dbname->slice and domain->dbname caches are rebuilt from
    /// scratch and swapped in wholesale under their write locks.
    ///
    /// NOTE(review): concurrent callers that miss the cache will each run
    /// this query; the last writer wins, which is harmless but duplicated
    /// work — confirm that's acceptable.
    async fn load_mapping(&self, wanted: &str) -> Result<Option<usize>> {
        // meta_p database is on s7
        let pool = self.pool(7).await;
        let mut conn = pool.get_conn().await?;
        conn.query_drop("USE meta_p").await?;
        let rows: Vec<(String, String, String)> =
            conn.query("SELECT dbname, url, slice FROM wiki").await?;
        let mut ret = None;
        // Hardcode meta_p and centralauth_p to s7
        let mut mapping = HashMap::from([
            ("meta".to_string(), 7),
            ("centralauth".to_string(), 7),
        ]);
        if wanted == "meta" || wanted == "centralauth" {
            ret = Some(7);
        }
        let mut domains = HashMap::new();
        for (dbname, url, hostname) in rows {
            // e.g. "s7.labsdb" -> 7
            let (s_name, _) = hostname
                .split_once('.')
                .expect("invalid value in meta_p.wiki database");
            let slice: usize = s_name
                .trim_start_matches('s')
                .parse()
                .expect("invalid value in meta_p.wiki database");
            if dbname == wanted {
                ret = Some(slice);
            }
            domains.insert(
                url.trim_start_matches("https://").to_string(),
                dbname.to_string(),
            );
            mapping.insert(dbname, slice);
        }
        // Swap both caches in; locks are taken one at a time, so a reader
        // could briefly observe the new mapping with the old domains.
        let mut w = self.inner.mapping.write().await;
        *w = mapping;
        let mut w = self.inner.domains.write().await;
        *w = domains;
        Ok(ret)
    }

    /// Open a database connection to the specified database.
    /// The database name will be validated against the list of known databases,
    /// so it is safe to supply arbitrary user input to this function, but it is
    /// recommended you do your own validation for performance.
    ///
    /// A trailing `_p` suffix on `dbname` is accepted and stripped.
    pub async fn connect(&self, dbname: &str) -> Result<Conn> {
        let dbname = dbname.trim_end_matches("_p");
        let slice = self.inner.mapping.read().await.get(dbname).copied();
        let slice = match slice {
            Some(slice) => slice,
            // Not present in our possibly loaded mapping. Let's force
            // reload the mapping, maybe it's a newly added wiki?
            None => match self.load_mapping(dbname).await? {
                Some(slice) => slice,
                None => {
                    // Not present again, return an error.
                    return Err(Error::UnknownDatabase(dbname.to_string()));
                }
            },
        };
        let pool = self.pool(slice).await;
        let mut conn = pool.get_conn().await?;
        // This is safe against SQL injection because:
        // * We validated the database name against the meta_p.wiki mapping
        // * The replicas are read-only, so no stored SQLi is possible
        // * We drop the result, so messing with the query results is useless
        conn.query_drop(format!("USE {dbname}_p")).await?;
        Ok(conn)
    }

    /// Open a database connection to the specified domain.
    /// The domain name will be validated against the list of known databases,
    /// so it is safe to supply arbitrary user input to this function.
    pub async fn connect_by_domain(&self, domain: &str) -> Result<Conn> {
        if let Some(dbname) = self.inner.domains.read().await.get(domain) {
            return self.connect(dbname).await;
        }
        // Force reload the mapping; the "" argument means no dbname lookup
        // is wanted, we only care about the refreshed domains cache.
        let _ = self.load_mapping("").await?;
        match self.inner.domains.read().await.get(domain) {
            Some(dbname) => self.connect(dbname).await,
            None => Err(Error::UnknownDomain(domain.to_string())),
        }
    }

    /// Open a database connection to the specified slice.
    ///
    /// No database will be selected.
    pub async fn connect_to_slice(&self, slice: usize) -> Result<Conn> {
        let pool = self.pool(slice).await;
        Ok(pool.get_conn().await?)
    }

    /// Get a list of all the database names (e.g. "enwiki", "dewiktionary")
    /// from the meta_p.wiki table. This will include closed wikis; the
    /// returned `BTreeSet` iterates in sorted order.
    pub async fn list_dbnames(&self) -> Result<BTreeSet<String>> {
        let dbnames = self
            .connect("meta")
            .await?
            .query("SELECT dbname FROM wiki")
            .await?
            .into_iter()
            .collect();
        Ok(dbnames)
    }

    /// Get a list of all the wiki domains (e.g. "en.wikipedia.org") from the
    /// meta_p.wiki table. This will include closed wikis; the returned
    /// `BTreeSet` iterates in sorted order.
    pub async fn list_domains(&self) -> Result<BTreeSet<String>> {
        let domains = self
            .connect("meta")
            .await?
            .query_map("SELECT url FROM wiki", |url: String| {
                url.trim_start_matches("https://").to_string()
            })
            .await?
            .into_iter()
            .collect();
        Ok(domains)
    }
}