mssql-browser 0.1.1

Rust implementation of the SQL Server Resolution Protocol
Documentation
use super::error::*;
use super::info::*;
use super::socket::{UdpSocket, UdpSocketFactory};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

/// Message-type byte of a CLNT_BCAST_EX request: the broadcast or multicast datagram
/// clients send when they want to enumerate the database instances on the network
/// together with their network protocol connection information.
const CLNT_BCAST_EX: u8 = 2;

/// Message-type byte of an SVR_RESP, the reply a server sends to every client request.
const SVR_RESP: u8 = 5;

/// Discovers SQL Server instances on the hosts reachable through `multicast_addr`,
/// using a `super::socket::DefaultSocketFactory` to create the UDP socket.
///
/// Only compiled when the `tokio` or `async-std` feature is enabled.
///
/// # Arguments
/// * `multicast_addr` - Address the discovery datagram is sent to. This can be the
///                      IPv4 broadcast address, or an IPv6 multicast address.
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn browse(
    multicast_addr: IpAddr,
) -> Result<
    AsyncInstanceIterator<<super::socket::DefaultSocketFactory as UdpSocketFactory>::Socket>,
    BrowserError<
        <super::socket::DefaultSocketFactory as UdpSocketFactory>::Error,
        <<super::socket::DefaultSocketFactory as UdpSocketFactory>::Socket as UdpSocket>::Error,
    >,
> {
    browse_inner(multicast_addr, &mut super::socket::DefaultSocketFactory::new()).await
}

/// Sends a CLNT_BCAST_EX discovery request from a socket created by
/// `socket_factory` and returns an iterator over the instances reported back.
///
/// This is the socket-implementation-agnostic core of `browse`.
///
/// # Arguments
/// * `multicast_addr` - Address the request is sent to, on UDP port 1434. This can be
///                      the IPv4 broadcast address, or an IPv6 multicast address.
/// * `socket_factory` - Creates the UDP socket. The socket is bound to the unspecified
///                      address of the same family as `multicast_addr`, on port 0
///                      (letting the OS pick a port), and has broadcast enabled.
///
/// # Errors
/// * `BrowserError::BindFailed` if the factory cannot bind the socket.
/// * `BrowserError::SetBroadcastFailed` if broadcasting cannot be enabled on it.
/// * `BrowserError::SendFailed` if the request datagram cannot be sent.
pub async fn browse_inner<SF: UdpSocketFactory>(
    multicast_addr: IpAddr,
    socket_factory: &mut SF,
) -> Result<
    AsyncInstanceIterator<SF::Socket>,
    BrowserError<SF::Error, <SF::Socket as UdpSocket>::Error>,
> {
    // UDP port the SQL Server Browser service listens on.
    const BROWSER_PORT: u16 = 1434;

    // The local socket must be of the same address family as the destination.
    let local_ip = match multicast_addr {
        IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
    };

    let bind_to = SocketAddr::new(local_ip, 0);
    let mut socket = socket_factory
        .bind(&bind_to)
        .await
        .map_err(BrowserError::BindFailed)?;

    socket
        .enable_broadcast()
        .await
        .map_err(BrowserError::SetBroadcastFailed)?;

    // The request consists of the single message-type byte.
    let request = [CLNT_BCAST_EX];
    let remote = SocketAddr::new(multicast_addr, BROWSER_PORT);
    socket
        .send_to(&request, &remote)
        .await
        .map_err(|e| BrowserError::SendFailed(remote, e))?;

    // An empty buffer at offset 0 makes the first `next()` call receive a datagram.
    Ok(AsyncInstanceIterator {
        socket,
        buffer: Vec::new(),
        current_remote_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        current_offset: 0,
    })
}

/// Hands out, one `next` call at a time, the SQL Server instances reported in
/// the responses to a browse request sent by `browse` or `browse_inner`.
pub struct AsyncInstanceIterator<S: UdpSocket> {
    /// Socket the browse request was sent from; responses are read from it.
    socket: S,

    /// Receive buffer holding the datagram currently being parsed.
    buffer: Vec<u8>,
    /// Offset into `buffer` at which the next instance entry starts.
    current_offset: usize,
    /// IP address of the host that sent the datagram held in `buffer`.
    current_remote_addr: IpAddr,
}

impl<S: UdpSocket> AsyncInstanceIterator<S> {
    /// Gets the next received instance information. You can call this method multiple
    /// times to receive information about multiple instances until it returns Ok(None).
    pub async fn next(
        &mut self,
    ) -> Result<InstanceInfo, BrowserError<std::convert::Infallible, S::Error>> {
        loop {
            if self.current_offset >= self.buffer.len() {
                // Need to receive a new packet
                // TODO: Find a way to determine buffer size based on FIONREAD
                // once/if ever tokio supports it
                self.buffer.resize_with(65535 + 3, Default::default);

                let (bytes_received, remote_addr) = self
                    .socket
                    .recv_from(&mut self.buffer)
                    .await
                    .map_err(BrowserError::ReceiveFailed)?;

                self.current_remote_addr = remote_addr.ip();

                if bytes_received < 3 || self.buffer[0] != SVR_RESP {
                    self.current_offset = std::usize::MAX;
                    continue;
                }

                let resp_data_len = u16::from_le_bytes([self.buffer[1], self.buffer[2]]);
                if resp_data_len as usize != bytes_received - 3 {
                    self.current_offset = std::usize::MAX;
                    continue;
                }

                // Validate that the buffer is valid utf-8
                // TODO: Decode mbcs string
                if std::str::from_utf8(&self.buffer[3..]).is_err() {
                    self.current_offset = std::usize::MAX;
                    continue;
                }

                self.buffer.truncate(bytes_received);
                self.current_offset = 3;
            }

            // UNSAFE: Buffer is already validated to be valid utf-8 when the iterator was created
            let as_str =
                unsafe { std::str::from_utf8_unchecked(&self.buffer[self.current_offset..]) };

            let (instance, consumed) = match parse_instance_info(self.current_remote_addr, as_str) {
                Ok(x) => x,
                Err(_) => {
                    self.current_offset = std::usize::MAX;
                    continue;
                }
            };

            self.current_offset += consumed;
            return Ok(instance);
        }
    }
}