1#![allow(unexpected_cfgs)]
2pub mod decoder;
72pub mod descriptor;
73pub mod lsn;
74pub mod stream;
75pub mod types;
76
77pub use drasi_mssql_common::config;
79pub use drasi_mssql_common::connection;
80pub use drasi_mssql_common::error;
81pub use drasi_mssql_common::keys;
82
83pub use decoder::CdcOperation;
85pub use drasi_mssql_common::{
86 validate_sql_identifier, AuthMode, ConnectionError, EncryptionMode, LsnError, MsSqlConnection,
87 MsSqlError, MsSqlErrorKind, MsSqlSourceConfig, PrimaryKeyCache, PrimaryKeyError, StartPosition,
88 TableKeyConfig,
89};
90pub use lsn::Lsn;
91
92use anyhow::Result;
93use async_trait::async_trait;
94use drasi_lib::sources::base::{SourceBase, SourceBaseParams};
95use drasi_lib::sources::Source;
96use drasi_lib::state_store::StateStoreProvider;
97use std::sync::Arc;
98use tokio::sync::watch;
99use tokio::sync::RwLock;
100
101pub struct MsSqlSource {
105 source_id: String,
107
108 config: MsSqlSourceConfig,
110
111 base: SourceBase,
113
114 state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
116
117 task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
119
120 shutdown_tx: watch::Sender<bool>,
122
123 shutdown_rx: watch::Receiver<bool>,
125}
126
127impl MsSqlSource {
128 pub fn new(id: impl Into<String>, config: MsSqlSourceConfig) -> Result<Self> {
134 let source_id = id.into();
135
136 let params = SourceBaseParams::new(&source_id);
138
139 let (shutdown_tx, shutdown_rx) = watch::channel(false);
141
142 Ok(Self {
143 source_id,
144 config,
145 base: SourceBase::new(params)?,
146 state_store: Arc::new(RwLock::new(None)),
147 task_handle: Arc::new(RwLock::new(None)),
148 shutdown_tx,
149 shutdown_rx,
150 })
151 }
152
153 pub fn builder(id: impl Into<String>) -> MsSqlSourceBuilder {
173 MsSqlSourceBuilder::new(id)
174 }
175}
176
177#[async_trait]
178impl Source for MsSqlSource {
179 fn id(&self) -> &str {
180 &self.base.id
181 }
182
183 fn type_name(&self) -> &str {
184 "mssql"
185 }
186
187 fn properties(&self) -> std::collections::HashMap<String, serde_json::Value> {
188 use crate::descriptor::{
189 AuthModeDto, EncryptionModeDto, MsSqlSourceConfigDto, StartPositionDto,
190 TableKeyConfigDto,
191 };
192 use drasi_plugin_sdk::ConfigValue;
193
194 let auth_mode_dto = match self.config.auth_mode {
195 crate::AuthMode::SqlServer => AuthModeDto::SqlServer,
196 crate::AuthMode::Windows => AuthModeDto::Windows,
197 crate::AuthMode::AzureAd => AuthModeDto::AzureAd,
198 };
199
200 let encryption_dto = match self.config.encryption {
201 crate::EncryptionMode::Off => EncryptionModeDto::Off,
202 crate::EncryptionMode::On => EncryptionModeDto::On,
203 crate::EncryptionMode::NotSupported => EncryptionModeDto::NotSupported,
204 };
205
206 let start_position_dto = match self.config.start_position {
207 crate::StartPosition::Beginning => StartPositionDto::Beginning,
208 crate::StartPosition::Current => StartPositionDto::Current,
209 };
210
211 let table_keys_dto: Vec<TableKeyConfigDto> = self
212 .config
213 .table_keys
214 .iter()
215 .map(|tk| TableKeyConfigDto {
216 table: tk.table.clone(),
217 key_columns: tk.key_columns.clone(),
218 })
219 .collect();
220
221 let dto = MsSqlSourceConfigDto {
222 host: ConfigValue::Static(self.config.host.clone()),
223 port: ConfigValue::Static(self.config.port),
224 database: ConfigValue::Static(self.config.database.clone()),
225 user: ConfigValue::Static(self.config.user.clone()),
226 password: ConfigValue::Static(self.config.password.clone()),
227 auth_mode: ConfigValue::Static(auth_mode_dto),
228 tables: self.config.tables.clone(),
229 poll_interval_ms: ConfigValue::Static(self.config.poll_interval_ms),
230 encryption: ConfigValue::Static(encryption_dto),
231 trust_server_certificate: ConfigValue::Static(self.config.trust_server_certificate),
232 table_keys: table_keys_dto,
233 start_position: ConfigValue::Static(start_position_dto),
234 };
235
236 match serde_json::to_value(&dto) {
237 Ok(serde_json::Value::Object(mut map)) => {
238 map.remove("password");
240 map.into_iter().collect()
241 }
242 _ => std::collections::HashMap::new(),
243 }
244 }
245
246 fn auto_start(&self) -> bool {
247 self.base.get_auto_start()
248 }
249
250 async fn status(&self) -> drasi_lib::channels::ComponentStatus {
251 self.base.get_status().await
252 }
253
254 async fn start(&self) -> Result<()> {
255 use drasi_lib::channels::ComponentStatus;
256
257 if self.base.get_status().await == ComponentStatus::Running {
258 return Ok(());
259 }
260
261 self.base.set_status(ComponentStatus::Starting, None).await;
262 log::info!("Starting MS SQL CDC source: {}", self.base.id);
263
264 let config = self.config.clone();
265 let source_id = self.base.id.clone();
266 let dispatchers = self.base.dispatchers.clone();
267 let state_store = self.state_store.read().await.clone();
268 let shutdown_rx = self.shutdown_rx.clone();
269
270 let task_handle = tokio::spawn(async move {
272 if let Err(e) = stream::run_cdc_stream(
273 source_id.clone(),
274 config,
275 dispatchers,
276 state_store,
277 shutdown_rx,
278 )
279 .await
280 {
281 log::error!("CDC stream task failed for {source_id}: {e}");
282 }
283 });
284
285 *self.task_handle.write().await = Some(task_handle);
287
288 self.base.set_status(ComponentStatus::Running, None).await;
289
290 log::info!("MS SQL source '{}' started CDC polling", self.base.id);
291 Ok(())
292 }
293
294 async fn stop(&self) -> Result<()> {
295 use drasi_lib::channels::ComponentStatus;
296
297 log::info!("MS SQL source '{}' stopping", self.base.id);
298
299 if let Err(e) = self.shutdown_tx.send(true) {
301 log::warn!("Failed to send shutdown signal: {e}");
302 }
303
304 if let Some(handle) = self.task_handle.write().await.take() {
306 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
308 Ok(Ok(())) => {
309 log::debug!("CDC polling task stopped gracefully");
310 }
311 Ok(Err(e)) => {
312 log::warn!("CDC polling task panicked: {e}");
313 }
314 Err(_) => {
315 log::warn!("CDC polling task did not stop within timeout, it will be dropped");
316 }
317 }
318 }
319
320 self.base.set_status(ComponentStatus::Stopped, None).await;
321
322 Ok(())
323 }
324
325 async fn subscribe(
326 &self,
327 settings: drasi_lib::config::SourceSubscriptionSettings,
328 ) -> Result<drasi_lib::channels::SubscriptionResponse> {
329 self.base
330 .subscribe_with_bootstrap(&settings, "MS SQL")
331 .await
332 }
333
334 fn as_any(&self) -> &dyn std::any::Any {
335 self
336 }
337
338 async fn initialize(&self, context: drasi_lib::context::SourceRuntimeContext) {
339 self.base.initialize(context.clone()).await;
340
341 if let Some(state_store) = context.state_store {
343 *self.state_store.write().await = Some(state_store);
344 log::debug!("State store injected into MS SQL source '{}'", self.base.id);
345 }
346 }
347
348 async fn set_bootstrap_provider(
349 &self,
350 provider: Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>,
351 ) {
352 self.base.set_bootstrap_provider(provider).await;
353 }
354}
355
356pub struct MsSqlSourceBuilder {
360 id: String,
361 config: MsSqlSourceConfig,
362 bootstrap_provider: Option<Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>>,
363}
364
365impl MsSqlSourceBuilder {
366 pub fn new(id: impl Into<String>) -> Self {
368 Self {
369 id: id.into(),
370 config: MsSqlSourceConfig::default(),
371 bootstrap_provider: None,
372 }
373 }
374
375 pub fn with_host(mut self, host: impl Into<String>) -> Self {
377 self.config.host = host.into();
378 self
379 }
380
381 pub fn with_port(mut self, port: u16) -> Self {
383 self.config.port = port;
384 self
385 }
386
387 pub fn with_database(mut self, database: impl Into<String>) -> Self {
389 self.config.database = database.into();
390 self
391 }
392
393 pub fn with_user(mut self, user: impl Into<String>) -> Self {
395 self.config.user = user.into();
396 self
397 }
398
399 pub fn with_password(mut self, password: impl Into<String>) -> Self {
401 self.config.password = password.into();
402 self
403 }
404
405 pub fn with_auth_mode(mut self, auth_mode: AuthMode) -> Self {
407 self.config.auth_mode = auth_mode;
408 self
409 }
410
411 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
413 self.config.tables = tables;
414 self
415 }
416
417 pub fn with_table(mut self, table: impl Into<String>) -> Self {
419 self.config.tables.push(table.into());
420 self
421 }
422
423 pub fn with_poll_interval_ms(mut self, ms: u64) -> Self {
425 self.config.poll_interval_ms = ms;
426 self
427 }
428
429 pub fn with_encryption(mut self, encryption: EncryptionMode) -> Self {
431 self.config.encryption = encryption;
432 self
433 }
434
435 pub fn with_trust_server_certificate(mut self, trust: bool) -> Self {
437 self.config.trust_server_certificate = trust;
438 self
439 }
440
441 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
443 self.config.table_keys.push(TableKeyConfig {
444 table: table.into(),
445 key_columns,
446 });
447 self
448 }
449
450 pub fn with_start_position(mut self, position: StartPosition) -> Self {
452 self.config.start_position = position;
453 self
454 }
455
456 pub fn with_bootstrap_provider(
458 mut self,
459 provider: impl drasi_lib::bootstrap::BootstrapProvider + 'static,
460 ) -> Self {
461 self.bootstrap_provider = Some(Box::new(provider));
462 self
463 }
464
465 pub fn build(self) -> Result<MsSqlSource> {
470 if self.config.database.is_empty() {
472 return Err(anyhow::anyhow!("Database name is required"));
473 }
474 if self.config.user.is_empty() {
475 return Err(anyhow::anyhow!("Database user is required"));
476 }
477
478 let source_id = self.id.clone();
479
480 let mut params = SourceBaseParams::new(&source_id);
482
483 if let Some(provider) = self.bootstrap_provider {
485 params = params.with_bootstrap_provider(provider);
486 }
487
488 let (shutdown_tx, shutdown_rx) = watch::channel(false);
490
491 Ok(MsSqlSource {
492 source_id,
493 config: self.config,
494 base: SourceBase::new(params)?,
495 state_store: Arc::new(RwLock::new(None)),
496 task_handle: Arc::new(RwLock::new(None)),
497 shutdown_tx,
498 shutdown_rx,
499 })
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_basic() {
        let source = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .with_database("testdb")
            .with_user("testuser")
            .with_password("testpass")
            .build()
            .unwrap();

        assert_eq!(source.id(), "test-source");
        assert_eq!(source.type_name(), "mssql");
        assert_eq!(source.config.host, "localhost");
        assert_eq!(source.config.database, "testdb");
    }

    #[test]
    fn test_builder_with_tables() {
        let source = MsSqlSource::builder("test-source")
            .with_database("testdb")
            .with_user("testuser")
            .with_tables(vec!["table1".to_string(), "table2".to_string()])
            .build()
            .unwrap();

        assert_eq!(source.config.tables.len(), 2);
    }

    #[test]
    fn test_builder_with_table_appends() {
        // `with_tables` replaces the list; `with_table` appends to it.
        let source = MsSqlSource::builder("test-source")
            .with_database("testdb")
            .with_user("testuser")
            .with_tables(vec!["table1".to_string()])
            .with_table("table2")
            .build()
            .unwrap();

        assert_eq!(
            source.config.tables,
            vec!["table1".to_string(), "table2".to_string()]
        );
    }

    #[test]
    fn test_builder_missing_required_fields() {
        let result = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_builder_table_keys() {
        let source = MsSqlSource::builder("test-source")
            .with_database("testdb")
            .with_user("testuser")
            .with_table_key("orders", vec!["order_id".to_string()])
            .build()
            .unwrap();

        assert_eq!(source.config.table_keys.len(), 1);
        assert_eq!(source.config.table_keys[0].table, "orders");
    }

    #[test]
    fn test_new_uses_given_id() {
        let source = MsSqlSource::new("direct-source", MsSqlSourceConfig::default()).unwrap();

        assert_eq!(source.id(), "direct-source");
        assert_eq!(source.type_name(), "mssql");
    }

    #[test]
    fn test_properties_never_expose_password() {
        let source = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .with_database("testdb")
            .with_user("testuser")
            .with_password("super-secret-password")
            .build()
            .unwrap();

        let props = source.properties();

        assert!(!props.is_empty(), "configuration properties should be reported");
        assert!(!props.contains_key("password"));
        let rendered = serde_json::to_string(&props).unwrap();
        assert!(!rendered.contains("super-secret-password"));
    }
}
557
// Dynamic-plugin entry point, compiled only when the `dynamic-plugin` feature is enabled.
// Hands `descriptor::MsSqlSourceDescriptor` to `drasi_plugin_sdk::export_plugin!` as the
// only source descriptor, under the plugin id "mssql-source". No reaction or bootstrap
// descriptors are passed.
// NOTE(review): `core_version`, `lib_version` and `plugin_version` all read this crate's
// `CARGO_PKG_VERSION`; confirm that is what the plugin host expects for `core_version`
// and `lib_version`.
#[cfg(feature = "dynamic-plugin")]
drasi_plugin_sdk::export_plugin!(
    plugin_id = "mssql-source",
    core_version = env!("CARGO_PKG_VERSION"),
    lib_version = env!("CARGO_PKG_VERSION"),
    plugin_version = env!("CARGO_PKG_VERSION"),
    source_descriptors = [descriptor::MsSqlSourceDescriptor],
    reaction_descriptors = [],
    bootstrap_descriptors = [],
);