use widestring::U16CString;
use windows::Win32::Foundation::ERROR_SUCCESS;
use windows::Win32::NetworkManagement::Dns;
use windows::Win32::System::SystemInformation::{ComputerNameDnsHostname, GetComputerNameExW};
use windows::core::PCWSTR;
use windows::core::PWSTR;
use windows_strings::*;
use std::ffi::c_void;
use std::ptr::NonNull;
use log::{error, trace};
use std::num::NonZeroU32;
use tokio::sync::oneshot;
use crate::{ServiceRegistrationError, TxtRecordValue};
/// Value written to the `Version` field of every `DNS_SERVICE_REGISTER_REQUEST`
/// built in this file.
const DNS_QUERY_REQUEST_VERSION1: u32 = 1;
/// Return code that `DnsServiceRegister` / `DnsServiceDeRegister` must yield for
/// this module to treat the asynchronous request as accepted; any other value
/// is handled as an immediate failure.
const DNS_REQUEST_PENDING: u32 = 9506;
/// Owning handle around a non-null `DNS_SERVICE_INSTANCE` pointer.
///
/// The pointer is obtained from `DnsServiceConstructInstance` and released with
/// `DnsServiceFreeInstance` when the wrapper is dropped.
struct ServiceInstance(NonNull<Dns::DNS_SERVICE_INSTANCE>);
// The wrapper travels inside `CallbackContext` and through a oneshot channel,
// which is presumably why these impls exist.
// NOTE(review): soundness relies on the pointee not being accessed concurrently
// while the OS holds the raw pointer — confirm against the Win32 contract.
unsafe impl Send for ServiceInstance {}
unsafe impl Sync for ServiceInstance {}
impl ServiceInstance {
    /// Builds a service instance by calling `DnsServiceConstructInstance`.
    ///
    /// `instance_name` and `host_name` are converted to wide strings for the
    /// call. `key_value_pairs` supplies the properties as parallel key/value
    /// wide strings; no IPv4/IPv6 address is passed.
    ///
    /// # Errors
    ///
    /// Returns `ServiceRegistrationError::RegistrationFailed` when the number
    /// of properties does not fit the `u32` count the API expects, or when the
    /// API returns a null pointer.
    fn new(
        instance_name: &str,
        host_name: &str,
        port: u16,
        priority: u16,
        weight: u16,
        key_value_pairs: &[(U16CString, U16CString)],
    ) -> Result<Self, ServiceRegistrationError> {
        // A fallible constructor should not panic on an oversized input.
        let property_count = u32::try_from(key_value_pairs.len()).map_err(|_| {
            ServiceRegistrationError::RegistrationFailed("too many TXT record entries".into())
        })?;

        let instance_name = HSTRING::from(instance_name);
        let host_name = HSTRING::from(host_name);

        // Parallel pointer arrays borrowing from `key_value_pairs`; both
        // vectors stay alive until after the call below.
        let (key_ptrs, value_ptrs): (Vec<PCWSTR>, Vec<PCWSTR>) = key_value_pairs
            .iter()
            .map(|(k, v)| (PCWSTR(k.as_ptr()), PCWSTR(v.as_ptr())))
            .unzip();

        // SAFETY: every pointer handed to the API refers to a NUL-terminated
        // wide string or pointer array that outlives this call, and
        // `property_count` equals the length of both arrays.
        let instance = unsafe {
            Dns::DnsServiceConstructInstance(
                &instance_name,
                &host_name,
                None,
                None,
                port,
                priority,
                weight,
                property_count,
                key_ptrs.as_ptr(),
                value_ptrs.as_ptr(),
            )
        };

        NonNull::new(instance).map(Self).ok_or_else(|| {
            ServiceRegistrationError::RegistrationFailed(
                "DnsServiceConstructInstance returned null".into(),
            )
        })
    }

