sqlx-core-oldapi 0.6.53

Core of SQLx, the Rust SQL toolkit. Not intended to be used directly.
Documentation
use crate::error::Error;
use encoding_rs::WINDOWS_1252;
use sqlx_rt::{timeout, UdpSocket};
use std::time::Duration;

/// UDP port on the target server that SSRP requests are sent to.
const SSRP_PORT: u16 = 1434;
/// Message-type byte written at the start of the instance lookup request.
const CLNT_UCAST_INST: u8 = 0x04;
/// Message-type byte a valid response datagram must start with.
const SVR_RESP: u8 = 0x05;
/// Maximum time to wait for the response datagram before giving up.
const SSRP_TIMEOUT: Duration = Duration::from_secs(1);

/// Fields extracted from one instance record of an SSRP response.
///
/// String fields borrow from the response text they were parsed from. Every
/// field is `None` when the matching key is absent from the record (or, for
/// the typed fields, when its value could not be interpreted).
struct InstanceInfo<'a> {
    /// Value of the `ServerName` key.
    server_name: Option<&'a str>,
    /// Value of the `InstanceName` key; compared against the requested instance.
    instance_name: Option<&'a str>,
    /// Value of the `IsClustered` key: `Yes` maps to `true`, `No` to `false`.
    is_clustered: Option<bool>,
    /// Value of the `Version` key, kept as the raw string.
    version: Option<&'a str>,
    /// Value of the `tcp` key parsed as a port number.
    tcp_port: Option<u16>,
}

/// Resolves the TCP port of the named SQL Server `instance` on `server` by
/// querying it over UDP with the SQL Server Resolution Protocol (SSRP).
///
/// A `CLNT_UCAST_INST` request (message-type byte, instance name, NUL
/// terminator) is sent to `server` on `SSRP_PORT`. The function then waits up
/// to `SSRP_TIMEOUT` for one reply datagram, validates its `SVR_RESP` header,
/// decodes the payload as Windows-1252 and looks the instance up with
/// `find_instance_tcp_port`.
///
/// # Errors
///
/// Returns a protocol error when the socket cannot be bound, the request
/// cannot be sent, no reply arrives before the timeout, receiving fails, the
/// reply is shorter than its header, has the wrong message type, declares more
/// payload than was received, or does not contain a TCP port for `instance`.
pub(crate) async fn resolve_instance_port(server: &str, instance: &str) -> Result<u16, Error> {
    // A reply starts with one message-type byte followed by a little-endian
    // u16 holding the payload length.
    const HEADER_LEN: usize = 3;
    // NOTE(review): MS-SQLR is understood to cap the payload of a reply to
    // CLNT_UCAST_INST at 1024 bytes - confirm against the specification.
    const MAX_PAYLOAD_LEN: usize = 1024;

    log::debug!(
        "resolving SQL Server instance port for '{}' on server '{}'",
        instance,
        server
    );

    let mut request = Vec::with_capacity(1 + instance.len() + 1);
    request.push(CLNT_UCAST_INST);
    request.extend_from_slice(instance.as_bytes());
    request.push(0);

    // Ephemeral local port on the IPv4 wildcard address.
    // NOTE(review): a `server` reachable only over IPv6 presumably cannot be
    // queried through this socket - confirm whether that matters to callers.
    let socket = UdpSocket::bind("0.0.0.0:0")
        .await
        .map_err(|e| err_protocol!("failed to bind UDP socket for SSRP: {}", e))?;

    log::debug!(
        "sending SSRP CLNT_UCAST_INST request to {}:{} for instance '{}'",
        server,
        SSRP_PORT,
        instance
    );

    socket
        .send_to(&request, (server, SSRP_PORT))
        .await
        .map_err(|e| {
            err_protocol!(
                "failed to send SSRP request to {}:{}: {}",
                server,
                SSRP_PORT,
                e
            )
        })?;

    // The buffer must hold the header *and* a maximum-size payload; a buffer
    // of only MAX_PAYLOAD_LEN bytes would truncate the largest valid reply,
    // which would then be rejected by the size check below.
    let mut buffer = [0u8; HEADER_LEN + MAX_PAYLOAD_LEN];
    let bytes_read = timeout(SSRP_TIMEOUT, socket.recv(&mut buffer))
        .await
        .map_err(|_| {
            err_protocol!(
                "SSRP request to {} for instance {} timed out after {:?}",
                server,
                instance,
                SSRP_TIMEOUT
            )
        })?
        .map_err(|e| {
            err_protocol!(
                "failed to receive SSRP response from {} for instance {}: {}",
                server,
                instance,
                e
            )
        })?;

    log::debug!(
        "received SSRP response from {} ({} bytes)",
        server,
        bytes_read
    );

    if bytes_read < HEADER_LEN {
        return Err(err_protocol!(
            "SSRP response too short: {} bytes",
            bytes_read
        ));
    }

    if buffer[0] != SVR_RESP {
        return Err(err_protocol!(
            "invalid SSRP response type: expected 0x05, got 0x{:02x}",
            buffer[0]
        ));
    }

    // The declared payload length must fit inside what was actually received;
    // this also guarantees the slice below stays in bounds.
    let response_size = usize::from(u16::from_le_bytes([buffer[1], buffer[2]]));
    if HEADER_LEN + response_size > bytes_read {
        return Err(err_protocol!(
            "SSRP response size mismatch: expected {} bytes, got {}",
            HEADER_LEN + response_size,
            bytes_read
        ));
    }

    let response_bytes = &buffer[HEADER_LEN..(HEADER_LEN + response_size)];
    let (response_str, _encoding_used, had_errors) = WINDOWS_1252.decode(response_bytes);

    // Decoding problems are tolerated: the keys and port number of interest
    // are plain ASCII, so the lookup is still attempted.
    if had_errors {
        log::debug!("SSRP response had MBCS decoding errors, continuing anyway");
    }

    log::debug!("SSRP response data: {}", response_str);

    find_instance_tcp_port(&response_str, instance)
}

