// ruzor 0.1.2
//
// Ruzor, a 1:1-compatible Rust port of the Pyzor UDP client and server.
// Documentation
use std::process::Command;

use crate::Result;
use crate::account::now_timestamp;
use crate::engines::Record;
use crate::error::PyzorError;
use crate::local_time;
use crate::mysql_engine::{MySqlDatabase, MySqlDsn, MySqlExecutor};

/// Error text used by `database_unavailable()` when the MySQL client exits
/// unsuccessfully or returns rows that cannot be parsed.
/// NOTE(review): presumably mirrors the reference Pyzor server's message — confirm.
const MYSQL_UNAVAILABLE: &str = "Database temporarily unavailable.";

/// Executes SQL by spawning the MySQL command-line client (`program`) once per
/// `execute_statement` call, passing connection settings derived from `dsn`.
///
/// NOTE(review): the derived `Debug` output includes `dsn`, which may contain
/// the database password — avoid logging this value verbatim.
#[derive(Clone, Debug)]
pub struct MySqlCommandExecutor {
    // Connection parameters turned into client flags by `command_args`.
    dsn: MySqlDsn,
    // Name or path of the client binary handed to `std::process::Command`.
    program: String,
}

impl MySqlCommandExecutor {
    /// Creates an executor for `dsn` that runs the MySQL command-line client.
    ///
    /// The program comes from the `PYZOR_MYSQL_BIN` environment variable when
    /// it is set and readable; otherwise the plain name `mysql` is used and
    /// resolved by `std::process::Command`.
    pub fn new(dsn: MySqlDsn) -> Self {
        let program = match std::env::var("PYZOR_MYSQL_BIN") {
            Ok(configured) => configured,
            Err(_) => String::from("mysql"),
        };
        Self::with_program(dsn, program)
    }

    /// Creates an executor for `dsn` that runs `program` as the MySQL client.
    pub fn with_program(dsn: MySqlDsn, program: impl Into<String>) -> Self {
        let program: String = program.into();
        Self { dsn, program }
    }

    /// Returns the argument list used to run `sql` against the configured DSN.
    ///
    /// Output is requested in batch, raw, header-less form; callers split it
    /// into lines and tab-separated columns. A host beginning with `/` is
    /// passed as `--socket`, any other non-empty host as `--host`. Empty
    /// host/user/password/database fields produce no flag at all.
    ///
    /// NOTE(review): the password travels as a command-line argument, which
    /// other local users may be able to see in the process list.
    pub fn command_args(&self, sql: &str) -> Vec<String> {
        let mut args: Vec<String> = ["--batch", "--raw", "--skip-column-names"]
            .iter()
            .map(|flag| flag.to_string())
            .collect();
        {
            let mut push_if_set = |flag: &str, value: &str| {
                if !value.is_empty() {
                    args.push(format!("{flag}={value}"));
                }
            };
            let host: &str = &self.dsn.host;
            // An absolute path means a Unix socket rather than a host name.
            let host_flag = if host.starts_with('/') {
                "--socket"
            } else {
                "--host"
            };
            push_if_set(host_flag, host);
            push_if_set("--user", &self.dsn.user);
            push_if_set("--password", &self.dsn.password);
            push_if_set("--database", &self.dsn.database);
        }
        args.push(String::from("--execute"));
        args.push(sql.to_owned());
        args
    }

    /// Runs `sql` through the client and returns its standard output.
    ///
    /// # Errors
    /// * `PyzorError::Comm` naming the program if it cannot be spawned.
    /// * The generic "temporarily unavailable" error if the client exits
    ///   unsuccessfully (its stderr is discarded).
    /// * `PyzorError::Comm` if stdout is not valid UTF-8.
    pub fn execute_statement(&self, sql: &str) -> Result<String> {
        let mut command = Command::new(&self.program);
        command.args(self.command_args(sql));
        let output = match command.output() {
            Ok(output) => output,
            Err(error) => {
                return Err(PyzorError::Comm(format!(
                    "Unable to run mysql client '{}': {}",
                    self.program, error
                )));
            }
        };
        if output.status.success() {
            String::from_utf8(output.stdout)
                .map_err(|error| PyzorError::Comm(format!("Invalid mysql output: {error}")))
        } else {
            Err(database_unavailable())
        }
    }
}

