mssql-browser 0.1.1

Rust implementation of the SQL Server Resolution Protocol
Documentation
use super::error::*;
use super::info::*;
use super::socket::{UdpSocket, UdpSocketFactory};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

/// Message identifier of the CLNT_UCAST_EX request: a unicast request sent by clients that
/// want the list of database instances, and their network protocol connection information,
/// installed on a single machine.
const CLNT_UCAST_EX: u8 = 3;

/// Message identifier of SVR_RESP, the packet the server answers every client request with.
const SVR_RESP: u8 = 5;

/// Discovers any SQL Server instances running on the given host, using the crate's default
/// UDP socket factory.
///
/// This is a convenience wrapper that forwards to `browse_host_inner` with a freshly created
/// `DefaultSocketFactory`.
///
/// # Arguments
/// * `remote_addr` - The address of the remote host of which to retrieve information
///                   about the instances running on it.
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn browse_host(
    remote_addr: IpAddr,
) -> Result<
    InstanceIterator,
    BrowserError<
        <super::socket::DefaultSocketFactory as UdpSocketFactory>::Error,
        <<super::socket::DefaultSocketFactory as UdpSocketFactory>::Socket as UdpSocket>::Error,
    >,
> {
    use super::socket::DefaultSocketFactory as Factory;

    let mut default_factory = Factory::new();
    browse_host_inner(remote_addr, &mut default_factory).await
}

/// Discovers any SQL Server instances running on the given host, using a caller-supplied
/// socket factory.
///
/// Sends a single `CLNT_UCAST_EX` request to UDP port 1434 of `remote_addr` and reads a single
/// datagram in reply. The reply must start with the `SVR_RESP` identifier, followed by a
/// little-endian `u16` equal to the number of bytes that follow the 3-byte header. That
/// remaining data must be valid UTF-8. The returned [`InstanceIterator`] parses instances out
/// of it on demand.
///
/// # Arguments
/// * `remote_addr` - The address of the remote host of which to retrieve information
///                   about the instances running on it.
/// * `socket_factory` - Used to bind the local UDP socket. The socket is bound to the
///                      unspecified address of the same IP family as `remote_addr`, port 0.
///
/// # Errors
/// * `BrowserError::BindFailed` if the local socket cannot be bound.
/// * `BrowserError::ConnectFailed` or `BrowserError::SendFailed` if the request cannot be sent.
/// * `BrowserError::ReceiveFailed` if receiving the reply fails.
/// * `BrowserError::ProtocolError` if the reply is empty, has the wrong message identifier,
///   is shorter than its 3-byte header, disagrees with its own length field, or its payload is
///   not valid UTF-8.
///
/// NOTE(review): no timeout is applied here, so whether this can wait indefinitely for a host
/// that never answers depends on the socket implementation — confirm.
pub async fn browse_host_inner<SF: UdpSocketFactory>(
    remote_addr: IpAddr,
    socket_factory: &mut SF,
) -> Result<InstanceIterator, BrowserError<SF::Error, <SF::Socket as UdpSocket>::Error>> {
    // UDP port the request is sent to (the SQL Server Browser service).
    const BROWSER_PORT: u16 = 1434;
    // SVR_RESP header: 1-byte message identifier + 2-byte little-endian length field.
    const HEADER_LEN: usize = 3;
    // The length field is a u16, so no well-formed reply can be longer than this.
    const MAX_DATAGRAM_LEN: usize = u16::MAX as usize + HEADER_LEN;

    let local_addr = if remote_addr.is_ipv4() {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    } else {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    };

    // Port 0: ask for any free local port.
    let bind_to = SocketAddr::new(local_addr, 0);
    let mut socket = socket_factory
        .bind(&bind_to)
        .await
        .map_err(BrowserError::BindFailed)?;

    let remote = SocketAddr::new(remote_addr, BROWSER_PORT);
    socket
        .connect(&remote)
        .await
        .map_err(|e| BrowserError::ConnectFailed(remote, e))?;

    let request = [CLNT_UCAST_EX];
    socket
        .send_to(&request, &remote)
        .await
        .map_err(|e| BrowserError::SendFailed(remote, e))?;

    let mut buffer = vec![0u8; MAX_DATAGRAM_LEN];
    let bytes_received = socket
        .recv(&mut buffer)
        .await
        .map_err(BrowserError::ReceiveFailed)?;

    if bytes_received == 0 {
        return Err(BrowserError::ProtocolError(
            BrowserProtocolError::UnexpectedToken {
                expected: BrowserProtocolToken::MessageIdentifier(SVR_RESP),
                found: BrowserProtocolToken::EndOfMessage,
            },
        ));
    }

    if buffer[0] != SVR_RESP {
        return Err(BrowserError::ProtocolError(
            BrowserProtocolError::UnexpectedToken {
                expected: BrowserProtocolToken::MessageIdentifier(SVR_RESP),
                found: BrowserProtocolToken::MessageIdentifier(buffer[0]),
            },
        ));
    }

    if bytes_received < HEADER_LEN {
        return Err(BrowserError::ProtocolError(
            BrowserProtocolError::UnexpectedToken {
                expected: BrowserProtocolToken::MessageLength,
                found: BrowserProtocolToken::EndOfMessage,
            },
        ));
    }

    // Widen to usize BEFORE adding the header size: the length field comes off the wire and
    // can be as large as u16::MAX, so adding 3 in u16 would overflow (panicking in debug
    // builds, producing a bogus value in release builds).
    let expected_len = usize::from(u16::from_le_bytes([buffer[1], buffer[2]])) + HEADER_LEN;
    if expected_len != bytes_received {
        return Err(BrowserError::ProtocolError(
            BrowserProtocolError::LengthMismatch {
                datagram: bytes_received,
                header: expected_len,
            },
        ));
    }

    buffer.truncate(bytes_received);

    // Validate once, up front, that the payload is UTF-8 so malformed replies are rejected
    // before an iterator is handed out.
    // TODO: Decode mbcs string
    std::str::from_utf8(&buffer[HEADER_LEN..])
        .map_err(|e| BrowserError::ProtocolError(BrowserProtocolError::InvalidUtf8(e)))?;

    Ok(InstanceIterator {
        remote_addr,
        buffer,
        offset: HEADER_LEN,
    })
}

/// Iterates over the instances returned by `browse_host`
pub struct InstanceIterator {
    remote_addr: IpAddr,
    buffer: Vec<u8>,
    offset: usize,
}

impl InstanceIterator {
    /// Gets the next received instance information. You can call this method multiple
    /// times to receive information about multiple instances until it returns Ok(None).
    pub fn next(
        &mut self,
    ) -> Result<
        Option<InstanceInfo>,
        BrowserError<std::convert::Infallible, std::convert::Infallible>,
    > {
        if self.offset == self.buffer.len() {
            return Ok(None);
        }

        // UNSAFE: Buffer is already validated to be valid utf-8 when the iterator was created
        let as_str = unsafe { std::str::from_utf8_unchecked(&self.buffer[self.offset..]) };
        let (instance, consumed) = parse_instance_info(self.remote_addr, as_str)
            .map_err(|e| BrowserError::ProtocolError(e))?;

        self.offset += consumed;
        Ok(Some(instance))
    }
}