mod ffi;
use ffi::*;
use log::trace;
use std::{ffi::CString, num::NonZeroU32, sync::Mutex};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use crate::{ServiceRegistrationError, TxtRecordValue};
/// A service registration created through `DNSServiceRegister`.
///
/// The registration is released when [`ServiceRegistration::unregister`] is
/// awaited or when the value is dropped.
pub struct ServiceRegistration {
/// The service reference returned by `DNSServiceRegister`, paired with the
/// boxed callback context whose address was passed to that call. The box keeps
/// the context at a stable address for as long as the reference is held.
/// Becomes `None` once the registration has been released.
reference: Option<(DNSServiceRef, Box<CallbackContext>)>,
}
/// State shared with the C registration `callback` through its `context` pointer.
struct CallbackContext {
/// One-shot channel used to deliver the callback's error code to the
/// registering task. The sender is taken (leaving `None`) when the first result
/// is sent, so later invocations send nothing.
tx: Mutex<Option<oneshot::Sender<DNSServiceErrorType>>>,
}
impl ServiceRegistration {
pub(crate) async fn new(
service_type: &String,
port: u16,
name: &Option<String>,
host: &Option<String>,
domain: &Option<String>,
interface_index: Option<NonZeroU32>,
txt_record_values: &[(String, TxtRecordValue)],
) -> Result<ServiceRegistration, ServiceRegistrationError> {
let service_type = CString::new(service_type.as_bytes()).map_err(|e| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
service_type.clone(),
e.nul_position(),
)
})?;
let name: Option<CString> = name
.as_ref()
.map(|name| {
CString::new(name.as_bytes()).map_err(|e| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
name.clone(),
e.nul_position(),
)
})
})
.transpose()?;
let domain: Option<CString> = domain
.as_ref()
.map(|domain| {
CString::new(domain.as_bytes()).map_err(|e| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
domain.clone(),
e.nul_position(),
)
})
})
.transpose()?;
let host: Option<CString> = host
.as_ref()
.map(|host| {
CString::new(host.as_bytes()).map_err(|e| {
ServiceRegistrationError::ParameterContainsInteriorNulByte(
host.clone(),
e.nul_position(),
)
})
})
.transpose()?;
let interface = interface_index.map(|i| i.get()).unwrap_or(0);
let mut txt_record = Vec::new();
for (key, value) in txt_record_values {
match value {
TxtRecordValue::KeyOnly => {
let key_bytes = key.as_bytes();
txt_record.push(key_bytes.len() as u8);
txt_record.extend_from_slice(key_bytes);
}
TxtRecordValue::String(s) => {
let key_bytes = key.as_bytes();
let value_bytes = s.as_bytes();
txt_record.push((key_bytes.len() + 1 + value_bytes.len()) as u8);
txt_record.extend_from_slice(key_bytes);
txt_record.push(b'=');
txt_record.extend_from_slice(value_bytes);
}
TxtRecordValue::Binary(b) => {
let key_bytes = key.as_bytes();
txt_record.push((key_bytes.len() + 1 + b.len()) as u8);
txt_record.extend_from_slice(key_bytes);
txt_record.push(b'=');
txt_record.extend_from_slice(b);
}
}
}
if txt_record_values.is_empty() {
txt_record.push(0);
}
trace!(
"Registering service with type: {:?}, port: {}, name: {:?}, host: {:?}, domain: {:?}, interface_index: {:?}, txt_record: {:?}",
service_type, port, name, host, domain, interface_index, txt_record
);
let registration = tokio::task::spawn_blocking(move || {
let mut service_ref = DNSServiceRef::default();
let (tx, mut rx) = oneshot::channel();
let ctx = Box::new(CallbackContext {
tx: Mutex::new(Some(tx)),
});
let ctx_ptr: *const CallbackContext = &*ctx;
let error = unsafe {
DNSServiceRegister(
&mut service_ref,
0,
interface,
name.as_ref().map_or(std::ptr::null(), |name| name.as_ptr()),
service_type.as_ptr(),
domain
.as_ref()
.map_or(std::ptr::null(), |domain| domain.as_ptr()),
host.as_ref().map_or(std::ptr::null(), |host| host.as_ptr()),
port.to_be(),
txt_record.len() as u16,
txt_record.as_ptr() as *const std::ffi::c_void,
Some(callback),
ctx_ptr as *mut std::ffi::c_void,
)
};
trace!("DNSServiceRegister returned: {error}");
if error != error::NO_ERROR {
if error == error::NAME_CONFLICT {
return Err(ServiceRegistrationError::NameConflict);
}
return Err(ServiceRegistrationError::RegistrationError(format!(
"DNSServiceRegister failed with error code: {error}"
)));
}
let process_error = unsafe { DNSServiceProcessResult(service_ref.0) };
trace!("DNSServiceProcessResult returned: {process_error}");
if process_error != error::NO_ERROR {
unsafe {
DNSServiceRefDeallocate(service_ref);
}
return Err(ServiceRegistrationError::RegistrationError(format!(
"DNSServiceProcessResult failed with error code: {process_error}"
)));
}
let callback_status = match rx.try_recv() {
Ok(status) => status,
Err(_) => {
unsafe {
DNSServiceRefDeallocate(service_ref);
}
return Err(ServiceRegistrationError::RegistrationError(
"DNSServiceRegister callback did not fire".into(),
));
}
};
trace!("Callback status: {callback_status}");
if callback_status != error::NO_ERROR {
unsafe {
DNSServiceRefDeallocate(service_ref);
}
if callback_status == error::NAME_CONFLICT {
return Err(ServiceRegistrationError::NameConflict);
}
return Err(ServiceRegistrationError::RegistrationError(format!(
"DNSServiceRegister callback failed with error code: {callback_status}"
)));
}
Ok(ServiceRegistration {
reference: Some((service_ref, ctx)),
})
})
.await
.map_err(|err| {
ServiceRegistrationError::RegistrationError(format!(
"service registration task failed: {err}"
))
})??;
Ok(registration)
}
pub async fn unregister(mut self) -> Result<(), String> {
if let Some(join_handle) = self.deallocate() {
join_handle
.await
.map_err(|err| format!("service unregistration panicked: {err:?}"))?;
}
Ok(())
}
fn deallocate(&mut self) -> Option<JoinHandle<()>> {
self.reference.take().map(|(service_ref, ctx)| {
tokio::task::spawn_blocking(move || unsafe {
DNSServiceRefDeallocate(service_ref);
drop(ctx);
})
})
}
}
impl Drop for ServiceRegistration {
    /// Releases the registration if it is still held (a prior `unregister`
    /// leaves nothing to release).
    fn drop(&mut self) {
        // `drop` cannot await, so any join handle returned by `deallocate` is
        // discarded without being awaited.
        drop(self.deallocate());
    }
}
/// C callback registered with `DNSServiceRegister`. It forwards the reported
/// error code to the waiting task through the [`CallbackContext`] behind
/// `context`.
///
/// Only the first invocation sends a value, because the sender is taken out of
/// the context. A poisoned lock or an already-dropped receiver is ignored.
///
/// # Safety
///
/// `context` must point to a live `CallbackContext`.
unsafe extern "C" fn callback(
    _service_ref: DNSServiceRef,
    _flags: DNSServiceFlags,
    error_code: DNSServiceErrorType,
    _name: *const ::std::os::raw::c_char,
    _regtype: *const ::std::os::raw::c_char,
    _domain: *const ::std::os::raw::c_char,
    context: *mut ::std::os::raw::c_void,
) {
    trace!("Service registration callback with error_code: {error_code}");
    // SAFETY: per this function's contract, `context` points to a live
    // `CallbackContext`; only a shared reference is created.
    let shared: &CallbackContext = unsafe { &*context.cast::<CallbackContext>() };
    let Ok(mut sender_slot) = shared.tx.lock() else {
        return;
    };
    if let Some(sender) = sender_slot.take() {
        // The receiving side may already be gone; that is not an error here.
        let _ = sender.send(error_code);
    }
}