Skip to main content

faucet_source_mssql/
stream.rs

//! The MSSQL [`Source`] implementation — connection pool, query execution,
//! streaming, and incremental-replication bookkeeping.
3
4use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::pin::Pin;
7use std::sync::Mutex;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use faucet_core::check::{CheckContext, CheckReport, Probe};
12use faucet_core::replication::{filter_incremental, max_replication_value, max_value};
13use faucet_core::{FaucetError, Source, StreamPage};
14use futures::{Stream, TryStreamExt};
15use serde_json::Value;
16use tiberius::{QueryItem, ToSql};
17
18use faucet_common_mssql::{MssqlPool, build_pool, with_statement_timeout};
19
20use crate::config::{MssqlReplication, MssqlSourceConfig};
21use crate::convert::row_to_json;
22
23/// Microsoft SQL Server query source.
24pub struct MssqlSource {
25    config: MssqlSourceConfig,
26    pool: MssqlPool,
27    /// Bookmark loaded via [`Source::apply_start_bookmark`]; overrides the
28    /// configured `initial_value` for incremental runs.
29    start_bookmark: Mutex<Option<Value>>,
30}
31
32impl MssqlSource {
33    /// Connect, validate the config, and build the connection pool.
34    pub async fn new(config: MssqlSourceConfig) -> Result<Self, FaucetError> {
35        config.validate()?;
36        let pool = build_pool(&config.connection, config.max_connections).await?;
37        Ok(Self {
38            config,
39            pool,
40            start_bookmark: Mutex::new(None),
41        })
42    }
43
44    fn timeout(&self) -> Option<Duration> {
45        match self.config.statement_timeout_secs {
46            0 => None,
47            secs => Some(Duration::from_secs(secs)),
48        }
49    }
50
51    fn current_start(&self) -> Option<Value> {
52        self.start_bookmark
53            .lock()
54            .expect("start_bookmark mutex poisoned")
55            .clone()
56    }
57}
58
59/// Incremental-replication context resolved for one run.
60#[derive(Debug, Clone, PartialEq)]
61struct IncrementalCtx {
62    column: String,
63    start: Value,
64}
65
66/// Build the final query string, the ordered bind values, and (for incremental
67/// runs) the client-side filter context.
68///
69/// Pure function (no pool) so it is unit-testable. Param order is:
70/// `config.params` → context-substituted values → the incremental bookmark
71/// (only when the query contains the `@bookmark` token).
72fn build_query_and_params(
73    config: &MssqlSourceConfig,
74    context: &HashMap<String, Value>,
75    start_bookmark: Option<&Value>,
76) -> (String, Vec<Value>, Option<IncrementalCtx>) {
77    // Resolve parent-context placeholders to positional @P markers.
78    let (mut query, mut values) = if context.is_empty() {
79        (config.query.clone(), config.params.clone())
80    } else {
81        let (q, ctx_values) = faucet_core::util::substitute_context_bind_params(
82            &config.query,
83            context,
84            config.params.len() + 1,
85            |i| format!("@P{i}"),
86        );
87        let mut v = config.params.clone();
88        v.extend(ctx_values);
89        (q, v)
90    };
91
92    let incremental = match &config.replication {
93        MssqlReplication::Full => None,
94        MssqlReplication::Incremental {
95            column,
96            initial_value,
97        } => {
98            let start = start_bookmark
99                .cloned()
100                .unwrap_or_else(|| initial_value.clone());
101            // Server-side pushdown: bind the cursor where the user wrote
102            // `@bookmark`. If absent, only the client-side filter applies.
103            if query.contains("@bookmark") {
104                let idx = values.len() + 1;
105                query = query.replace("@bookmark", &format!("@P{idx}"));
106                values.push(start.clone());
107            }
108            Some(IncrementalCtx {
109                column: column.clone(),
110                start,
111            })
112        }
113    };
114
115    (query, values, incremental)
116}
117
/// Bind parameter that owns its value, so the `&dyn ToSql` slice passed to
/// `tiberius` can borrow from storage that lives as long as the call needs.
enum OwnedParam {
    /// Integer parameter.
    I64(i64),
    /// Floating-point parameter.
    F64(f64),
    /// Boolean parameter.
    Bool(bool),
    /// Text parameter.
    Str(String),
    /// Carrier for JSON `null`; constructed with `None`.
    Null(Option<i32>),
}
127
128impl OwnedParam {
129    fn from_value(v: &Value) -> Self {
130        match v {
131            Value::String(s) => OwnedParam::Str(s.clone()),
132            Value::Number(n) if n.is_i64() => OwnedParam::I64(n.as_i64().unwrap()),
133            Value::Number(n) if n.is_u64() => OwnedParam::I64(n.as_u64().unwrap() as i64),
134            Value::Number(n) => OwnedParam::F64(n.as_f64().unwrap_or(0.0)),
135            Value::Bool(b) => OwnedParam::Bool(*b),
136            Value::Null => OwnedParam::Null(None),
137            other => OwnedParam::Str(other.to_string()),
138        }
139    }
140
141    fn as_tosql(&self) -> &dyn ToSql {
142        match self {
143            OwnedParam::I64(v) => v,
144            OwnedParam::F64(v) => v,
145            OwnedParam::Bool(v) => v,
146            OwnedParam::Str(v) => v,
147            OwnedParam::Null(v) => v,
148        }
149    }
150}
151
152/// Derive a default state-store key from the connection host + a query
153/// fingerprint, stable across runs.
154fn default_state_key(config: &MssqlSourceConfig) -> String {
155    let host = config
156        .connection
157        .connection_url
158        .as_deref()
159        .and_then(|u| url::Url::parse(u).ok())
160        .and_then(|u| u.host_str().map(|h| h.to_string()))
161        .unwrap_or_else(|| "mssql".to_string());
162
163    let mut hasher = std::collections::hash_map::DefaultHasher::new();
164    config.query.hash(&mut hasher);
165    let fingerprint = hasher.finish();
166    // Host may contain dots (allowed mid-key); sanitise anything else.
167    let host: String = host
168        .chars()
169        .map(|c| {
170            if c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.') {
171                c
172            } else {
173                '_'
174            }
175        })
176        .collect();
177    format!("mssql:{host}:{fingerprint:016x}")
178}
179
180#[async_trait]
181impl Source for MssqlSource {
182    async fn fetch_with_context(
183        &self,
184        context: &HashMap<String, Value>,
185    ) -> Result<Vec<Value>, FaucetError> {
186        Ok(self.collect_all(context).await?.0)
187    }
188
189    async fn fetch_with_context_incremental(
190        &self,
191        context: &HashMap<String, Value>,
192    ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
193        self.collect_all(context).await
194    }
195
196    fn stream_pages<'a>(
197        &'a self,
198        context: &'a HashMap<String, Value>,
199        _batch_size: usize,
200    ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
201        let batch_size = self.config.batch_size;
202        let chunk = if batch_size == 0 {
203            usize::MAX
204        } else {
205            batch_size
206        };
207        let cap = if batch_size == 0 { 1024 } else { batch_size };
208        let start = self.current_start();
209        let (query, values, incr) = build_query_and_params(&self.config, context, start.as_ref());
210
211        Box::pin(async_stream::try_stream! {
212            let mut conn = self
213                .pool
214                .get()
215                .await
216                .map_err(|e| FaucetError::Source(format!("MSSQL pool checkout failed: {e}")))?;
217
218            // Scope the borrowed param slice to the query() call — the
219            // QueryStream borrows the connection, not the params.
220            let mut stream = {
221                let owned: Vec<OwnedParam> = values.iter().map(OwnedParam::from_value).collect();
222                let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
223                let query_fut = conn.query(&query, &refs);
224                match self.timeout() {
225                    Some(t) => {
226                        with_statement_timeout(t, async {
227                            query_fut.await.map_err(|e| {
228                                FaucetError::Source(format!("MSSQL query failed: {e}"))
229                            })
230                        }, || FaucetError::Source("MSSQL query timed out".into()))
231                        .await?
232                    }
233                    None => query_fut
234                        .await
235                        .map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?,
236                }
237            };
238
239            let mut buffer: Vec<Value> = Vec::with_capacity(cap);
240            let mut running_max: Option<Value> = None;
241            let mut total = 0usize;
242
243            while let Some(item) = stream
244                .try_next()
245                .await
246                .map_err(|e| FaucetError::Source(format!("MSSQL row stream failed: {e}")))?
247            {
248                let QueryItem::Row(row) = item else { continue };
249                buffer.push(row_to_json(&row)?);
250                if buffer.len() >= chunk {
251                    let page = std::mem::replace(&mut buffer, Vec::with_capacity(cap));
252                    let kept = apply_incremental(page, incr.as_ref(), &mut running_max);
253                    total += kept.len();
254                    if !kept.is_empty() {
255                        yield StreamPage { records: kept, bookmark: None };
256                    }
257                }
258            }
259
260            // Final page carries the bookmark so the pipeline persists only
261            // after everything before it has been written.
262            let kept = apply_incremental(buffer, incr.as_ref(), &mut running_max);
263            total += kept.len();
264            let bookmark = if incr.is_some() { running_max.clone() } else { None };
265            if !kept.is_empty() || bookmark.is_some() {
266                yield StreamPage { records: kept, bookmark };
267            }
268
269            tracing::info!(rows = total, query = %self.config.query, "MSSQL source stream complete");
270        })
271    }
272
273    fn config_schema(&self) -> Value {
274        serde_json::to_value(faucet_core::schema_for!(MssqlSourceConfig))
275            .expect("schema serialization")
276    }
277
278    fn connector_name(&self) -> &'static str {
279        "mssql"
280    }
281
282    fn state_key(&self) -> Option<String> {
283        match &self.config.replication {
284            MssqlReplication::Full => None,
285            MssqlReplication::Incremental { .. } => Some(
286                self.config
287                    .state_key
288                    .clone()
289                    .unwrap_or_else(|| default_state_key(&self.config)),
290            ),
291        }
292    }
293
294    async fn apply_start_bookmark(&self, bookmark: Value) -> Result<(), FaucetError> {
295        *self
296            .start_bookmark
297            .lock()
298            .expect("start_bookmark mutex poisoned") = Some(bookmark);
299        Ok(())
300    }
301
302    async fn check(&self, ctx: &CheckContext) -> Result<CheckReport, FaucetError> {
303        let started = std::time::Instant::now();
304        let probe = match tokio::time::timeout(ctx.timeout, self.pool.get()).await {
305            Ok(Ok(_conn)) => Probe::pass("connect", started.elapsed()),
306            Ok(Err(e)) => Probe::fail_hint(
307                "connect",
308                started.elapsed(),
309                e.to_string(),
310                "check connection_url / credentials / TLS / that the server is reachable",
311            ),
312            Err(_) => Probe::fail_hint(
313                "connect",
314                started.elapsed(),
315                "timed out",
316                "check connection_url / credentials / TLS / that the server is reachable",
317            ),
318        };
319        Ok(CheckReport::single(probe))
320    }
321}
322
323impl MssqlSource {
324    /// Run the query and return all decoded rows plus (for incremental) the new
325    /// bookmark. Used by the non-streaming convenience methods.
326    async fn collect_all(
327        &self,
328        context: &HashMap<String, Value>,
329    ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
330        let start = self.current_start();
331        let (query, values, incr) = build_query_and_params(&self.config, context, start.as_ref());
332
333        let mut conn = self
334            .pool
335            .get()
336            .await
337            .map_err(|e| FaucetError::Source(format!("MSSQL pool checkout failed: {e}")))?;
338
339        let rows = {
340            let owned: Vec<OwnedParam> = values.iter().map(OwnedParam::from_value).collect();
341            let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
342            let run = async {
343                conn.query(&query, &refs)
344                    .await
345                    .map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?
346                    .into_first_result()
347                    .await
348                    .map_err(|e| FaucetError::Source(format!("MSSQL result read failed: {e}")))
349            };
350            match self.timeout() {
351                Some(t) => {
352                    with_statement_timeout(t, run, || {
353                        FaucetError::Source("MSSQL query timed out".into())
354                    })
355                    .await?
356                }
357                None => run.await?,
358            }
359        };
360
361        let mut records = Vec::with_capacity(rows.len());
362        for row in &rows {
363            records.push(row_to_json(row)?);
364        }
365
366        let mut running_max: Option<Value> = None;
367        let records = apply_incremental(records, incr.as_ref(), &mut running_max);
368        let bookmark = if incr.is_some() { running_max } else { None };
369        Ok((records, bookmark))
370    }
371}
372
373/// Filter a page for incremental replication and advance `running_max`.
374/// For full replication the page passes through unchanged.
375fn apply_incremental(
376    page: Vec<Value>,
377    incr: Option<&IncrementalCtx>,
378    running_max: &mut Option<Value>,
379) -> Vec<Value> {
380    match incr {
381        None => page,
382        Some(ctx) => {
383            let kept = filter_incremental(page, &ctx.column, &ctx.start);
384            if let Some(m) = max_replication_value(&kept, &ctx.column) {
385                let m = m.clone();
386                *running_max = Some(match running_max.take() {
387                    Some(prev) => max_value(prev, m),
388                    None => m,
389                });
390            }
391            kept
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Full-replication config pointing at a fixed example server.
    fn full_cfg() -> MssqlSourceConfig {
        MssqlSourceConfig::new("mssql://sa:pw@db.example.com:1433/sales", "SELECT * FROM t")
    }

    /// Incremental-replication variant of [`full_cfg`].
    fn incremental_cfg(query: &str, column: &str, initial_value: Value) -> MssqlSourceConfig {
        MssqlSourceConfig {
            query: query.into(),
            replication: MssqlReplication::Incremental {
                column: column.into(),
                initial_value,
            },
            ..full_cfg()
        }
    }

    #[test]
    fn build_full_returns_query_and_params_unchanged() {
        let params = vec![json!(1), json!("x")];
        let mut cfg = full_cfg();
        cfg.params = params.clone();

        let (query, binds, incr) = build_query_and_params(&cfg, &HashMap::new(), None);

        assert_eq!(query, "SELECT * FROM t");
        assert_eq!(binds, params);
        assert!(incr.is_none());
    }

    #[test]
    fn build_incremental_binds_bookmark_token() {
        let cfg = incremental_cfg(
            "SELECT * FROM t WHERE updated_at > @bookmark",
            "updated_at",
            json!("1970-01-01"),
        );

        let (query, binds, incr) = build_query_and_params(&cfg, &HashMap::new(), None);

        assert_eq!(query, "SELECT * FROM t WHERE updated_at > @P1");
        assert_eq!(binds, vec![json!("1970-01-01")]);
        let expected = IncrementalCtx {
            column: "updated_at".into(),
            start: json!("1970-01-01"),
        };
        assert_eq!(incr, Some(expected));
    }

    #[test]
    fn build_incremental_uses_stored_bookmark_over_initial() {
        let mut cfg = incremental_cfg("SELECT * FROM t WHERE c > @bookmark", "c", json!(0));
        cfg.params = vec![json!("p0")];
        let stored = json!(500);

        let (query, binds, incr) = build_query_and_params(&cfg, &HashMap::new(), Some(&stored));

        // One configured param precedes the bookmark, so it is bound as @P2.
        assert_eq!(query, "SELECT * FROM t WHERE c > @P2");
        assert_eq!(binds, vec![json!("p0"), json!(500)]);
        assert_eq!(incr.unwrap().start, json!(500));
    }

    #[test]
    fn build_incremental_without_token_still_returns_filter_ctx() {
        let cfg = incremental_cfg("SELECT * FROM t", "c", json!(0));

        let (query, binds, incr) = build_query_and_params(&cfg, &HashMap::new(), None);

        assert_eq!(query, "SELECT * FROM t");
        assert!(binds.is_empty());
        assert!(incr.is_some(), "client-side filter must still run");
    }

    #[test]
    fn owned_param_classifies_json() {
        let text = OwnedParam::from_value(&json!("s"));
        assert!(matches!(text, OwnedParam::Str(_)));

        let int = OwnedParam::from_value(&json!(7));
        assert!(matches!(int, OwnedParam::I64(7)));

        let float = OwnedParam::from_value(&json!(1.5));
        assert!(matches!(float, OwnedParam::F64(_)));

        let boolean = OwnedParam::from_value(&json!(true));
        assert!(matches!(boolean, OwnedParam::Bool(true)));

        let null = OwnedParam::from_value(&Value::Null);
        assert!(matches!(null, OwnedParam::Null(None)));

        let object = OwnedParam::from_value(&json!({"a":1}));
        assert!(matches!(object, OwnedParam::Str(_)));
    }

    #[test]
    fn apply_incremental_filters_and_tracks_max() {
        let ctx = IncrementalCtx {
            column: "c".into(),
            start: json!(10),
        };
        let page = vec![json!({"c": 5}), json!({"c": 15}), json!({"c": 20})];
        let mut running = None;

        let kept = apply_incremental(page, Some(&ctx), &mut running);

        assert_eq!(kept.len(), 2);
        assert_eq!(running, Some(json!(20)));
    }

    #[test]
    fn apply_incremental_full_passes_through() {
        let page = vec![json!({"c": 1}), json!({"c": 2})];
        let mut running = None;

        let kept = apply_incremental(page, None, &mut running);

        assert_eq!(kept.len(), 2);
        assert_eq!(running, None);
    }

    #[test]
    fn default_state_key_is_stable_and_valid() {
        let cfg = full_cfg();
        let first = default_state_key(&cfg);
        let second = default_state_key(&cfg);

        assert_eq!(first, second);
        assert!(first.starts_with("mssql:db.example.com:"));
        faucet_core::state::validate_state_key(&first).expect("derived key must be valid");
    }
}
531}