1use anyhow::{anyhow, Result};
18use serde::{Deserialize, Serialize};
19
20const MAX_IDENTIFIER_LENGTH: usize = 128;
22
23pub fn validate_sql_identifier(name: &str) -> Result<()> {
47 if name.is_empty() {
48 return Err(anyhow!("SQL identifier cannot be empty"));
49 }
50
51 if name.len() > MAX_IDENTIFIER_LENGTH {
52 return Err(anyhow!(
53 "SQL identifier exceeds maximum length of {MAX_IDENTIFIER_LENGTH} characters"
54 ));
55 }
56
57 for (i, c) in name.chars().enumerate() {
59 if !c.is_ascii_alphanumeric() && c != '_' && c != '.' {
60 return Err(anyhow!(
61 "Invalid character '{c}' at position {i} in SQL identifier '{name}'. \
62 Only alphanumeric characters, underscores, and dots are allowed."
63 ));
64 }
65 }
66
67 if name.chars().next().is_some_and(|c| c.is_ascii_digit()) {
69 return Err(anyhow!("SQL identifier '{name}' cannot start with a digit"));
70 }
71
72 if name.starts_with('.') || name.ends_with('.') || name.contains("..") {
74 return Err(anyhow!("SQL identifier '{name}' has invalid dot placement"));
75 }
76
77 Ok(())
78}
79
80#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
82#[serde(rename_all = "lowercase")]
83#[derive(Default)]
84pub enum AuthMode {
85 #[default]
87 SqlServer,
88 Windows,
90 AzureAd,
92}
93
94impl std::fmt::Display for AuthMode {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Self::SqlServer => write!(f, "sql_server"),
98 Self::Windows => write!(f, "windows"),
99 Self::AzureAd => write!(f, "azure_ad"),
100 }
101 }
102}
103
104#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
106#[serde(rename_all = "lowercase")]
107#[derive(Default)]
108pub enum EncryptionMode {
109 Off,
111 On,
113 #[default]
115 NotSupported,
116}
117
118#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
120#[serde(rename_all = "lowercase")]
121#[derive(Default)]
122pub enum StartPosition {
123 Beginning,
125 #[default]
127 Current,
128}
129
130impl std::fmt::Display for EncryptionMode {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 match self {
133 Self::Off => write!(f, "off"),
134 Self::On => write!(f, "on"),
135 Self::NotSupported => write!(f, "not_supported"),
136 }
137 }
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
142pub struct TableKeyConfig {
143 pub table: String,
145 pub key_columns: Vec<String>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
151pub struct MsSqlSourceConfig {
152 #[serde(default = "default_host")]
154 pub host: String,
155
156 #[serde(default = "default_port")]
158 pub port: u16,
159
160 pub database: String,
162
163 pub user: String,
165
166 #[serde(default)]
168 pub password: String,
169
170 #[serde(default)]
172 pub auth_mode: AuthMode,
173
174 #[serde(default)]
176 pub tables: Vec<String>,
177
178 #[serde(default = "default_poll_interval_ms")]
180 pub poll_interval_ms: u64,
181
182 #[serde(default)]
184 pub encryption: EncryptionMode,
185
186 #[serde(default)]
188 pub trust_server_certificate: bool,
189
190 #[serde(default)]
192 pub table_keys: Vec<TableKeyConfig>,
193
194 #[serde(default)]
196 pub start_position: StartPosition,
197}
198
/// Serde default for the `host` field of `MsSqlSourceConfig`.
fn default_host() -> String {
    String::from("localhost")
}
202
/// Serde default for the `port` field of `MsSqlSourceConfig`.
fn default_port() -> u16 {
    // 1433 is SQL Server's well-known TCP port.
    const SQL_SERVER_DEFAULT_PORT: u16 = 1_433;
    SQL_SERVER_DEFAULT_PORT
}
206
/// Serde default for the `poll_interval_ms` field of `MsSqlSourceConfig`.
fn default_poll_interval_ms() -> u64 {
    const ONE_SECOND_IN_MS: u64 = 1_000;
    ONE_SECOND_IN_MS
}
210
211impl Default for MsSqlSourceConfig {
212 fn default() -> Self {
213 Self {
214 host: default_host(),
215 port: default_port(),
216 database: String::new(),
217 user: String::new(),
218 password: String::new(),
219 auth_mode: AuthMode::default(),
220 tables: Vec::new(),
221 poll_interval_ms: default_poll_interval_ms(),
222 encryption: EncryptionMode::default(),
223 trust_server_certificate: false,
224 table_keys: Vec::new(),
225 start_position: StartPosition::default(),
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every name in `names` is rejected by `validate_sql_identifier`.
    fn assert_all_rejected(names: &[&str]) {
        for name in names {
            assert!(
                validate_sql_identifier(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_default_config() {
        let MsSqlSourceConfig {
            host,
            port,
            poll_interval_ms,
            auth_mode,
            encryption,
            trust_server_certificate,
            ..
        } = MsSqlSourceConfig::default();

        assert_eq!(host, "localhost");
        assert_eq!(port, 1433);
        assert_eq!(poll_interval_ms, 1000);
        assert_eq!(auth_mode, AuthMode::SqlServer);
        assert_eq!(encryption, EncryptionMode::NotSupported);
        assert!(!trust_server_certificate);
    }

    #[test]
    fn test_config_serialization() {
        let orders_keys = TableKeyConfig {
            table: "orders".to_string(),
            key_columns: vec!["order_id".to_string()],
        };
        let original = MsSqlSourceConfig {
            host: "sqlserver.example.com".to_string(),
            port: 1433,
            database: "production".to_string(),
            user: "drasi_user".to_string(),
            password: "secret".to_string(),
            auth_mode: AuthMode::SqlServer,
            tables: ["orders", "customers"]
                .iter()
                .map(|t| t.to_string())
                .collect(),
            poll_interval_ms: 2000,
            encryption: EncryptionMode::On,
            trust_server_certificate: true,
            table_keys: vec![orders_keys],
            start_position: StartPosition::Beginning,
        };

        let encoded = serde_json::to_string(&original).expect("config serializes to JSON");
        let decoded: MsSqlSourceConfig =
            serde_json::from_str(&encoded).expect("serialized config deserializes");
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_auth_mode_display() {
        let cases = [
            (AuthMode::SqlServer, "sql_server"),
            (AuthMode::Windows, "windows"),
            (AuthMode::AzureAd, "azure_ad"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    #[test]
    fn test_encryption_mode_display() {
        let cases = [
            (EncryptionMode::Off, "off"),
            (EncryptionMode::On, "on"),
            (EncryptionMode::NotSupported, "not_supported"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    #[test]
    fn test_table_key_config() {
        let TableKeyConfig { table, key_columns } = TableKeyConfig {
            table: "orders".to_string(),
            key_columns: vec!["order_id".to_string(), "line_item".to_string()],
        };

        assert_eq!(table, "orders");
        assert_eq!(key_columns.len(), 2);
    }

    #[test]
    fn test_start_position_default() {
        let position = StartPosition::default();
        assert_eq!(position, StartPosition::Current);
    }

    #[test]
    fn test_start_position_serialization() {
        let cases = [
            (StartPosition::Beginning, "\"beginning\""),
            (StartPosition::Current, "\"current\""),
        ];
        for (position, expected_json) in cases {
            let json = serde_json::to_string(&position).unwrap();
            assert_eq!(json, expected_json);
        }
    }

    #[test]
    fn test_validate_sql_identifier_valid() {
        let accepted = [
            // Plain names.
            "orders",
            "Orders",
            "order_items",
            "Order_Items_2024",
            // Schema-qualified names.
            "dbo.orders",
            "sales.order_items",
            "MySchema.MyTable",
        ];
        for name in accepted {
            assert!(
                validate_sql_identifier(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    #[test]
    fn test_validate_sql_identifier_sql_injection() {
        assert_all_rejected(&[
            "orders; DROP TABLE users--",
            "orders'; DELETE FROM users;--",
            "orders OR 1=1",
            "orders/**/UNION/**/SELECT",
            "orders\n; DROP TABLE",
        ]);
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        assert_all_rejected(&[""]);
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        let at_limit = "a".repeat(128);
        assert!(validate_sql_identifier(&at_limit).is_ok());

        let over_limit = "a".repeat(129);
        assert!(validate_sql_identifier(&over_limit).is_err());
    }

    #[test]
    fn test_validate_sql_identifier_invalid_start() {
        assert_all_rejected(&["123table", "1orders"]);
    }

    #[test]
    fn test_validate_sql_identifier_invalid_dots() {
        assert_all_rejected(&[".orders", "orders.", "dbo..orders"]);
    }
}