use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::Mutex;
use std::time::Duration;
use async_trait::async_trait;
use faucet_core::check::{CheckContext, CheckReport, Probe};
use faucet_core::replication::{filter_incremental, max_replication_value, max_value};
use faucet_core::{FaucetError, Source, StreamPage};
use futures::{Stream, TryStreamExt};
use serde_json::Value;
use tiberius::{QueryItem, ToSql};
use faucet_common_mssql::{MssqlPool, build_pool, with_statement_timeout};
use crate::config::{MssqlReplication, MssqlSourceConfig};
use crate::convert::row_to_json;
/// MSSQL-backed source: runs the configured query through a connection pool
/// and exposes the rows as JSON records.
pub struct MssqlSource {
/// Validated connector configuration (query, params, replication mode, ...).
config: MssqlSourceConfig,
/// Connection pool built once in `new` and shared by every fetch.
pool: MssqlPool,
/// Bookmark to resume from; `None` until `apply_start_bookmark` stores one.
start_bookmark: Mutex<Option<Value>>,
}
impl MssqlSource {
    /// Validates `config`, builds the connection pool and returns a source
    /// with no start bookmark set.
    ///
    /// # Errors
    ///
    /// Propagates any error from `config.validate()` or `build_pool`.
    pub async fn new(config: MssqlSourceConfig) -> Result<Self, FaucetError> {
        config.validate()?;
        let pool = build_pool(&config.connection, config.max_connections).await?;
        let start_bookmark = Mutex::new(None);
        Ok(Self {
            config,
            pool,
            start_bookmark,
        })
    }

    /// Statement timeout from the config; a value of `0` means "no timeout".
    fn timeout(&self) -> Option<Duration> {
        let secs = self.config.statement_timeout_secs;
        (secs != 0).then(|| Duration::from_secs(secs))
    }

    /// Returns a copy of the currently stored start bookmark, if any.
    ///
    /// # Panics
    ///
    /// Panics if the bookmark mutex has been poisoned.
    fn current_start(&self) -> Option<Value> {
        let guard = self
            .start_bookmark
            .lock()
            .expect("start_bookmark mutex poisoned");
        guard.clone()
    }
}
/// Incremental-replication settings resolved for a single query run.
#[derive(Debug, Clone, PartialEq)]
struct IncrementalCtx {
/// Name of the replication column, copied from the replication config.
column: String,
/// Starting value: the stored bookmark if one was supplied, otherwise the
/// configured `initial_value`.
start: Value,
}
/// Builds the final SQL text, its positional bind values and the optional
/// incremental context for one run.
///
/// * With a non-empty `context`, the query is passed through
///   `substitute_context_bind_params`; the resulting values are appended after
///   the configured params, with placeholders numbered from
///   `config.params.len() + 1` in `@P{n}` form.
/// * In incremental mode the start value is `start_bookmark` when present,
///   otherwise the configured `initial_value`. If the query contains the
///   `@bookmark` token, every occurrence is replaced with the next `@P{n}`
///   placeholder and the start value is appended to the bind values.
/// * The incremental context is returned even when the token is absent.
fn build_query_and_params(
    config: &MssqlSourceConfig,
    context: &HashMap<String, Value>,
    start_bookmark: Option<&Value>,
) -> (String, Vec<Value>, Option<IncrementalCtx>) {
    const BOOKMARK_TOKEN: &str = "@bookmark";

    let mut binds = config.params.clone();
    let mut sql = if context.is_empty() {
        config.query.clone()
    } else {
        let (substituted, ctx_binds) = faucet_core::util::substitute_context_bind_params(
            &config.query,
            context,
            config.params.len() + 1,
            |i| format!("@P{i}"),
        );
        binds.extend(ctx_binds);
        substituted
    };

    let incremental = match &config.replication {
        MssqlReplication::Full => None,
        MssqlReplication::Incremental {
            column,
            initial_value,
        } => {
            let start = match start_bookmark {
                Some(stored) => stored.clone(),
                None => initial_value.clone(),
            };
            if sql.contains(BOOKMARK_TOKEN) {
                // The bookmark takes the next free positional slot.
                let placeholder = format!("@P{}", binds.len() + 1);
                sql = sql.replace(BOOKMARK_TOKEN, &placeholder);
                binds.push(start.clone());
            }
            Some(IncrementalCtx {
                column: column.clone(),
                start,
            })
        }
    };

    (sql, binds, incremental)
}
/// Owned SQL bind parameter converted from a JSON value.
///
/// `as_tosql` hands out `&dyn ToSql` borrows of the payloads stored here.
enum OwnedParam {
/// 64-bit signed integer parameter.
I64(i64),
/// 64-bit floating-point parameter.
F64(f64),
/// Boolean parameter.
Bool(bool),
/// Text parameter; also carries the JSON text of arrays and objects.
Str(String),
/// Null parameter; `from_value` always stores `None` here.
Null(Option<i32>),
}
impl OwnedParam {
    /// Converts a JSON value into an owned SQL parameter.
    ///
    /// Mapping:
    /// * strings bind as text;
    /// * numbers representable as `i64` bind as 64-bit integers;
    /// * every other number (fractional values and unsigned integers above
    ///   `i64::MAX`) binds as `f64`;
    /// * booleans bind as booleans;
    /// * `null` binds as a null parameter;
    /// * arrays and objects bind as their JSON text.
    fn from_value(v: &Value) -> Self {
        match v {
            Value::String(s) => OwnedParam::Str(s.clone()),
            Value::Number(n) => match n.as_i64() {
                Some(int) => OwnedParam::I64(int),
                // Not representable as `i64`. An unsigned value above
                // `i64::MAX` must not be cast with `as i64`, which would wrap
                // it to a negative number; bind it as a float so sign and
                // magnitude are preserved (precision may be lost).
                None => OwnedParam::F64(n.as_f64().unwrap_or(0.0)),
            },
            Value::Bool(b) => OwnedParam::Bool(*b),
            Value::Null => OwnedParam::Null(None),
            other => OwnedParam::Str(other.to_string()),
        }
    }