    /// Returns the raw instance pointer without giving up ownership.
    fn as_ptr(&self) -> *mut Dns::DNS_SERVICE_INSTANCE {
        self.0.as_ptr()
    }
}
impl Drop for ServiceInstance {
    /// Releases the wrapped instance with `DnsServiceFreeInstance`.
    fn drop(&mut self) {
        let raw = self.as_ptr();
        // SAFETY: `raw` is the non-null pointer this wrapper owns, and it is
        // freed exactly once because `drop` runs at most once per value.
        unsafe {
            Dns::DnsServiceFreeInstance(raw);
        }
    }
}
/// State boxed by `build_request`, passed to the OS as `pQueryContext`, and
/// reclaimed either by `completion_callback` or by the caller when the request
/// fails to start.
struct CallbackContext {
/// Channel on which `completion_callback` reports `(status, instance)`.
/// `None` when nobody waits for the result; the callback then just drops the
/// instance.
status_tx: Option<oneshot::Sender<(u32, ServiceInstance)>>,
/// Owns the service instance whose raw pointer is stored in the request.
instance: ServiceInstance,
}
/// Handle to a service registered through `DnsServiceRegister`.
///
/// Dropping the handle starts a deregistration without waiting for its result;
/// `unregister` starts it and awaits the reported status.
pub struct ServiceRegistration {
/// The registered instance; becomes `None` once `deregister` has taken it.
instance: Option<ServiceInstance>,
/// Interface index used for registration and reused for deregistration;
/// `None` is sent to the API as `0`.
interface_index: Option<NonZeroU32>,
}
impl ServiceRegistration {
pub(crate) async fn new(
service_type: &String,
port: u16,
name: &Option<String>,
host: &Option<String>,
domain: &Option<String>,
interface_index: Option<NonZeroU32>,
txt_record_values: &[(String, TxtRecordValue)],
) -> Result<ServiceRegistration, ServiceRegistrationError> {
let key_value_pairs = txt_record_values
.iter()
.map(|(key, value)| {
let k = U16CString::from_str(key.as_str()).map_err(|err| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
key.clone(),
err.nul_position(),
)
})?;
let v = match value {
TxtRecordValue::KeyOnly => U16CString::new(),
TxtRecordValue::String(s) => {
U16CString::from_str(s.as_str()).map_err(|err| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
s.clone(),
err.nul_position(),
)
})?
}
};
Ok((k, v))
})
.collect::<Result<Vec<_>, ServiceRegistrationError>>()?;
let domain_str = domain.as_deref().unwrap_or("local");
let hostname = if let Some(host) = host {
host
} else {
&get_hostname().map_err(ServiceRegistrationError::HostnameUnavailable)?
};
let instance_name = name.as_deref().unwrap_or(hostname);
let priority = 0;
let weight = 0;
let instance = ServiceInstance::new(
&format!("{instance_name}.{service_type}.{domain_str}"),
&format!("{hostname}.{domain_str}"),
port,
priority,
weight,
&key_value_pairs,
)?;
let (tx, rx) = oneshot::channel::<(u32, ServiceInstance)>();
let request = build_request(instance, interface_index, Some(tx));
let result = unsafe { Dns::DnsServiceRegister(&request, None) };
trace!("DnsServiceRegister result: {result}");
if result != DNS_REQUEST_PENDING {
let context_ptr = request.pQueryContext as *mut CallbackContext;
unsafe {
drop(Box::from_raw(context_ptr));
}
return Err(ServiceRegistrationError::RegistrationError(format!(
"DnsServiceRegister failed with status: {result}"
)));
}
let (status, instance) = rx.await.map_err(|_| {
ServiceRegistrationError::RegistrationError(
"registration callback channel closed unexpectedly".into(),
)
})?;
if status != ERROR_SUCCESS.0 {
return Err(ServiceRegistrationError::RegistrationError(format!(
"registration failed with status: {status}"
)));
}
Ok(ServiceRegistration {
instance: Some(instance),
interface_index,
})
}
pub async fn unregister(mut self) -> Result<(), String> {
let (tx, rx) = oneshot::channel::<(u32, ServiceInstance)>();
self.deregister(Some(tx))?;
let (status, _instance) = rx
.await
.map_err(|_| "deregistration callback channel closed unexpectedly".to_string())?;
if status != ERROR_SUCCESS.0 {
return Err(format!(
"service deregistration failed with status: {status}"
));
}
Ok(())
}
fn deregister(
&mut self,
status_tx: Option<oneshot::Sender<(u32, ServiceInstance)>>,
) -> Result<(), String> {
let Some(instance) = self.instance.take() else {
return Ok(());
};
let request = build_request(instance, self.interface_index, status_tx);
let result = unsafe { Dns::DnsServiceDeRegister(&request, None) };
trace!("DnsServiceDeRegister result: {result}");
if result != DNS_REQUEST_PENDING {
let context_ptr = request.pQueryContext as *mut CallbackContext;
unsafe {
drop(Box::from_raw(context_ptr));
}
return Err(format!(
"deregistration failed to start with status {result}"
));
}
Ok(())
}
}
impl Drop for ServiceRegistration {
fn drop(&mut self) {
self.deregister(None).ok();
}
}
fn build_request(
instance: ServiceInstance,
interface_index: Option<NonZeroU32>,
status_tx: Option<oneshot::Sender<(u32, ServiceInstance)>>,
) -> Dns::DNS_SERVICE_REGISTER_REQUEST {
let service_instance = instance.as_ptr();
let context = Box::new(CallbackContext {
status_tx,
instance,
});
let context_ptr = Box::into_raw(context) as *mut c_void;
Dns::DNS_SERVICE_REGISTER_REQUEST {
Version: DNS_QUERY_REQUEST_VERSION1,
InterfaceIndex: interface_index.map_or(0, |idx| idx.get()),
pServiceInstance: service_instance,
pRegisterCompletionCallback: Some(completion_callback),
pQueryContext: context_ptr,
..Default::default()
}
}
extern "system" fn completion_callback(
status: u32,
context: *const c_void,
_service_instance: *const Dns::DNS_SERVICE_INSTANCE,
) {
let ctx = unsafe { Box::from_raw(context as *mut CallbackContext) };
let CallbackContext {
status_tx,
instance,
} = *ctx;
if status == ERROR_SUCCESS.0 {
trace!("DNS-SD operation completed successfully");
} else {
error!("DNS-SD operation failed with status: {status}");
}
match status_tx {
Some(status_tx) => {
let _ = status_tx.send((status, instance));
}
None => drop(instance),
}
}
/// Returns the local DNS host name reported by
/// `GetComputerNameExW(ComputerNameDnsHostname, ..)`.
///
/// The API is called twice: first without a buffer to learn the required size,
/// then with a buffer of that size.
///
/// # Errors
///
/// Returns a message when the size probe reports `0`, when the second call
/// fails, or when the resulting name is empty.
fn get_hostname() -> Result<String, String> {
    let mut size: u32 = 0;
    // The probe is expected to fail because no buffer is supplied; only the
    // size it writes back matters.
    // SAFETY: `size` is a valid, writable `u32` and no buffer is passed.
    let _ = unsafe { GetComputerNameExW(ComputerNameDnsHostname, None, &mut size) };
    if size == 0 {
        return Err("hostname empty".to_string());
    }

    let mut buffer = vec![0u16; size as usize];
    // SAFETY: `buffer` holds exactly `size` elements, matching the capacity
    // passed through `size`.
    unsafe {
        GetComputerNameExW(
            ComputerNameDnsHostname,
            Some(PWSTR(buffer.as_mut_ptr())),
            &mut size,
        )
    }
    .map_err(|err| format!("GetComputerNameExW failed: {err}"))?;

    // On success `size` is the number of characters written, excluding the
    // terminating NUL. `truncate` cannot panic even if the value were larger
    // than the buffer.
    buffer.truncate(size as usize);
    // The probe size counts the terminating NUL, so the check above cannot
    // detect an empty name; test the actual result as well.
    if buffer.is_empty() {
        return Err("hostname empty".to_string());
    }
    Ok(String::from_utf16_lossy(&buffer))
}