drasi_mssql_common/
connection.rs1use crate::config::{AuthMode, EncryptionMode, MsSqlSourceConfig};
18use anyhow::{anyhow, Result};
19use log::{debug, info};
20use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
21use tokio::net::TcpStream;
22use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
23
24pub struct MsSqlConnection {
26 client: Client<Compat<TcpStream>>,
27}
28
29impl MsSqlConnection {
30 pub async fn connect(config: &MsSqlSourceConfig) -> Result<Self> {
38 info!(
39 "Connecting to MS SQL Server at {}:{} database '{}'",
40 config.host, config.port, config.database
41 );
42
43 let mut tiberius_config = Config::new();
45 tiberius_config.host(&config.host);
46 tiberius_config.port(config.port);
47 tiberius_config.database(&config.database);
48
49 match config.auth_mode {
51 AuthMode::SqlServer => {
52 debug!("Using SQL Server authentication");
53 tiberius_config
54 .authentication(AuthMethod::sql_server(&config.user, &config.password));
55 }
56 AuthMode::Windows => {
57 return Err(anyhow!("Windows authentication not yet implemented"));
60 }
61 AuthMode::AzureAd => {
62 return Err(anyhow!("Azure AD authentication not yet implemented"));
63 }
64 }
65
66 let encryption_level = match config.encryption {
68 EncryptionMode::Off => EncryptionLevel::Off,
69 EncryptionMode::On => EncryptionLevel::Required,
70 EncryptionMode::NotSupported => EncryptionLevel::NotSupported,
71 };
72 tiberius_config.encryption(encryption_level);
73
74 if config.trust_server_certificate {
76 tiberius_config.trust_cert();
77 }
78
79 let tcp = TcpStream::connect((config.host.as_str(), config.port))
81 .await
82 .map_err(|e| {
83 anyhow!(
84 "Failed to connect to {}:{}: {}",
85 config.host,
86 config.port,
87 e
88 )
89 })?;
90
91 tcp.set_nodelay(true)?;
92
93 let client = Client::connect(tiberius_config, tcp.compat_write())
95 .await
96 .map_err(|e| anyhow!("Failed to authenticate with MS SQL Server: {e}"))?;
97
98 info!("Successfully connected to MS SQL Server");
99
100 Ok(Self { client })
101 }
102
103 pub fn client_mut(&mut self) -> &mut Client<Compat<TcpStream>> {
105 &mut self.client
106 }
107
108 pub fn client(&self) -> &Client<Compat<TcpStream>> {
110 &self.client
111 }
112
113 pub async fn test_connection(&mut self) -> Result<()> {
115 debug!("Testing MS SQL connection");
116
117 let query = "SELECT @@VERSION AS version";
118 let stream = self.client.query(query, &[]).await?;
119 let rows = stream.into_first_result().await?;
120
121 if let Some(row) = rows.first() {
122 let version: &str = row.get(0).ok_or_else(|| anyhow!("No version returned"))?;
123 info!(
124 "MS SQL Server version: {}",
125 version.lines().next().unwrap_or(version)
126 );
127 }
128
129 Ok(())
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: constructing SQL Server credentials must not panic.
    #[test]
    fn test_auth_mode_conversion() {
        let _credentials = AuthMethod::sql_server("user", "password");
    }

    /// `EncryptionMode::Off` is expected to correspond to
    /// `EncryptionLevel::Off`.
    #[test]
    fn test_encryption_level_conversion() {
        let mapped = match EncryptionMode::Off {
            EncryptionMode::Off => EncryptionLevel::Off,
            _ => unreachable!(),
        };

        assert!(matches!(mapped, EncryptionLevel::Off));
    }
}