    /// Borrows the stored payload as a `ToSql` trait object.
    fn as_tosql(&self) -> &dyn ToSql {
        match self {
            OwnedParam::I64(v) => v,
            OwnedParam::F64(v) => v,
            OwnedParam::Bool(v) => v,
            OwnedParam::Str(v) => v,
            OwnedParam::Null(v) => v,
        }
    }
}
/// Derives the fallback state key `mssql:{host}:{fingerprint}`.
///
/// * `host` is the host component of `connection_url`, or the literal `mssql`
///   when the URL is absent, unparsable or has no host. Characters other than
///   ASCII alphanumerics, `_`, `-` and `.` are replaced with `_`.
/// * `fingerprint` is the 64-bit `DefaultHasher` hash of the query text,
///   rendered as 16 zero-padded lowercase hex digits.
///
/// NOTE(review): the standard library does not guarantee `DefaultHasher`'s
/// algorithm across Rust releases, so the fingerprint may change after a
/// toolchain upgrade — confirm whether persisted state depends on it.
fn default_state_key(config: &MssqlSourceConfig) -> String {
    let parsed_host = config
        .connection
        .connection_url
        .as_deref()
        .and_then(|raw| url::Url::parse(raw).ok())
        .and_then(|parsed| parsed.host_str().map(str::to_owned));

    let safe_host: String = parsed_host
        .as_deref()
        .unwrap_or("mssql")
        .chars()
        .map(|c| match c {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | '.' => c,
            _ => '_',
        })
        .collect();

    let fingerprint = {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        config.query.hash(&mut state);
        state.finish()
    };

    format!("mssql:{safe_host}:{fingerprint:016x}")
}
// `Source` implementation: one-shot fetches delegate to `collect_all`, while
// `stream_pages` reads the result set row by row and emits it in pages.
#[async_trait]
impl Source for MssqlSource {
/// Runs the query and returns all records, discarding the bookmark that
/// `collect_all` computes.
async fn fetch_with_context(
&self,
context: &HashMap<String, Value>,
) -> Result<Vec<Value>, FaucetError> {
Ok(self.collect_all(context).await?.0)
}
/// Runs the query and returns the records together with the bookmark
/// computed by `collect_all`.
async fn fetch_with_context_incremental(
&self,
context: &HashMap<String, Value>,
) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
self.collect_all(context).await
}
/// Streams the query result as pages of JSON records.
///
/// The `_batch_size` argument is ignored; the page size comes from
/// `config.batch_size`, where `0` means "do not split into pages".
/// NOTE(review): confirm that overriding the caller's batch size is intended.
///
/// The start bookmark is read and the query is built when this method is
/// called, not when the returned stream is first polled.
fn stream_pages<'a>(
&'a self,
context: &'a HashMap<String, Value>,
_batch_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
let batch_size = self.config.batch_size;
// A batch size of 0 disables chunking: the threshold can never be reached.
let chunk = if batch_size == 0 {
usize::MAX
} else {
batch_size
};
// Initial buffer capacity; 1024 is used when chunking is disabled.
let cap = if batch_size == 0 { 1024 } else { batch_size };
let start = self.current_start();
let (query, values, incr) = build_query_and_params(&self.config, context, start.as_ref());
Box::pin(async_stream::try_stream! {
let mut conn = self
.pool
.get()
.await
.map_err(|e| FaucetError::Source(format!("MSSQL pool checkout failed: {e}")))?;
// The owned params and their `&dyn ToSql` borrows live only inside this block.
let mut stream = {
let owned: Vec<OwnedParam> = values.iter().map(OwnedParam::from_value).collect();
let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
let query_fut = conn.query(&query, &refs);
// The statement timeout wraps only the await of `conn.query`; the row
// iteration below is not wrapped.
match self.timeout() {
Some(t) => {
with_statement_timeout(t, async {
query_fut.await.map_err(|e| {
FaucetError::Source(format!("MSSQL query failed: {e}"))
})
}, || FaucetError::Source("MSSQL query timed out".into()))
.await?
}
None => query_fut
.await
.map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?,
}
};
let mut buffer: Vec<Value> = Vec::with_capacity(cap);
// Largest replication value seen across all pages so far.
let mut running_max: Option<Value> = None;
let mut total = 0usize;
while let Some(item) = stream
.try_next()
.await
.map_err(|e| FaucetError::Source(format!("MSSQL row stream failed: {e}")))?
{
// Items that are not rows are skipped.
let QueryItem::Row(row) = item else { continue };
buffer.push(row_to_json(&row)?);
if buffer.len() >= chunk {
let page = std::mem::replace(&mut buffer, Vec::with_capacity(cap));
let kept = apply_incremental(page, incr.as_ref(), &mut running_max);
total += kept.len();
// Intermediate pages never carry a bookmark; pages left empty by the
// incremental filter are not emitted.
if !kept.is_empty() {
yield StreamPage { records: kept, bookmark: None };
}
}
}
// Flush the remaining rows; only this final page may carry the bookmark.
let kept = apply_incremental(buffer, incr.as_ref(), &mut running_max);
total += kept.len();
let bookmark = if incr.is_some() { running_max.clone() } else { None };
// Emit the final page when it has records or a bookmark to report.
if !kept.is_empty() || bookmark.is_some() {
yield StreamPage { records: kept, bookmark };
}
tracing::info!(rows = total, query = %self.config.query, "MSSQL source stream complete");
})
}
/// Returns the schema of `MssqlSourceConfig` as a JSON value.
///
/// Panics if the schema cannot be serialized.
fn config_schema(&self) -> Value {
serde_json::to_value(faucet_core::schema_for!(MssqlSourceConfig))
.expect("schema serialization")
}
/// Connector identifier: `"mssql"`.
fn connector_name(&self) -> &'static str {
"mssql"
}
/// State key for this source: `None` in full replication mode; in
/// incremental mode the configured `state_key`, or the key derived by
/// `default_state_key` when none is configured.
fn state_key(&self) -> Option<String> {
match &self.config.replication {
MssqlReplication::Full => None,
MssqlReplication::Incremental { .. } => Some(
self.config
.state_key
.clone()
.unwrap_or_else(|| default_state_key(&self.config)),
),
}
}
/// Stores `bookmark` as the start value for subsequent runs.
///
/// Panics if the bookmark mutex has been poisoned.
async fn apply_start_bookmark(&self, bookmark: Value) -> Result<(), FaucetError> {
*self
.start_bookmark
.lock()
.expect("start_bookmark mutex poisoned") = Some(bookmark);
Ok(())
}
/// Connectivity check: tries to check a connection out of the pool within
/// `ctx.timeout` and reports the outcome as a single `connect` probe.
/// Failures are reported inside the probe; this method itself returns `Ok`.
async fn check(&self, ctx: &CheckContext) -> Result<CheckReport, FaucetError> {
let started = std::time::Instant::now();
let probe = match tokio::time::timeout(ctx.timeout, self.pool.get()).await {
// Checkout succeeded within the deadline.
Ok(Ok(_conn)) => Probe::pass("connect", started.elapsed()),
// The pool returned an error before the deadline.
Ok(Err(e)) => Probe::fail_hint(
"connect",
started.elapsed(),
e.to_string(),
"check connection_url / credentials / TLS / that the server is reachable",
),
// The deadline elapsed before the pool answered.
Err(_) => Probe::fail_hint(
"connect",
started.elapsed(),
"timed out",
"check connection_url / credentials / TLS / that the server is reachable",
),
};
Ok(CheckReport::single(probe))
}
}
impl MssqlSource {
    /// Runs the configured query to completion and returns every record plus
    /// the resulting bookmark.
    ///
    /// The statement timeout, when configured, covers both issuing the query
    /// and reading its first result set. In incremental mode the records are
    /// passed through `apply_incremental` and the bookmark is the largest
    /// replication value among the kept records; in full mode the bookmark is
    /// always `None`.
    ///
    /// # Errors
    ///
    /// Returns `FaucetError::Source` on pool checkout failure, query failure,
    /// result read failure or timeout, and propagates row conversion errors.
    async fn collect_all(
        &self,
        context: &HashMap<String, Value>,
    ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
        let resume_from = self.current_start();
        let (sql, binds, incr) =
            build_query_and_params(&self.config, context, resume_from.as_ref());

        let mut conn = match self.pool.get().await {
            Ok(conn) => conn,
            Err(e) => {
                return Err(FaucetError::Source(format!(
                    "MSSQL pool checkout failed: {e}"
                )))
            }
        };

        let rows = {
            // Owned parameter values must outlive the `&dyn ToSql` borrows.
            let owned: Vec<OwnedParam> = binds.iter().map(OwnedParam::from_value).collect();
            let refs: Vec<&dyn ToSql> = owned.iter().map(OwnedParam::as_tosql).collect();
            let fetch = async {
                let result_stream = conn
                    .query(&sql, &refs)
                    .await
                    .map_err(|e| FaucetError::Source(format!("MSSQL query failed: {e}")))?;
                result_stream
                    .into_first_result()
                    .await
                    .map_err(|e| FaucetError::Source(format!("MSSQL result read failed: {e}")))
            };
            if let Some(limit) = self.timeout() {
                with_statement_timeout(limit, fetch, || {
                    FaucetError::Source("MSSQL query timed out".into())
                })
                .await?
            } else {
                fetch.await?
            }
        };

        let records = rows
            .iter()
            .map(|row| row_to_json(row))
            .collect::<Result<Vec<Value>, _>>()?;

        let mut running_max: Option<Value> = None;
        let records = apply_incremental(records, incr.as_ref(), &mut running_max);
        let bookmark = match incr {
            Some(_) => running_max,
            None => None,
        };
        Ok((records, bookmark))
    }
}
/// Applies the client-side incremental filter to one page of records and
/// folds the page's largest replication value into `running_max`.
///
/// Without an incremental context the page is returned untouched and
/// `running_max` is left as it was. Otherwise the page is filtered with
/// `filter_incremental` against `ctx.start`, and the maximum replication
/// value of the kept records (if any) is merged into `running_max` via
/// `max_value`.
fn apply_incremental(
    page: Vec<Value>,
    incr: Option<&IncrementalCtx>,
    running_max: &mut Option<Value>,
) -> Vec<Value> {
    let Some(ctx) = incr else {
        return page;
    };

    let kept = filter_incremental(page, &ctx.column, &ctx.start);
    let page_max = max_replication_value(&kept, &ctx.column).map(|v| v.clone());
    if let Some(candidate) = page_max {
        let merged = match running_max.take() {
            None => candidate,
            Some(previous) => max_value(previous, candidate),
        };
        *running_max = Some(merged);
    }
    kept
}
// Unit tests for the pure helpers in this file; none of them open a database
// connection.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Baseline config: a connection URL plus a plain SELECT. Individual tests
// override fields with struct-update syntax.
fn full_cfg() -> MssqlSourceConfig {
MssqlSourceConfig::new("mssql://sa:pw@db.example.com:1433/sales", "SELECT * FROM t")
}
// With the baseline replication mode and no context, the query and params
// pass through untouched and no incremental context is produced.
#[test]
fn build_full_returns_query_and_params_unchanged() {
let mut cfg = full_cfg();
cfg.params = vec![json!(1), json!("x")];
let (q, v, incr) = build_query_and_params(&cfg, &HashMap::new(), None);
assert_eq!(q, "SELECT * FROM t");
assert_eq!(v, vec![json!(1), json!("x")]);
assert!(incr.is_none());
}
// The `@bookmark` token becomes the first positional placeholder and binds
// the configured initial value when no stored bookmark is supplied.
#[test]
fn build_incremental_binds_bookmark_token() {
let cfg = MssqlSourceConfig {
query: "SELECT * FROM t WHERE updated_at > @bookmark".into(),
replication: MssqlReplication::Incremental {
column: "updated_at".into(),
initial_value: json!("1970-01-01"),
},
..full_cfg()
};
let (q, v, incr) = build_query_and_params(&cfg, &HashMap::new(), None);
assert_eq!(q, "SELECT * FROM t WHERE updated_at > @P1");
assert_eq!(v, vec![json!("1970-01-01")]);
assert_eq!(
incr,
Some(IncrementalCtx {
column: "updated_at".into(),
start: json!("1970-01-01")
})
);
}
// A stored bookmark wins over `initial_value`, and the placeholder index
// follows the already-configured params (`@P2` after one param).
#[test]
fn build_incremental_uses_stored_bookmark_over_initial() {
let cfg = MssqlSourceConfig {
query: "SELECT * FROM t WHERE c > @bookmark".into(),
params: vec![json!("p0")],
replication: MssqlReplication::Incremental {
column: "c".into(),
initial_value: json!(0),
},
..full_cfg()
};
let stored = json!(500);
let (q, v, incr) = build_query_and_params(&cfg, &HashMap::new(), Some(&stored));
assert_eq!(q, "SELECT * FROM t WHERE c > @P2");
assert_eq!(v, vec![json!("p0"), json!(500)]);
assert_eq!(incr.unwrap().start, json!(500));
}
// Without the token nothing is bound, but the incremental context is still
// returned so the client-side filter can run.
#[test]
fn build_incremental_without_token_still_returns_filter_ctx() {
let cfg = MssqlSourceConfig {
query: "SELECT * FROM t".into(),
replication: MssqlReplication::Incremental {
column: "c".into(),
initial_value: json!(0),
},
..full_cfg()
};
let (q, v, incr) = build_query_and_params(&cfg, &HashMap::new(), None);
assert_eq!(q, "SELECT * FROM t");
assert!(v.is_empty());
assert!(incr.is_some(), "client-side filter must still run");
}
// One representative JSON value per `OwnedParam` variant; a JSON object
// falls back to the text variant.
#[test]
fn owned_param_classifies_json() {
assert!(matches!(
OwnedParam::from_value(&json!("s")),
OwnedParam::Str(_)
));
assert!(matches!(
OwnedParam::from_value(&json!(7)),
OwnedParam::I64(7)
));
assert!(matches!(
OwnedParam::from_value(&json!(1.5)),
OwnedParam::F64(_)
));
assert!(matches!(
OwnedParam::from_value(&json!(true)),
OwnedParam::Bool(true)
));
assert!(matches!(
OwnedParam::from_value(&Value::Null),
OwnedParam::Null(None)
));
assert!(matches!(
OwnedParam::from_value(&json!({"a":1})),
OwnedParam::Str(_)
));
}
// With start 10, the record with value 5 is dropped, 15 and 20 are kept,
// and the running maximum becomes 20.
#[test]
fn apply_incremental_filters_and_tracks_max() {
let ctx = IncrementalCtx {
column: "c".into(),
start: json!(10),
};
let mut running = None;
let page = vec![json!({"c": 5}), json!({"c": 15}), json!({"c": 20})];
let kept = apply_incremental(page, Some(&ctx), &mut running);
assert_eq!(kept.len(), 2);
assert_eq!(running, Some(json!(20)));
}
// Without an incremental context the page is returned as-is and no maximum
// is tracked.
#[test]
fn apply_incremental_full_passes_through() {
let mut running = None;
let page = vec![json!({"c": 1}), json!({"c": 2})];
let kept = apply_incremental(page, None, &mut running);
assert_eq!(kept.len(), 2);
assert_eq!(running, None);
}
// The derived key is deterministic within a process, starts with the URL
// host, and passes faucet_core's state-key validation.
#[test]
fn default_state_key_is_stable_and_valid() {
let cfg = full_cfg();
let k1 = default_state_key(&cfg);
let k2 = default_state_key(&cfg);
assert_eq!(k1, k2);
assert!(k1.starts_with("mssql:db.example.com:"));
faucet_core::state::validate_state_key(&k1).expect("derived key must be valid");
}
}