1mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::sync::Arc;
24
25use arrow::datatypes::SchemaRef;
26use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
27use tokio::net::TcpStream;
28use tokio::runtime::Runtime;
29use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
30
31use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
32
33use crate::config::TlsConfig;
34use crate::error::Result;
35use crate::source::batch_controller::{
36 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
37};
38use crate::source::query::build_export_query;
39use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
40use crate::types::{ColumnOverrides, TypeMapping};
41
42type MssqlClient = Client<Compat<TcpStream>>;
43
44pub struct MssqlSource {
50 rt: Runtime,
51 client: MssqlClient,
52 database: String,
54 proxy_kind: MssqlProxyKind,
56}
57
/// Components extracted from a `sqlserver://user:pass@host:port/<db>` connection URL.
struct MssqlUrl {
    /// Host name or address to connect to.
    host: String,
    /// TCP port; the parser falls back to 1433 when the URL has none.
    port: u16,
    /// SQL Server login name.
    user: String,
    /// Login password; empty when the URL carries no `:password` part.
    password: String,
    /// Target database name (never empty once parsing succeeded).
    database: String,
}
66
67fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
68 let rest = url
69 .strip_prefix("sqlserver://")
70 .or_else(|| url.strip_prefix("mssql://"))
71 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
72 let (userinfo, hostpart) = rest
75 .rsplit_once('@')
76 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
77 let (user, password) = match userinfo.split_once(':') {
78 Some((u, p)) => (u.to_string(), p.to_string()),
79 None => (userinfo.to_string(), String::new()),
80 };
81 let (hostport, database) = hostpart
82 .split_once('/')
83 .map(|(h, d)| (h, d.to_string()))
84 .unwrap_or((hostpart, String::new()));
85 let (host, port) = match hostport.rsplit_once(':') {
86 Some((h, p)) => (
87 h.to_string(),
88 p.parse::<u16>()
89 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
90 ),
91 None => (hostport.to_string(), 1433),
92 };
93 if database.is_empty() {
94 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
95 }
96 Ok(MssqlUrl {
97 host,
98 port,
99 user,
100 password,
101 database,
102 })
103}
104
105impl MssqlSource {
106 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
110 let parts = parse_mssql_url(url)?;
111 let mut config = Config::new();
112 config.host(&parts.host);
113 config.port(parts.port);
114 config.database(&parts.database);
115 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
116
117 config.encryption(EncryptionLevel::Required);
122 match tls {
123 Some(cfg) if cfg.accept_invalid_certs => config.trust_cert(),
124 Some(cfg) => {
125 if let Some(ca) = cfg.ca_file.as_deref() {
126 config.trust_cert_ca(ca);
127 }
128 }
129 None => config.trust_cert(),
130 }
131
132 let rt = tokio::runtime::Builder::new_current_thread()
133 .enable_all()
134 .build()
135 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
136
137 let client = rt.block_on(async {
138 let tcp = TcpStream::connect(config.get_addr())
139 .await
140 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
141 tcp.set_nodelay(true).ok();
142 Client::connect(config, tcp.compat_write())
143 .await
144 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
145 })?;
146
147 let mut src = Self {
148 rt,
149 client,
150 database: parts.database,
151 proxy_kind: MssqlProxyKind::Direct,
152 };
153 src.query_scalar("SELECT 1")?;
156 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
161 warn_proxy_kind(kind);
162 src.proxy_kind = kind;
163 Ok(src)
164 }
165
166 #[allow(dead_code)]
170 pub fn proxy_kind(&self) -> MssqlProxyKind {
171 self.proxy_kind
172 }
173}
174
175impl Source for MssqlSource {
176 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
177 let built = build_export_query(request, crate::config::SourceType::Mssql);
180 let sql = built.sql.clone();
181 let overrides = request.column_overrides.clone();
182 let mut ctl =
191 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
192 let mut cap_applied = false;
193 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
203 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
204 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
205
206 let Self { rt, client, .. } = self;
207 rt.block_on(async {
208 use futures_util::stream::TryStreamExt;
209 use tiberius::QueryItem;
210
211 if lock_timeout_ms > 0 {
212 client
213 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
214 .await
215 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
216 }
217
218 let started = std::time::Instant::now();
219 let mut stream = client
220 .query(sql.as_str(), &[])
221 .await
222 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
223
224 let mut columns: Vec<tiberius::Column> = Vec::new();
225 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
226 let mut schema: Option<SchemaRef> = None;
227
228 while let Some(item) = stream
229 .try_next()
230 .await
231 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
232 {
233 if let Some(budget) = stmt_timeout
234 && started.elapsed() > budget
235 {
236 anyhow::bail!(
237 "mssql: statement timeout after {}s (tuning.statement_timeout_s)",
238 budget.as_secs()
239 );
240 }
241 match item {
242 QueryItem::Metadata(meta) if columns.is_empty() => {
245 columns = meta.columns().to_vec();
246 }
247 QueryItem::Metadata(_) => {}
248 QueryItem::Row(row) => {
249 buf.push(row);
250 if buf.len() >= ctl.target() {
251 let arrow_bytes =
252 emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
253 let n = buf.len();
254 buf.clear();
255 if !cap_applied && n > 0 {
260 let arrow_per_row = (arrow_bytes / n).max(64);
261 let target_mb = request
262 .tuning
263 .batch_size_memory_mb
264 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
265 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
266 .max(PROBE_BATCH_SIZE);
267 if let Some(new) = ctl.apply_memory_cap(safe) {
268 log::info!(
269 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
270 arrow_per_row,
271 target_mb,
272 new
273 );
274 buf.reserve(new.saturating_sub(buf.capacity()));
275 }
276 cap_applied = true;
277 }
278 ctl.after_batch(|| None);
280 ctl.throttle();
281 }
282 }
283 }
284 }
285 if !buf.is_empty() || schema.is_none() {
291 emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
292 }
293 Ok::<_, anyhow::Error>(())
294 })?;
295 Ok(())
296 }
297
298 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
299 let Self { rt, client, .. } = self;
300 rt.block_on(async {
301 let row = client
302 .query(sql, &[])
303 .await
304 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
305 .into_row()
306 .await
307 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
308 Ok(row.and_then(|r| scalar_to_string(&r)))
309 })
310 }
311
312 fn type_mappings(
313 &mut self,
314 query: &str,
315 column_overrides: &ColumnOverrides,
316 ) -> Result<Vec<TypeMapping>> {
317 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
319 let overrides = column_overrides.clone();
320 let Self { rt, client, .. } = self;
321 rt.block_on(async {
322 let mut stream = client
323 .query(wrapped.as_str(), &[])
324 .await
325 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
326 let columns = stream
327 .columns()
328 .await
329 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
330 .map(<[_]>::to_vec)
331 .unwrap_or_default();
332 let _ = stream.into_first_result().await;
334 Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
335 })
336 }
337
338 fn sample_pressure(&mut self) -> Option<u64> {
339 let db = self.database.clone();
340 let Self { rt, client, .. } = self;
341 let sql = "SELECT cntr_value FROM sys.dm_os_performance_counters \
346 WHERE counter_name LIKE 'Log Flush Wait%' AND instance_name = @P1";
347 rt.block_on(async {
348 let row = client
349 .query(sql, &[&db])
350 .await
351 .ok()?
352 .into_row()
353 .await
354 .ok()??;
355 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
356 })
357 }
358}
359
360fn emit_mssql_batch(
370 columns: &[tiberius::Column],
371 overrides: &ColumnOverrides,
372 schema: &mut Option<SchemaRef>,
373 rows: &[tiberius::Row],
374 sink: &mut dyn BatchSink,
375) -> Result<usize> {
376 let schema_ref = match schema {
377 Some(s) => s.clone(),
378 None => {
379 let (built, _decoders) =
380 arrow_convert::mssql_columns_to_schema(columns, overrides, rows)?;
381 let s: SchemaRef = Arc::new(built);
382 sink.on_schema(s.clone())?;
383 *schema = Some(s.clone());
384 s
385 }
386 };
387 if !rows.is_empty() {
388 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows)?;
389 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
390 sink.on_batch(&batch)?;
391 return Ok(bytes);
392 }
393 Ok(0)
394}
395
396fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
400 use tiberius::ColumnData;
401 let cell = row.cells().next().map(|(_, d)| d)?;
402 match cell {
403 ColumnData::U8(v) => v.map(|x| x.to_string()),
404 ColumnData::I16(v) => v.map(|x| x.to_string()),
405 ColumnData::I32(v) => v.map(|x| x.to_string()),
406 ColumnData::I64(v) => v.map(|x| x.to_string()),
407 ColumnData::F32(v) => v.map(|x| x.to_string()),
408 ColumnData::F64(v) => v.map(|x| x.to_string()),
409 ColumnData::Bit(v) => v.map(|x| x.to_string()),
410 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
411 ColumnData::Numeric(v) => v.map(|n| {
412 let raw = n.value();
414 let scale = n.scale() as usize;
415 if scale == 0 {
416 raw.to_string()
417 } else {
418 let neg = raw < 0;
419 let digits = raw.unsigned_abs().to_string();
420 let digits = format!("{digits:0>width$}", width = scale + 1);
421 let (int, frac) = digits.split_at(digits.len() - scale);
422 format!("{}{int}.{frac}", if neg { "-" } else { "" })
423 }
424 }),
425 ColumnData::Guid(v) => v.map(|g| g.to_string()),
426 other => Some(format!("{other:?}")),
427 }
428}
429
430pub(crate) fn introspect_mssql_table_for_chunking(
433 url: &str,
434 tls: Option<&TlsConfig>,
435 qualified_table: &str,
436) -> Result<TableIntrospection> {
437 let (schema, table) = match qualified_table.split_once('.') {
438 Some((s, t)) => (s.to_string(), t.to_string()),
439 None => ("dbo".to_string(), qualified_table.to_string()),
440 };
441 let mut src = MssqlSource::connect_with_tls(url, tls)?;
442
443 let count_sql = format!(
446 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
447 JOIN sys.objects o ON o.object_id = p.object_id \
448 JOIN sys.schemas s ON s.schema_id = o.schema_id \
449 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
450 schema.replace('\'', "''"),
451 table.replace('\'', "''")
452 );
453 let row_estimate = src
454 .query_scalar(&count_sql)?
455 .and_then(|s| s.parse::<i64>().ok())
456 .unwrap_or(0);
457
458 let pk_sql = format!(
461 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
462 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
463 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
464 JOIN sys.types t ON t.user_type_id = c.user_type_id \
465 JOIN sys.objects o ON o.object_id = i.object_id \
466 JOIN sys.schemas s ON s.schema_id = o.schema_id \
467 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
468 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
469 schema.replace('\'', "''"),
470 table.replace('\'', "''")
471 );
472 let mut single_int_pk = None;
473 let mut keyset_keys = Vec::new();
474 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
475 keyset_keys.push(pk_col.clone());
477 let type_sql = format!(
480 "SELECT t.name FROM sys.columns c \
481 JOIN sys.types t ON t.user_type_id = c.user_type_id \
482 JOIN sys.objects o ON o.object_id = c.object_id \
483 JOIN sys.schemas s ON s.schema_id = o.schema_id \
484 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
485 schema.replace('\'', "''"),
486 table.replace('\'', "''"),
487 pk_col.replace('\'', "''")
488 );
489 if let Some(ty) = src.query_scalar(&type_sql)?
490 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
491 {
492 single_int_pk = Some(pk_col);
493 }
494 }
495
496 Ok(TableIntrospection {
497 single_int_pk,
498 keyset_keys,
499 row_estimate,
500 avg_row_bytes: None,
501 })
502}