rivet/source/mssql/
proxy.rs1use tokio::runtime::Runtime;
14
15use super::{MssqlClient, scalar_to_string};
16
/// How the current SQL Server connection appears to reach the engine, as
/// inferred from connect-time probes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MssqlProxyKind {
    /// No multiplexing or Azure signal was observed.
    Direct,
    /// Back-to-back `@@SPID` probes returned different session ids.
    Multiplexed,
    /// The engine edition or version banner identifies an Azure-hosted engine.
    AzureGateway,
}

impl MssqlProxyKind {
    /// Returns `true` for every variant except [`MssqlProxyKind::Direct`].
    // `allow(dead_code)`: this helper is exercised by the unit tests in this
    // file; it may have no non-test caller — TODO confirm before dropping it.
    #[allow(dead_code)]
    pub fn is_proxy(self) -> bool {
        self != Self::Direct
    }

    /// Short, stable identifier for this kind, intended for log output.
    // `allow(dead_code)`: see `is_proxy` — only the tests are known callers here.
    #[allow(dead_code)]
    pub fn log_label(self) -> &'static str {
        match self {
            Self::Direct => "direct",
            Self::Multiplexed => "mssql-multiplexed",
            Self::AzureGateway => "azure-gateway",
        }
    }

    /// Warning text to surface at connect time, or `None` for a direct
    /// connection where there is nothing to warn about.
    fn warn_message(self) -> Option<&'static str> {
        let text = match self {
            Self::Direct => return None,
            Self::Multiplexed => {
                "SQL Server connection multiplexing detected (@@SPID differs across queries) \
                 — session SET options, #temp tables, and cursors may not persist across \
                 statements; use a direct (non-pooled) connection for production exports"
            }
            Self::AzureGateway => {
                "Azure SQL gateway detected (EngineEdition Azure) — the connection may be \
                 proxied or redirected and is subject to gateway idle timeouts and transient \
                 faults; keep retries enabled and prefer the Redirect connection policy for \
                 throughput"
            }
        };
        Some(text)
    }
}
87
88fn classify_mssql_proxy(
101 spid_pair: Option<(i32, i32)>,
102 engine_edition: Option<i32>,
103 version_banner: Option<&str>,
104) -> MssqlProxyKind {
105 if let Some((a, b)) = spid_pair
106 && a != b
107 {
108 return MssqlProxyKind::Multiplexed;
109 }
110 if matches!(engine_edition, Some(5) | Some(8)) {
111 return MssqlProxyKind::AzureGateway;
112 }
113 if let Some(v) = version_banner {
114 let l = v.to_ascii_lowercase();
115 if l.contains("sql azure") || l.contains("azure sql") {
116 return MssqlProxyKind::AzureGateway;
117 }
118 }
119 MssqlProxyKind::Direct
120}
121
122pub(super) fn detect_mssql_proxy_kind(rt: &Runtime, client: &mut MssqlClient) -> MssqlProxyKind {
127 rt.block_on(async {
128 let spid1 = scalar_i32(client, "SELECT CAST(@@SPID AS INT)").await;
133 let spid2 = scalar_i32(client, "SELECT CAST(@@SPID AS INT)").await;
134 let edition = scalar_i32(
135 client,
136 "SELECT CAST(SERVERPROPERTY('EngineEdition') AS INT)",
137 )
138 .await;
139 let banner = scalar_string(client, "SELECT @@VERSION").await;
140 let pair = match (spid1, spid2) {
141 (Some(a), Some(b)) => Some((a, b)),
142 _ => None,
143 };
144 classify_mssql_proxy(pair, edition, banner.as_deref())
145 })
146}
147
148async fn scalar_i32(client: &mut MssqlClient, sql: &str) -> Option<i32> {
150 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
151 row.get::<i32, _>(0)
152}
153
154async fn scalar_string(client: &mut MssqlClient, sql: &str) -> Option<String> {
157 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
158 scalar_to_string(&row)
159}
160
161pub(super) fn warn_proxy_kind(kind: MssqlProxyKind) {
164 if let Some(msg) = kind.warn_message() {
165 log::warn!("{msg}");
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::MssqlProxyKind;
    use super::classify_mssql_proxy;

    // ---- classify_mssql_proxy: direct ------------------------------------

    #[test]
    fn classify_direct_when_no_signals() {
        let kind = classify_mssql_proxy(None, None, None);
        assert_eq!(kind, MssqlProxyKind::Direct);
    }

    #[test]
    fn classify_direct_when_spids_match_and_on_prem_edition() {
        let kind = classify_mssql_proxy(
            Some((53, 53)),
            Some(3),
            Some("Microsoft SQL Server 2022"),
        );
        assert_eq!(kind, MssqlProxyKind::Direct);
    }

    #[test]
    fn classify_direct_when_spid_pair_missing_and_on_prem() {
        let kind = classify_mssql_proxy(None, Some(2), Some("Microsoft SQL Server 2019"));
        assert_eq!(kind, MssqlProxyKind::Direct);
    }

    // ---- classify_mssql_proxy: multiplexed -------------------------------

    #[test]
    fn classify_multiplexed_when_spid_drifts() {
        let kind = classify_mssql_proxy(Some((53, 71)), Some(3), None);
        assert_eq!(kind, MssqlProxyKind::Multiplexed);
    }

    #[test]
    fn classify_multiplexed_takes_precedence_over_azure() {
        // SPID drift must win even when every Azure signal is also present.
        let kind = classify_mssql_proxy(Some((1, 2)), Some(5), Some("Microsoft SQL Azure"));
        assert_eq!(kind, MssqlProxyKind::Multiplexed);
    }

    // ---- classify_mssql_proxy: azure --------------------------------------

    #[test]
    fn classify_azure_via_engine_edition_5() {
        let kind = classify_mssql_proxy(Some((100, 100)), Some(5), None);
        assert_eq!(kind, MssqlProxyKind::AzureGateway);
    }

    #[test]
    fn classify_azure_via_managed_instance_edition_8() {
        let kind = classify_mssql_proxy(None, Some(8), None);
        assert_eq!(kind, MssqlProxyKind::AzureGateway);
    }

    #[test]
    fn classify_azure_via_version_banner_when_edition_missing() {
        let banner = "Microsoft SQL Azure (RTM) - 12.0.2000.8";
        let kind = classify_mssql_proxy(Some((9, 9)), None, Some(banner));
        assert_eq!(kind, MssqlProxyKind::AzureGateway);
    }

    #[test]
    fn classify_azure_banner_case_insensitive() {
        let kind = classify_mssql_proxy(None, None, Some("...AZURE SQL..."));
        assert_eq!(kind, MssqlProxyKind::AzureGateway);
    }

    // ---- MssqlProxyKind helpers -------------------------------------------

    #[test]
    fn is_proxy_helper_matches_variants() {
        let direct = MssqlProxyKind::Direct;
        assert!(!direct.is_proxy());
        assert!(MssqlProxyKind::Multiplexed.is_proxy());
        assert!(MssqlProxyKind::AzureGateway.is_proxy());
    }

    #[test]
    fn direct_has_no_warning() {
        let message = MssqlProxyKind::Direct.warn_message();
        assert!(message.is_none());
    }

    #[test]
    fn non_direct_variants_have_warnings() {
        let proxied = [MssqlProxyKind::Multiplexed, MssqlProxyKind::AzureGateway];
        for k in proxied {
            assert!(
                k.warn_message().is_some(),
                "{k:?} must emit a warning at connect time"
            );
        }
    }

    #[test]
    fn log_labels_are_stable_and_distinct() {
        let all = [
            MssqlProxyKind::Direct,
            MssqlProxyKind::Multiplexed,
            MssqlProxyKind::AzureGateway,
        ];
        let labels = all.map(MssqlProxyKind::log_label);
        assert_eq!(labels, ["direct", "mssql-multiplexed", "azure-gateway"]);
    }
}
285}