Skip to main content

rivet/source/mssql/
proxy.rs

1//! SQL Server connection-proxy classifier — distinguishes a direct connection
2//! from a transaction-mode multiplexer or an Azure SQL gateway in front of the
3//! instance.
4//!
5//! The SQL Server analogue of [`crate::source::mysql::proxy`] and the Postgres
6//! `detect_pg_transaction_pooler`: it runs once at connect time so the operator
7//! gets a one-line warning when a pooler/gateway sits in front of the database
8//! (session `SET` options, `#temp` tables, and any open cursor may not survive
9//! statement-level multiplexing). [`classify_mssql_proxy`] is a pure function,
10//! exhaustively unit-tested in this file; the I/O wrapper
11//! [`detect_mssql_proxy_kind`] collects the live signals and delegates.
12
13use tokio::runtime::Runtime;
14
15use super::{MssqlClient, scalar_to_string};
16
/// The kind of endpoint a SQL Server connection actually reaches.
///
/// Determined once, at connect time, by [`detect_mssql_proxy_kind`].
///
/// Exposed as `pub` only so integration tests can reach it through
/// `MssqlSource::proxy_kind()`; like the rest of `rivet::source::mssql::*`,
/// it carries no external API contract.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MssqlProxyKind {
    /// No proxy detected: the connection talks straight to a SQL Server
    /// instance.
    Direct,
    /// Statement-level (transaction-mode) multiplexer or connection router in
    /// front of the instance. Recognised when two consecutive `@@SPID` reads
    /// on the same connection return different session ids, i.e. each
    /// statement lands on a different backend session — so `SET` options,
    /// `#temp` tables, and cursors do not carry over between statements.
    ///
    /// Can be missed (false negative) when the proxy's backend pool holds a
    /// single session, because that one session is then reused every time.
    Multiplexed,
    /// Azure-hosted SQL (for example Azure SQL Database, `EngineEdition` `5`,
    /// or Managed Instance, `EngineEdition` `8`), recognised from
    /// `SERVERPROPERTY('EngineEdition')` with the `@@VERSION` banner as a
    /// secondary signal. Such endpoints sit behind the Azure SQL gateway: the
    /// session itself is preserved, but the gateway may proxy *or* redirect
    /// the connection, enforces idle timeouts, and injects transient faults.
    AzureGateway,
}
45
46impl MssqlProxyKind {
47    /// True for any non-direct connection (`is_proxy() == false` only for
48    /// [`MssqlProxyKind::Direct`]).
49    ///
50    /// `#[allow(dead_code)]` because the binary compilation unit (which
51    /// re-declares `mod source`) does not reference this; the lib + tests do.
52    #[allow(dead_code)]
53    pub fn is_proxy(self) -> bool {
54        !matches!(self, MssqlProxyKind::Direct)
55    }
56
57    /// Stable label for diagnostic logs. Keep terse and stable: external log
58    /// parsers grep on these strings.
59    #[allow(dead_code)]
60    pub fn log_label(self) -> &'static str {
61        match self {
62            MssqlProxyKind::Direct => "direct",
63            MssqlProxyKind::Multiplexed => "mssql-multiplexed",
64            MssqlProxyKind::AzureGateway => "azure-gateway",
65        }
66    }
67
68    /// One-time warning emitted at connect time. Returns `None` for
69    /// [`MssqlProxyKind::Direct`] (the common case, no warning needed).
70    fn warn_message(self) -> Option<&'static str> {
71        match self {
72            MssqlProxyKind::Direct => None,
73            MssqlProxyKind::Multiplexed => Some(
74                "SQL Server connection multiplexing detected (@@SPID differs across queries) \
75                 — session SET options, #temp tables, and cursors may not persist across \
76                 statements; use a direct (non-pooled) connection for production exports",
77            ),
78            MssqlProxyKind::AzureGateway => Some(
79                "Azure SQL gateway detected (EngineEdition Azure) — the connection may be \
80                 proxied or redirected and is subject to gateway idle timeouts and transient \
81                 faults; keep retries enabled and prefer the Redirect connection policy for \
82                 throughput",
83            ),
84        }
85    }
86}
87
88/// Pure classifier for SQL Server proxy detection signals. Kept separate from
89/// [`detect_mssql_proxy_kind`] so it can be exhaustively unit-tested without a
90/// live SQL Server. See [`MssqlProxyKind`] for the meaning of each variant.
91///
92/// Precedence is intentional:
93///
94/// 1. `@@SPID` differing across two queries → [`MssqlProxyKind::Multiplexed`].
95///    This is the strongest *risk* signal: a true statement-level multiplexer
96///    breaks session state even on Azure, so it wins over the edition probe.
97/// 2. `SERVERPROPERTY('EngineEdition')` of `5`/`8`, or an Azure `@@VERSION`
98///    banner → [`MssqlProxyKind::AzureGateway`].
99/// 3. Otherwise → [`MssqlProxyKind::Direct`].
100fn classify_mssql_proxy(
101    spid_pair: Option<(i32, i32)>,
102    engine_edition: Option<i32>,
103    version_banner: Option<&str>,
104) -> MssqlProxyKind {
105    if let Some((a, b)) = spid_pair
106        && a != b
107    {
108        return MssqlProxyKind::Multiplexed;
109    }
110    if matches!(engine_edition, Some(5) | Some(8)) {
111        return MssqlProxyKind::AzureGateway;
112    }
113    if let Some(v) = version_banner {
114        let l = v.to_ascii_lowercase();
115        if l.contains("sql azure") || l.contains("azure sql") {
116            return MssqlProxyKind::AzureGateway;
117        }
118    }
119    MssqlProxyKind::Direct
120}
121
122/// I/O wrapper around [`classify_mssql_proxy`]: collects the detection signals
123/// from a live connection and returns the classification. On any query failure
124/// the missing signal is simply dropped — detection is best-effort and must
125/// never break a real export (worst case it reports [`MssqlProxyKind::Direct`]).
126pub(super) fn detect_mssql_proxy_kind(rt: &Runtime, client: &mut MssqlClient) -> MssqlProxyKind {
127    rt.block_on(async {
128        // `@@SPID` is the session id; comparing two consecutive calls on the
129        // same connection detects transaction-mode multiplexers that hand each
130        // statement to a different backend session. `CAST(... AS INT)` so the
131        // value decodes uniformly (raw `@@SPID` is smallint).
132        let spid1 = scalar_i32(client, "SELECT CAST(@@SPID AS INT)").await;
133        let spid2 = scalar_i32(client, "SELECT CAST(@@SPID AS INT)").await;
134        let edition = scalar_i32(
135            client,
136            "SELECT CAST(SERVERPROPERTY('EngineEdition') AS INT)",
137        )
138        .await;
139        let banner = scalar_string(client, "SELECT @@VERSION").await;
140        let pair = match (spid1, spid2) {
141            (Some(a), Some(b)) => Some((a, b)),
142            _ => None,
143        };
144        classify_mssql_proxy(pair, edition, banner.as_deref())
145    })
146}
147
148/// Fetch a single `INT` scalar; `None` on any driver error or NULL.
149async fn scalar_i32(client: &mut MssqlClient, sql: &str) -> Option<i32> {
150    let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
151    row.get::<i32, _>(0)
152}
153
154/// Fetch a single string scalar via the shared [`scalar_to_string`] decoder;
155/// `None` on any driver error or NULL.
156async fn scalar_string(client: &mut MssqlClient, sql: &str) -> Option<String> {
157    let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
158    scalar_to_string(&row)
159}
160
161/// Emit the one-time connect-time warning for a non-direct proxy kind.
162/// Centralized so the wording stays consistent across connect entry points.
163pub(super) fn warn_proxy_kind(kind: MssqlProxyKind) {
164    if let Some(msg) = kind.warn_message() {
165        log::warn!("{msg}");
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::MssqlProxyKind::{AzureGateway, Direct, Multiplexed};
    use super::classify_mssql_proxy as classify;

    // ── Classifier ──────────────────────────────────────────────────────

    #[test]
    fn classify_direct_when_no_signals() {
        assert_eq!(classify(None, None, None), Direct);
    }

    #[test]
    fn classify_direct_when_spids_match_and_on_prem_edition() {
        // EngineEdition 3 = Enterprise/Developer, as in the dev container.
        let banner = Some("Microsoft SQL Server 2022");
        assert_eq!(classify(Some((53, 53)), Some(3), banner), Direct);
    }

    #[test]
    fn classify_multiplexed_when_spid_drifts() {
        assert_eq!(classify(Some((53, 71)), Some(3), None), Multiplexed);
    }

    #[test]
    fn classify_azure_via_engine_edition_5() {
        assert_eq!(classify(Some((100, 100)), Some(5), None), AzureGateway);
    }

    #[test]
    fn classify_azure_via_managed_instance_edition_8() {
        assert_eq!(classify(None, Some(8), None), AzureGateway);
    }

    #[test]
    fn classify_azure_via_version_banner_when_edition_missing() {
        let banner = Some("Microsoft SQL Azure (RTM) - 12.0.2000.8");
        assert_eq!(classify(Some((9, 9)), None, banner), AzureGateway);
    }

    #[test]
    fn classify_multiplexed_takes_precedence_over_azure() {
        // SPID drift means a statement-level multiplexer, which breaks session
        // state even on Azure — it must outrank the edition probe.
        let banner = Some("Microsoft SQL Azure");
        assert_eq!(classify(Some((1, 2)), Some(5), banner), Multiplexed);
    }

    #[test]
    fn classify_direct_when_spid_pair_missing_and_on_prem() {
        let banner = Some("Microsoft SQL Server 2019");
        assert_eq!(classify(None, Some(2), banner), Direct);
    }

    #[test]
    fn classify_azure_banner_case_insensitive() {
        assert_eq!(classify(None, None, Some("...AZURE SQL...")), AzureGateway);
    }

    // ── Warning / label contract ────────────────────────────────────────

    #[test]
    fn is_proxy_helper_matches_variants() {
        assert!(!Direct.is_proxy());
        assert!(Multiplexed.is_proxy());
        assert!(AzureGateway.is_proxy());
    }

    #[test]
    fn direct_has_no_warning() {
        assert_eq!(Direct.warn_message(), None);
    }

    #[test]
    fn non_direct_variants_have_warnings() {
        for k in [Multiplexed, AzureGateway] {
            assert!(
                k.warn_message().is_some(),
                "{k:?} must emit a warning at connect time"
            );
        }
    }

    #[test]
    fn log_labels_are_stable_and_distinct() {
        let labels = [Direct, Multiplexed, AzureGateway].map(|k| k.log_label());
        assert_eq!(labels, ["direct", "mssql-multiplexed", "azure-gateway"]);
    }
}