1use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::pin::Pin;
7use std::sync::Mutex;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use faucet_core::check::{CheckContext, CheckReport, Probe};
12use faucet_core::replication::{filter_incremental, max_replication_value, max_value};
13use faucet_core::{FaucetError, Source, StreamPage};
14use futures::{Stream, TryStreamExt};
15use serde_json::Value;
16use tiberius::{QueryItem, ToSql};
17
18use faucet_common_mssql::{MssqlPool, build_pool, with_statement_timeout};
19
20use crate::config::{MssqlReplication, MssqlSourceConfig};
21use crate::convert::row_to_json;
22
23pub struct MssqlSource {
25 config: MssqlSourceConfig,
26 pool: MssqlPool,
27 start_bookmark: Mutex<Option<Value>>,
30}
31
32impl MssqlSource {
33 pub async fn new(config: MssqlSourceConfig) -> Result<Self, FaucetError> {
35 config.validate()?;
36 let pool = build_pool(&config.connection, config.max_connections).await?;
37 Ok(Self {
38 config,
39 pool,
40 start_bookmark: Mutex::new(None),
41 })
42 }
43
44 fn timeout(&self) -> Option<Duration> {
45 match self.config.statement_timeout_secs {
46 0 => None,
47 secs => Some(Duration::from_secs(secs)),
48 }
49 }
50
51 fn current_start(&self) -> Option<Value> {
52 self.start_bookmark
53 .lock()
54 .expect("start_bookmark mutex poisoned")
55 .clone()
56 }
57}
58
59#[derive(Debug, Clone, PartialEq)]
61struct IncrementalCtx {
62 column: String,
63 start: Value,
64}
65
66fn build_query_and_params(
73 config: &MssqlSourceConfig,
74 context: &HashMap<String, Value>,
75 start_bookmark: Option<&Value>,
76) -> (String, Vec<Value>, Option<IncrementalCtx>) {
77 let (mut query, mut values) = if context.is_empty() {
79 (config.query.clone(), config.params.clone())
80 } else {
81 let (q, ctx_values) = faucet_core::util::substitute_context_bind_params(
82 &config.query,
83 context,
84 config.params.len() + 1,
85 |i| format!("@P{i}"),
86 );
87 let mut v = config.params.clone();
88 v.extend(ctx_values);
89 (q, v)
90 };
91
92 let incremental = match &config.replication {
93 MssqlReplication::Full => None,
94 MssqlReplication::Incremental {
95 column,
96 initial_value,
97 } => {
98 let start = start_bookmark
99 .cloned()
100 .unwrap_or_else(|| initial_value.clone());
101 if query.contains("@bookmark") {
104 let idx = values.len() + 1;
105 query = query.replace("@bookmark", &format!("@P{idx}"));
106 values.push(start.clone());
107 }
108 Some(IncrementalCtx {
109 column: column.clone(),
110 start,
111 })
112 }
113 };
114
115 (query, values, incremental)
116}
117
/// Owned storage for a single bound query parameter.
///
/// The query call takes parameters as borrowed `&dyn ToSql`. JSON values are therefore
/// converted into this enum first, kept alive for the duration of the call, and borrowed
/// via `as_tosql`.
enum OwnedParam {
    /// 64-bit signed integer.
    I64(i64),
    /// 64-bit floating-point number.
    F64(f64),
    /// Boolean.
    Bool(bool),
    /// Text.
    Str(String),
    /// SQL NULL, carried as `Option<i32>` so it still has a concrete `ToSql` type.
    Null(Option<i32>),
}
127
128impl OwnedParam {
129 fn from_value(v: &Value) -> Self {
130 match v {
131 Value::String(s) => OwnedParam::Str(s.clone()),
132 Value::Number(n) if n.is_i64() => OwnedParam::I64(n.as_i64().unwrap()),
133 Value::Number(n) if n.is_u64() => OwnedParam::I64(n.as_u64().unwrap() as i64),
134 Value::Number(n) => OwnedParam::F64(n.as_f64().unwrap_or(0.0)),
135 Value::Bool(b) => OwnedParam::Bool(*b),
136 Value::Null => OwnedParam::Null(None),
137 other => OwnedParam::Str(other.to_string()),
138 }
139 }
140
141 fn as_tosql(&self) -> &dyn ToSql {
142 match self {
143 OwnedParam::I64(v) => v,
144 OwnedParam::F64(v) => v,
145 OwnedParam::Bool(v) => v,
146 OwnedParam::Str(v) => v,
147 OwnedParam::Null(v) => v,
148 }
149 }
150}
151
152fn default_state_key(config: &MssqlSourceConfig) -> String {
155 let host = config
156 .connection
157 .connection_url
158 .as_deref()
159 .and_then(|u| url::Url::parse(u).ok())
160 .and_then(|u| u.host_str().map(|h| h.to_string()))
161 .unwrap_or_else(|| "mssql".to_string());
162
163 let mut hasher = std::collections::hash_map::DefaultHasher::new();
164 config.query.hash(&mut hasher);
165 let fingerprint = hasher.finish();
166 let host: String = host
168 .chars()
169 .map(|c| {
170 if c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.') {
171 c
172 } else {
173 '_'
174 }
175 })
176 .collect();
177 format!("mssql:{host}:{fingerprint:016x}")
178}
179
/// `Source` implementation for SQL Server.
///
/// Both the buffered (`fetch_*`) and streaming (`stream_pages`) paths build their SQL with
/// `build_query_and_params`, using the start bookmark captured at call time. Both apply the
/// incremental filter client-side through `apply_incremental`.
#[async_trait]
impl Source for MssqlSource {
    /// Runs the query to completion and returns all records, discarding the bookmark.
    async fn fetch_with_context(
        &self,
        context: &HashMap<String, Value>,
    ) -> Result<Vec<Value>, FaucetError> {
        Ok(self.collect_all(context).await?.0)
    }