impl MySqlDatabase<MySqlCommandExecutor> {
    /// Connects to the database described by `dsn` without any age-based pruning.
    pub fn connect(dsn: impl AsRef<str>) -> Result<Self> {
        Self::connect_with_max_age_and_db_connections(dsn, None, 0)
    }

    /// Connects to `dsn`. When `max_age` is a non-zero number of seconds, the
    /// DSN's reorganize statement is also run once at connect time.
    pub fn connect_with_max_age(dsn: impl AsRef<str>, max_age: Option<i64>) -> Result<Self> {
        Self::connect_with_max_age_and_db_connections(dsn, max_age, 0)
    }

    /// Parses `dsn`, checks that the client can run a trivial query, and
    /// optionally runs the reorganize statement.
    ///
    /// When `max_age` is `Some(n)` with `n != 0`, the reorganize statement is
    /// bound with a cutoff of `now - n` seconds, presumably pruning records not
    /// updated since then (verify against `statements.reorganize()`).
    /// `_db_connections` is accepted for signature compatibility but ignored:
    /// each statement spawns its own client process.
    ///
    /// # Errors
    /// Fails if the DSN does not parse, if `SELECT 1` fails, or if the
    /// reorganize statement cannot be bound or executed.
    pub fn connect_with_max_age_and_db_connections(
        dsn: impl AsRef<str>,
        max_age: Option<i64>,
        _db_connections: usize,
    ) -> Result<Self> {
        let parsed = MySqlDsn::parse(dsn.as_ref())?;
        let statements = parsed.statements();
        let executor = MySqlCommandExecutor::new(parsed.clone());
        // Round-trip once so an unreachable server is reported at connect time.
        executor.execute_statement("SELECT 1")?;
        match max_age {
            None | Some(0) => {}
            Some(seconds) => {
                let cutoff = now_timestamp() - seconds;
                let reorganize = mysql_prepared_placeholders(&statements.reorganize());
                let prune_sql =
                    bind_statement(&reorganize, &[SqlValue::DateTime(Some(cutoff))])?;
                executor.execute_statement(&prune_sql)?;
            }
        }
        Ok(MySqlDatabase::with_executor(parsed, executor))
    }
}

impl MySqlExecutor for MySqlCommandExecutor {
    fn fetch_record(&mut self, statement: &str, digest: &str) -> Result<Option<Record>> {
        let sql = bind_statement(statement, &[SqlValue::Text(digest)])?;
        let output = self.execute_statement(&sql)?;
        output
            .lines()
            .find(|line| !line.trim().is_empty())
            .map(parse_record_row)
            .transpose()
    }

    fn execute_digest_batch(&mut self, statement: &str, digests: &[String]) -> Result<()> {
        let mut sql = String::new();
        for digest in digests {
            if !sql.is_empty() {
                sql.push(';');
            }
            sql.push_str(&bind_statement(statement, &[SqlValue::Text(digest)])?);
        }
        if sql.is_empty() {
            return Ok(());
        }
        self.execute_statement(&sql).map(|_| ())
    }

    fn execute_set_record(&mut self, statement: &str, digest: &str, record: &Record) -> Result<()> {
        let values = [
            SqlValue::Text(digest),
            SqlValue::Integer(record.r_count),
            SqlValue::Integer(record.wl_count),
            SqlValue::DateTime(record.r_entered),
            SqlValue::DateTime(record.r_updated),
            SqlValue::DateTime(record.wl_entered),
            SqlValue::DateTime(record.wl_updated),
            SqlValue::Integer(record.r_count),
            SqlValue::Integer(record.wl_count),
            SqlValue::DateTime(record.r_entered),
            SqlValue::DateTime(record.r_updated),
            SqlValue::DateTime(record.wl_entered),
            SqlValue::DateTime(record.wl_updated),
        ];
        let sql = bind_statement(statement, &values)?;
        self.execute_statement(&sql).map(|_| ())
    }
}

