1mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::sync::Arc;
24
25use arrow::datatypes::SchemaRef;
26use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
27use tokio::net::TcpStream;
28use tokio::runtime::Runtime;
29use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
30
31use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
32
33use crate::config::TlsConfig;
34use crate::error::Result;
35use crate::source::batch_controller::{
36 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
37};
38use crate::source::query::build_export_query;
39use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
40use crate::types::{ColumnOverrides, TypeMapping};
41
/// Async tiberius client over a tokio `TcpStream`, adapted to the `futures`
/// I/O traits tiberius expects through `tokio_util::compat::Compat`.
type MssqlClient = Client<Compat<TcpStream>>;
43
/// Synchronous SQL Server `Source`: wraps the async tiberius client and drives
/// every call to completion on a private tokio runtime via `block_on`.
///
/// NOTE(review): Rust drops fields in declaration order after `Drop::drop`
/// runs, so `rt` is dropped before `client`; keep that in mind before
/// reordering the fields.
pub struct MssqlSource {
    /// Runtime built in `connect_with_tls` (current-thread, all drivers enabled).
    rt: Runtime,
    /// Authenticated TDS connection opened in `connect_with_tls`.
    client: MssqlClient,
    /// Proxy kind detected right after connecting; `Direct` until detection runs.
    proxy_kind: MssqlProxyKind,
    /// Set by `export` when it issues `SET LOCK_TIMEOUT`; tells `Drop` to reset
    /// the session setting.
    lock_timeout_applied: bool,
}
58
59impl Drop for MssqlSource {
60 fn drop(&mut self) {
74 if !self.lock_timeout_applied {
75 return;
76 }
77 let Self { rt, client, .. } = self;
78 let _ = rt.block_on(async {
79 tokio::time::timeout(
80 std::time::Duration::from_secs(2),
81 client.execute("SET LOCK_TIMEOUT -1", &[]),
82 )
83 .await
84 });
85 }
86}
87
/// Connection parameters parsed out of a `sqlserver://` / `mssql://` URL.
struct MssqlUrl {
    /// Server host name or address, as written in the URL.
    host: String,
    /// TCP port; 1433 when the URL does not specify one.
    port: u16,
    /// Database named by the URL path segment.
    database: String,
    /// SQL Server login name.
    user: String,
    /// Login password (empty when the URL has no `:password`). Never log it.
    password: String,
}
96
97fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
98 let rest = url
99 .strip_prefix("sqlserver://")
100 .or_else(|| url.strip_prefix("mssql://"))
101 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
102 let (userinfo, hostpart) = rest
105 .rsplit_once('@')
106 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
107 let (user, password) = match userinfo.split_once(':') {
108 Some((u, p)) => (u.to_string(), p.to_string()),
109 None => (userinfo.to_string(), String::new()),
110 };
111 let (hostport, database) = hostpart
112 .split_once('/')
113 .map(|(h, d)| (h, d.to_string()))
114 .unwrap_or((hostpart, String::new()));
115 let (host, port) = match hostport.rsplit_once(':') {
116 Some((h, p)) => (
117 h.to_string(),
118 p.parse::<u16>()
119 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
120 ),
121 None => (hostport.to_string(), 1433),
122 };
123 if database.is_empty() {
124 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
125 }
126 Ok(MssqlUrl {
127 host,
128 port,
129 user,
130 password,
131 database,
132 })
133}
134
135impl MssqlSource {
136 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
140 let parts = parse_mssql_url(url)?;
141 let mut config = Config::new();
142 config.host(&parts.host);
143 config.port(parts.port);
144 config.database(&parts.database);
145 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
146
147 config.encryption(EncryptionLevel::Required);
152 match tls {
153 Some(cfg) if cfg.accept_invalid_certs => config.trust_cert(),
154 Some(cfg) => {
155 if let Some(ca) = cfg.ca_file.as_deref() {
156 config.trust_cert_ca(ca);
157 }
158 }
159 None => config.trust_cert(),
160 }
161
162 let rt = tokio::runtime::Builder::new_current_thread()
163 .enable_all()
164 .build()
165 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
166
167 let client = rt.block_on(async {
168 let tcp = TcpStream::connect(config.get_addr())
169 .await
170 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
171 tcp.set_nodelay(true).ok();
172 Client::connect(config, tcp.compat_write())
173 .await
174 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
175 })?;
176
177 let mut src = Self {
178 rt,
179 client,
180 proxy_kind: MssqlProxyKind::Direct,
181 lock_timeout_applied: false,
182 };
183 src.query_scalar("SELECT 1")?;
186 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
191 warn_proxy_kind(kind);
192 src.proxy_kind = kind;
193 Ok(src)
194 }
195
196 #[allow(dead_code)]
200 pub fn proxy_kind(&self) -> MssqlProxyKind {
201 self.proxy_kind
202 }
203}
204
205impl Source for MssqlSource {
206 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
207 let built = build_export_query(request, crate::config::SourceType::Mssql);
210 let sql = built.sql.clone();
211 let overrides = request.column_overrides.clone();
212 let mut ctl =
221 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
222 let mut cap_applied = false;
223 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
233 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
234 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
235
236 if lock_timeout_ms > 0 {
239 self.lock_timeout_applied = true;
240 }
241
242 let Self { rt, client, .. } = self;
243 rt.block_on(async {
244 use futures_util::stream::TryStreamExt;
245 use tiberius::QueryItem;
246
247 if lock_timeout_ms > 0 {
248 client
249 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
250 .await
251 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
252 }
253
254 let started = std::time::Instant::now();
255 let mut stream = client
256 .query(sql.as_str(), &[])
257 .await
258 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
259
260 let mut columns: Vec<tiberius::Column> = Vec::new();
261 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
262 let mut schema: Option<SchemaRef> = None;
263
264 while let Some(item) = stream
265 .try_next()
266 .await
267 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
268 {
269 if let Some(budget) = stmt_timeout
270 && started.elapsed() > budget
271 {
272 anyhow::bail!(
273 "mssql: statement timeout after {}s (tuning.statement_timeout_s) — \
274 this query cannot finish within the budget; split it with \
275 `mode: chunked` (per-chunk statements stay under the limit) or \
276 raise `tuning.statement_timeout_s`",
277 budget.as_secs()
278 );
279 }
280 match item {
281 QueryItem::Metadata(meta) if columns.is_empty() => {
284 columns = meta.columns().to_vec();
285 }
286 QueryItem::Metadata(_) => {}
287 QueryItem::Row(row) => {
288 buf.push(row);
289 if buf.len() >= ctl.target() {
290 let arrow_bytes =
291 emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
292 let n = buf.len();
293 buf.clear();
294 if !cap_applied && n > 0 {
299 let arrow_per_row = (arrow_bytes / n).max(64);
300 let target_mb = request
301 .tuning
302 .batch_size_memory_mb
303 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
304 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
305 .max(PROBE_BATCH_SIZE);
306 if let Some(new) = ctl.apply_memory_cap(safe) {
307 log::info!(
308 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
309 arrow_per_row,
310 target_mb,
311 new
312 );
313 buf.reserve(new.saturating_sub(buf.capacity()));
314 }
315 cap_applied = true;
316 }
317 ctl.after_batch(|| None);
319 ctl.throttle();
320 }
321 }
322 }
323 }
324 if !buf.is_empty() || schema.is_none() {
330 emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
331 }
332 Ok::<_, anyhow::Error>(())
333 })?;
334 Ok(())
335 }
336
337 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
338 let Self { rt, client, .. } = self;
339 rt.block_on(async {
340 let row = client
341 .query(sql, &[])
342 .await
343 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
344 .into_row()
345 .await
346 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
347 Ok(row.and_then(|r| scalar_to_string(&r)))
348 })
349 }
350
351 fn type_mappings(
352 &mut self,
353 query: &str,
354 column_overrides: &ColumnOverrides,
355 ) -> Result<Vec<TypeMapping>> {
356 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
358 let overrides = column_overrides.clone();
359 let Self { rt, client, .. } = self;
360 rt.block_on(async {
361 let mut stream = client
362 .query(wrapped.as_str(), &[])
363 .await
364 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
365 let columns = stream
366 .columns()
367 .await
368 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
369 .map(<[_]>::to_vec)
370 .unwrap_or_default();
371 let _ = stream.into_first_result().await;
373 Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
374 })
375 }
376
377 fn sample_pressure(&mut self) -> Option<u64> {
378 let Self { rt, client, .. } = self;
379 let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
389 WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
390 rt.block_on(async {
391 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
392 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
393 })
394 }
395}
396
397fn emit_mssql_batch(
407 columns: &[tiberius::Column],
408 overrides: &ColumnOverrides,
409 schema: &mut Option<SchemaRef>,
410 rows: &[tiberius::Row],
411 sink: &mut dyn BatchSink,
412) -> Result<usize> {
413 let schema_ref = match schema {
414 Some(s) => s.clone(),
415 None => {
416 let (built, _decoders) =
417 arrow_convert::mssql_columns_to_schema(columns, overrides, rows)?;
418 let s: SchemaRef = Arc::new(built);
419 sink.on_schema(s.clone())?;
420 *schema = Some(s.clone());
421 s
422 }
423 };
424 if !rows.is_empty() {
425 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows)?;
426 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
427 sink.on_batch(&batch)?;
428 return Ok(bytes);
429 }
430 Ok(0)
431}
432
433fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
437 use tiberius::ColumnData;
438 let cell = row.cells().next().map(|(_, d)| d)?;
439 match cell {
440 ColumnData::U8(v) => v.map(|x| x.to_string()),
441 ColumnData::I16(v) => v.map(|x| x.to_string()),
442 ColumnData::I32(v) => v.map(|x| x.to_string()),
443 ColumnData::I64(v) => v.map(|x| x.to_string()),
444 ColumnData::F32(v) => v.map(|x| x.to_string()),
445 ColumnData::F64(v) => v.map(|x| x.to_string()),
446 ColumnData::Bit(v) => v.map(|x| x.to_string()),
447 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
448 ColumnData::Numeric(v) => v.map(|n| {
449 let raw = n.value();
451 let scale = n.scale() as usize;
452 if scale == 0 {
453 raw.to_string()
454 } else {
455 let neg = raw < 0;
456 let digits = raw.unsigned_abs().to_string();
457 let digits = format!("{digits:0>width$}", width = scale + 1);
458 let (int, frac) = digits.split_at(digits.len() - scale);
459 format!("{}{int}.{frac}", if neg { "-" } else { "" })
460 }
461 }),
462 ColumnData::Guid(v) => v.map(|g| g.to_string()),
463 other => Some(format!("{other:?}")),
464 }
465}
466
467pub(crate) fn introspect_mssql_table_for_chunking(
470 url: &str,
471 tls: Option<&TlsConfig>,
472 qualified_table: &str,
473) -> Result<TableIntrospection> {
474 let (schema, table) = match qualified_table.split_once('.') {
475 Some((s, t)) => (s.to_string(), t.to_string()),
476 None => ("dbo".to_string(), qualified_table.to_string()),
477 };
478 let mut src = MssqlSource::connect_with_tls(url, tls)?;
479
480 let count_sql = format!(
483 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
484 JOIN sys.objects o ON o.object_id = p.object_id \
485 JOIN sys.schemas s ON s.schema_id = o.schema_id \
486 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
487 schema.replace('\'', "''"),
488 table.replace('\'', "''")
489 );
490 let row_estimate = src
491 .query_scalar(&count_sql)?
492 .and_then(|s| s.parse::<i64>().ok())
493 .unwrap_or(0);
494
495 let pk_sql = format!(
498 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
499 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
500 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
501 JOIN sys.types t ON t.user_type_id = c.user_type_id \
502 JOIN sys.objects o ON o.object_id = i.object_id \
503 JOIN sys.schemas s ON s.schema_id = o.schema_id \
504 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
505 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
506 schema.replace('\'', "''"),
507 table.replace('\'', "''")
508 );
509 let keyset_sql = format!(
519 "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
520 SELECT col, MAX(is_pk) AS is_pk FROM ( \
521 SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
522 FROM sys.indexes i \
523 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
524 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
525 JOIN sys.objects o ON o.object_id = i.object_id \
526 JOIN sys.schemas s ON s.schema_id = o.schema_id \
527 WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
528 GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
529 ) per_index GROUP BY col \
530 ) deduped",
531 schema.replace('\'', "''"),
532 table.replace('\'', "''")
533 );
534 let keyset_keys: Vec<String> = src
535 .query_scalar(&keyset_sql)?
536 .map(|s| {
537 s.split('\u{1f}')
538 .filter(|c| !c.is_empty())
539 .map(str::to_string)
540 .collect()
541 })
542 .unwrap_or_default();
543
544 let mut single_int_pk = None;
547 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
548 let type_sql = format!(
551 "SELECT t.name FROM sys.columns c \
552 JOIN sys.types t ON t.user_type_id = c.user_type_id \
553 JOIN sys.objects o ON o.object_id = c.object_id \
554 JOIN sys.schemas s ON s.schema_id = o.schema_id \
555 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
556 schema.replace('\'', "''"),
557 table.replace('\'', "''"),
558 pk_col.replace('\'', "''")
559 );
560 if let Some(ty) = src.query_scalar(&type_sql)?
561 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
562 {
563 single_int_pk = Some(pk_col);
564 }
565 }
566
567 Ok(TableIntrospection {
568 single_int_pk,
569 keyset_keys,
570 row_estimate,
571 avg_row_bytes: None,
572 })
573}