#![allow(unexpected_cfgs)]
pub mod decoder;
pub mod descriptor;
pub mod lsn;
pub mod stream;
pub mod types;
pub use drasi_mssql_common::config;
pub use drasi_mssql_common::connection;
pub use drasi_mssql_common::error;
pub use drasi_mssql_common::keys;
pub use decoder::CdcOperation;
pub use drasi_mssql_common::{
validate_sql_identifier, AuthMode, ConnectionError, EncryptionMode, LsnError, MsSqlConnection,
MsSqlError, MsSqlErrorKind, MsSqlSourceConfig, PrimaryKeyCache, PrimaryKeyError, StartPosition,
TableKeyConfig,
};
pub use lsn::Lsn;
use anyhow::Result;
use async_trait::async_trait;
use drasi_lib::sources::base::{SourceBase, SourceBaseParams};
use drasi_lib::sources::Source;
use drasi_lib::state_store::StateStoreProvider;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::sync::RwLock;
/// MS SQL Server change-data-capture (CDC) source.
///
/// Holds the source configuration plus the state needed to run and stop the
/// background CDC polling task spawned by `start()` in the `Source` impl.
pub struct MsSqlSource {
    /// Source id as passed to the constructor.
    // NOTE(review): not read in the code visible here; `id()` returns `base.id` — confirm it is still needed.
    source_id: String,
    /// Configuration reported by `properties()` and cloned into the polling task on `start()`.
    config: MsSqlSourceConfig,
    /// Shared source plumbing: id, status, dispatchers, auto-start flag and bootstrap provider.
    base: SourceBase,
    /// State store injected by `initialize()` (if the runtime context has one); passed to the CDC stream on `start()`.
    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
    /// Handle of the spawned CDC polling task; set by `start()` and taken by `stop()`.
    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Used by `stop()` to signal the polling task to shut down.
    shutdown_tx: watch::Sender<bool>,
    /// Original receiver; clones are handed to the polling task on `start()`.
    /// Keeping it here also means `shutdown_tx.send` never fails for lack of receivers.
    shutdown_rx: watch::Receiver<bool>,
}
impl MsSqlSource {
    /// Creates a source directly from an already populated configuration.
    ///
    /// Unlike [`MsSqlSourceBuilder::build`], this performs no validation of
    /// `config` and installs no bootstrap provider.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the underlying `SourceBase`.
    pub fn new(id: impl Into<String>, config: MsSqlSourceConfig) -> Result<Self> {
        let id: String = id.into();
        let base = SourceBase::new(SourceBaseParams::new(&id))?;
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        Ok(Self {
            base,
            config,
            source_id: id,
            state_store: Arc::new(RwLock::new(None)),
            task_handle: Arc::new(RwLock::new(None)),
            shutdown_tx,
            shutdown_rx,
        })
    }

    /// Returns a [`MsSqlSourceBuilder`] for a source with the given id.
    pub fn builder(id: impl Into<String>) -> MsSqlSourceBuilder {
        MsSqlSourceBuilder::new(id)
    }
}
#[async_trait]
impl Source for MsSqlSource {
    /// Returns the source id stored in the shared `SourceBase`.
    fn id(&self) -> &str {
        &self.base.id
    }

    /// Returns the fixed type name `"mssql"`.
    fn type_name(&self) -> &str {
        "mssql"
    }

