Skip to main content

drasi_mssql_common/
config.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Configuration for MS SQL CDC source
16
17use anyhow::{anyhow, Result};
18use serde::{Deserialize, Serialize};
19
20/// Maximum length for SQL identifiers (SQL Server limit is 128)
21const MAX_IDENTIFIER_LENGTH: usize = 128;
22
23/// Validate a SQL identifier to prevent SQL injection
24///
25/// Valid identifiers contain only:
26/// - Alphanumeric characters (a-z, A-Z, 0-9)
27/// - Underscores (_)
28/// - Dots (.) for schema.table notation
29///
30/// # Arguments
31/// * `name` - The identifier to validate
32///
33/// # Returns
34/// * `Ok(())` if the identifier is valid
35/// * `Err` if the identifier contains invalid characters or is empty
36///
37/// # Example
38/// ```
39/// use drasi_mssql_common::validate_sql_identifier;
40///
41/// assert!(validate_sql_identifier("orders").is_ok());
42/// assert!(validate_sql_identifier("dbo.orders").is_ok());
43/// assert!(validate_sql_identifier("order_items").is_ok());
44/// assert!(validate_sql_identifier("orders; DROP TABLE users--").is_err());
45/// ```
46pub fn validate_sql_identifier(name: &str) -> Result<()> {
47    if name.is_empty() {
48        return Err(anyhow!("SQL identifier cannot be empty"));
49    }
50
51    if name.len() > MAX_IDENTIFIER_LENGTH {
52        return Err(anyhow!(
53            "SQL identifier exceeds maximum length of {MAX_IDENTIFIER_LENGTH} characters"
54        ));
55    }
56
57    // Check that all characters are valid SQL identifier characters
58    for (i, c) in name.chars().enumerate() {
59        if !c.is_ascii_alphanumeric() && c != '_' && c != '.' {
60            return Err(anyhow!(
61                "Invalid character '{c}' at position {i} in SQL identifier '{name}'. \
62                 Only alphanumeric characters, underscores, and dots are allowed."
63            ));
64        }
65    }
66
67    // Identifier cannot start with a digit
68    if name.chars().next().is_some_and(|c| c.is_ascii_digit()) {
69        return Err(anyhow!("SQL identifier '{name}' cannot start with a digit"));
70    }
71
72    // Check for consecutive dots or leading/trailing dots
73    if name.starts_with('.') || name.ends_with('.') || name.contains("..") {
74        return Err(anyhow!("SQL identifier '{name}' has invalid dot placement"));
75    }
76
77    Ok(())
78}
79
80/// Authentication mode for MS SQL Server
81#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
82#[serde(rename_all = "lowercase")]
83#[derive(Default)]
84pub enum AuthMode {
85    /// SQL Server authentication (username/password)
86    #[default]
87    SqlServer,
88    /// Windows integrated authentication (Kerberos)
89    Windows,
90    /// Azure AD authentication
91    AzureAd,
92}
93
94impl std::fmt::Display for AuthMode {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        match self {
97            Self::SqlServer => write!(f, "sql_server"),
98            Self::Windows => write!(f, "windows"),
99            Self::AzureAd => write!(f, "azure_ad"),
100        }
101    }
102}
103
104/// TLS/SSL encryption mode
105#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
106#[serde(rename_all = "lowercase")]
107#[derive(Default)]
108pub enum EncryptionMode {
109    /// No encryption
110    Off,
111    /// Require encryption
112    On,
113    /// Encrypt if supported, otherwise allow unencrypted
114    #[default]
115    NotSupported,
116}
117
118/// Starting position when no LSN is found in the state store
119#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
120#[serde(rename_all = "lowercase")]
121#[derive(Default)]
122pub enum StartPosition {
123    /// Start from the beginning (earliest available LSN)
124    Beginning,
125    /// Start from the current LSN (now)
126    #[default]
127    Current,
128}
129
130impl std::fmt::Display for EncryptionMode {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        match self {
133            Self::Off => write!(f, "off"),
134            Self::On => write!(f, "on"),
135            Self::NotSupported => write!(f, "not_supported"),
136        }
137    }
138}
139
140/// Table key configuration for custom primary keys
141#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
142pub struct TableKeyConfig {
143    /// Table name
144    pub table: String,
145    /// Column names that form the primary key
146    pub key_columns: Vec<String>,
147}
148
149/// MS SQL CDC source configuration
150#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
151pub struct MsSqlSourceConfig {
152    /// MS SQL server hostname or IP address
153    #[serde(default = "default_host")]
154    pub host: String,
155
156    /// MS SQL server port
157    #[serde(default = "default_port")]
158    pub port: u16,
159
160    /// Database name
161    pub database: String,
162
163    /// Database user
164    pub user: String,
165
166    /// Database password
167    #[serde(default)]
168    pub password: String,
169
170    /// Authentication mode
171    #[serde(default)]
172    pub auth_mode: AuthMode,
173
174    /// Tables to monitor (empty = all CDC-enabled tables)
175    #[serde(default)]
176    pub tables: Vec<String>,
177
178    /// CDC polling interval in milliseconds
179    #[serde(default = "default_poll_interval_ms")]
180    pub poll_interval_ms: u64,
181
182    /// TLS/SSL encryption mode
183    #[serde(default)]
184    pub encryption: EncryptionMode,
185
186    /// Trust server certificate (for self-signed certificates)
187    #[serde(default)]
188    pub trust_server_certificate: bool,
189
190    /// Custom primary key configuration
191    #[serde(default)]
192    pub table_keys: Vec<TableKeyConfig>,
193
194    /// Starting position when no LSN is found in the state store
195    #[serde(default)]
196    pub start_position: StartPosition,
197}
198
/// Serde default for `MsSqlSourceConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}
202
/// Serde default for `MsSqlSourceConfig::port`.
fn default_port() -> u16 {
    const DEFAULT_MSSQL_PORT: u16 = 1433;
    DEFAULT_MSSQL_PORT
}
206
/// Serde default for `MsSqlSourceConfig::poll_interval_ms`.
fn default_poll_interval_ms() -> u64 {
    // One second, expressed in milliseconds.
    1_000
}
210
211impl Default for MsSqlSourceConfig {
212    fn default() -> Self {
213        Self {
214            host: default_host(),
215            port: default_port(),
216            database: String::new(),
217            user: String::new(),
218            password: String::new(),
219            auth_mode: AuthMode::default(),
220            tables: Vec::new(),
221            poll_interval_ms: default_poll_interval_ms(),
222            encryption: EncryptionMode::default(),
223            trust_server_certificate: false,
224            table_keys: Vec::new(),
225            start_position: StartPosition::default(),
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the documented fallback values.
    #[test]
    fn test_default_config() {
        let MsSqlSourceConfig {
            host,
            port,
            poll_interval_ms,
            auth_mode,
            encryption,
            trust_server_certificate,
            ..
        } = MsSqlSourceConfig::default();

        assert_eq!(host, "localhost");
        assert_eq!(port, 1433);
        assert_eq!(poll_interval_ms, 1000);
        assert_eq!(auth_mode, AuthMode::SqlServer);
        assert_eq!(encryption, EncryptionMode::NotSupported);
        assert!(!trust_server_certificate);
    }

    /// A fully populated config survives a JSON round trip unchanged.
    #[test]
    fn test_config_serialization() {
        let original = MsSqlSourceConfig {
            host: String::from("sqlserver.example.com"),
            port: 1433,
            database: String::from("production"),
            user: String::from("drasi_user"),
            password: String::from("secret"),
            auth_mode: AuthMode::SqlServer,
            tables: vec![String::from("orders"), String::from("customers")],
            poll_interval_ms: 2000,
            encryption: EncryptionMode::On,
            trust_server_certificate: true,
            table_keys: vec![TableKeyConfig {
                table: String::from("orders"),
                key_columns: vec![String::from("order_id")],
            }],
            start_position: StartPosition::Beginning,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MsSqlSourceConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Each authentication mode displays as its snake_case label.
    #[test]
    fn test_auth_mode_display() {
        let cases = [
            (AuthMode::SqlServer, "sql_server"),
            (AuthMode::Windows, "windows"),
            (AuthMode::AzureAd, "azure_ad"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    /// Each encryption mode displays as its snake_case label.
    #[test]
    fn test_encryption_mode_display() {
        let cases = [
            (EncryptionMode::Off, "off"),
            (EncryptionMode::On, "on"),
            (EncryptionMode::NotSupported, "not_supported"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    /// A table key config keeps the table name and all key columns.
    #[test]
    fn test_table_key_config() {
        let key_config = TableKeyConfig {
            table: String::from("orders"),
            key_columns: ["order_id", "line_item"]
                .iter()
                .map(|column| column.to_string())
                .collect(),
        };

        assert_eq!(key_config.table, "orders");
        assert_eq!(key_config.key_columns.len(), 2);
    }

    /// The default start position is the current LSN.
    #[test]
    fn test_start_position_default() {
        assert_eq!(StartPosition::default(), StartPosition::Current);
    }

    /// Start positions serialize to lowercase JSON strings.
    #[test]
    fn test_start_position_serialization() {
        let cases = [
            (StartPosition::Beginning, "\"beginning\""),
            (StartPosition::Current, "\"current\""),
        ];
        for (position, expected) in cases {
            let json = serde_json::to_string(&position).unwrap();
            assert_eq!(json, expected);
        }
    }

    /// Plain and schema-qualified names are accepted.
    #[test]
    fn test_validate_sql_identifier_valid() {
        let accepted = [
            // Simple table names
            "orders",
            "Orders",
            "order_items",
            "Order_Items_2024",
            // Schema-qualified names
            "dbo.orders",
            "sales.order_items",
            "MySchema.MyTable",
        ];
        for name in accepted {
            assert!(
                validate_sql_identifier(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    /// Typical SQL injection payloads are rejected.
    #[test]
    fn test_validate_sql_identifier_sql_injection() {
        let payloads = [
            "orders; DROP TABLE users--",
            "orders'; DELETE FROM users;--",
            "orders OR 1=1",
            "orders/**/UNION/**/SELECT",
            "orders\n; DROP TABLE",
        ];
        for payload in payloads {
            assert!(
                validate_sql_identifier(payload).is_err(),
                "expected {payload:?} to be rejected"
            );
        }
    }

    /// The empty string is not an identifier.
    #[test]
    fn test_validate_sql_identifier_empty() {
        assert!(validate_sql_identifier("").is_err());
    }

    /// 128 characters is the longest accepted identifier; 129 is too long.
    #[test]
    fn test_validate_sql_identifier_too_long() {
        let over_limit = "a".repeat(129);
        assert!(validate_sql_identifier(&over_limit).is_err());

        let at_limit = "a".repeat(128);
        assert!(validate_sql_identifier(&at_limit).is_ok());
    }

    /// Identifiers may not begin with a digit.
    #[test]
    fn test_validate_sql_identifier_invalid_start() {
        for name in ["123table", "1orders"] {
            assert!(
                validate_sql_identifier(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    /// Leading, trailing and doubled dots are rejected.
    #[test]
    fn test_validate_sql_identifier_invalid_dots() {
        for name in [".orders", "orders.", "dbo..orders"] {
            assert!(
                validate_sql_identifier(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }
}