    /// Runs the query to completion and returns all records together with the bookmark
    /// computed by `collect_all` (`None` under full replication).
    async fn fetch_with_context_incremental(
        &self,
        context: &HashMap<String, Value>,
    ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
        self.collect_all(context).await
    }

    /// Streams the query result as pages of records.
    ///
    /// The `_batch_size` argument is ignored. The page size comes from `config.batch_size`,
    /// where `0` means "no chunking": every row lands in the single final page.
    ///
    /// Paging rules:
    /// - Intermediate pages carry `bookmark: None`.
    /// - Only the final page carries the bookmark, and only under incremental replication.
    /// - Intermediate pages emptied by the incremental filter are not yielded.
    /// - The final page is yielded only if it has records or a bookmark.
    ///
    /// NOTE(review): two differences from `collect_all`, to confirm as intended:
    /// - The statement timeout wraps only the `conn.query` call, not the row streaming that
    ///   follows; `collect_all` times the whole read.
    /// - This loop keeps every `QueryItem::Row` in the stream, whereas `collect_all` reads
    ///   only the first result set via `into_first_result`.
    fn stream_pages<'a>(
        &'a self,
        context: &'a HashMap<String, Value>,
        _batch_size: usize,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
        let batch_size = self.config.batch_size;
        let chunk = if batch_size == 0 {
            usize::MAX
        } else {
            batch_size
        };
        // Buffer pre-allocation: one page, or 1024 rows when chunking is disabled.
        let cap = if batch_size == 0 { 1024 } else { batch_size };
        // Snapshot the start bookmark once so the whole stream uses a single start value.
        let start = self.current_start();
        let (query, values, incr) = build_query_and_params(&self.config, context, start.as_ref());

