use std::collections::HashMap;
use std::str::FromStr;
use log::LevelFilter;
use url::Url;
use sqlx_core::connection::{ConnectOptions, LogSettings};
use crate::error::Error;
use crate::FirebirdConnection;
/// Connection options for a Firebird server.
///
/// Build one with [`FirebirdConnectOptions::new`] plus the builder-style setters, or
/// parse it from a `firebird://` URL through the `FromStr` implementation.
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}` prints the
/// plaintext `password`; consider a manual `Debug` impl that redacts it.
#[derive(Debug, Clone)]
pub struct FirebirdConnectOptions {
    /// Server host name or address (default `"localhost"`).
    pub(crate) host: String,
    /// Server port (default `3050`).
    pub(crate) port: u16,
    /// User to authenticate as (default `"SYSDBA"`).
    pub(crate) username: String,
    /// Password for `username` (default `"masterkey"`).
    pub(crate) password: String,
    /// Database to connect to; `None` until set with `database(..)` or taken from a
    /// URL path.
    pub(crate) database: Option<String>,
    /// Additional string-valued driver options keyed by name. The defaults seed
    /// `role`, `timezone`, `wire_crypt`, `auth_plugin_name` and `page_size`; URL query
    /// pairs are merged in when parsing.
    pub(crate) conn_options: HashMap<String, String>,
    /// Statement-logging configuration, adjusted by `log_statements` and
    /// `log_slow_statements`.
    pub(crate) log_settings: LogSettings,
}
impl Default for FirebirdConnectOptions {
fn default() -> Self {
let mut conn_options = HashMap::new();
conn_options.insert("role".to_string(), String::new());
conn_options.insert("timezone".to_string(), String::new());
conn_options.insert("wire_crypt".to_string(), "true".to_string());
conn_options.insert("auth_plugin_name".to_string(), "Srp256".to_string());
conn_options.insert("page_size".to_string(), "4096".to_string());
Self {
host: "localhost".to_string(),
port: 3050,
username: "SYSDBA".to_string(),
password: "masterkey".to_string(),
database: None,
conn_options,
log_settings: LogSettings::default(),
}
}
}
impl FirebirdConnectOptions {
pub fn new() -> Self {
Self::default()
}
pub fn host(mut self, host: &str) -> Self {
self.host = host.to_string();
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn username(mut self, username: &str) -> Self {
self.username = username.to_string();
self
}
pub fn password(mut self, password: &str) -> Self {
self.password = password.to_string();
self
}
pub fn database(mut self, database: &str) -> Self {
self.database = Some(database.to_string());
self
}
pub fn parse_from_url(url: &Url) -> Result<Self, Error> {
Self::from_str(url.as_str())
}
}
impl FromStr for FirebirdConnectOptions {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url = Url::parse(s).map_err(|e| Error::Configuration(e.to_string().into()))?;
if url.scheme() != "firebird" {
return Err(Error::Configuration(
format!("unsupported scheme: {}", url.scheme()).into(),
));
}
let host = url
.host_str()
.unwrap_or("localhost")
.to_string();
let port = url.port().unwrap_or(3050);
let username = if url.username().is_empty() {
"SYSDBA".to_string()
} else {
percent_decode(url.username())
};
let password = url
.password()
.map(percent_decode)
.unwrap_or_else(|| "masterkey".to_string());
let mut db_name = percent_decode(url.path());
let slash_count = db_name.chars().filter(|&c| c == '/').count();
if slash_count == 1 && db_name.starts_with('/') {
db_name = db_name[1..].to_string();
}
let mut conn_options: HashMap<String, String> =
url.query_pairs().into_owned().collect();
conn_options
.entry("role".to_string())
.or_insert_with(String::new);
conn_options
.entry("timezone".to_string())
.or_insert_with(String::new);
conn_options
.entry("wire_crypt".to_string())
.or_insert_with(|| "true".to_string());
conn_options
.entry("auth_plugin_name".to_string())
.or_insert_with(|| "Srp256".to_string());
conn_options
.entry("page_size".to_string())
.or_insert_with(|| "4096".to_string());
Ok(Self {
host,
port,
username,
password,
database: Some(db_name),
conn_options,
log_settings: LogSettings::default(),
})
}
}
/// Percent-decodes a URL component (the userinfo and path of a connection URL).
///
/// This follows RFC 3986 percent-decoding, not `application/x-www-form-urlencoded`:
/// `+` stays a literal plus, and `&` / `=` have no special meaning. Credentials and
/// database paths containing those characters therefore survive intact. A `%` that
/// is not followed by two hex digits is kept verbatim. Decoded bytes that are not
/// valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` if `byte` is not one.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let mut decoded = Vec::with_capacity(s.len());
    let mut rest = s.as_bytes();
    while let Some((&byte, tail)) = rest.split_first() {
        if byte == b'%' {
            if let [hi, lo, after @ ..] = tail {
                if let (Some(hi), Some(lo)) = (hex_value(*hi), hex_value(*lo)) {
                    decoded.push((hi << 4) | lo);
                    rest = after;
                    continue;
                }
            }
        }
        // Ordinary byte or malformed escape: copy through unchanged.
        decoded.push(byte);
        rest = tail;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
impl ConnectOptions for FirebirdConnectOptions {
    type Connection = FirebirdConnection;

    /// Parses options from a URL using the `FromStr` implementation.
    fn from_url(url: &Url) -> Result<Self, Error> {
        Self::from_str(url.as_str())
    }

    /// Renders the options as a `firebird://user:password@host:port/database` URL.
    ///
    /// Lossy: `conn_options` and `log_settings` are not included, and a host that the
    /// URL grammar rejects is replaced by `localhost`. Username, password and database
    /// are percent-encoded through the `Url` setters, so values containing URL
    /// delimiters (`@`, `/`, `:`, `#`, `?`, `%`, ...) no longer yield a malformed URL
    /// or a panic.
    fn to_url_lossy(&self) -> Url {
        // The `Url` setters leave a literal `%` unescaped, which would later be read
        // back as the start of an escape sequence; encode it up front.
        fn escape_percent(s: &str) -> String {
            s.replace('%', "%25")
        }

        let mut url =
            Url::parse("firebird://localhost").expect("BUG: static base URL is invalid");

        // Setter failures are deliberately ignored: this conversion is lossy and must
        // not panic on unusual option values. A rejected host keeps the placeholder.
        let _ = url.set_host(Some(&self.host));
        let _ = url.set_port(Some(self.port));
        let _ = url.set_username(&escape_percent(&self.username));
        let _ = url.set_password(Some(&escape_percent(&self.password)));

        // A database that already starts with `/` (an absolute path) is used as the
        // path directly; otherwise one separator slash is prepended.
        let db = escape_percent(self.database.as_deref().unwrap_or(""));
        if db.starts_with('/') {
            url.set_path(&db);
        } else {
            url.set_path(&format!("/{db}"));
        }
        url
    }

    /// Opens a connection with these options via `FirebirdConnection::establish`.
    ///
    /// NOTE(review): the future is wrapped in `crate::connection::AssertSend`
    /// (defined elsewhere); presumably this asserts `Send` to satisfy the signature
    /// — confirm.
    fn connect(
        &self,
    ) -> impl std::future::Future<Output = Result<Self::Connection, Error>> + Send + '_
    where
        Self::Connection: Sized,
    {
        crate::connection::AssertSend(FirebirdConnection::establish(self))
    }

    /// Sets the level at which executed statements are logged.
    fn log_statements(mut self, level: LevelFilter) -> Self {
        self.log_settings.statements_level = level;
        self
    }

    /// Sets the level and the duration threshold used for slow-statement logging.
    fn log_slow_statements(mut self, level: LevelFilter, duration: std::time::Duration) -> Self {
        self.log_settings.slow_statements_level = level;
        self.log_settings.slow_statements_duration = duration;
        self
    }
}