Skip to main content

drasi_source_mssql/
lib.rs

1#![allow(unexpected_cfgs)]
2// Copyright 2025 The Drasi Authors.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Microsoft SQL Server CDC Source for Drasi
17//!
18//! This source plugin captures data changes from MS SQL Server databases using
19//! Change Data Capture (CDC). It monitors CDC change tables and converts changes
20//! to Drasi source events.
21//!
22//! # Features
23//!
24//! - Real-time change capture via CDC polling
25//! - LSN-based progress tracking with StateStore persistence
26//! - Automatic LSN validation and recovery
27//! - Support for INSERT, UPDATE, DELETE operations
28//! - Transaction grouping
29//! - Primary key discovery and custom key configuration
30//!
31//! # Prerequisites
32//!
33//! Before using this source, you must:
34//!
35//! 1. Enable CDC on the database:
36//!    ```sql
37//!    EXEC sys.sp_cdc_enable_db;
38//!    ```
39//!
40//! 2. Enable CDC on tables you want to monitor:
41//!    ```sql
42//!    EXEC sys.sp_cdc_enable_table
43//!        @source_schema = N'dbo',
44//!        @source_name = N'YourTable',
45//!        @role_name = NULL;
46//!    ```
47//!
48//! 3. Ensure SQL Server Agent is running (required for CDC)
49//!
50//! # Example
51//!
52//! ```no_run
53//! use drasi_source_mssql::{MsSqlSource, StartPosition};
54//! use drasi_lib::Source;
55//!
56//! # async fn example() -> anyhow::Result<()> {
57//! let source = MsSqlSource::builder("mssql-source")
58//!     .with_host("localhost")
59//!     .with_database("production")
60//!     .with_user("drasi_user")
61//!     .with_password("secure_password")
62//!     .with_tables(vec!["orders".to_string(), "customers".to_string()])
63//!     .with_start_position(StartPosition::Beginning)  // Capture all retained changes
64//!     .build()?;
65//!
66//! source.start().await?;
67//! # Ok(())
68//! # }
69//! ```
70
71pub mod decoder;
72pub mod descriptor;
73pub mod lsn;
74pub mod stream;
75pub mod types;
76
77// Re-export from drasi-mssql-common
78pub use drasi_mssql_common::config;
79pub use drasi_mssql_common::connection;
80pub use drasi_mssql_common::error;
81pub use drasi_mssql_common::keys;
82
83// Re-export main types
84pub use decoder::CdcOperation;
85pub use drasi_mssql_common::{
86    validate_sql_identifier, AuthMode, ConnectionError, EncryptionMode, LsnError, MsSqlConnection,
87    MsSqlError, MsSqlErrorKind, MsSqlSourceConfig, PrimaryKeyCache, PrimaryKeyError, StartPosition,
88    TableKeyConfig,
89};
90pub use lsn::Lsn;
91
92use anyhow::Result;
93use async_trait::async_trait;
94use drasi_lib::sources::base::{SourceBase, SourceBaseParams};
95use drasi_lib::sources::Source;
96use drasi_lib::state_store::StateStoreProvider;
97use std::sync::Arc;
98use tokio::sync::watch;
99use tokio::sync::RwLock;
100
101/// MS SQL CDC Source
102///
103/// Monitors MS SQL Server CDC change tables and emits source change events.
104pub struct MsSqlSource {
105    /// Source identifier
106    source_id: String,
107
108    /// Configuration
109    config: MsSqlSourceConfig,
110
111    /// Base source implementation (handles dispatching, status, etc.)
112    base: SourceBase,
113
114    /// State store for LSN persistence
115    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
116
117    /// CDC polling task handle
118    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
119
120    /// Shutdown signal sender for graceful shutdown
121    shutdown_tx: watch::Sender<bool>,
122
123    /// Shutdown signal receiver (cloned for each task)
124    shutdown_rx: watch::Receiver<bool>,
125}
126
127impl MsSqlSource {
128    /// Create a new MS SQL source with configuration
129    ///
130    /// # Arguments
131    /// * `id` - Unique identifier for this source
132    /// * `config` - MS SQL source configuration
133    pub fn new(id: impl Into<String>, config: MsSqlSourceConfig) -> Result<Self> {
134        let source_id = id.into();
135
136        // Create base source parameters
137        let params = SourceBaseParams::new(&source_id);
138
139        // Create shutdown channel (false = running, true = shutdown)
140        let (shutdown_tx, shutdown_rx) = watch::channel(false);
141
142        Ok(Self {
143            source_id,
144            config,
145            base: SourceBase::new(params)?,
146            state_store: Arc::new(RwLock::new(None)),
147            task_handle: Arc::new(RwLock::new(None)),
148            shutdown_tx,
149            shutdown_rx,
150        })
151    }
152
153    /// Create a builder for configuring the MS SQL source
154    ///
155    /// # Arguments
156    /// * `id` - Unique identifier for this source
157    ///
158    /// # Example
159    /// ```no_run
160    /// use drasi_source_mssql::MsSqlSource;
161    ///
162    /// # fn example() -> anyhow::Result<()> {
163    /// let source = MsSqlSource::builder("my-source")
164    ///     .with_host("localhost")
165    ///     .with_database("mydb")
166    ///     .with_user("user")
167    ///     .with_password("password")
168    ///     .build()?;
169    /// # Ok(())
170    /// # }
171    /// ```
172    pub fn builder(id: impl Into<String>) -> MsSqlSourceBuilder {
173        MsSqlSourceBuilder::new(id)
174    }
175}
176
177#[async_trait]
178impl Source for MsSqlSource {
179    fn id(&self) -> &str {
180        &self.base.id
181    }
182
183    fn type_name(&self) -> &str {
184        "mssql"
185    }
186
187    fn properties(&self) -> std::collections::HashMap<String, serde_json::Value> {
188        use crate::descriptor::{
189            AuthModeDto, EncryptionModeDto, MsSqlSourceConfigDto, StartPositionDto,
190            TableKeyConfigDto,
191        };
192        use drasi_plugin_sdk::ConfigValue;
193
194        let auth_mode_dto = match self.config.auth_mode {
195            crate::AuthMode::SqlServer => AuthModeDto::SqlServer,
196            crate::AuthMode::Windows => AuthModeDto::Windows,
197            crate::AuthMode::AzureAd => AuthModeDto::AzureAd,
198        };
199
200        let encryption_dto = match self.config.encryption {
201            crate::EncryptionMode::Off => EncryptionModeDto::Off,
202            crate::EncryptionMode::On => EncryptionModeDto::On,
203            crate::EncryptionMode::NotSupported => EncryptionModeDto::NotSupported,
204        };
205
206        let start_position_dto = match self.config.start_position {
207            crate::StartPosition::Beginning => StartPositionDto::Beginning,
208            crate::StartPosition::Current => StartPositionDto::Current,
209        };
210
211        let table_keys_dto: Vec<TableKeyConfigDto> = self
212            .config
213            .table_keys
214            .iter()
215            .map(|tk| TableKeyConfigDto {
216                table: tk.table.clone(),
217                key_columns: tk.key_columns.clone(),
218            })
219            .collect();
220
221        let dto = MsSqlSourceConfigDto {
222            host: ConfigValue::Static(self.config.host.clone()),
223            port: ConfigValue::Static(self.config.port),
224            database: ConfigValue::Static(self.config.database.clone()),
225            user: ConfigValue::Static(self.config.user.clone()),
226            password: ConfigValue::Static(self.config.password.clone()),
227            auth_mode: ConfigValue::Static(auth_mode_dto),
228            tables: self.config.tables.clone(),
229            poll_interval_ms: ConfigValue::Static(self.config.poll_interval_ms),
230            encryption: ConfigValue::Static(encryption_dto),
231            trust_server_certificate: ConfigValue::Static(self.config.trust_server_certificate),
232            table_keys: table_keys_dto,
233            start_position: ConfigValue::Static(start_position_dto),
234        };
235
236        match serde_json::to_value(&dto) {
237            Ok(serde_json::Value::Object(mut map)) => {
238                // Don't expose password
239                map.remove("password");
240                map.into_iter().collect()
241            }
242            _ => std::collections::HashMap::new(),
243        }
244    }
245
246    fn auto_start(&self) -> bool {
247        self.base.get_auto_start()
248    }
249
250    async fn status(&self) -> drasi_lib::channels::ComponentStatus {
251        self.base.get_status().await
252    }
253
254    async fn start(&self) -> Result<()> {
255        use drasi_lib::channels::ComponentStatus;
256
257        if self.base.get_status().await == ComponentStatus::Running {
258            return Ok(());
259        }
260
261        self.base.set_status(ComponentStatus::Starting, None).await;
262        log::info!("Starting MS SQL CDC source: {}", self.base.id);
263
264        let config = self.config.clone();
265        let source_id = self.base.id.clone();
266        let dispatchers = self.base.dispatchers.clone();
267        let state_store = self.state_store.read().await.clone();
268        let shutdown_rx = self.shutdown_rx.clone();
269
270        // Spawn CDC polling task
271        let task_handle = tokio::spawn(async move {
272            if let Err(e) = stream::run_cdc_stream(
273                source_id.clone(),
274                config,
275                dispatchers,
276                state_store,
277                shutdown_rx,
278            )
279            .await
280            {
281                log::error!("CDC stream task failed for {source_id}: {e}");
282            }
283        });
284
285        // Store task handle for shutdown
286        *self.task_handle.write().await = Some(task_handle);
287
288        self.base.set_status(ComponentStatus::Running, None).await;
289
290        log::info!("MS SQL source '{}' started CDC polling", self.base.id);
291        Ok(())
292    }
293
294    async fn stop(&self) -> Result<()> {
295        use drasi_lib::channels::ComponentStatus;
296
297        log::info!("MS SQL source '{}' stopping", self.base.id);
298
299        // Signal the CDC polling loop to stop gracefully
300        if let Err(e) = self.shutdown_tx.send(true) {
301            log::warn!("Failed to send shutdown signal: {e}");
302        }
303
304        // Wait for the task to complete (with timeout)
305        if let Some(handle) = self.task_handle.write().await.take() {
306            // Give the task a chance to shut down gracefully
307            match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
308                Ok(Ok(())) => {
309                    log::debug!("CDC polling task stopped gracefully");
310                }
311                Ok(Err(e)) => {
312                    log::warn!("CDC polling task panicked: {e}");
313                }
314                Err(_) => {
315                    log::warn!("CDC polling task did not stop within timeout, it will be dropped");
316                }
317            }
318        }
319
320        self.base.set_status(ComponentStatus::Stopped, None).await;
321
322        Ok(())
323    }
324
325    async fn subscribe(
326        &self,
327        settings: drasi_lib::config::SourceSubscriptionSettings,
328    ) -> Result<drasi_lib::channels::SubscriptionResponse> {
329        self.base
330            .subscribe_with_bootstrap(&settings, "MS SQL")
331            .await
332    }
333
334    fn as_any(&self) -> &dyn std::any::Any {
335        self
336    }
337
338    async fn initialize(&self, context: drasi_lib::context::SourceRuntimeContext) {
339        self.base.initialize(context.clone()).await;
340
341        // Store state store if provided
342        if let Some(state_store) = context.state_store {
343            *self.state_store.write().await = Some(state_store);
344            log::debug!("State store injected into MS SQL source '{}'", self.base.id);
345        }
346    }
347
348    async fn set_bootstrap_provider(
349        &self,
350        provider: Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>,
351    ) {
352        self.base.set_bootstrap_provider(provider).await;
353    }
354}
355
356/// Builder for MS SQL source
357///
358/// Provides a fluent API for constructing an MS SQL source with validation.
359pub struct MsSqlSourceBuilder {
360    id: String,
361    config: MsSqlSourceConfig,
362    bootstrap_provider: Option<Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>>,
363}
364
365impl MsSqlSourceBuilder {
366    /// Create a new builder
367    pub fn new(id: impl Into<String>) -> Self {
368        Self {
369            id: id.into(),
370            config: MsSqlSourceConfig::default(),
371            bootstrap_provider: None,
372        }
373    }
374
375    /// Set the MS SQL server hostname
376    pub fn with_host(mut self, host: impl Into<String>) -> Self {
377        self.config.host = host.into();
378        self
379    }
380
381    /// Set the MS SQL server port
382    pub fn with_port(mut self, port: u16) -> Self {
383        self.config.port = port;
384        self
385    }
386
387    /// Set the database name
388    pub fn with_database(mut self, database: impl Into<String>) -> Self {
389        self.config.database = database.into();
390        self
391    }
392
393    /// Set the database user
394    pub fn with_user(mut self, user: impl Into<String>) -> Self {
395        self.config.user = user.into();
396        self
397    }
398
399    /// Set the database password
400    pub fn with_password(mut self, password: impl Into<String>) -> Self {
401        self.config.password = password.into();
402        self
403    }
404
405    /// Set the authentication mode
406    pub fn with_auth_mode(mut self, auth_mode: AuthMode) -> Self {
407        self.config.auth_mode = auth_mode;
408        self
409    }
410
411    /// Set the list of tables to monitor
412    pub fn with_tables(mut self, tables: Vec<String>) -> Self {
413        self.config.tables = tables;
414        self
415    }
416
417    /// Add a single table to monitor
418    pub fn with_table(mut self, table: impl Into<String>) -> Self {
419        self.config.tables.push(table.into());
420        self
421    }
422
423    /// Set the CDC polling interval in milliseconds
424    pub fn with_poll_interval_ms(mut self, ms: u64) -> Self {
425        self.config.poll_interval_ms = ms;
426        self
427    }
428
429    /// Set the encryption mode
430    pub fn with_encryption(mut self, encryption: EncryptionMode) -> Self {
431        self.config.encryption = encryption;
432        self
433    }
434
435    /// Set whether to trust the server certificate
436    pub fn with_trust_server_certificate(mut self, trust: bool) -> Self {
437        self.config.trust_server_certificate = trust;
438        self
439    }
440
441    /// Add a table key configuration
442    pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
443        self.config.table_keys.push(TableKeyConfig {
444            table: table.into(),
445            key_columns,
446        });
447        self
448    }
449
450    /// Set the starting position when no LSN is found in state store
451    pub fn with_start_position(mut self, position: StartPosition) -> Self {
452        self.config.start_position = position;
453        self
454    }
455
456    /// Set the bootstrap provider for initial data delivery
457    pub fn with_bootstrap_provider(
458        mut self,
459        provider: impl drasi_lib::bootstrap::BootstrapProvider + 'static,
460    ) -> Self {
461        self.bootstrap_provider = Some(Box::new(provider));
462        self
463    }
464
465    /// Build the MS SQL source
466    ///
467    /// # Errors
468    /// Returns error if required configuration is missing
469    pub fn build(self) -> Result<MsSqlSource> {
470        // Validate required fields
471        if self.config.database.is_empty() {
472            return Err(anyhow::anyhow!("Database name is required"));
473        }
474        if self.config.user.is_empty() {
475            return Err(anyhow::anyhow!("Database user is required"));
476        }
477
478        let source_id = self.id.clone();
479
480        // Create base source parameters
481        let mut params = SourceBaseParams::new(&source_id);
482
483        // Add bootstrap provider if configured
484        if let Some(provider) = self.bootstrap_provider {
485            params = params.with_bootstrap_provider(provider);
486        }
487
488        // Create shutdown channel (false = running, true = shutdown)
489        let (shutdown_tx, shutdown_rx) = watch::channel(false);
490
491        Ok(MsSqlSource {
492            source_id,
493            config: self.config,
494            base: SourceBase::new(params)?,
495            state_store: Arc::new(RwLock::new(None)),
496            task_handle: Arc::new(RwLock::new(None)),
497            shutdown_tx,
498            shutdown_rx,
499        })
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder pre-populated with the two fields `build` requires.
    fn minimal_builder() -> MsSqlSourceBuilder {
        MsSqlSource::builder("test-source")
            .with_database("testdb")
            .with_user("testuser")
    }

    /// A fully specified builder yields a source carrying the given id and settings.
    #[test]
    fn test_builder_basic() {
        let built = minimal_builder()
            .with_host("localhost")
            .with_password("testpass")
            .build();
        let source = built.unwrap();

        assert_eq!(source.id(), "test-source");
        assert_eq!(source.type_name(), "mssql");
        assert_eq!(source.config.host, "localhost");
        assert_eq!(source.config.database, "testdb");
    }

    /// `with_tables` stores every table that was passed in.
    #[test]
    fn test_builder_with_tables() {
        let tables = vec!["table1".to_string(), "table2".to_string()];
        let source = minimal_builder().with_tables(tables).build().unwrap();

        assert_eq!(source.config.tables.len(), 2);
    }

    /// Building without database and user must be rejected.
    #[test]
    fn test_builder_missing_required_fields() {
        let result = MsSqlSource::builder("test-source")
            .with_host("localhost")
            .build();

        assert!(result.is_err());
    }

    /// `with_table_key` records the table together with its key columns.
    #[test]
    fn test_builder_table_keys() {
        let key_columns = vec!["order_id".to_string()];
        let source = minimal_builder()
            .with_table_key("orders", key_columns)
            .build()
            .unwrap();

        assert_eq!(source.config.table_keys.len(), 1);
        assert_eq!(source.config.table_keys[0].table, "orders");
    }
}
557
// Dynamic plugin entry point.
//
// Compiled only when the `dynamic-plugin` feature is enabled. Invokes the plugin
// SDK's `export_plugin!` macro with the plugin id "mssql-source", a single source
// descriptor (`MsSqlSourceDescriptor`) and no reaction or bootstrap descriptors.
// All three version fields are taken from this crate's `CARGO_PKG_VERSION`.
//
// Plain comments are used on purpose: rustdoc does not generate documentation
// for macro invocations, so a `///` comment here would be an unused doc comment.
#[cfg(feature = "dynamic-plugin")]
drasi_plugin_sdk::export_plugin!(
    plugin_id = "mssql-source",
    core_version = env!("CARGO_PKG_VERSION"),
    lib_version = env!("CARGO_PKG_VERSION"),
    plugin_version = env!("CARGO_PKG_VERSION"),
    source_descriptors = [descriptor::MsSqlSourceDescriptor],
    reaction_descriptors = [],
    bootstrap_descriptors = [],
);