use crate::error::Error;
use encoding_rs::WINDOWS_1252;
use sqlx_rt::{timeout, UdpSocket};
use std::time::Duration;
/// UDP port that SSRP requests are sent to.
const SSRP_PORT: u16 = 1434;
/// SSRP message type of the single-instance unicast request this module sends.
const CLNT_UCAST_INST: u8 = 0x04;
/// SSRP message type that a valid response must carry in its first byte.
const SVR_RESP: u8 = 0x05;
/// Upper bound on how long to wait for an SSRP response before failing.
const SSRP_TIMEOUT: Duration = Duration::from_millis(1_000);
/// One instance record from an SSRP response payload, borrowing its string values
/// from the decoded response text.
///
/// A field is `None` when its key is absent from the record. For `is_clustered` and
/// `tcp_port`, it is also `None` when the value cannot be interpreted.
// Only `instance_name` and `tcp_port` are currently read by callers. The remaining
// fields are still parsed so the whole record is available (e.g. via `Debug`), so
// silence the "field is never read" warnings rather than dropping them.
#[allow(dead_code)]
#[derive(Debug, Default)]
struct InstanceInfo<'a> {
    /// Value of the `ServerName` key.
    server_name: Option<&'a str>,
    /// Value of the `InstanceName` key.
    instance_name: Option<&'a str>,
    /// `IsClustered` value: `Yes` -> `true`, `No` -> `false`.
    is_clustered: Option<bool>,
    /// Value of the `Version` key, kept as text.
    version: Option<&'a str>,
    /// Value of the `tcp` key parsed as a port number.
    tcp_port: Option<u16>,
}
/// Resolves the TCP port of the named SQL Server `instance` on `server` using SSRP.
///
/// Sends a `CLNT_UCAST_INST` request to UDP port [`SSRP_PORT`] on `server`. It then
/// waits up to [`SSRP_TIMEOUT`] for one `SVR_RESP` datagram, decodes its payload as
/// Windows-1252 text, and looks up the instance's `tcp` entry in it.
///
/// # Errors
///
/// Returns a protocol error if any of the following happens:
/// - the UDP socket cannot be bound, or the request cannot be sent;
/// - no response arrives within [`SSRP_TIMEOUT`];
/// - the response is shorter than its header, has the wrong message type, or
///   declares a payload length larger than the datagram received;
/// - the response does not list a TCP port for `instance`.
pub(crate) async fn resolve_instance_port(server: &str, instance: &str) -> Result<u16, Error> {
    // `SVR_RESP` header: 1-byte message type followed by a 2-byte little-endian
    // payload length.
    const SVR_RESP_HEADER_LEN: usize = 3;
    // Sized for the largest payload length the u16 header field can express.
    // The 1024-byte payload limit for CLNT_UCAST_INST replies does not include the
    // 3-byte header, so a 1024-byte buffer could not hold a maximal valid reply.
    // An oversized datagram is silently truncated on some platforms and makes
    // `recv` fail on others.
    const MAX_SVR_RESP_LEN: usize = SVR_RESP_HEADER_LEN + u16::MAX as usize;

    log::debug!(
        "resolving SQL Server instance port for '{}' on server '{}'",
        instance,
        server
    );

    // CLNT_UCAST_INST: message type byte, then the NUL-terminated instance name.
    // Encode the name with the same code page the response is decoded with below,
    // so non-ASCII instance names are handled consistently in both directions.
    let (instance_bytes, _encoding_used, _had_unmappable) = WINDOWS_1252.encode(instance);
    let mut request = Vec::with_capacity(1 + instance_bytes.len() + 1);
    request.push(CLNT_UCAST_INST);
    request.extend_from_slice(&instance_bytes);
    request.push(0);

    let socket = UdpSocket::bind("0.0.0.0:0")
        .await
        .map_err(|e| err_protocol!("failed to bind UDP socket for SSRP: {}", e))?;

    log::debug!(
        "sending SSRP CLNT_UCAST_INST request to {}:{} for instance '{}'",
        server,
        SSRP_PORT,
        instance
    );

    socket
        .send_to(&request, (server, SSRP_PORT))
        .await
        .map_err(|e| {
            err_protocol!(
                "failed to send SSRP request to {}:{}: {}",
                server,
                SSRP_PORT,
                e
            )
        })?;

    // Heap-allocated: the buffer is held across the `.await` below, so an array of
    // this size would otherwise be stored inline in this future.
    let mut buffer = vec![0u8; MAX_SVR_RESP_LEN];
    let bytes_read = timeout(SSRP_TIMEOUT, socket.recv(&mut buffer[..]))
        .await
        .map_err(|_| {
            err_protocol!(
                "SSRP request to {} for instance {} timed out after {:?}",
                server,
                instance,
                SSRP_TIMEOUT
            )
        })?
        .map_err(|e| {
            err_protocol!(
                "failed to receive SSRP response from {} for instance {}: {}",
                server,
                instance,
                e
            )
        })?;

    log::debug!(
        "received SSRP response from {} ({} bytes)",
        server,
        bytes_read
    );

    if bytes_read < SVR_RESP_HEADER_LEN {
        return Err(err_protocol!(
            "SSRP response too short: {} bytes",
            bytes_read
        ));
    }

    if buffer[0] != SVR_RESP {
        return Err(err_protocol!(
            "invalid SSRP response type: expected 0x05, got 0x{:02x}",
            buffer[0]
        ));
    }

    let response_size = usize::from(u16::from_le_bytes([buffer[1], buffer[2]]));
    let response_end = SVR_RESP_HEADER_LEN + response_size;
    if response_end > bytes_read {
        return Err(err_protocol!(
            "SSRP response size mismatch: expected {} bytes, got {}",
            response_end,
            bytes_read
        ));
    }

    let response_bytes = &buffer[SVR_RESP_HEADER_LEN..response_end];
    let (response_str, _encoding_used, had_errors) = WINDOWS_1252.decode(response_bytes);
    if had_errors {
        log::debug!("SSRP response had MBCS decoding errors, continuing anyway");
    }

    log::debug!("SSRP response data: {}", response_str);

    find_instance_tcp_port(&response_str, instance)
}
fn find_instance_tcp_port(data: &str, instance_name: &str) -> Result<u16, Error> {
for instance_data in data.split(";;") {
if instance_data.is_empty() {
continue;
}
let info = parse_instance_info(instance_data);
if let Some(name) = info.instance_name {
log::debug!("found instance '{}' in SSRP response", name);
if name.eq_ignore_ascii_case(instance_name) {
log::debug!(
"instance '{}' matches requested instance '{}'",
name,
instance_name
);
if let Some(port) = info.tcp_port {
log::debug!("resolved instance '{}' to port {}", instance_name, port);
return Ok(port);
} else {
return Err(err_protocol!(
"instance '{}' found but no TCP port available",
instance_name
));
}
}
}
}
Err(err_protocol!(
"instance '{}' not found in SSRP response",
instance_name
))
}
/// Parses one SSRP instance record into an [`InstanceInfo`] that borrows from `data`.
///
/// A record is alternating `key` and `value` tokens separated by `;`.
///
/// - Recognised keys are `ServerName`, `InstanceName`, `IsClustered`, `Version` and
///   `tcp`. A later occurrence of a key overwrites an earlier one.
/// - Other non-empty keys are logged at debug level and skipped; empty keys are
///   skipped silently.
/// - The field is left `None` when:
///   - a key has no following token;
///   - `IsClustered` is anything other than `Yes`/`No`;
///   - a `tcp` value does not parse as a `u16`.
fn parse_instance_info<'a>(data: &'a str) -> InstanceInfo<'a> {
    let mut record = InstanceInfo {
        server_name: None,
        instance_name: None,
        is_clustered: None,
        version: None,
        tcp_port: None,
    };

    let mut tokens = data.split(';');
    // Yields `(key, value)` pairs; `value` is `None` when the record ends on a key.
    let pairs = std::iter::from_fn(move || {
        let key = tokens.next()?;
        Some((key, tokens.next()))
    });

    for (key, value) in pairs {
        match key {
            "ServerName" => record.server_name = value,
            "InstanceName" => record.instance_name = value,
            "Version" => record.version = value,
            "IsClustered" => {
                record.is_clustered = match value {
                    Some("Yes") => Some(true),
                    Some("No") => Some(false),
                    _ => None,
                }
            }
            "tcp" => record.tcp_port = value.and_then(|port| port.parse::<u16>().ok()),
            "" => {}
            unknown => log::debug!("ignoring unknown SSRP key: '{}'", unknown),
        }
    }

    record
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_instance_tcp_port_single_instance() {
        let data = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
        let port = find_instance_tcp_port(data, "SQLEXPRESS").unwrap();
        assert_eq!(port, 1433);
    }

    #[test]
    fn test_find_instance_tcp_port_multiple_instances() {
        let data = "ServerName;SRV1;InstanceName;INST1;IsClustered;No;Version;15.0.2000.5;tcp;1433;;ServerName;SRV1;InstanceName;INST2;IsClustered;No;Version;16.0.1000.6;tcp;1434;np;\\\\SRV1\\pipe\\MSSQL$INST2\\sql\\query;;";
        let port = find_instance_tcp_port(data, "INST2").unwrap();
        assert_eq!(port, 1434);
    }

    #[test]
    fn test_find_instance_tcp_port_case_insensitive() {
        let data = "ServerName;MYSERVER;InstanceName;SQLExpress;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
        let port = find_instance_tcp_port(data, "sqlexpress").unwrap();
        assert_eq!(port, 1433);
    }

    #[test]
    fn test_find_instance_tcp_port_instance_not_found() {
        let data = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
        let result = find_instance_tcp_port(data, "NOTFOUND");
        assert!(result.is_err());
    }

    #[test]
    fn test_find_instance_tcp_port_no_tcp_port() {
        let data =
            "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;;";
        let result = find_instance_tcp_port(data, "SQLEXPRESS");
        assert!(result.is_err());
    }

    #[test]
    fn test_find_instance_tcp_port_empty_response() {
        assert!(find_instance_tcp_port("", "SQLEXPRESS").is_err());
        assert!(find_instance_tcp_port(";;", "SQLEXPRESS").is_err());
    }

    #[test]
    fn test_find_instance_tcp_port_unusable_port_value() {
        // Non-numeric and out-of-u16-range values are both treated as "no TCP port".
        for bad in &["notaport", "70000", "-1"] {
            let data = format!(
                "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;{};;",
                bad
            );
            assert!(
                find_instance_tcp_port(&data, "SQLEXPRESS").is_err(),
                "port value {:?} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn test_find_instance_tcp_port_skips_records_without_instance_name() {
        let data = "ServerName;SRV;tcp;1500;;ServerName;SRV;InstanceName;INST;tcp;1600;;";
        assert_eq!(find_instance_tcp_port(data, "INST").unwrap(), 1600);
    }

    #[test]
    fn test_parse_instance_info_all_fields() {
        let info = parse_instance_info(
            "ServerName;SRV1;InstanceName;INST1;IsClustered;Yes;Version;16.0.1000.6;tcp;50123;np;\\\\SRV1\\pipe\\sql\\query",
        );
        assert_eq!(info.server_name, Some("SRV1"));
        assert_eq!(info.instance_name, Some("INST1"));
        assert_eq!(info.is_clustered, Some(true));
        assert_eq!(info.version, Some("16.0.1000.6"));
        assert_eq!(info.tcp_port, Some(50123));
    }

    #[test]
    fn test_parse_instance_info_is_clustered_values() {
        assert_eq!(parse_instance_info("IsClustered;No").is_clustered, Some(false));
        assert_eq!(parse_instance_info("IsClustered;Maybe").is_clustered, None);
        assert_eq!(parse_instance_info("IsClustered").is_clustered, None);
    }

    #[test]
    fn test_parse_instance_info_dangling_key() {
        // A key with no following value token yields `None` rather than panicking.
        let info = parse_instance_info("ServerName;SRV;InstanceName");
        assert_eq!(info.server_name, Some("SRV"));
        assert_eq!(info.instance_name, None);
        assert_eq!(info.tcp_port, None);
    }
}