use super::ExportDiagnostic;
use super::analysis::*;
use super::cursor_expr::incremental_key_expr;
use super::schema_error::PreflightSchemaError;
use crate::config::{ExportConfig, ExportMode, SourceType, TlsConfig};
use crate::error::Result;
use crate::source::Source;
use crate::source::mssql::MssqlSource;
/// Runs the SQL Server preflight for every export over one shared connection.
///
/// Exports are diagnosed in the order given. The first diagnosis that fails
/// stops the loop and its error is returned. Unless `silent` is set, each
/// successful diagnostic is printed via `super::print_diagnostic` as soon as it
/// is produced.
pub(super) fn check_mssql(
    url: &str,
    tls: Option<&TlsConfig>,
    exports: &[&ExportConfig],
    silent: bool,
) -> Result<()> {
    let mut source = MssqlSource::connect_with_tls(url, tls)?;
    exports.iter().try_for_each(|export| -> Result<()> {
        let report = diagnose_mssql(&mut source, export)?;
        if silent {
            return Ok(());
        }
        super::print_diagnostic(&report);
        Ok(())
    })
}
/// Opens a dedicated SQL Server connection and diagnoses a single export.
///
/// Both connection failures and diagnosis failures are propagated to the caller.
pub(super) fn diagnose_export_mssql(
    url: &str,
    tls: Option<&TlsConfig>,
    export: &ExportConfig,
) -> Result<super::ExportDiagnostic> {
    let mut source = MssqlSource::connect_with_tls(url, tls)?;
    let diagnostic = diagnose_mssql(&mut source, export)?;
    Ok(diagnostic)
}
fn diagnose_mssql(conn: &mut MssqlSource, export: &ExportConfig) -> Result<ExportDiagnostic> {
let mode_str = match export.mode {
ExportMode::Full => "full".to_string(),
ExportMode::Incremental => format!(
"incremental (cursor: {})",
export.cursor_column.as_deref().unwrap_or("?")
),
ExportMode::Chunked => format!(
"chunked (column: {}, size: {})",
export.chunk_column.as_deref().unwrap_or("?"),
export.chunk_size
),
ExportMode::TimeWindow => format!(
"time_window (column: {}, days: {})",
export.time_column.as_deref().unwrap_or("?"),
export.days_window.unwrap_or(0)
),
};
let base_query: String = match export.resolve_query(std::path::Path::new(""), None) {
Ok(q) => q,
Err(e) => {
log::debug!(
"preflight: base-query resolution failed for export '{}': {e}",
export.name
);
export
.query
.clone()
.unwrap_or_else(|| "SELECT 1".to_string())
}
};
let base_query = base_query.as_str();
if let Some(fail) = schema_fail_mssql(conn, base_query) {
return Err(fail);
}
let range_col = export
.chunk_column
.as_deref()
.or(export.cursor_column.as_deref());
let base_table =
strip_select_star_from(base_query).or_else(|| table_from_simple_query(base_query));
let row_estimate = match base_table {
Some(table) => row_estimate_mssql(conn, table),
None => None,
};
let (range_min, range_max) = if export.mode == ExportMode::Incremental {
match incremental_key_expr(export, SourceType::Mssql) {
Some(expr) => range_min_max_mssql(conn, base_query, base_table, &expr),
None => (None, None),
}
} else if let Some(col) = range_col {
let expr = crate::sql::quote_ident(SourceType::Mssql, col);
range_min_max_mssql(conn, base_query, base_table, &expr)
} else {
(None, None)
};
let scan_type = None;
let uses_index = if matches!(export.mode, ExportMode::Chunked | ExportMode::Incremental)
&& let Some(col) = range_col
&& let Some(table) = base_table
{
column_has_index_mssql(conn, table, col).unwrap_or(false)
} else {
false
};
let strategy = derive_strategy(export);
let verdict = compute_verdict(row_estimate, uses_index, export.cursor_column.is_some());
let recommended_profile = recommend_profile(row_estimate, uses_index, export);
let recommended_parallel = recommend_parallelism(export, row_estimate, uses_index);
let warnings = collect_warnings(
export,
row_estimate,
range_min.as_deref(),
range_max.as_deref(),
None,
);
let suggestion = build_suggestion(&verdict, row_estimate, uses_index, export);
Ok(ExportDiagnostic {
export_name: export.name.clone(),
strategy,
mode: mode_str,
cursor_column: export.cursor_column.clone(),
row_estimate,
cursor_min: range_min,
cursor_max: range_max,
scan_type,
uses_index,
verdict,
recommended_profile,
recommended_parallel,
warnings,
suggestion,
})
}
/// Probes whether `base_query` references schema objects that do not exist.
///
/// The query is wrapped in a `SELECT TOP 0` derived table, so no rows are
/// returned. `Some(error)` is returned only when the failure message mentions an
/// invalid object name or an invalid column name. The error is labelled with the
/// parsed SQL Server error number when one is present. A successful probe, or
/// any other kind of failure, yields `None` so preflight carries on best-effort.
fn schema_fail_mssql(conn: &mut MssqlSource, base_query: &str) -> Option<anyhow::Error> {
    // Message fragment that identifies a schema error, paired with the
    // explanation shown to the user. Checked in order; the first hit wins.
    const SCHEMA_MARKERS: [(&str, &str); 2] = [
        (
            "Invalid object name",
            "a table/view in the export's query does not exist",
        ),
        (
            "Invalid column name",
            "a column in the export's query does not exist",
        ),
    ];

    let probe = format!("SELECT TOP 0 1 AS _ok FROM ({base_query}) AS _rivet_probe");
    let failure = conn.query_scalar(&probe).err()?;
    let message = format!("{failure:#}");

    let detail = SCHEMA_MARKERS
        .iter()
        .find(|(marker, _)| message.contains(*marker))
        .map(|&(_, text)| text)?;

    let code_label = match mssql_error_code(&message) {
        Some(code) => format!("error {code}"),
        None => "SQL Server schema error".into(),
    };
    Some(PreflightSchemaError::new(detail, code_label).into_error())
}
/// Extracts the numeric SQL Server error number from a rendered driver error.
///
/// In the driver error text this module parses, the error number appears in a
/// trailing `(code: N, state: S, class: C)` field. The server message before it
/// is free text that may echo identifiers from the query. The search is anchored
/// on the *last* `code: ` occurrence, so an identifier such as `'code: 9'`
/// cannot be mistaken for the error number.
///
/// Returns `None` in three cases: no `code: ` marker is present, the marker is
/// not followed by digits, or the number does not fit in a `u16`.
fn mssql_error_code(msg: &str) -> Option<u16> {
    let (_, tail) = msg.rsplit_once("code: ")?;
    // Keep only the leading run of ASCII digits (stops at `,` or `)`).
    let digits_end = tail
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(tail.len());
    tail[..digits_end].parse().ok()
}
/// Estimates a table's row count from SQL Server partition metadata instead of scanning it.
///
/// The query sums `sys.dm_db_partition_stats.row_count` over the heap
/// (`index_id` 0) or clustered-index (`index_id` 1) partitions of the table.
/// Schema and table names are embedded as `N'...'` literals with single quotes
/// doubled. A negative sum is clamped to zero. `None` is returned when the query
/// yields no value, the value does not parse as `i64`, or the query fails; a
/// failure is also logged at debug level.
fn row_estimate_mssql(conn: &mut MssqlSource, qualified_table: &str) -> Option<i64> {
    let (schema, table) = split_qualified(qualified_table);
    let schema_lit = schema.replace('\'', "''");
    let table_lit = table.replace('\'', "''");
    let sql = format!(
        "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
         JOIN sys.objects o ON o.object_id = p.object_id \
         JOIN sys.schemas s ON s.schema_id = o.schema_id \
         WHERE s.name = N'{schema_lit}' AND o.name = N'{table_lit}' AND p.index_id IN (0,1)"
    );

    let fetched = match conn.query_scalar(&sql) {
        Ok(value) => value,
        Err(e) => {
            log::debug!("preflight: row-estimate probe failed for {qualified_table}: {e}");
            return None;
        }
    };
    fetched?.parse::<i64>().ok().map(|rows| rows.max(0))
}
fn range_min_max_mssql(
conn: &mut MssqlSource,
base_query: &str,
base_table: Option<&str>,
expr: &str,
) -> (Option<String>, Option<String>) {
const US: char = '\u{1f}';
let select_list = format!(
"CONCAT(CONVERT(varchar(64), MIN({expr})), CHAR(31), CONVERT(varchar(64), MAX({expr})))"
);
let sql = match base_table {
Some(table) => format!("SELECT {select_list} FROM {table}"),
None => format!("SELECT {select_list} FROM ({base_query}) AS _rivet"),
};
match conn.query_scalar(&sql) {
Ok(Some(agg)) => {
let mut parts = agg.splitn(2, US);
let min_v = parts.next().filter(|s| !s.is_empty()).map(str::to_string);
let max_v = parts.next().filter(|s| !s.is_empty()).map(str::to_string);
(min_v, max_v)
}
Ok(None) => (None, None),
Err(e) => {
log::debug!("preflight: range probe on '{expr}' failed: {e}");
(None, None)
}
}
}
/// Reports whether `column` is the leading key column of a usable index on `qualified_table`.
///
/// Only `key_ordinal = 1` counts, because a range predicate can seek an index
/// only through its leading key column. The following are excluded:
/// - heaps (`index_id = 0`);
/// - disabled indexes (`is_disabled = 1`);
/// - hypothetical, tuning-advisor indexes (`is_hypothetical = 1`).
///
/// Neither disabled nor hypothetical indexes can serve as an access path, so
/// counting them would make the diagnostic claim an index the export cannot use.
/// Schema, table and column names are embedded as `N'...'` literals with single
/// quotes doubled.
///
/// Returns `Some(true)` or `Some(false)` when the catalog query succeeds. Returns
/// `None` when it fails; the failure is logged at debug level.
fn column_has_index_mssql(
    conn: &mut MssqlSource,
    qualified_table: &str,
    column: &str,
) -> Option<bool> {
    let (schema, table) = split_qualified(qualified_table);
    let sql = format!(
        "SELECT COUNT(*) FROM sys.indexes i \
         JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
         JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
         JOIN sys.objects o ON o.object_id = i.object_id \
         JOIN sys.schemas s ON s.schema_id = o.schema_id \
         WHERE ic.key_ordinal = 1 AND i.index_id > 0 \
         AND i.is_disabled = 0 AND i.is_hypothetical = 0 \
         AND s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
        schema.replace('\'', "''"),
        table.replace('\'', "''"),
        column.replace('\'', "''"),
    );
    match conn.query_scalar(&sql) {
        Ok(count) => Some(count.and_then(|s| s.parse::<i64>().ok()).unwrap_or(0) > 0),
        Err(e) => {
            log::debug!("preflight: index probe failed for {qualified_table}.{column}: {e}");
            None
        }
    }
}
/// Splits a possibly qualified SQL Server table name into `(schema, table)`.
///
/// - Unqualified names default to the `dbo` schema.
/// - For three-part names (`db.schema.table`), the database part is ignored.
/// - An empty schema part, as in `db..table`, also falls back to `dbo`.
/// - Dots inside `[...]` or `"..."` delimited identifiers do not split the name.
///   A doubled closing delimiter (`]]` / `""`) inside them is treated as an escape.
/// - Each returned part is trimmed and stripped of one surrounding delimiter
///   layer, so it can be compared with `sys.schemas.name` / `sys.objects.name`.
///   Doubled delimiters are not unescaped, because the results borrow from the input.
fn split_qualified(qualified_table: &str) -> (&str, &str) {
    // Trims whitespace and removes one surrounding `[...]` or `"..."` pair.
    fn unquote(part: &str) -> &str {
        let part = part.trim();
        let delimited = part.len() >= 2
            && ((part.starts_with('[') && part.ends_with(']'))
                || (part.starts_with('"') && part.ends_with('"')));
        if delimited {
            &part[1..part.len() - 1]
        } else {
            part
        }
    }

    let mut parts: Vec<&str> = Vec::with_capacity(3);
    let mut start = 0;
    // Closing delimiter we are waiting for while inside a delimited identifier.
    let mut closer: Option<char> = None;
    let mut chars = qualified_table.char_indices().peekable();
    while let Some((idx, ch)) = chars.next() {
        match closer {
            Some(end) if ch == end => {
                // `]]` / `""` is an escaped delimiter character, not the end.
                if chars.peek().is_some_and(|&(_, next)| next == end) {
                    chars.next();
                } else {
                    closer = None;
                }
            }
            Some(_) => {}
            None => match ch {
                '[' => closer = Some(']'),
                '"' => closer = Some('"'),
                '.' => {
                    parts.push(&qualified_table[start..idx]);
                    start = idx + 1;
                }
                _ => {}
            },
        }
    }
    parts.push(&qualified_table[start..]);

    let mut from_right = parts.into_iter().rev().map(unquote);
    // `parts` always holds at least one element, so the default is never used.
    let table = from_right.next().unwrap_or_default();
    let schema = from_right.next().filter(|s| !s.is_empty()).unwrap_or("dbo");
    (schema, table)
}
use super::postgres::table_from_simple_query;
use crate::sql::strip_select_star_from;
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the single base table the probes would target, using the same
    /// lookup chain as `diagnose_mssql`.
    fn base_table_of(q: &str) -> Option<&str> {
        strip_select_star_from(q).or_else(|| table_from_simple_query(q))
    }

    #[test]
    fn split_qualified_defaults_schema_to_dbo() {
        let (schema, table) = split_qualified("orders");
        assert_eq!(schema, "dbo");
        assert_eq!(table, "orders");
    }

    #[test]
    fn split_qualified_keeps_explicit_schema() {
        let cases = [
            ("dbo.orders", ("dbo", "orders")),
            ("sales.line_items", ("sales", "line_items")),
        ];
        for (input, expected) in cases {
            assert_eq!(split_qualified(input), expected, "input: {input}");
        }
    }

    #[test]
    fn base_table_from_select_star_shortcut() {
        let cases = [
            ("SELECT * FROM dbo.orders", "dbo.orders"),
            ("SELECT * FROM orders", "orders"),
        ];
        for (query, table) in cases {
            assert_eq!(base_table_of(query), Some(table), "query: {query}");
        }
    }

    #[test]
    fn base_table_from_init_column_list_query() {
        let query = "SELECT id, user_id, product FROM dbo.orders";
        assert_eq!(base_table_of(query), Some("dbo.orders"));
    }

    #[test]
    fn base_table_none_for_multi_relation_query() {
        let queries = [
            "SELECT * FROM orders JOIN users ON orders.user_id = users.id",
            "SELECT * FROM orders, users",
            "SELECT * FROM (SELECT 1 AS x) AS s",
        ];
        for query in queries {
            assert!(
                base_table_of(query).is_none(),
                "expected no single base table for: {query}"
            );
        }
    }

    #[test]
    fn mssql_error_code_parsed_from_tiberius_display() {
        let token_error = "mssql: scalar query failed: Token error: 'Invalid object name \
            'ordrs'.' on server x executing on line 1 (code: 208, state: 1, class: 16)";
        let cases: [(&str, Option<u16>); 3] = [
            (token_error, Some(208)),
            ("Invalid column name 'totl'. (code: 207)", Some(207)),
            ("connection reset; no code here", None),
        ];
        for (message, expected) in cases {
            assert_eq!(mssql_error_code(message), expected, "message: {message}");
        }
    }
}