use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
/// Maximum accepted length, in bytes, of a whole identifier checked by
/// [`validate_sql_identifier`]. Only ASCII is accepted, so for any name that can
/// pass validation this is also the maximum number of characters.
const MAX_IDENTIFIER_LENGTH: usize = 128;

/// Checks that `name` is a plain, optionally dot-qualified SQL identifier
/// (for example `orders` or `dbo.orders`).
///
/// A name is accepted only when all of the following hold:
/// - it is non-empty and at most 128 bytes long (`MAX_IDENTIFIER_LENGTH`);
/// - every character is an ASCII letter, an ASCII digit, `_`, or `.`;
/// - splitting on `.` yields only non-empty parts (no leading, trailing, or
///   doubled dots);
/// - no part starts with a digit, so `dbo.1orders` is rejected just like `1orders`.
///
/// # Errors
///
/// Returns an error describing the first rule that `name` violates.
pub fn validate_sql_identifier(name: &str) -> Result<()> {
    if name.is_empty() {
        return Err(anyhow!("SQL identifier cannot be empty"));
    }
    // Measured in bytes; for names that pass the ASCII-only check below this
    // equals the character count.
    if name.len() > MAX_IDENTIFIER_LENGTH {
        return Err(anyhow!(
            "SQL identifier exceeds maximum length of {MAX_IDENTIFIER_LENGTH} characters"
        ));
    }
    let is_allowed = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.';
    if let Some((i, c)) = name.chars().enumerate().find(|&(_, c)| !is_allowed(c)) {
        return Err(anyhow!(
            "Invalid character '{c}' at position {i} in SQL identifier '{name}'. \
             Only alphanumeric characters, underscores, and dots are allowed."
        ));
    }
    // Validate each dot-separated part on its own. An empty part means a
    // misplaced dot. A leading digit is invalid in every part, not only the first.
    for (index, part) in name.split('.').enumerate() {
        if part.is_empty() {
            return Err(anyhow!("SQL identifier '{name}' has invalid dot placement"));
        }
        if part.starts_with(|c: char| c.is_ascii_digit()) {
            // Keep the original message when the whole name starts with a digit.
            return Err(if index == 0 {
                anyhow!("SQL identifier '{name}' cannot start with a digit")
            } else {
                anyhow!("Part '{part}' of SQL identifier '{name}' cannot start with a digit")
            });
        }
    }
    Ok(())
}
/// Authentication mode used when connecting to SQL Server.
///
/// Serialized in lowercase: `"sqlserver"`, `"windows"`, `"azuread"`. When
/// deserializing, the snake_case spellings produced by the `Display` impl
/// (`"sql_server"`, `"azure_ad"`) are accepted as aliases. This means a value
/// written out via `to_string()` can be read back into a config.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
    /// SQL Server authentication. This is the default.
    #[default]
    #[serde(alias = "sql_server")]
    SqlServer,
    /// Windows authentication.
    Windows,
    /// Azure AD authentication.
    #[serde(alias = "azure_ad")]
    AzureAd,
}

