use crate::{MssqlConnection, Result};
use log::LevelFilter;
use std::fmt::{self, Debug, Formatter};
use std::str::FromStr;
use std::time::Duration;
use url::Url;
/// Settings that control how query results are buffered.
///
/// Both knobs can also be set individually through the `MssqlConnectOptions`
/// setters, which reject a zero `batch_size` and a `Some(0)` `max_column_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlBufferSettings {
    /// Batch size used when buffering results (presumably rows per fetch — confirm
    /// against the connection code). Defaults to 64.
    pub batch_size: usize,
    /// Optional upper bound on a buffered column's size; `None` (the default) sets no
    /// bound. NOTE(review): unit (bytes vs. characters) is not visible here — confirm.
    pub max_column_size: Option<usize>,
}

impl Default for MssqlBufferSettings {
    /// A batch size of 64 with no column-size limit.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 64;
        MssqlBufferSettings {
            max_column_size: None,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}
/// Connection options for the MSSQL ODBC backend.
///
/// Created via `FromStr` from an `mssql://` URL, a raw ODBC connection string, or a
/// bare DSN name. `Debug` output redacts the connection string because it can carry
/// credentials (URL user/password are written into it as `UID=`/`PWD=`).
#[derive(Clone)]
pub struct MssqlConnectOptions {
    /// The ODBC connection string; exposed read-only through `connection_string()`.
    pub(crate) conn_str: String,
    /// Result-buffering knobs (batch size, optional column-size cap).
    pub(crate) buffer_settings: MssqlBufferSettings,
    /// Statement cache capacity; `FromStr` initializes it to 100.
    pub(crate) statement_cache_capacity: usize,
    /// Level used for statement logging; `FromStr` initializes it to `Debug`.
    pub(crate) log_statements: LevelFilter,
    /// Level used for slow-statement logging; `FromStr` initializes it to `Warn`.
    pub(crate) log_slow_statements: LevelFilter,
    /// Duration threshold paired with `log_slow_statements`; `FromStr` initializes it to 1s.
    pub(crate) log_slow_statement_duration: Duration,
}
impl MssqlConnectOptions {
    /// Returns the ODBC connection string these options will connect with.
    pub fn connection_string(&self) -> &str {
        &self.conn_str
    }

    /// Replaces all buffer settings at once.
    ///
    /// # Panics
    ///
    /// Panics if `settings.batch_size` is 0 or `settings.max_column_size` is `Some(0)`.
    pub fn buffer_settings(&mut self, settings: MssqlBufferSettings) -> &mut Self {
        assert!(settings.batch_size > 0, "batch_size must be greater than 0");
        if let Some(size) = settings.max_column_size {
            assert!(size > 0, "max_column_size must be greater than 0");
        }
        self.buffer_settings = settings;
        self
    }

    /// Returns the current buffer settings.
    pub fn buffer_settings_ref(&self) -> &MssqlBufferSettings {
        &self.buffer_settings
    }

    /// Sets only the batch size, keeping the other buffer settings.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is 0.
    pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
        assert!(batch_size > 0, "batch_size must be greater than 0");
        self.buffer_settings.batch_size = batch_size;
        self
    }

    /// Sets (or, with `None`, removes) the column-size cap, keeping the other buffer settings.
    ///
    /// # Panics
    ///
    /// Panics if `max_column_size` is `Some(0)`.
    pub fn max_column_size(&mut self, max_column_size: Option<usize>) -> &mut Self {
        if let Some(size) = max_column_size {
            assert!(size > 0, "max_column_size must be greater than 0");
        }
        self.buffer_settings.max_column_size = max_column_size;
        self
    }

    /// Sets the statement cache capacity.
    pub fn statement_cache_capacity(&mut self, capacity: usize) -> &mut Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Sets the level used for statement logging.
    pub fn log_statements(&mut self, level: LevelFilter) -> &mut Self {
        self.log_statements = level;
        self
    }

    /// Sets the level and the duration threshold used for slow-statement logging.
    pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self {
        self.log_slow_statements = level;
        self.log_slow_statement_duration = duration;
        self
    }