/// A value substituted for a `?` placeholder by `bind_statement`.
#[derive(Clone, Copy, Debug)]
enum SqlValue<'a> {
    /// Emitted as a single-quoted, backslash-escaped string literal.
    Text(&'a str),
    /// Emitted as a bare decimal integer.
    Integer(i64),
    /// A Unix timestamp emitted as a quoted local-time datetime, or `NULL` when absent.
    DateTime(Option<i64>),
}

impl SqlValue<'_> {
    /// Renders this value as MySQL literal text.
    fn to_sql(self) -> String {
        match self {
            Self::Integer(number) => number.to_string(),
            Self::Text(text) => quote_sql_string(text),
            Self::DateTime(None) => String::from("NULL"),
            Self::DateTime(Some(timestamp)) => {
                let rendered = format_mysql_datetime(timestamp);
                quote_sql_string(&rendered)
            }
        }
    }
}

/// Substitutes `values`, in order, for each `?` in `statement`.
///
/// Placeholders are found by a plain character scan of `statement`. A `?`
/// inside a quoted literal in the template therefore also counts as a
/// placeholder. Text produced by the substituted values is never rescanned.
///
/// # Errors
/// Returns `PyzorError::Comm` if there are fewer values than placeholders, or
/// values left over once every placeholder has been filled.
fn bind_statement(statement: &str, values: &[SqlValue<'_>]) -> Result<String> {
    let mut pending = values.iter();
    let mut segments = statement.split('?');
    let mut bound = String::with_capacity(statement.len() + values.len() * 16);
    // `split` always yields at least one segment: the text before the first `?`.
    bound.push_str(segments.next().unwrap_or(""));
    for segment in segments {
        let value = pending
            .next()
            .ok_or_else(|| PyzorError::Comm("Missing MySQL statement parameter.".to_string()))?;
        bound.push_str(&value.to_sql());
        bound.push_str(segment);
    }
    match pending.next() {
        Some(_) => Err(PyzorError::Comm(
            "Too many MySQL statement parameters.".to_string(),
        )),
        None => Ok(bound),
    }
}

/// Converts Python DB-API `%s` markers into the `?` placeholders that
/// `bind_statement` understands. Markers are matched left to right without overlap.
fn mysql_prepared_placeholders(statement: &str) -> String {
    let pieces: Vec<&str> = statement.split("%s").collect();
    pieces.join("?")
}

/// Wraps `value` in single quotes, backslash-escaping the characters MySQL
/// treats specially: NUL, `\n`, `\r`, `\`, `'`, `"` and Ctrl-Z.
///
/// NOTE(review): backslash escapes are only honoured when the server's
/// `sql_mode` does not include `NO_BACKSLASH_ESCAPES`. Under that mode `\'`
/// would not protect an embedded quote — confirm the deployment's SQL mode.
fn quote_sql_string(value: &str) -> String {
    fn escape_for(ch: char) -> Option<&'static str> {
        match ch {
            '\0' => Some("\\0"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\\' => Some("\\\\"),
            '\'' => Some("\\'"),
            '"' => Some("\\\""),
            '\x1a' => Some("\\Z"),
            _ => None,
        }
    }

    let mut out = String::with_capacity(value.len() + 2);
    out.push('\'');
    for ch in value.chars() {
        match escape_for(ch) {
            Some(escape) => out.push_str(escape),
            None => out.push(ch),
        }
    }
    out.push('\'');
    out
}

/// Parses one tab-separated client output row into a `Record`.
///
/// The column order is r_count, wl_count, r_entered, r_updated, wl_entered,
/// wl_updated. NULL counts become 0 and NULL datetimes become `None`.
///
/// # Errors
/// Returns the generic "temporarily unavailable" error if the row does not
/// have exactly six columns or any column fails to parse.
fn parse_record_row(row: &str) -> Result<Record> {
    let columns: Vec<&str> = row.split('\t').collect();
    match columns.as_slice() {
        [r_count, wl_count, r_entered, r_updated, wl_entered, wl_updated] => Ok(Record {
            r_count: parse_nullable_i64(r_count)?.unwrap_or(0),
            wl_count: parse_nullable_i64(wl_count)?.unwrap_or(0),
            r_entered: parse_mysql_datetime(r_entered)?,
            r_updated: parse_mysql_datetime(r_updated)?,
            wl_entered: parse_mysql_datetime(wl_entered)?,
            wl_updated: parse_mysql_datetime(wl_updated)?,
        }),
        _ => Err(database_unavailable()),
    }
}

/// Parses an integer column, mapping any NULL spelling to `None`.
///
/// # Errors
/// Returns the generic "temporarily unavailable" error for non-numeric text.
fn parse_nullable_i64(value: &str) -> Result<Option<i64>> {
    if is_mysql_null(value) {
        Ok(None)
    } else {
        match value.parse::<i64>() {
            Ok(number) => Ok(Some(number)),
            Err(_) => Err(database_unavailable()),
        }
    }
}

fn parse_mysql_datetime(value: &str) -> Result<Option<i64>> {
    if is_mysql_null(value) {
        return Ok(None);
    }
    local_time::parse_datetime(value)
        .map(Some)
        .ok_or_else(database_unavailable)
}

/// Reports whether a raw output column represents SQL NULL.
///
/// Accepts `NULL` in any ASCII case, the `\N` marker, or an empty column.
fn is_mysql_null(value: &str) -> bool {
    matches!(value, "" | "\\N") || value.eq_ignore_ascii_case("NULL")
}

fn format_mysql_datetime(timestamp: i64) -> String {
    local_time::format_timestamp(timestamp)
}

/// Builds the generic error reported when the MySQL client fails or its
/// output cannot be parsed.
fn database_unavailable() -> PyzorError {
    let message = String::from(MYSQL_UNAVAILABLE);
    PyzorError::Comm(message)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every non-empty DSN field becomes a `--flag=value` argument, in a fixed order.
    #[test]
    fn mysql_command_args_use_reference_dsn_fields() {
        let dsn = MySqlDsn::parse("localhost,pyzor,secret,pyzord,public").unwrap();
        let executor = MySqlCommandExecutor::with_program(dsn, "mysql");

        assert_eq!(
            executor.command_args("SELECT 1"),
            vec![
                "--batch",
                "--raw",
                "--skip-column-names",
                "--host=localhost",
                "--user=pyzor",
                "--password=secret",
                "--database=pyzord",
                "--execute",
                "SELECT 1"
            ]
        );
    }

    /// A host that is an absolute path is passed as a socket, not a host name.
    #[test]
    fn mysql_command_args_support_unix_socket_hosts() {
        let dsn = MySqlDsn::parse("/tmp/mysql.sock,pyzor,,pyzord,public").unwrap();
        let executor = MySqlCommandExecutor::with_program(dsn, "mysql");

        assert!(
            executor
                .command_args("SELECT 1")
                .contains(&"--socket=/tmp/mysql.sock".to_string())
        );
    }

    #[test]
    fn mysql_statement_binding_quotes_values_and_times() {
        crate::local_time::with_timezone_for_tests("UTC", || {
            let statement = "INSERT INTO public VALUES (?, ?, ?, ?)";
            let sql = bind_statement(
                statement,
                &[
                    SqlValue::Text("abc'd\\e"),
                    SqlValue::Integer(24),
                    SqlValue::DateTime(Some(1_400_221_786)),
                    SqlValue::DateTime(None),
                ],
            )
            .unwrap();

            assert_eq!(
                sql,
                "INSERT INTO public VALUES ('abc\\'d\\\\e', 24, '2014-05-16 06:29:46', NULL)"
            );
        });
    }

    /// Datetimes are bound and parsed in the configured local zone, not UTC.
    #[test]
    fn mysql_datetime_binding_and_parsing_use_python_local_time() {
        crate::local_time::with_timezone_for_tests("Europe/Paris", || {
            let sql = bind_statement(
                "DELETE FROM public WHERE r_updated<?",
                &[SqlValue::DateTime(Some(1_400_221_786))],
            )
            .unwrap();

            assert_eq!(
                sql,
                "DELETE FROM public WHERE r_updated<'2014-05-16 08:29:46'"
            );
            assert_eq!(
                parse_mysql_datetime("2014-05-16 08:29:46").unwrap(),
                Some(1_400_221_786)
            );
        });
    }

    #[test]
    fn mysql_statement_binding_rejects_placeholder_mismatch() {
        assert!(bind_statement("SELECT ?", &[]).is_err());
        assert!(bind_statement("SELECT 1", &[SqlValue::Integer(1)]).is_err());
    }

    /// A `?` produced by a bound value must not be treated as another placeholder.
    #[test]
    fn mysql_statement_binding_does_not_rescan_bound_values() {
        let sql = bind_statement(
            "SELECT ?, ?",
            &[SqlValue::Text("?"), SqlValue::Integer(-7)],
        )
        .unwrap();

        assert_eq!(sql, "SELECT '?', -7");
    }

    /// Every character in the escape table gets its backslash sequence.
    #[test]
    fn mysql_string_quoting_escapes_every_special_character() {
        assert_eq!(quote_sql_string(""), "''");
        assert_eq!(
            quote_sql_string("a\0b\nc\rd\\e'f\"g\x1ah"),
            "'a\\0b\\nc\\rd\\\\e\\'f\\\"g\\Zh'"
        );
    }

    #[test]
    fn mysql_placeholder_conversion_rewrites_python_style_markers() {
        assert_eq!(
            mysql_prepared_placeholders("DELETE FROM public WHERE r_updated<%s"),
            "DELETE FROM public WHERE r_updated<?"
        );
        assert_eq!(mysql_prepared_placeholders("SELECT 1"), "SELECT 1");
    }

    /// The client can render NULL several ways; all of them must be recognised.
    #[test]
    fn mysql_null_detection_accepts_client_null_spellings() {
        assert!(is_mysql_null("NULL"));
        assert!(is_mysql_null("null"));
        assert!(is_mysql_null("\\N"));
        assert!(is_mysql_null(""));
        assert!(!is_mysql_null("0"));
        assert!(!is_mysql_null("2014-05-16 06:29:46"));
    }

    #[test]
    fn mysql_row_parser_matches_reference_column_order() {
        crate::local_time::with_timezone_for_tests("UTC", || {
            let row = "24\t42\t2014-05-16 06:29:46\t2014-05-16 06:29:54\tNULL\t\\N";

            assert_eq!(
                parse_record_row(row).unwrap(),
                Record {
                    r_count: 24,
                    wl_count: 42,
                    r_entered: Some(1_400_221_786),
                    r_updated: Some(1_400_221_794),
                    wl_entered: None,
                    wl_updated: None,
                }
            );
        });
    }

    /// NULL counts fall back to zero rather than failing the lookup.
    #[test]
    fn mysql_row_parser_defaults_null_counts_to_zero() {
        assert_eq!(
            parse_record_row("NULL\t\\N\tNULL\tNULL\tNULL\tNULL").unwrap(),
            Record {
                r_count: 0,
                wl_count: 0,
                r_entered: None,
                r_updated: None,
                wl_entered: None,
                wl_updated: None,
            }
        );
    }

    /// Malformed rows surface as errors instead of a partially filled record.
    #[test]
    fn mysql_row_parser_rejects_malformed_rows() {
        assert!(parse_record_row("24\t42").is_err());
        assert!(parse_record_row("24\t42\tNULL\tNULL\tNULL\tNULL\tNULL").is_err());
        assert!(parse_record_row("x\t42\tNULL\tNULL\tNULL\tNULL").is_err());
    }
}