impl std::fmt::Display for AuthMode {
    /// Writes the snake_case name of the mode: `sql_server`, `windows`, or `azure_ad`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::SqlServer => "sql_server",
            Self::Windows => "windows",
            Self::AzureAd => "azure_ad",
        })
    }
}
/// Encryption setting for the SQL Server connection.
///
/// Serialized in lowercase: `"off"`, `"on"`, `"notsupported"`. When
/// deserializing, `"not_supported"` (the spelling produced by this type's
/// `Display` impl) is also accepted.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EncryptionMode {
    /// Serialized as `"off"`.
    Off,
    /// Serialized as `"on"`.
    On,
    /// The default. Serialized as `"notsupported"`; `"not_supported"` is accepted as an alias.
    #[default]
    #[serde(alias = "not_supported")]
    NotSupported,
}
/// Starting position for the source, serialized as `"beginning"` or `"current"`.
///
/// Defaults to [`StartPosition::Current`].
/// NOTE(review): presumably selects whether existing data/changes are replayed
/// from the beginning or only new ones are read — confirm against the consumer.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum StartPosition {
    /// Serialized as `"beginning"`.
    Beginning,
    /// Serialized as `"current"`. This is the default.
    #[default]
    Current,
}
impl std::fmt::Display for EncryptionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Off => write!(f, "off"),
Self::On => write!(f, "on"),
Self::NotSupported => write!(f, "not_supported"),
}
}
}
/// Key columns configured for a single table.
///
/// NOTE(review): presumably used to identify rows in tables whose key is not
/// otherwise known to the source — confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableKeyConfig {
    /// Name of the table these key columns belong to.
    pub table: String,
    /// Names of the columns that make up the key.
    pub key_columns: Vec<String>,
}
/// Configuration for an MS SQL Server source.
///
/// Every field except `database` and `user` has a serde default and may be
/// omitted when deserializing. `Debug` output redacts `password`.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MsSqlSourceConfig {
    /// Server host name. Defaults to `"localhost"`.
    #[serde(default = "default_host")]
    pub host: String,
    /// Server TCP port. Defaults to `1433`.
    #[serde(default = "default_port")]
    pub port: u16,
    /// Database name. Required when deserializing.
    pub database: String,
    /// Login user name. Required when deserializing.
    pub user: String,
    /// Login password. Defaults to an empty string; never shown by `Debug`.
    #[serde(default)]
    pub password: String,
    /// Authentication mode. Defaults to [`AuthMode::SqlServer`].
    #[serde(default)]
    pub auth_mode: AuthMode,
    /// Table names configured for this source. Defaults to empty.
    #[serde(default)]
    pub tables: Vec<String>,
    /// Polling interval in milliseconds. Defaults to `1000`.
    #[serde(default = "default_poll_interval_ms")]
    pub poll_interval_ms: u64,
    /// Connection encryption mode. Defaults to [`EncryptionMode::NotSupported`].
    #[serde(default)]
    pub encryption: EncryptionMode,
    /// Whether to trust the server certificate. Defaults to `false`.
    #[serde(default)]
    pub trust_server_certificate: bool,
    /// Per-table key column configuration. Defaults to empty.
    #[serde(default)]
    pub table_keys: Vec<TableKeyConfig>,
    /// Starting position. Defaults to [`StartPosition::Current`].
    #[serde(default)]
    pub start_position: StartPosition,
}

impl std::fmt::Debug for MsSqlSourceConfig {
    /// Formats every field for diagnostics except the password, which is shown
    /// as `"<redacted>"` when set and `""` when empty.
    ///
    /// This means configs can be logged with `{:?}` without leaking credentials.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let password: &str = if self.password.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("MsSqlSourceConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            .field("password", &password)
            .field("auth_mode", &self.auth_mode)
            .field("tables", &self.tables)
            .field("poll_interval_ms", &self.poll_interval_ms)
            .field("encryption", &self.encryption)
            .field("trust_server_certificate", &self.trust_server_certificate)
            .field("table_keys", &self.table_keys)
            .field("start_position", &self.start_position)
            .finish()
    }
}
/// Serde default for `MsSqlSourceConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}

/// Serde default for `MsSqlSourceConfig::port`: 1433, SQL Server's well-known TCP port.
fn default_port() -> u16 {
    1_433
}