    /// With `enable == true`, appends `Encrypt=yes` unless the connection string already
    /// has an `Encrypt` attribute (matched by key, ASCII case-insensitively), in which case
    /// the existing value is kept. `enable == false` leaves the connection string untouched.
    pub fn encrypt(&mut self, enable: bool) -> &mut Self {
        if enable && Self::find_attribute(&self.conn_str, "Encrypt").is_none() {
            self.push_attribute("Encrypt=yes");
        }
        self
    }

    /// With `enable == true`, appends `TrustServerCertificate=yes` unless the connection
    /// string already has a `TrustServerCertificate` attribute (matched by key, ASCII
    /// case-insensitively). `enable == false` leaves the connection string untouched.
    pub fn trust_certificate(&mut self, enable: bool) -> &mut Self {
        if enable && Self::find_attribute(&self.conn_str, "TrustServerCertificate").is_none() {
            self.push_attribute("TrustServerCertificate=yes");
        }
        self
    }

    /// Opens a connection with these options by delegating to
    /// `MssqlConnection::connect_blocking`.
    pub fn connect_blocking(&self) -> Result<MssqlConnection> {
        MssqlConnection::connect_blocking(self)
    }

    /// Returns a copy of these options whose connection string targets `database`.
    ///
    /// An existing `Database` attribute is replaced in place. ODBC uses the first
    /// occurrence of a repeated keyword, so appending a second `Database=` would be
    /// ignored. Without an existing attribute, one is appended. The name is escaped
    /// with `escape_odbc_value`.
    #[cfg(feature = "migrate")]
    pub(crate) fn with_database(&self, database: &str) -> Self {
        let mut new = self.clone();
        let attribute = format!("Database={}", escape_odbc_value(database));
        match Self::find_attribute(&new.conn_str, "Database") {
            Some(range) => new.conn_str.replace_range(range, &attribute),
            None => new.push_attribute(&attribute),
        }
        new
    }

    /// Locates the first attribute named `key` (compared ASCII case-insensitively after
    /// trimming whitespace) in an ODBC connection string.
    ///
    /// Returns the byte range of the whole `key=value` segment, excluding the `;` that
    /// separates it from the next segment. A braced value (`{...}`, where `}}` is an
    /// escaped `}`) is skipped as a unit, so a `;` or `=` inside it is never mistaken for
    /// syntax. All delimiters are ASCII, so every returned index is a char boundary.
    fn find_attribute(conn_str: &str, key: &str) -> Option<std::ops::Range<usize>> {
        let bytes = conn_str.as_bytes();
        let mut start = 0;
        while start < bytes.len() {
            // The key runs up to the first `=`; a segment with no `=` has no value to match.
            let mut pos = start;
            while pos < bytes.len() && bytes[pos] != b'=' && bytes[pos] != b';' {
                pos += 1;
            }
            if pos == bytes.len() || bytes[pos] == b';' {
                start = pos + 1;
                continue;
            }
            let name = conn_str[start..pos].trim();
            pos += 1; // skip '='
            while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
                pos += 1;
            }
            if bytes.get(pos) == Some(&b'{') {
                pos += 1;
                while pos < bytes.len() {
                    if bytes[pos] == b'}' {
                        if bytes.get(pos + 1) == Some(&b'}') {
                            pos += 2; // escaped `}}` stays inside the braces
                            continue;
                        }
                        pos += 1;
                        break;
                    }
                    pos += 1;
                }
            }
            // Anything after the (possibly braced) value belongs to it until the next `;`.
            while pos < bytes.len() && bytes[pos] != b';' {
                pos += 1;
            }
            if name.eq_ignore_ascii_case(key) {
                return Some(start..pos);
            }
            start = pos + 1;
        }
        None
    }

    /// Appends a complete `Key=value` attribute, adding a `;` separator only when the
    /// connection string is non-empty and does not already end with one.
    fn push_attribute(&mut self, attribute: &str) {
        if !self.conn_str.is_empty() && !self.conn_str.ends_with(';') {
            self.conn_str.push(';');
        }
        self.conn_str.push_str(attribute);
    }
}
impl Debug for MssqlConnectOptions {
    /// Formats every field except the connection string, which is printed as
    /// `"<redacted>"` because it may contain `UID=`/`PWD=` credentials.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field forces a decision here about
        // whether it is safe to print.
        let Self {
            conn_str: _,
            buffer_settings,
            statement_cache_capacity,
            log_statements,
            log_slow_statements,
            log_slow_statement_duration,
        } = self;

