// mssql_browser/browse.rs

1use super::error::*;
2use super::info::*;
3use super::socket::{UdpSocket, UdpSocketFactory};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
5
/// Message-type byte of SVR_RESP, the packet the server sends back in reply to
/// every client request.
const SVR_RESP: u8 = 0x05;

/// Message-type byte of CLNT_BCAST_EX: the broadcast or multicast request a client sends
/// when it wants to enumerate the database instances on the network, together with
/// their network protocol connection information.
const CLNT_BCAST_EX: u8 = 0x02;
12
/// Discovers SQL Server instances running on the hosts reached through `multicast_addr`,
/// using the default socket factory (`super::socket::DefaultSocketFactory`).
///
/// This is a thin convenience wrapper around `browse_inner`. It is only compiled when
/// the `tokio` or `async-std` feature is enabled.
///
/// # Arguments
/// * `multicast_addr` - A multicast address to which to broadcast the browse datagram.
///                      This can be the IPv4 broadcast address, or an IPv6 multicast address.
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn browse(
    multicast_addr: IpAddr,
) -> Result<
    AsyncInstanceIterator<<super::socket::DefaultSocketFactory as UdpSocketFactory>::Socket>,
    BrowserError<
        <super::socket::DefaultSocketFactory as UdpSocketFactory>::Error,
        <<super::socket::DefaultSocketFactory as UdpSocketFactory>::Socket as UdpSocket>::Error,
    >,
> {
    browse_inner(multicast_addr, &mut super::socket::DefaultSocketFactory::new()).await
}
32
33/// Discovers any SQL Server instances running on hosts reached by
34/// the given multicast address.
35///
36/// # Arguments
37/// * `multicast_addr` - A multicast address to which to broadcast the browse datagram.
38///                      This can be the Ipv4 BROADCAST address, or a Ipv6 multicast address.
39pub async fn browse_inner<SF: UdpSocketFactory>(
40    multicast_addr: IpAddr,
41    socket_factory: &mut SF,
42) -> Result<
43    AsyncInstanceIterator<SF::Socket>,
44    BrowserError<SF::Error, <SF::Socket as UdpSocket>::Error>,
45> {
46    let local_addr = if multicast_addr.is_ipv4() {
47        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
48    } else {
49        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
50    };
51
52    let bind_to = SocketAddr::new(local_addr, 0);
53    let mut socket = socket_factory
54        .bind(&bind_to)
55        .await
56        .map_err(BrowserError::BindFailed)?;
57
58    socket
59        .enable_broadcast()
60        .await
61        .map_err(BrowserError::SetBroadcastFailed)?;
62
63    let buffer = [CLNT_BCAST_EX];
64    let remote = SocketAddr::new(multicast_addr, 1434);
65    socket
66        .send_to(&buffer, &remote)
67        .await
68        .map_err(|e| BrowserError::SendFailed(remote, e))?;
69
70    Ok(AsyncInstanceIterator {
71        socket: socket,
72        buffer: Vec::new(),
73        current_remote_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
74        current_offset: 0,
75    })
76}
77
78/// Iterates over the instances returned by `browse`
79pub struct AsyncInstanceIterator<S: UdpSocket> {
80    socket: S,
81    buffer: Vec<u8>,
82
83    current_remote_addr: IpAddr,
84    current_offset: usize,
85}
86
87impl<S: UdpSocket> AsyncInstanceIterator<S> {
88    /// Gets the next received instance information. You can call this method multiple
89    /// times to receive information about multiple instances until it returns Ok(None).
90    pub async fn next(
91        &mut self,
92    ) -> Result<InstanceInfo, BrowserError<std::convert::Infallible, S::Error>> {
93        loop {
94            if self.current_offset >= self.buffer.len() {
95                // Need to receive a new packet
96                // TODO: Find a way to determine buffer size based on FIONREAD
97                // once/if ever tokio supports it
98                self.buffer.resize_with(65535 + 3, Default::default);
99
100                let (bytes_received, remote_addr) = self
101                    .socket
102                    .recv_from(&mut self.buffer)
103                    .await
104                    .map_err(BrowserError::ReceiveFailed)?;
105
106                self.current_remote_addr = remote_addr.ip();
107
108                if bytes_received < 3 || self.buffer[0] != SVR_RESP {
109                    self.current_offset = std::usize::MAX;
110                    continue;
111                }
112
113                let resp_data_len = u16::from_le_bytes([self.buffer[1], self.buffer[2]]);
114                if resp_data_len as usize != bytes_received - 3 {
115                    self.current_offset = std::usize::MAX;
116                    continue;
117                }
118
119                // Validate that the buffer is valid utf-8
120                // TODO: Decode mbcs string
121                if std::str::from_utf8(&self.buffer[3..]).is_err() {
122                    self.current_offset = std::usize::MAX;
123                    continue;
124                }
125
126                self.buffer.truncate(bytes_received);
127                self.current_offset = 3;
128            }
129
130            // UNSAFE: Buffer is already validated to be valid utf-8 when the iterator was created
131            let as_str =
132                unsafe { std::str::from_utf8_unchecked(&self.buffer[self.current_offset..]) };
133
134            let (instance, consumed) = match parse_instance_info(self.current_remote_addr, as_str) {
135                Ok(x) => x,
136                Err(_) => {
137                    self.current_offset = std::usize::MAX;
138                    continue;
139                }
140            };
141
142            self.current_offset += consumed;
143            return Ok(instance);
144        }
145    }
146}