    /// Returns the source configuration as a flat JSON property map.
    ///
    /// The config is mapped onto the plugin descriptor DTO (each scalar wrapped
    /// in `ConfigValue::Static`) and serialized. The `password` entry is removed
    /// so credentials are not exposed. If serialization does not produce a JSON
    /// object, an empty map is returned.
    fn properties(&self) -> std::collections::HashMap<String, serde_json::Value> {
        use crate::descriptor::{
            AuthModeDto, EncryptionModeDto, MsSqlSourceConfigDto, StartPositionDto,
            TableKeyConfigDto,
        };
        use drasi_plugin_sdk::ConfigValue;
        let auth_mode_dto = match self.config.auth_mode {
            crate::AuthMode::SqlServer => AuthModeDto::SqlServer,
            crate::AuthMode::Windows => AuthModeDto::Windows,
            crate::AuthMode::AzureAd => AuthModeDto::AzureAd,
        };
        let encryption_dto = match self.config.encryption {
            crate::EncryptionMode::Off => EncryptionModeDto::Off,
            crate::EncryptionMode::On => EncryptionModeDto::On,
            crate::EncryptionMode::NotSupported => EncryptionModeDto::NotSupported,
        };
        let start_position_dto = match self.config.start_position {
            crate::StartPosition::Beginning => StartPositionDto::Beginning,
            crate::StartPosition::Current => StartPositionDto::Current,
        };
        let table_keys_dto: Vec<TableKeyConfigDto> = self
            .config
            .table_keys
            .iter()
            .map(|tk| TableKeyConfigDto {
                table: tk.table.clone(),
                key_columns: tk.key_columns.clone(),
            })
            .collect();
        let dto = MsSqlSourceConfigDto {
            host: ConfigValue::Static(self.config.host.clone()),
            port: ConfigValue::Static(self.config.port),
            database: ConfigValue::Static(self.config.database.clone()),
            user: ConfigValue::Static(self.config.user.clone()),
            password: ConfigValue::Static(self.config.password.clone()),
            auth_mode: ConfigValue::Static(auth_mode_dto),
            tables: self.config.tables.clone(),
            poll_interval_ms: ConfigValue::Static(self.config.poll_interval_ms),
            encryption: ConfigValue::Static(encryption_dto),
            trust_server_certificate: ConfigValue::Static(self.config.trust_server_certificate),
            table_keys: table_keys_dto,
            start_position: ConfigValue::Static(start_position_dto),
        };
        match serde_json::to_value(&dto) {
            Ok(serde_json::Value::Object(mut map)) => {
                map.remove("password");
                map.into_iter().collect()
            }
            _ => std::collections::HashMap::new(),
        }
    }

    /// Returns the auto-start flag held by the shared `SourceBase`.
    fn auto_start(&self) -> bool {
        self.base.get_auto_start()
    }

    /// Returns the current component status from the shared `SourceBase`.
    async fn status(&self) -> drasi_lib::channels::ComponentStatus {
        self.base.get_status().await
    }

    /// Starts CDC polling in a background tokio task.
    ///
    /// Does nothing if the source is already `Running`. Otherwise the shutdown
    /// flag left behind by a previous `stop()` is reset, the polling task is
    /// spawned and the status is set to `Running`. Errors from the CDC stream
    /// are logged by the task; they are not reported through the returned
    /// `Result`, which is always `Ok`.
    async fn start(&self) -> Result<()> {
        use drasi_lib::channels::ComponentStatus;
        if self.base.get_status().await == ComponentStatus::Running {
            return Ok(());
        }
        self.base.set_status(ComponentStatus::Starting, None).await;
        log::info!("Starting MS SQL CDC source: {}", self.base.id);

        // `stop()` leaves `true` in the watch channel. Reset it, then give the
        // task a receiver that has already observed the reset, so a restarted
        // stream neither reads a stale `true` nor wakes on an old change.
        self.shutdown_tx.send_replace(false);
        let mut shutdown_rx = self.shutdown_rx.clone();
        let _ = shutdown_rx.borrow_and_update();

        let config = self.config.clone();
        let source_id = self.base.id.clone();
        let dispatchers = self.base.dispatchers.clone();
        let state_store = self.state_store.read().await.clone();
        let task_handle = tokio::spawn(async move {
            if let Err(e) = stream::run_cdc_stream(
                source_id.clone(),
                config,
                dispatchers,
                state_store,
                shutdown_rx,
            )
            .await
            {
                log::error!("CDC stream task failed for {source_id}: {e}");
            }
        });

        // Defensive: never leave an earlier polling task running alongside the
        // new one. Aborting an already-finished task is a no-op.
        let previous = self.task_handle.write().await.replace(task_handle);
        if let Some(previous) = previous {
            previous.abort();
        }

        self.base.set_status(ComponentStatus::Running, None).await;
        log::info!("MS SQL source '{}' started CDC polling", self.base.id);
        Ok(())
    }

    /// Stops CDC polling.
    ///
    /// Signals shutdown through the watch channel and waits up to 5 seconds
    /// for the polling task to finish. If it has not finished by then, the task
    /// is aborted. The status is then set to `Stopped`. Always returns `Ok`.
    async fn stop(&self) -> Result<()> {
        use drasi_lib::channels::ComponentStatus;
        log::info!("MS SQL source '{}' stopping", self.base.id);
        if let Err(e) = self.shutdown_tx.send(true) {
            log::warn!("Failed to send shutdown signal: {e}");
        }

        // Take the handle in its own statement so the write lock is released
        // before the (up to 5 s) wait below.
        let handle = self.task_handle.write().await.take();
        if let Some(mut handle) = handle {
            // Await by `&mut` so the handle is still available to abort on timeout.
            match tokio::time::timeout(std::time::Duration::from_secs(5), &mut handle).await {
                Ok(Ok(())) => {
                    log::debug!("CDC polling task stopped gracefully");
                }
                Ok(Err(e)) => {
                    log::warn!("CDC polling task panicked: {e}");
                }
                Err(_) => {
                    // Dropping a JoinHandle only detaches the task; cancel it explicitly.
                    log::warn!("CDC polling task did not stop within timeout, aborting it");
                    handle.abort();
                }
            }
        }
        self.base.set_status(ComponentStatus::Stopped, None).await;
        Ok(())
    }