        let mut out = f.debug_struct("MssqlConnectOptions");
        out.field("conn_str", &"<redacted>");
        out.field("buffer_settings", buffer_settings);
        out.field("statement_cache_capacity", statement_cache_capacity);
        out.field("log_statements", log_statements);
        out.field("log_slow_statements", log_slow_statements);
        out.field("log_slow_statement_duration", log_slow_statement_duration);
        out.finish()
    }
}
/// Quotes a value for the right-hand side of an ODBC `key=value` attribute.
///
/// A value containing `;`, `{`, `}` or `=` is wrapped in braces, and each `}` is
/// doubled so it cannot close the quoted value early (`pass}word` → `{pass}}word}`).
/// Any other value, including the empty string, is returned unchanged.
fn escape_odbc_value(value: &str) -> String {
    let needs_braces = value.chars().any(|ch| matches!(ch, ';' | '{' | '}' | '='));
    if !needs_braces {
        return value.to_owned();
    }

    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('{');
    for ch in value.chars() {
        if ch == '}' {
            quoted.push('}');
        }
        quoted.push(ch);
    }
    quoted.push('}');
    quoted
}
fn mssql_url_to_connection_string(url: &Url) -> String {
let scheme = url.scheme();
let is_mssql = scheme.eq_ignore_ascii_case("mssql");
if !is_mssql && !scheme.eq_ignore_ascii_case("odbc") {
return url.as_str().to_owned();
}
let host = url.host_str().unwrap_or("localhost");
let port = url.port().unwrap_or(1433);
let database = url.path().trim_start_matches('/');
let username = url.username();
let password = url.password().unwrap_or_default();
let mut conn_str = format!(
"Driver={{ODBC Driver 18 for SQL Server}};Server={host},{port}"
);
if !database.is_empty() {
conn_str.push_str(&format!(";Database={}", escape_odbc_value(database)));
}
if !username.is_empty() {
conn_str.push_str(&format!(";UID={}", escape_odbc_value(username)));
}
if !password.is_empty() {
conn_str.push_str(&format!(";PWD={}", escape_odbc_value(password)));
}
for (key, value) in url.query_pairs() {
match key.as_ref() {
"trust_certificate" if value == "true" => {
if !conn_str.contains("TrustServerCertificate=") {
conn_str.push_str(";TrustServerCertificate=yes");
}
}
"encrypt" if value == "true" => {
if !conn_str.contains("Encrypt=") {
conn_str.push_str(";Encrypt=yes");
}
}
"driver" => {
let driver_val = format!("Driver={value}");
if let Some(pos) = conn_str.find("Driver=") {
let end = conn_str[pos..].find(';').map(|i| pos + i).unwrap_or(conn_str.len());
conn_str.replace_range(pos..end, &driver_val);
}
}
_ => {}
}
}
conn_str
}
impl FromStr for MssqlConnectOptions {
type Err = sqlx_core::Error;
fn from_str(input: &str) -> std::result::Result<Self, Self::Err> {
let trimmed = input.trim();
let (trimmed, _had_odbc_prefix) = if let Some(rest) = trimmed.strip_prefix("odbc:") {
(rest, true)
} else {
(trimmed, false)
};
if trimmed.starts_with("mssql://") || trimmed.starts_with("mssql:") {
if let Ok(url) = Url::parse(trimmed) {
let scheme = url.scheme();
if scheme.eq_ignore_ascii_case("mssql") {
let conn_str = mssql_url_to_connection_string(&url);
return Ok(Self {
conn_str,
buffer_settings: MssqlBufferSettings::default(),
statement_cache_capacity: 100,
log_statements: LevelFilter::Debug,
log_slow_statements: LevelFilter::Warn,
log_slow_statement_duration: Duration::from_secs(1),
});
}
}
}
let conn_str = if trimmed.contains('=') {
trimmed.to_owned()
} else {
format!("DSN={trimmed}")
};
Ok(Self {
conn_str,
buffer_settings: MssqlBufferSettings::default(),
statement_cache_capacity: 100,
log_statements: LevelFilter::Debug,
log_slow_statements: LevelFilter::Warn,
log_slow_statement_duration: Duration::from_secs(1),
})
}
}
impl sqlx_core::connection::ConnectOptions for MssqlConnectOptions {
type Connection = MssqlConnection;
fn from_url(url: &Url) -> std::result::Result<Self, sqlx_core::Error> {
Self::from_str(url.as_str())
}
async fn connect(&self) -> std::result::Result<Self::Connection, sqlx_core::Error> {
self.connect_blocking().map_err(Into::into)
}
fn log_statements(mut self, level: LevelFilter) -> Self {
self.log_statements = level;
self
}
fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
self.log_slow_statements = level;
self.log_slow_statement_duration = duration;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parses_mssql_url_with_all_components() {
let url = "mssql://sa:Password1!@server.example.com:1433/testdb";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("Driver={ODBC Driver 18 for SQL Server}"));
assert!(cs.contains("Server=server.example.com,1433"));
assert!(cs.contains("Database=testdb"));
assert!(cs.contains("UID=sa"));
assert!(cs.contains("PWD=Password1!"));
}
#[test]
fn parses_mssql_url_with_default_port() {
let url = "mssql://user:pass@localhost/mydb";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("Server=localhost,1433"));
}
#[test]
fn parses_mssql_url_without_credentials() {
let url = "mssql://localhost/mydb";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("Server=localhost,1433"));
assert!(cs.contains("Database=mydb"));
assert!(!cs.contains("UID="));
assert!(!cs.contains("PWD="));
}
#[test]
fn parses_mssql_url_with_trust_certificate() {
let url = "mssql://localhost/mydb?trust_certificate=true";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("TrustServerCertificate=yes"));
}
#[test]
fn parses_mssql_url_with_encrypt() {
let url = "mssql://localhost/mydb?encrypt=true";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("Encrypt=yes"));
}
#[test]
fn parses_mssql_url_with_custom_driver() {
let url = "mssql://localhost/mydb?driver={ODBC Driver 17 for SQL Server}";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(cs.contains("Driver={ODBC Driver 17 for SQL Server}"));
}
#[test]
fn preserves_raw_odbc_connection_strings() {
let input = "Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test";
let options = MssqlConnectOptions::from_str(input).unwrap();
assert_eq!(options.connection_string(), input);
}
#[test]
fn supports_dsn_format() {
let options = MssqlConnectOptions::from_str("MyMssqlDSN").unwrap();
assert_eq!(options.connection_string(), "DSN=MyMssqlDSN");
}
#[test]
fn strips_legacy_odbc_prefix() {
let options = MssqlConnectOptions::from_str("odbc:DSN=Warehouse").unwrap();
assert_eq!(options.connection_string(), "DSN=Warehouse");
}
#[test]
fn encrypt_method_adds_encrypt() {
let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
options.encrypt(true);
assert!(options.connection_string().contains("Encrypt=yes"));
}
#[test]
fn trust_certificate_method_adds_flag() {
let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
options.trust_certificate(true);
assert!(options.connection_string().contains("TrustServerCertificate=yes"));
}
#[test]
fn updates_buffer_settings_incrementally() {
let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
options.batch_size(128).max_column_size(Some(2048));
assert_eq!(options.buffer_settings.batch_size, 128);
assert_eq!(options.buffer_settings.max_column_size, Some(2048));
}
#[test]
fn escape_odbc_value_preserves_safe_values() {
assert_eq!(escape_odbc_value("simple"), "simple");
assert_eq!(escape_odbc_value(""), "");
assert_eq!(escape_odbc_value("abc123"), "abc123");
}
#[test]
fn escape_odbc_value_wraps_values_with_special_chars() {
assert_eq!(escape_odbc_value("pass;word"), "{pass;word}");
assert_eq!(escape_odbc_value("pass=word"), "{pass=word}");
assert_eq!(escape_odbc_value("pass{word"), "{pass{word}");
assert_eq!(escape_odbc_value("pass}word"), "{pass}}word}");
assert_eq!(escape_odbc_value("a}b}c"), "{a}}b}}c}");
}
#[test]
fn parses_mssql_url_with_special_chars_in_password() {
let url = "mssql://user:a%3Bb%3Dc%7Dd@localhost/mydb";
let options = MssqlConnectOptions::from_str(url).unwrap();
let cs = options.connection_string();
assert!(
cs.contains("PWD=a%3Bb%3Dc%7Dd"),
"password not included correctly; got: {cs}"
);
}
}