use crate::{OdbcConnection, Result};
use log::LevelFilter;
use std::fmt::{self, Debug, Formatter};
use std::str::FromStr;
use std::time::Duration;
use url::Url;
/// Buffer-sizing knobs carried by [`OdbcConnectOptions`].
///
/// The setters on `OdbcConnectOptions` reject a zero `batch_size` and a
/// `Some(0)` `max_column_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OdbcBufferSettings {
    /// Batch size used when buffering; presumably the number of rows fetched
    /// per round trip (confirm against the fetch code). Must be non-zero.
    pub batch_size: usize,
    /// Optional upper bound on a column buffer's size. `None` leaves it unset.
    pub max_column_size: Option<usize>,
}

impl Default for OdbcBufferSettings {
    /// A batch size of 64 with no explicit column size limit.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 64;
        let max_column_size = None;
        OdbcBufferSettings {
            max_column_size,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}
/// Options used to open an ODBC connection.
///
/// Usually built by parsing a string (see the `FromStr` impl in this module),
/// then adjusted with the builder-style setters.
#[derive(Clone)]
pub struct OdbcConnectOptions {
    /// Full ODBC connection string. The `Debug` impl prints it as `<redacted>`.
    pub(crate) conn_str: String,
    /// Buffer sizing; values are validated by the setters on this type.
    pub(crate) buffer_settings: OdbcBufferSettings,
    /// Level filter configured for statement logging.
    pub(crate) log_statements: LevelFilter,
    /// Level filter configured for slow-statement logging.
    pub(crate) log_slow_statements: LevelFilter,
    /// Duration threshold paired with `log_slow_statements`; presumably
    /// statements running longer than this count as slow (confirm at use site).
    pub(crate) log_slow_statement_duration: Duration,
}
impl OdbcConnectOptions {
    /// Returns the ODBC connection string held by these options.
    pub fn connection_string(&self) -> &str {
        &self.conn_str
    }

    /// Replaces the buffer settings as a whole.
    ///
    /// Both values are validated before anything is stored, so a rejected
    /// call leaves the current settings untouched.
    ///
    /// # Panics
    ///
    /// Panics if `settings.batch_size` is 0 or `settings.max_column_size`
    /// is `Some(0)`.
    #[track_caller]
    pub fn buffer_settings(&mut self, settings: OdbcBufferSettings) -> &mut Self {
        Self::check_batch_size(settings.batch_size);
        Self::check_max_column_size(settings.max_column_size);
        self.buffer_settings = settings;
        self
    }

    /// Returns the current buffer settings.
    pub fn buffer_settings_ref(&self) -> &OdbcBufferSettings {
        &self.buffer_settings
    }

    /// Sets only the batch size, keeping the other buffer settings.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is 0.
    #[track_caller]
    pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
        Self::check_batch_size(batch_size);
        self.buffer_settings.batch_size = batch_size;
        self
    }

    /// Sets only the maximum column size, keeping the other buffer settings.
    /// Passing `None` clears any previously configured limit.
    ///
    /// # Panics
    ///
    /// Panics if `max_column_size` is `Some(0)`.
    #[track_caller]
    pub fn max_column_size(&mut self, max_column_size: Option<usize>) -> &mut Self {
        Self::check_max_column_size(max_column_size);
        self.buffer_settings.max_column_size = max_column_size;
        self
    }

    /// Sets the level filter used for statement logging.
    pub fn log_statements(&mut self, level: LevelFilter) -> &mut Self {
        self.log_statements = level;
        self
    }

    /// Sets the level filter and duration threshold used for slow-statement
    /// logging.
    pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self {
        self.log_slow_statements = level;
        self.log_slow_statement_duration = duration;
        self
    }

    /// Opens a connection synchronously using these options.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `OdbcConnection::connect_blocking`.
    pub fn connect_blocking(&self) -> Result<OdbcConnection> {
        OdbcConnection::connect_blocking(self)
    }

    /// Shared precondition for every setter that touches `batch_size`.
    /// `#[track_caller]` forwards the panic location to the public caller.
    #[track_caller]
    fn check_batch_size(batch_size: usize) {
        assert!(batch_size > 0, "batch_size must be greater than 0");
    }