/// Searches decoded SSRP response text for `instance_name` and returns its
/// TCP port.
///
/// `data` is split into records on `;;`; empty records are skipped. Each
/// record is parsed with `parse_instance_info`, and the first record whose
/// instance name equals `instance_name` (ignoring ASCII case) decides the
/// outcome.
///
/// # Errors
///
/// Returns a protocol error when the matching record carries no usable TCP
/// port, or when no record matches `instance_name` at all.
fn find_instance_tcp_port(data: &str, instance_name: &str) -> Result<u16, Error> {
    let records = data.split(";;").filter(|record| !record.is_empty());

    for record in records {
        let parsed = parse_instance_info(record);

        // Records without an instance name can never match.
        let candidate = match parsed.instance_name {
            Some(candidate) => candidate,
            None => continue,
        };
        log::debug!("found instance '{}' in SSRP response", candidate);

        if !candidate.eq_ignore_ascii_case(instance_name) {
            continue;
        }
        log::debug!(
            "instance '{}' matches requested instance '{}'",
            candidate,
            instance_name
        );

        // The first matching record is authoritative, with or without a port.
        return match parsed.tcp_port {
            Some(port) => {
                log::debug!("resolved instance '{}' to port {}", instance_name, port);
                Ok(port)
            }
            None => Err(err_protocol!(
                "instance '{}' found but no TCP port available",
                instance_name
            )),
        };
    }

    Err(err_protocol!(
        "instance '{}' not found in SSRP response",
        instance_name
    ))
}

/// Parses one `;`-separated record of alternating keys and values into an
/// `InstanceInfo`.
///
/// Tokens are consumed two at a time as `(key, value)`; a trailing key with no
/// value is seen with `value == None`. Recognised keys are `ServerName`,
/// `InstanceName`, `IsClustered`, `Version` and `tcp`. A key that occurs more
/// than once overwrites the earlier value. Any other non-empty key is logged
/// at debug level and ignored.
fn parse_instance_info<'a>(data: &'a str) -> InstanceInfo<'a> {
    let mut server_name = None;
    let mut instance_name = None;
    let mut is_clustered = None;
    let mut version = None;
    let mut tcp_port = None;

    let mut fields = data.split(';');
    // Pull tokens pairwise; the value is fetched only after a key was found.
    let pairs = std::iter::from_fn(|| fields.next().map(|key| (key, fields.next())));

    for (key, value) in pairs {
        match key {
            "ServerName" => server_name = value,
            "InstanceName" => instance_name = value,
            "IsClustered" => {
                // Anything other than an exact "Yes"/"No" is treated as unknown.
                is_clustered = match value {
                    Some("Yes") => Some(true),
                    Some("No") => Some(false),
                    _ => None,
                };
            }
            "Version" => version = value,
            "tcp" => {
                // A port that does not parse as u16 is treated as absent.
                tcp_port = value.and_then(|raw| raw.parse::<u16>().ok());
            }
            "" => {}
            unknown => log::debug!("ignoring unknown SSRP key: '{}'", unknown),
        }
    }

    InstanceInfo {
        server_name,
        instance_name,
        is_clustered,
        version,
        tcp_port,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Response text advertising one instance, `SQLEXPRESS`, on TCP port 1433.
    const SINGLE_INSTANCE: &str = "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";

    // A lone record is found by its exact name.
    #[test]
    fn test_find_instance_tcp_port_single_instance() {
        let resolved = find_instance_tcp_port(SINGLE_INSTANCE, "SQLEXPRESS");
        assert_eq!(resolved.unwrap(), 1433);
    }

    // The second of two records is selected, and its extra `np` entry is ignored.
    #[test]
    fn test_find_instance_tcp_port_multiple_instances() {
        let two_instances = "ServerName;SRV1;InstanceName;INST1;IsClustered;No;Version;15.0.2000.5;tcp;1433;;ServerName;SRV1;InstanceName;INST2;IsClustered;No;Version;16.0.1000.6;tcp;1434;np;\\\\SRV1\\pipe\\MSSQL$INST2\\sql\\query;;";
        let resolved = find_instance_tcp_port(two_instances, "INST2");
        assert_eq!(resolved.unwrap(), 1434);
    }

    // Instance names are compared without regard to ASCII case.
    #[test]
    fn test_find_instance_tcp_port_case_insensitive() {
        let mixed_case = "ServerName;MYSERVER;InstanceName;SQLExpress;IsClustered;No;Version;15.0.2000.5;tcp;1433;;";
        let resolved = find_instance_tcp_port(mixed_case, "sqlexpress");
        assert_eq!(resolved.unwrap(), 1433);
    }

    // Asking for a name that is not advertised is an error.
    #[test]
    fn test_find_instance_tcp_port_instance_not_found() {
        assert!(find_instance_tcp_port(SINGLE_INSTANCE, "NOTFOUND").is_err());
    }

    // A matching record without a `tcp` entry is an error.
    #[test]
    fn test_find_instance_tcp_port_no_tcp_port() {
        let without_tcp =
            "ServerName;MYSERVER;InstanceName;SQLEXPRESS;IsClustered;No;Version;15.0.2000.5;;";
        assert!(find_instance_tcp_port(without_tcp, "SQLEXPRESS").is_err());
    }
}