        Box::pin(async_stream::try_stream! {
            let mut conn = self
                .pool
                .get()
                .await
                .map_err(|e| FaucetError::Source(format!("MSSQL pool checkout failed: {e}")))?;

            let mut stream = {
                // `owned` holds the parameter values that the borrowed `refs` slice points at.
                let owned: Vec<OwnedParam> = values.iter().map(OwnedParam::from_value).collect();
                let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
                let query_fut = conn.query(&query, &refs);
                // Only this initial query call is subject to the statement timeout.
                match self.timeout() {
                    Some(t) => {
                        with_statement_timeout(t, async {
                            query_fut.await.map_err(|e| {
                                FaucetError::Source(format!("MSSQL query failed: {e}"))
                            })
                        }, || FaucetError::Source("MSSQL query timed out".into()))
                        .await?
                    }
                    None => query_fut
                        .await
                        .map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?,
                }
            };

            let mut buffer: Vec<Value> = Vec::with_capacity(cap);
            let mut running_max: Option<Value> = None;
            // Counts records kept after the incremental filter, not raw rows read.
            let mut total = 0usize;

            while let Some(item) = stream
                .try_next()
                .await
                .map_err(|e| FaucetError::Source(format!("MSSQL row stream failed: {e}")))?
            {
                // Non-row items (e.g. result-set metadata) are skipped.
                let QueryItem::Row(row) = item else { continue };
                buffer.push(row_to_json(&row)?);
                if buffer.len() >= chunk {
                    let page = std::mem::replace(&mut buffer, Vec::with_capacity(cap));
                    let kept = apply_incremental(page, incr.as_ref(), &mut running_max);
                    total += kept.len();
                    if !kept.is_empty() {
                        yield StreamPage { records: kept, bookmark: None };
                    }
                }
            }

            // Flush the remainder. This final page is the only one that carries the bookmark.
            let kept = apply_incremental(buffer, incr.as_ref(), &mut running_max);
            total += kept.len();
            let bookmark = if incr.is_some() { running_max.clone() } else { None };
            if !kept.is_empty() || bookmark.is_some() {
                yield StreamPage { records: kept, bookmark };
            }

            tracing::info!(rows = total, query = %self.config.query, "MSSQL source stream complete");
        })
    }

    /// JSON schema of `MssqlSourceConfig`.
    ///
    /// # Panics
    /// Panics if the generated schema cannot be serialized to a `serde_json::Value`.
    fn config_schema(&self) -> Value {
        serde_json::to_value(faucet_core::schema_for!(MssqlSourceConfig))
            .expect("schema serialization")
    }

    /// Connector identifier: `"mssql"`.
    fn connector_name(&self) -> &'static str {
        "mssql"
    }

    /// Key under which replication state is stored.
    ///
    /// Returns `None` for full replication. Otherwise it returns the configured `state_key`,
    /// or one derived by `default_state_key`.
    fn state_key(&self) -> Option<String> {
        match &self.config.replication {
            MssqlReplication::Full => None,
            MssqlReplication::Incremental { .. } => Some(
                self.config
                    .state_key
                    .clone()
                    .unwrap_or_else(|| default_state_key(&self.config)),
            ),
        }
    }

    /// Installs `bookmark` as the replication start value for subsequent fetches and
    /// streams. Always returns `Ok`; panics only if the mutex is poisoned.
    async fn apply_start_bookmark(&self, bookmark: Value) -> Result<(), FaucetError> {
        *self
            .start_bookmark
            .lock()
            .expect("start_bookmark mutex poisoned") = Some(bookmark);
        Ok(())
    }

    /// Connectivity probe: checks a connection out of the pool within `ctx.timeout`.
    ///
    /// No query is executed, and the checked-out connection is dropped right away.
    /// Checkout errors and timeouts are reported as a failed `connect` probe, not as `Err`.
    async fn check(&self, ctx: &CheckContext) -> Result<CheckReport, FaucetError> {
        let started = std::time::Instant::now();
        let probe = match tokio::time::timeout(ctx.timeout, self.pool.get()).await {
            Ok(Ok(_conn)) => Probe::pass("connect", started.elapsed()),
            Ok(Err(e)) => Probe::fail_hint(
                "connect",
                started.elapsed(),
                e.to_string(),
                "check connection_url / credentials / TLS / that the server is reachable",
            ),
            Err(_) => Probe::fail_hint(
                "connect",
                started.elapsed(),
                "timed out",
                "check connection_url / credentials / TLS / that the server is reachable",
            ),
        };
        Ok(CheckReport::single(probe))
    }
}
322
323impl MssqlSource {
324 async fn collect_all(
327 &self,
328 context: &HashMap<String, Value>,
329 ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
330 let start = self.current_start();
331 let (query, values, incr) = build_query_and_params(&self.config, context, start.as_ref());
332
333 let mut conn = self
334 .pool
335 .get()
336 .await
337 .map_err(|e| FaucetError::Source(format!("MSSQL pool checkout failed: {e}")))?;
338
339 let rows = {
340 let owned: Vec<OwnedParam> = values.iter().map(OwnedParam::from_value).collect();
341 let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
342 let run = async {
343 conn.query(&query, &refs)
344 .await
345 .map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?
346 .into_first_result()
347 .await
348 .map_err(|e| FaucetError::Source(format!("MSSQL result read failed: {e}")))
349 };
350 match self.timeout() {
351 Some(t) => {
352 with_statement_timeout(t, run, || {
353 FaucetError::Source("MSSQL query timed out".into())
354 })
355 .await?
356 }
357 None => run.await?,
358 }
359 };
360
361 let mut records = Vec::with_capacity(rows.len());
362 for row in &rows {
363 records.push(row_to_json(row)?);
364 }
365
366 let mut running_max: Option<Value> = None;
367 let records = apply_incremental(records, incr.as_ref(), &mut running_max);
368 let bookmark = if incr.is_some() { running_max } else { None };
369 Ok((records, bookmark))
370 }
371}
372
373fn apply_incremental(
376 page: Vec<Value>,
377 incr: Option<&IncrementalCtx>,
378 running_max: &mut Option<Value>,
379) -> Vec<Value> {
380 match incr {
381 None => page,
382 Some(ctx) => {
383 let kept = filter_incremental(page, &ctx.column, &ctx.start);
384 if let Some(m) = max_replication_value(&kept, &ctx.column) {
385 let m = m.clone();
386 *running_max = Some(match running_max.take() {
387 Some(prev) => max_value(prev, m),
388 None => m,
389 });
390 }
391 kept
392 }
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Full-replication config against a dummy server with a trivial query.
    fn full_cfg() -> MssqlSourceConfig {
        MssqlSourceConfig::new("mssql://sa:pw@db.example.com:1433/sales", "SELECT * FROM t")
    }

    /// `full_cfg` with the given query, switched to incremental replication on `column`.
    fn incremental_cfg(query: &str, column: &str, initial_value: Value) -> MssqlSourceConfig {
        MssqlSourceConfig {
            query: query.into(),
            replication: MssqlReplication::Incremental {
                column: column.into(),
                initial_value,
            },
            ..full_cfg()
        }
    }

    #[test]
    fn build_full_returns_query_and_params_unchanged() {
        let cfg = MssqlSourceConfig {
            params: vec![json!(1), json!("x")],
            ..full_cfg()
        };
        let (sql, binds, incremental) = build_query_and_params(&cfg, &HashMap::new(), None);
        assert_eq!(sql, "SELECT * FROM t");
        assert_eq!(binds, vec![json!(1), json!("x")]);
        assert!(incremental.is_none());
    }

    #[test]
    fn build_incremental_binds_bookmark_token() {
        let cfg = incremental_cfg(
            "SELECT * FROM t WHERE updated_at > @bookmark",
            "updated_at",
            json!("1970-01-01"),
        );
        let (sql, binds, incremental) = build_query_and_params(&cfg, &HashMap::new(), None);
        assert_eq!(sql, "SELECT * FROM t WHERE updated_at > @P1");
        assert_eq!(binds, vec![json!("1970-01-01")]);
        let expected = IncrementalCtx {
            column: "updated_at".into(),
            start: json!("1970-01-01"),
        };
        assert_eq!(incremental, Some(expected));
    }

    #[test]
    fn build_incremental_uses_stored_bookmark_over_initial() {
        let cfg = MssqlSourceConfig {
            params: vec![json!("p0")],
            ..incremental_cfg("SELECT * FROM t WHERE c > @bookmark", "c", json!(0))
        };
        let stored = json!(500);
        let (sql, binds, incremental) =
            build_query_and_params(&cfg, &HashMap::new(), Some(&stored));
        // One user param already occupies @P1, so the bookmark binds as @P2.
        assert_eq!(sql, "SELECT * FROM t WHERE c > @P2");
        assert_eq!(binds, vec![json!("p0"), json!(500)]);
        assert_eq!(incremental.map(|ctx| ctx.start), Some(json!(500)));
    }

    #[test]
    fn build_incremental_without_token_still_returns_filter_ctx() {
        let cfg = incremental_cfg("SELECT * FROM t", "c", json!(0));
        let (sql, binds, incremental) = build_query_and_params(&cfg, &HashMap::new(), None);
        assert_eq!(sql, "SELECT * FROM t");
        assert!(binds.is_empty());
        assert!(incremental.is_some(), "client-side filter must still run");
    }

    #[test]
    fn owned_param_classifies_json() {
        let classify = |v: Value| OwnedParam::from_value(&v);
        assert!(matches!(classify(json!("s")), OwnedParam::Str(_)));
        assert!(matches!(classify(json!(7)), OwnedParam::I64(7)));
        assert!(matches!(classify(json!(1.5)), OwnedParam::F64(_)));
        assert!(matches!(classify(json!(true)), OwnedParam::Bool(true)));
        assert!(matches!(classify(Value::Null), OwnedParam::Null(None)));
        assert!(matches!(classify(json!({"a": 1})), OwnedParam::Str(_)));
    }

    #[test]
    fn apply_incremental_filters_and_tracks_max() {
        let ctx = IncrementalCtx {
            column: "c".into(),
            start: json!(10),
        };
        let mut running_max = None;
        let page = vec![json!({"c": 5}), json!({"c": 15}), json!({"c": 20})];
        let kept = apply_incremental(page, Some(&ctx), &mut running_max);
        assert_eq!(kept.len(), 2);
        assert_eq!(running_max, Some(json!(20)));
    }

    #[test]
    fn apply_incremental_full_passes_through() {
        let mut running_max = None;
        let page = vec![json!({"c": 1}), json!({"c": 2})];
        let kept = apply_incremental(page, None, &mut running_max);
        assert_eq!(kept.len(), 2);
        assert_eq!(running_max, None);
    }

    #[test]
    fn default_state_key_is_stable_and_valid() {
        let cfg = full_cfg();
        let first = default_state_key(&cfg);
        assert_eq!(first, default_state_key(&cfg));
        assert!(first.starts_with("mssql:db.example.com:"));
        faucet_core::state::validate_state_key(&first).expect("derived key must be valid");
    }
}