use crate::config::{AuthMode, EncryptionMode, MsSqlSourceConfig};
use anyhow::{anyhow, Context, Result};
use log::{debug, info};
use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
pub struct MsSqlConnection {
client: Client<Compat<TcpStream>>,
}
impl MsSqlConnection {
pub async fn connect(config: &MsSqlSourceConfig) -> Result<Self> {
info!(
"Connecting to MS SQL Server at {}:{} database '{}'",
config.host, config.port, config.database
);
let mut tiberius_config = Config::new();
tiberius_config.host(&config.host);
tiberius_config.port(config.port);
tiberius_config.database(&config.database);
match config.auth_mode {
AuthMode::SqlServer => {
debug!("Using SQL Server authentication");
tiberius_config
.authentication(AuthMethod::sql_server(&config.user, &config.password));
}
AuthMode::Windows => {
return Err(anyhow!("Windows authentication not yet implemented"));
}
AuthMode::AzureAd => {
return Err(anyhow!("Azure AD authentication not yet implemented"));
}
}
let encryption_level = match config.encryption {
EncryptionMode::Off => EncryptionLevel::Off,
EncryptionMode::On => EncryptionLevel::Required,
EncryptionMode::NotSupported => EncryptionLevel::NotSupported,
};
tiberius_config.encryption(encryption_level);
if config.trust_server_certificate {
tiberius_config.trust_cert();
}
let tcp = TcpStream::connect((config.host.as_str(), config.port))
.await
.map_err(|e| {
anyhow!(
"Failed to connect to {}:{}: {}",
config.host,
config.port,
e
)
})?;
tcp.set_nodelay(true)?;
let client = Client::connect(tiberius_config, tcp.compat_write())
.await
.map_err(|e| anyhow!("Failed to authenticate with MS SQL Server: {e}"))?;
info!("Successfully connected to MS SQL Server");
Ok(Self { client })
}
pub fn client_mut(&mut self) -> &mut Client<Compat<TcpStream>> {
&mut self.client
}
pub fn client(&self) -> &Client<Compat<TcpStream>> {
&self.client
}
pub async fn test_connection(&mut self) -> Result<()> {
debug!("Testing MS SQL connection");
let query = "SELECT @@VERSION AS version";
let stream = self.client.query(query, &[]).await?;
let rows = stream.into_first_result().await?;
if let Some(row) = rows.first() {
let version: &str = row.get(0).ok_or_else(|| anyhow!("No version returned"))?;
info!(
"MS SQL Server version: {}",
version.lines().next().unwrap_or(version)
);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: building SQL Server credentials must not panic.
    #[test]
    fn test_auth_mode_conversion() {
        let _credentials = AuthMethod::sql_server("user", "password");
    }

    /// Checks that `EncryptionMode::Off` corresponds to `EncryptionLevel::Off`.
    ///
    /// NOTE(review): this re-implements the mapping locally instead of calling
    /// the code used by `MsSqlConnection::connect`, so it cannot catch a
    /// regression there.
    #[test]
    fn test_encryption_level_conversion() {
        let mapped = match EncryptionMode::Off {
            EncryptionMode::Off => EncryptionLevel::Off,
            _ => unreachable!(),
        };
        assert!(matches!(mapped, EncryptionLevel::Off));
    }
}