    /// Shared precondition for every setter that touches `max_column_size`;
    /// only an explicit `Some(0)` is rejected.
    #[track_caller]
    fn check_max_column_size(max_column_size: Option<usize>) {
        if let Some(size) = max_column_size {
            assert!(size > 0, "max_column_size must be greater than 0");
        }
    }
}
impl Debug for OdbcConnectOptions {
    /// Formats the options for diagnostics without revealing the connection
    /// string, which may carry credentials (e.g. a `PWD=` attribute).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct becomes a
        // compile error here, so it cannot be silently left out of the output.
        let Self {
            conn_str: _,
            buffer_settings,
            log_statements,
            log_slow_statements,
            log_slow_statement_duration,
        } = self;

        let mut out = f.debug_struct("OdbcConnectOptions");
        out.field("conn_str", &"<redacted>");
        out.field("buffer_settings", buffer_settings);
        out.field("log_statements", log_statements);
        out.field("log_slow_statements", log_slow_statements);
        out.field("log_slow_statement_duration", log_slow_statement_duration);
        out.finish()
    }
}
impl FromStr for OdbcConnectOptions {
type Err = sqlx_core::Error;
fn from_str(input: &str) -> std::result::Result<Self, Self::Err> {
let mut trimmed = input.trim();
if let Some(rest) = trimmed.strip_prefix("odbc:") {
trimmed = rest;
}
let conn_str = if trimmed.contains('=') {
trimmed.to_owned()
} else {
format!("DSN={trimmed}")
};
Ok(Self {
conn_str,
buffer_settings: OdbcBufferSettings::default(),
log_statements: LevelFilter::Debug,
log_slow_statements: LevelFilter::Warn,
log_slow_statement_duration: Duration::from_secs(1),
})
}
}
impl sqlx_core::connection::ConnectOptions for OdbcConnectOptions {
type Connection = OdbcConnection;
fn from_url(url: &Url) -> std::result::Result<Self, sqlx_core::Error> {
Self::from_str(url.as_str())
}
async fn connect(&self) -> std::result::Result<Self::Connection, sqlx_core::Error> {
self.connect_blocking().map_err(Into::into)
}
fn log_statements(mut self, level: LevelFilter) -> Self {
self.log_statements = level;
self
}
fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
self.log_slow_statements = level;
self.log_slow_statement_duration = duration;
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_bare_dsn_as_dsn_connection_string() {
        let options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let options = OdbcConnectOptions::from_str("  Warehouse \n").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn preserves_standard_connection_strings() {
        let input = "Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test";
        let options = OdbcConnectOptions::from_str(input).unwrap();
        assert_eq!(options.connection_string(), input);
    }

    #[test]
    fn strips_legacy_odbc_prefix() {
        let options = OdbcConnectOptions::from_str("odbc:DSN=Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn starts_with_default_buffer_settings() {
        let options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        assert_eq!(*options.buffer_settings_ref(), OdbcBufferSettings::default());
    }

    #[test]
    fn updates_buffer_settings_incrementally() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.batch_size(128).max_column_size(Some(2048));
        assert_eq!(
            *options.buffer_settings_ref(),
            OdbcBufferSettings {
                batch_size: 128,
                max_column_size: Some(2048)
            }
        );
    }

    #[test]
    fn replaces_buffer_settings_wholesale() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        let settings = OdbcBufferSettings {
            batch_size: 16,
            max_column_size: Some(512),
        };
        options.buffer_settings(settings);
        assert_eq!(*options.buffer_settings_ref(), settings);
    }

    #[test]
    #[should_panic(expected = "batch_size must be greater than 0")]
    fn rejects_zero_batch_size() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.batch_size(0);
    }

    #[test]
    #[should_panic(expected = "max_column_size must be greater than 0")]
    fn rejects_zero_max_column_size() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.max_column_size(Some(0));
    }

    #[test]
    fn debug_output_redacts_connection_string() {
        let options = OdbcConnectOptions::from_str("Driver={X};PWD=secret").unwrap();
        let rendered = format!("{options:?}");
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("secret"));
    }
}