// sqlx-firebirdsql 0.1.0
//
// Firebird SQL driver for SQLx
use std::collections::HashMap;
use std::str::FromStr;

use log::LevelFilter;
use url::Url;

use sqlx_core::connection::{ConnectOptions, LogSettings};

use crate::error::Error;
use crate::FirebirdConnection;

/// Options for connecting to a Firebird database.
///
/// Built either programmatically (`FirebirdConnectOptions::new()` followed by the builder
/// methods) or by parsing a `firebird://` URL through `FromStr` / `ConnectOptions::from_url`.
#[derive(Debug, Clone)]
pub struct FirebirdConnectOptions {
    /// Server host name or address; `"localhost"` by default.
    pub(crate) host: String,
    /// Server TCP port; `3050` by default.
    pub(crate) port: u16,
    /// Login user name; `"SYSDBA"` by default.
    pub(crate) username: String,
    /// Login password; `"masterkey"` by default.
    // NOTE(review): `#[derive(Debug)]` prints this field in clear text — confirm that is
    // acceptable wherever these options get logged.
    pub(crate) password: String,
    /// Database name or path on the server. `None` until set via `database()`; URL parsing
    /// always sets it (possibly to an empty string when the URL has no path).
    pub(crate) database: Option<String>,
    /// Extra key/value connection parameters. `Default` pre-populates `role`, `timezone`,
    /// `wire_crypt`, `auth_plugin_name` and `page_size`; URL query pairs override them.
    pub(crate) conn_options: HashMap<String, String>,
    /// Statement-logging configuration, adjusted by `log_statements` / `log_slow_statements`.
    pub(crate) log_settings: LogSettings,
}

impl Default for FirebirdConnectOptions {
    /// Returns options for a local server: `SYSDBA`/`masterkey` on `localhost:3050`, no
    /// database selected, default log settings, and the baseline connection parameters
    /// (empty `role` and `timezone`, `wire_crypt = true`, `auth_plugin_name = Srp256`,
    /// `page_size = 4096`).
    fn default() -> Self {
        const BASELINE_OPTIONS: [(&str, &str); 5] = [
            ("role", ""),
            ("timezone", ""),
            ("wire_crypt", "true"),
            ("auth_plugin_name", "Srp256"),
            ("page_size", "4096"),
        ];

        let conn_options: HashMap<String, String> = BASELINE_OPTIONS
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect();

        Self {
            host: String::from("localhost"),
            port: 3050,
            username: String::from("SYSDBA"),
            password: String::from("masterkey"),
            database: None,
            conn_options,
            log_settings: LogSettings::default(),
        }
    }
}

impl FirebirdConnectOptions {
    /// Creates a new set of options initialised from [`Default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the server host name or address.
    pub fn host(self, host: &str) -> Self {
        Self {
            host: host.to_owned(),
            ..self
        }
    }

    /// Sets the server TCP port.
    pub fn port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Sets the login user name.
    pub fn username(self, username: &str) -> Self {
        Self {
            username: username.to_owned(),
            ..self
        }
    }

    /// Sets the login password.
    pub fn password(self, password: &str) -> Self {
        Self {
            password: password.to_owned(),
            ..self
        }
    }

    /// Sets the database name or server-side path to open.
    pub fn database(self, database: &str) -> Self {
        Self {
            database: Some(database.to_owned()),
            ..self
        }
    }

    /// Parses options from an already-parsed Firebird URL.
    ///
    /// Equivalent to parsing `url.as_str()` with this type's `FromStr` implementation, and
    /// fails in exactly the same cases.
    pub fn parse_from_url(url: &Url) -> Result<Self, Error> {
        url.as_str().parse()
    }
}

impl FromStr for FirebirdConnectOptions {
    type Err = Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Parse the URL manually since firebirust::ConnParams is private
        let url = Url::parse(s).map_err(|e| Error::Configuration(e.to_string().into()))?;

        if url.scheme() != "firebird" {
            return Err(Error::Configuration(
                format!("unsupported scheme: {}", url.scheme()).into(),
            ));
        }

        let host = url
            .host_str()
            .unwrap_or("localhost")
            .to_string();
        let port = url.port().unwrap_or(3050);
        let username = if url.username().is_empty() {
            "SYSDBA".to_string()
        } else {
            percent_decode(url.username())
        };
        let password = url
            .password()
            .map(percent_decode)
            .unwrap_or_else(|| "masterkey".to_string());