/// Serde default for `MsSqlSourceConfig::poll_interval_ms`.
fn default_poll_interval_ms() -> u64 {
    1_000
}
impl Default for MsSqlSourceConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
database: String::new(),
user: String::new(),
password: String::new(),
auth_mode: AuthMode::default(),
tables: Vec::new(),
poll_interval_ms: default_poll_interval_ms(),
encryption: EncryptionMode::default(),
trust_server_certificate: false,
table_keys: Vec::new(),
start_position: StartPosition::default(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every name in `names` passes validation.
    fn assert_all_valid(names: &[&str]) {
        for name in names {
            assert!(
                validate_sql_identifier(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    /// Asserts that every name in `names` is rejected by validation.
    fn assert_all_invalid(names: &[&str]) {
        for name in names {
            assert!(
                validate_sql_identifier(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_default_config() {
        let MsSqlSourceConfig {
            host,
            port,
            poll_interval_ms,
            auth_mode,
            encryption,
            trust_server_certificate,
            ..
        } = MsSqlSourceConfig::default();
        assert_eq!(host, "localhost");
        assert_eq!(port, 1433);
        assert_eq!(poll_interval_ms, 1000);
        assert_eq!(auth_mode, AuthMode::SqlServer);
        assert_eq!(encryption, EncryptionMode::NotSupported);
        assert!(!trust_server_certificate);
    }

    #[test]
    fn test_config_serialization() {
        let mut original = MsSqlSourceConfig::default();
        original.host = "sqlserver.example.com".to_string();
        original.port = 1433;
        original.database = "production".to_string();
        original.user = "drasi_user".to_string();
        original.password = "secret".to_string();
        original.auth_mode = AuthMode::SqlServer;
        original.tables = ["orders", "customers"].iter().map(|t| t.to_string()).collect();
        original.poll_interval_ms = 2000;
        original.encryption = EncryptionMode::On;
        original.trust_server_certificate = true;
        original.table_keys = vec![TableKeyConfig {
            table: "orders".to_string(),
            key_columns: vec!["order_id".to_string()],
        }];
        original.start_position = StartPosition::Beginning;

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MsSqlSourceConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_auth_mode_display() {
        let cases = [
            (AuthMode::SqlServer, "sql_server"),
            (AuthMode::Windows, "windows"),
            (AuthMode::AzureAd, "azure_ad"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    #[test]
    fn test_encryption_mode_display() {
        let cases = [
            (EncryptionMode::Off, "off"),
            (EncryptionMode::On, "on"),
            (EncryptionMode::NotSupported, "not_supported"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    #[test]
    fn test_table_key_config() {
        let TableKeyConfig { table, key_columns } = TableKeyConfig {
            table: "orders".to_string(),
            key_columns: vec!["order_id".to_string(), "line_item".to_string()],
        };
        assert_eq!(table, "orders");
        assert_eq!(key_columns.len(), 2);
    }

    #[test]
    fn test_start_position_default() {
        assert_eq!(StartPosition::default(), StartPosition::Current);
    }

    #[test]
    fn test_start_position_serialization() {
        let cases = [
            (StartPosition::Beginning, "\"beginning\""),
            (StartPosition::Current, "\"current\""),
        ];
        for (position, expected) in cases {
            assert_eq!(serde_json::to_string(&position).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_sql_identifier_valid() {
        assert_all_valid(&[
            "orders",
            "Orders",
            "order_items",
            "Order_Items_2024",
            "dbo.orders",
            "sales.order_items",
            "MySchema.MyTable",
        ]);
    }

    #[test]
    fn test_validate_sql_identifier_sql_injection() {
        assert_all_invalid(&[
            "orders; DROP TABLE users--",
            "orders'; DELETE FROM users;--",
            "orders OR 1=1",
            "orders/**/UNION/**/SELECT",
            "orders\n; DROP TABLE",
        ]);
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        assert_all_invalid(&[""]);
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        let over_limit = "a".repeat(129);
        let at_limit = "a".repeat(128);
        assert_all_invalid(&[over_limit.as_str()]);
        assert_all_valid(&[at_limit.as_str()]);
    }

    #[test]
    fn test_validate_sql_identifier_invalid_start() {
        assert_all_invalid(&["123table", "1orders"]);
    }

    #[test]
    fn test_validate_sql_identifier_invalid_dots() {
        assert_all_invalid(&[".orders", "orders.", "dbo..orders"]);
    }
}