    /// Subscribes a query to this source, delegating to the shared `SourceBase`
    /// (including any configured bootstrap).
    async fn subscribe(
        &self,
        settings: drasi_lib::config::SourceSubscriptionSettings,
    ) -> Result<drasi_lib::channels::SubscriptionResponse> {
        self.base
            .subscribe_with_bootstrap(&settings, "MS SQL")
            .await
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Initializes the shared `SourceBase` with the runtime context and keeps
    /// the context's state store (if any) for use by the CDC stream.
    async fn initialize(&self, context: drasi_lib::context::SourceRuntimeContext) {
        // Only the state store is needed afterwards; clone it rather than the whole context.
        let state_store = context.state_store.clone();
        self.base.initialize(context).await;
        if let Some(state_store) = state_store {
            *self.state_store.write().await = Some(state_store);
            log::debug!("State store injected into MS SQL source '{}'", self.base.id);
        }
    }

    /// Installs the bootstrap provider on the shared `SourceBase`.
    async fn set_bootstrap_provider(
        &self,
        provider: Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>,
    ) {
        self.base.set_bootstrap_provider(provider).await;
    }
}
/// Fluent builder for [`MsSqlSource`].
///
/// Starts from `MsSqlSourceConfig::default()`; `build()` validates the
/// required fields and creates the source.
pub struct MsSqlSourceBuilder {
    /// Id given to the built source.
    id: String,
    /// Configuration being assembled by the `with_*` methods.
    config: MsSqlSourceConfig,
    /// Optional bootstrap provider passed to the `SourceBase` on `build()`.
    bootstrap_provider: Option<Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>>,
}
impl MsSqlSourceBuilder {
    /// Creates a builder for a source with the given id, starting from
    /// `MsSqlSourceConfig::default()` and no bootstrap provider.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            config: MsSqlSourceConfig::default(),
            bootstrap_provider: None,
        }
    }

    /// Sets the SQL Server host.
    #[must_use]
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.config.host = host.into();
        self
    }

    /// Sets the SQL Server port.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.config.port = port;
        self
    }

    /// Sets the database name (required by `build()`).
    #[must_use]
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.config.database = database.into();
        self
    }

    /// Sets the database user (required by `build()`).
    #[must_use]
    pub fn with_user(mut self, user: impl Into<String>) -> Self {
        self.config.user = user.into();
        self
    }

    /// Sets the database password.
    #[must_use]
    pub fn with_password(mut self, password: impl Into<String>) -> Self {
        self.config.password = password.into();
        self
    }

    /// Sets the authentication mode.
    #[must_use]
    pub fn with_auth_mode(mut self, auth_mode: AuthMode) -> Self {
        self.config.auth_mode = auth_mode;
        self
    }

    /// Replaces the list of tables to monitor.
    #[must_use]
    pub fn with_tables(mut self, tables: Vec<String>) -> Self {
        self.config.tables = tables;
        self
    }

    /// Appends one table to the list of tables to monitor.
    #[must_use]
    pub fn with_table(mut self, table: impl Into<String>) -> Self {
        self.config.tables.push(table.into());
        self
    }

    /// Sets the CDC polling interval in milliseconds.
    #[must_use]
    pub fn with_poll_interval_ms(mut self, ms: u64) -> Self {
        self.config.poll_interval_ms = ms;
        self
    }

    /// Sets the connection encryption mode.
    #[must_use]
    pub fn with_encryption(mut self, encryption: EncryptionMode) -> Self {
        self.config.encryption = encryption;
        self
    }

    /// Sets whether the server certificate is trusted without validation.
    #[must_use]
    pub fn with_trust_server_certificate(mut self, trust: bool) -> Self {
        self.config.trust_server_certificate = trust;
        self
    }

    /// Appends an explicit key-column override for `table`.
    #[must_use]
    pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
        self.config.table_keys.push(TableKeyConfig {
            table: table.into(),
            key_columns,
        });
        self
    }

    /// Sets where CDC reading starts.
    #[must_use]
    pub fn with_start_position(mut self, position: StartPosition) -> Self {
        self.config.start_position = position;
        self
    }

    /// Sets the bootstrap provider installed on the built source.
    #[must_use]
    pub fn with_bootstrap_provider(
        mut self,
        provider: impl drasi_lib::bootstrap::BootstrapProvider + 'static,
    ) -> Self {
        self.bootstrap_provider = Some(Box::new(provider));
        self
    }

    /// Validates the configuration and builds the [`MsSqlSource`].
    ///
    /// # Errors
    ///
    /// Returns an error if the database name or user is empty or only
    /// whitespace, or if the underlying `SourceBase` cannot be created.
    pub fn build(self) -> Result<MsSqlSource> {
        // Whitespace-only values are as unusable as empty ones.
        if self.config.database.trim().is_empty() {
            return Err(anyhow::anyhow!("Database name is required"));
        }
        if self.config.user.trim().is_empty() {
            return Err(anyhow::anyhow!("Database user is required"));
        }
        let mut params = SourceBaseParams::new(&self.id);
        if let Some(provider) = self.bootstrap_provider {
            params = params.with_bootstrap_provider(provider);
        }
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        Ok(MsSqlSource {
            // The builder is consumed, so the id can be moved rather than cloned.
            source_id: self.id,
            config: self.config,
            base: SourceBase::new(params)?,
            state_store: Arc::new(RwLock::new(None)),
            task_handle: Arc::new(RwLock::new(None)),
            shutdown_tx,
            shutdown_rx,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a source with the two required fields set.
    fn minimal_builder() -> MsSqlSourceBuilder {
        MsSqlSource::builder("test-source")
            .with_database("testdb")
            .with_user("testuser")
    }

    #[test]
    fn test_builder_basic() {
        let source = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .with_database("testdb")
            .with_user("testuser")
            .with_password("testpass")
            .build()
            .unwrap();
        assert_eq!(source.id(), "test-source");
        assert_eq!(source.type_name(), "mssql");
        assert_eq!(source.config.host, "localhost");
        assert_eq!(source.config.database, "testdb");
    }

    #[test]
    fn test_builder_with_tables() {
        let source = minimal_builder()
            .with_tables(vec!["table1".to_string(), "table2".to_string()])
            .build()
            .unwrap();
        assert_eq!(source.config.tables.len(), 2);
    }

    #[test]
    fn test_builder_with_table_appends() {
        let source = minimal_builder()
            .with_tables(vec![])
            .with_table("table1")
            .with_table("table2")
            .build()
            .unwrap();
        assert_eq!(
            source.config.tables,
            vec!["table1".to_string(), "table2".to_string()]
        );
    }

    #[test]
    fn test_builder_missing_database() {
        let result = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .with_user("testuser")
            .with_database("")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_missing_user() {
        let result = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .with_database("testdb")
            .with_user("")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_table_keys() {
        let source = minimal_builder()
            .with_table_key("orders", vec!["order_id".to_string()])
            .build()
            .unwrap();
        assert_eq!(source.config.table_keys.len(), 1);
        assert_eq!(source.config.table_keys[0].table, "orders");
    }

    #[test]
    fn test_properties_redacts_password() {
        let source = minimal_builder()
            .with_password("super-secret-value")
            .build()
            .unwrap();
        let props = source.properties();
        assert!(!props.is_empty());
        assert!(!props.contains_key("password"));
        // Guard against the password leaking under a renamed key.
        assert!(props
            .values()
            .all(|v| !v.to_string().contains("super-secret-value")));
    }
}
// Dynamic-plugin entry point, compiled only with the `dynamic-plugin` feature.
// Registers `MsSqlSourceDescriptor` as this plugin's single source descriptor;
// no reaction or bootstrap descriptors are exported.
// NOTE(review): core_version, lib_version and plugin_version all use this
// crate's CARGO_PKG_VERSION — confirm that is what the plugin SDK expects for
// core/lib compatibility checks.
#[cfg(feature = "dynamic-plugin")]
drasi_plugin_sdk::export_plugin!(
    plugin_id = "mssql-source",
    core_version = env!("CARGO_PKG_VERSION"),
    lib_version = env!("CARGO_PKG_VERSION"),
    plugin_version = env!("CARGO_PKG_VERSION"),
    source_descriptors = [descriptor::MsSqlSourceDescriptor],
    reaction_descriptors = [],
    bootstrap_descriptors = [],
);