use crate::browse::{Result, Service, ServiceEventType};
use crate::ffi::windows::{
DNS_FREE_TYPE_DnsFreeRecordList, DnsFree, DnsServiceBrowse, DnsServiceBrowseCancel,
_DNS_SERVICE_BROWSE_REQUEST__bindgen_ty_1 as BrowseCallbackUnion, DNS_QUERY_REQUEST_VERSION1,
DNS_TYPE_A, DNS_TYPE_AAAA, DNS_TYPE_PTR, DNS_TYPE_SRV, DNS_TYPE_TEXT, DWORD, PDNS_RECORD,
PVOID, _DNS_SERVICE_BROWSE_REQUEST, _DNS_SERVICE_CANCEL,
};
use crate::os::windows::to_utf16;
use crate::ServiceBrowserBuilder;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::{Error as IoError, ErrorKind};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::ptr::null_mut;
use std::str::Utf8Error;
use std::sync::mpsc::{sync_channel, Receiver, RecvTimeoutError, SyncSender};
use std::time::Duration;
use thiserror::Error;
use widestring::{U16CStr, U16CString};
use winapi::shared::winerror::DNS_REQUEST_PENDING;
#[derive(Debug, Error)]
pub enum BrowseError {
#[error("Timeout waiting for data")]
Timeout,
#[error("IO Error: {0}")]
IoError(#[from] IoError),
#[error("Error from DNS Service APIs: {0}")]
DnsError(DWORD),
#[error("Error creating string from UTF8: {0}")]
Utf8StringError(#[from] Utf8Error),
}
/// One decoded DNS resource record, restricted to the record types the browser consumes.
enum DnsRecord {
    /// PTR record: the name it points to.
    Ptr(String),
    /// SRV record: the target host name and the service port.
    Srv { hostname: String, port: u16 },
    /// TXT record: its `key=value` strings collected into a map.
    Txt(HashMap<String, String>),
    /// A record: an IPv4 address.
    A(Ipv4Addr),
    /// AAAA record: an IPv6 address.
    Aaaa(Ipv6Addr),
}
/// Splits a DNS-SD service instance name into `(instance, regtype, domain)`.
///
/// The name is expected to look like `<instance>.<_service>.<_proto>.<domain>`, e.g.
/// `My Printer._ipp._tcp.local` -> `("My Printer", "_ipp._tcp", "local")`.
///
/// The last three labels are taken from the right. An instance name that itself
/// contains dots is therefore kept intact and in its original order. A single trailing
/// root dot (fully-qualified form) is ignored.
///
/// Returns `None` when the name has fewer than three labels. When there is no
/// instance part, the returned instance name is empty.
fn process_name(name: &str) -> Option<(String, String, String)> {
    // Tolerate the fully-qualified form, which would otherwise shift every label by one.
    let name = name.strip_suffix('.').unwrap_or(name);
    // Split from the right at most three times so that everything to the left of the
    // service type stays together, in order, as the instance name.
    let mut labels = name.rsplitn(4, '.');
    let domain = labels.next()?;
    let ip_protocol = labels.next()?;
    let protocol = labels.next()?;
    let instance = labels.next().unwrap_or("");
    Some((
        instance.to_string(),
        format!("{}.{}", protocol, ip_protocol),
        domain.to_string(),
    ))
}
fn services_from_record_list(start_record: PDNS_RECORD) -> Result<Service> {
let mut service = Service {
name: "".to_string(),
regtype: "".to_string(),
interface_index: None,
domain: "".to_string(),
event_type: ServiceEventType::Added,
hostname: "".to_string(),
port: 0,
txt_record: None,
};
let mut current_record = start_record;
while !current_record.is_null() {
match DnsRecord::try_from(current_record) {
Ok(DnsRecord::Ptr(name)) => {
if let Some((name, regtype, domain)) = process_name(&name) {
service.name = name;
service.regtype = regtype;
service.domain = domain;
}
}
Ok(DnsRecord::Srv { port, hostname }) => {
service.port = port;
service.hostname = hostname;
}
Ok(DnsRecord::Txt(hash)) => {
if !hash.is_empty() {
service.txt_record = Some(hash);
}
}
Ok(DnsRecord::A(_ip)) => {}
Ok(DnsRecord::Aaaa(_ip)) => {}
Err(e) => {
error!("Error processing DNS record, skipping it: {:?}", e);
}
}
unsafe {
current_record = (*current_record).pNext;
}
}
Ok(service)
}
impl TryFrom<PDNS_RECORD> for DnsRecord {
type Error = BrowseError;
fn try_from(record: PDNS_RECORD) -> std::result::Result<Self, Self::Error> {
if record.is_null() {
return Err(IoError::from(ErrorKind::InvalidData).into());
}
let t = unsafe { (*record).wType } as u32; match t {
DNS_TYPE_PTR => unsafe {
let data = (*record).Data.Ptr;
let name = U16CString::from_ptr_str(data.pNameHost);
let name = name.to_string_lossy();
trace!("PTR Name: {}", name);
Ok(DnsRecord::Ptr(name))
},
DNS_TYPE_SRV => unsafe {
let data = (*record).Data.Srv;
let port = data.wPort;
let hostname = U16CString::from_ptr_str(data.pNameTarget).to_string_lossy();
trace!("SRV Port: {}", port);
trace!("SRV Target: {hostname}");
Ok(DnsRecord::Srv { port, hostname })
},
DNS_TYPE_TEXT => unsafe {
let strings = std::slice::from_raw_parts(
(*record).Data.Txt.pStringArray.as_ptr(),
(*record).Data.Txt.dwStringCount as _,
);
let mut hash = HashMap::with_capacity(strings.len());
for str_ptr in strings {
match U16CStr::from_ptr_str(*str_ptr).to_string() {
Ok(s) => {
let mut split = s.split('=');
match (split.next(), split.next()) {
(Some(k), Some(v)) => {
hash.insert(k.to_string(), v.to_string());
}
_ => {
warn!("Failed to get key=value from TXT string: {}", s);
}
}
}
Err(e) => {
error!("Error parsing TXT string: {:?}", e);
}
}
}
Ok(DnsRecord::Txt(hash))
},
DNS_TYPE_A => unsafe {
let data = (*record).Data.A;
let ip = Ipv4Addr::from(data.IpAddress.to_le_bytes());
trace!("IP Address: {}", ip);
Ok(DnsRecord::A(ip))
},
DNS_TYPE_AAAA => unsafe {
let data = (*record).Data.AAAA;
let addr = data.Ip6Address; let ip = Ipv6Addr::from(addr.IP6Word);
trace!("IPv6 Address: {}", ip);
Ok(DnsRecord::Aaaa(ip))
},
_ => {
warn!("Got record: {:?}, unhandled type", t);
Err(IoError::from(ErrorKind::InvalidData).into())
}
}
}
}
/// Callback registered with `DnsServiceBrowse` that receives browse results.
///
/// On success, the record list is folded into a [`Service`]. That service is sent over
/// the `SyncSender<Service>` that `context` points to. The record list is released with
/// `DnsFree` on every path, including the error-status and null-context paths, so it is
/// never leaked.
///
/// # Safety
///
/// - `context` must be null or point to a live `SyncSender<Service>`, as set up by
///   `browse`.
/// - `record` must be null or the head of a DNS record list whose ownership passes to
///   this callback.
pub unsafe extern "C" fn browse_callback(status: DWORD, context: PVOID, record: PDNS_RECORD) {
    info!("Browse callback: {}", status);
    if status != 0 {
        error!("Error in callback: {}", status);
    } else if context.is_null() {
        error!("Callback has nil context, returning early");
    } else {
        // SAFETY: `browse` passes a leaked `Box<SyncSender<Service>>` as the query context.
        // That box is only reclaimed when the `ServiceBrowser` is dropped.
        let tx = &*(context as *const SyncSender<Service>);
        match services_from_record_list(record) {
            Ok(service) => {
                trace!("{:?}", service);
                // NOTE(review): `SyncSender::send` blocks while the channel buffer is full.
                // That would stall this DNS API callback until the receiver drains it.
                if let Err(e) = tx.send(service) {
                    error!("Error sending service info: {:?}", e);
                }
            }
            Err(e) => {
                error!("Error creating services from PDNS_RECORD: {:?}", e);
            }
        }
    }
    // The callback owns a non-null record list on every path, so always release it here.
    // The early exits above previously leaked it.
    if !record.is_null() {
        DnsFree(record as _, DNS_FREE_TYPE_DnsFreeRecordList);
    }
}
pub struct ServiceBrowser {
cancel: _DNS_SERVICE_CANCEL,
context: *mut SyncSender<Service>,
receiver: Receiver<Service>,
}
impl Drop for ServiceBrowser {
fn drop(&mut self) {
unsafe {
let r = DnsServiceBrowseCancel(&mut self.cancel);
if r != 0 {
error!("Error canceling service browser: {}", r);
}
self.free_context();
}
}
}
impl ServiceBrowser {
fn free_context(&mut self) {
if !self.context.is_null() {
_ = unsafe { Box::from_raw(self.context) };
self.context = null_mut();
}
}
pub fn recv_timeout(&self, timeout: Duration) -> Result<Service> {
match self.receiver.recv_timeout(timeout) {
Ok(service) => Ok(service),
Err(RecvTimeoutError::Timeout) => Err(BrowseError::Timeout),
Err(RecvTimeoutError::Disconnected) => {
Err(BrowseError::IoError(IoError::from(ErrorKind::BrokenPipe)))
}
}
}
}
/// Starts an asynchronous DNS-SD browse for `<builder.regtype>.local` on interface 0.
///
/// Results are delivered by `browse_callback` into a bounded channel with capacity 10.
/// The receiving end of that channel is held by the returned [`ServiceBrowser`].
///
/// # Errors
///
/// Returns `BrowseError::DnsError` with the status code when `DnsServiceBrowse` does
/// not return `DNS_REQUEST_PENDING`. In that case the callback context is freed rather
/// than leaked.
pub fn browse(builder: ServiceBrowserBuilder) -> Result<ServiceBrowser> {
    // NOTE(review): `name` is dropped when this function returns. This assumes the DNS
    // API copies `QueryName` during `DnsServiceBrowse`; confirm.
    let mut name = to_utf16(format!("{}.local", builder.regtype));
    let (tx, rx) = sync_channel::<Service>(10);
    // Leak the sender so it has a stable address to hand to the DNS API. Ownership is
    // reclaimed either below on failure or by `ServiceBrowser::drop`.
    let context = Box::into_raw(Box::new(tx));
    let mut request = _DNS_SERVICE_BROWSE_REQUEST {
        Version: DNS_QUERY_REQUEST_VERSION1,
        InterfaceIndex: 0,
        QueryName: name.as_mut_ptr(),
        __bindgen_anon_1: BrowseCallbackUnion {
            pBrowseCallback: Some(browse_callback),
        },
        pQueryContext: context as _,
    };

    // SAFETY: `_DNS_SERVICE_CANCEL` is a plain C struct; an all-zero value is its
    // "not yet filled in" state before the API writes to it.
    let mut cancel: _DNS_SERVICE_CANCEL = unsafe { std::mem::zeroed() };
    // SAFETY: `request` and `cancel` are valid for the duration of the call, and
    // `context` points to a live `SyncSender<Service>`.
    let r = unsafe { DnsServiceBrowse(&mut request, &mut cancel) } as u32;
    if r != DNS_REQUEST_PENDING {
        // The request was not accepted, so (assumed) the callback will never run with
        // this context. Reclaim the sender instead of leaking it.
        // SAFETY: `context` came from `Box::into_raw` above and has not been freed.
        drop(unsafe { Box::from_raw(context) });
        return Err(BrowseError::DnsError(r));
    }

    Ok(ServiceBrowser {
        cancel,
        context,
        receiver: rx,
    })
}