        let mut db_name = percent_decode(url.path());
        // "/foo.fdb" -> "foo.fdb" (only one leading slash)
        let slash_count = db_name.chars().filter(|&c| c == '/').count();
        if slash_count == 1 && db_name.starts_with('/') {
            db_name = db_name[1..].to_string();
        }

        let mut conn_options: HashMap<String, String> =
            url.query_pairs().into_owned().collect();
        conn_options
            .entry("role".to_string())
            .or_insert_with(String::new);
        conn_options
            .entry("timezone".to_string())
            .or_insert_with(String::new);
        conn_options
            .entry("wire_crypt".to_string())
            .or_insert_with(|| "true".to_string());
        conn_options
            .entry("auth_plugin_name".to_string())
            .or_insert_with(|| "Srp256".to_string());
        conn_options
            .entry("page_size".to_string())
            .or_insert_with(|| "4096".to_string());

        Ok(Self {
            host,
            port,
            username,
            password,
            database: Some(db_name),
            conn_options,
            log_settings: LogSettings::default(),
        })
    }
}

/// Decodes RFC 3986 `%XX` escapes in a single URL component (user name, password or path).
///
/// Unlike `application/x-www-form-urlencoded` decoding, `+` stays a literal plus sign and
/// `&` / `=` are ordinary characters. Credentials and database paths containing them are
/// therefore returned intact rather than altered or truncated. A `%` that is not followed by
/// two hex digits is kept verbatim. Decoded bytes that are not valid UTF-8 are replaced with
/// U+FFFD.
fn percent_decode(s: &str) -> String {
    /// Value of an ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let hi = bytes.get(i + 1).copied().and_then(hex_value);
            let lo = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        // Ordinary byte, or a malformed escape that is passed through unchanged.
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

impl ConnectOptions for FirebirdConnectOptions {
    type Connection = FirebirdConnection;

    /// Parses options from a `firebird://` URL using this type's `FromStr` implementation.
    fn from_url(url: &Url) -> Result<Self, Error> {
        Self::from_str(url.as_str())
    }

    /// Builds a `firebird://user:password@host:port/database` URL from these options.
    ///
    /// Entries in `conn_options` and the log settings are not included (hence "lossy").
    /// User name, password and database are percent-encoded. Characters such as `/`, `@`,
    /// `#`, `?` or `%` therefore survive a round trip through `FromStr` instead of breaking
    /// the URL. An absolute database path like `/var/db/x.fdb` is emitted with a single
    /// leading slash, so it parses back unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `host` inserted verbatim does not form a parseable URL authority.
    fn to_url_lossy(&self) -> Url {
        let mut url = Url::parse(&format!("firebird://{}:{}", self.host, self.port))
            .expect("BUG: generated URL is invalid");

        // The `url` setters escape URL delimiters but leave `%` alone, so escape it first;
        // otherwise a literal "%41" in a password would come back as "A".
        let escape_percent = |raw: &str| raw.replace('%', "%25");

        // Cannot fail: the URL above was parsed with a non-empty host (a port forces one).
        url.set_username(&escape_percent(&self.username))
            .expect("BUG: URL with a host rejected a username");
        url.set_password(Some(&escape_percent(&self.password)))
            .expect("BUG: URL with a host rejected a password");

        // `set_path` adds a leading "/" only when missing: "x.fdb" -> "/x.fdb", while
        // "/var/db/x.fdb" is kept as is. `?` and `#` are escaped up front so that a leading
        // one is not taken as the start of a query or fragment.
        let db = self.database.as_deref().unwrap_or("");
        url.set_path(&escape_percent(db).replace('?', "%3F").replace('#', "%23"));

        url
    }

    fn connect(
        &self,
    ) -> impl std::future::Future<Output = Result<Self::Connection, Error>> + Send + '_
    where
        Self::Connection: Sized,
    {
        crate::connection::AssertSend(FirebirdConnection::establish(self))
    }

    fn log_statements(mut self, level: LevelFilter) -> Self {
        self.log_settings.statements_level = level;
        self
    }

    fn log_slow_statements(mut self, level: LevelFilter, duration: std::time::Duration) -> Self {
        self.log_settings.slow_statements_level = level;
        self.log_settings.slow_statements_duration = duration;
        self
    }
}