Skip to main content

drasi_mssql_common/
connection.rs

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! MS SQL connection management using Tiberius.
16
17use crate::config::{AuthMode, EncryptionMode, MsSqlSourceConfig};
18use anyhow::{anyhow, Result};
19use log::{debug, info};
20use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
21use tokio::net::TcpStream;
22use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
23
24/// MS SQL connection wrapper
25pub struct MsSqlConnection {
26    client: Client<Compat<TcpStream>>,
27}
28
29impl MsSqlConnection {
30    /// Create a new connection to MS SQL Server
31    ///
32    /// # Arguments
33    /// * `config` - MS SQL source configuration
34    ///
35    /// # Errors
36    /// Returns error if connection fails
37    pub async fn connect(config: &MsSqlSourceConfig) -> Result<Self> {
38        info!(
39            "Connecting to MS SQL Server at {}:{} database '{}'",
40            config.host, config.port, config.database
41        );
42
43        // Build Tiberius config
44        let mut tiberius_config = Config::new();
45        tiberius_config.host(&config.host);
46        tiberius_config.port(config.port);
47        tiberius_config.database(&config.database);
48
49        // Set authentication
50        match config.auth_mode {
51            AuthMode::SqlServer => {
52                debug!("Using SQL Server authentication");
53                tiberius_config
54                    .authentication(AuthMethod::sql_server(&config.user, &config.password));
55            }
56            AuthMode::Windows => {
57                // TODO: Implement Windows authentication
58                // Windows integrated authentication not yet supported
59                return Err(anyhow!("Windows authentication not yet implemented"));
60            }
61            AuthMode::AzureAd => {
62                return Err(anyhow!("Azure AD authentication not yet implemented"));
63            }
64        }
65
66        // Set encryption
67        let encryption_level = match config.encryption {
68            EncryptionMode::Off => EncryptionLevel::Off,
69            EncryptionMode::On => EncryptionLevel::Required,
70            EncryptionMode::NotSupported => EncryptionLevel::NotSupported,
71        };
72        tiberius_config.encryption(encryption_level);
73
74        // Trust server certificate if configured
75        if config.trust_server_certificate {
76            tiberius_config.trust_cert();
77        }
78
79        // Connect via TCP
80        let tcp = TcpStream::connect((config.host.as_str(), config.port))
81            .await
82            .map_err(|e| {
83                anyhow!(
84                    "Failed to connect to {}:{}: {}",
85                    config.host,
86                    config.port,
87                    e
88                )
89            })?;
90
91        tcp.set_nodelay(true)?;
92
93        // Create Tiberius client
94        let client = Client::connect(tiberius_config, tcp.compat_write())
95            .await
96            .map_err(|e| anyhow!("Failed to authenticate with MS SQL Server: {e}"))?;
97
98        info!("Successfully connected to MS SQL Server");
99
100        Ok(Self { client })
101    }
102
103    /// Get mutable reference to the underlying Tiberius client
104    pub fn client_mut(&mut self) -> &mut Client<Compat<TcpStream>> {
105        &mut self.client
106    }
107
108    /// Get reference to the underlying Tiberius client
109    pub fn client(&self) -> &Client<Compat<TcpStream>> {
110        &self.client
111    }
112
113    /// Test the connection by running a simple query
114    pub async fn test_connection(&mut self) -> Result<()> {
115        debug!("Testing MS SQL connection");
116
117        let query = "SELECT @@VERSION AS version";
118        let stream = self.client.query(query, &[]).await?;
119        let rows = stream.into_first_result().await?;
120
121        if let Some(row) = rows.first() {
122            let version: &str = row.get(0).ok_or_else(|| anyhow!("No version returned"))?;
123            info!(
124                "MS SQL Server version: {}",
125                version.lines().next().unwrap_or(version)
126            );
127        }
128
129        Ok(())
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: building SQL Server credentials must not panic.
    ///
    /// Windows / integrated authentication is not implemented yet, so only the
    /// SQL Server credential constructor is exercised.
    #[test]
    fn test_auth_mode_conversion() {
        let _sql_auth = AuthMethod::sql_server("user", "password");
    }

    /// `EncryptionMode::Off` should translate to `EncryptionLevel::Off`.
    ///
    /// NOTE(review): this checks a local copy of the mapping, not the one
    /// inside `MsSqlConnection::connect`.
    #[test]
    fn test_encryption_level_conversion() {
        let translated = match EncryptionMode::Off {
            EncryptionMode::Off => EncryptionLevel::Off,
            _ => unreachable!(),
        };
        assert!(matches!(translated, EncryptionLevel::Off));
    }
}