use err_derive::Error;
use mdns::RecordKind;
use std::time::Duration;
use tokio_stream::{Stream, StreamExt};
use std::sync::atomic::{AtomicBool, Ordering};
/// Protocol suffix appended to every service type, giving `_<type>._udp`.
const HC_SERVICE_PROTOCOL: &str = "._udp";
/// Seconds the broadcast task sleeps between checks of its stop flag.
const BROADCAST_INTERVAL_SEC: u64 = 8;
/// Interval, in seconds, handed to `mdns::discover::all` when listening.
const QUERY_INTERVAL_SEC: u64 = 5;
/// Maximum length, in bytes, of each TXT string the encoded payload is split into.
// NOTE(review): presumably chosen to stay below the 255-byte limit of a single
// DNS TXT character-string — confirm.
const MAX_TXT_SIZE: usize = 192;
/// Errors yielded by the mDNS helpers in this module.
#[derive(Debug, Error)]
pub enum MdnsError {
    /// An error reported by the underlying `mdns` crate.
    #[error(display = "Regular Mdns error {}", _0)]
    Mdns(#[error(source)] mdns::Error),
    /// A TXT payload that could not be base64-decoded.
    #[error(display = "Base64 decoding error {}", _0)]
    Base64(#[error(source)] base64::DecodeError),
}
/// Asks a broadcast task started by `mdns_create_broadcast_thread` to stop.
///
/// This only clears the shared flag. The task notices it on its next periodic
/// check and then exits; nothing here waits for that to happen.
pub fn mdns_kill_thread(can_run: ::std::sync::Arc<AtomicBool>) {
    // `&Arc<AtomicBool>` deref-coerces to `&AtomicBool`.
    AtomicBool::store(&can_run, false, Ordering::Relaxed);
}
/// Advertises `buffer` over mDNS as `service_name` of type `_<service_type>._udp`,
/// using a detached Tokio task.
///
/// The payload is encoded as a leading `u` followed by the URL-safe, unpadded
/// base64 encoding of `buffer`. It is then split into TXT strings of at most
/// `MAX_TXT_SIZE` bytes and registered with a `libmdns::Responder` on port 0.
/// The task keeps the responder and its registration alive. Every
/// `BROADCAST_INTERVAL_SEC` seconds it checks the returned flag; the first check
/// comes after one full interval. Once the flag reads `false` (see
/// `mdns_kill_thread`), the task ends and drops the registration and responder.
///
/// # Panics
/// - Panics if `_<service_type>._udp` is 63 bytes or longer.
/// - Panics if `service_name` is 63 bytes or longer.
/// - `tokio::task::spawn` requires a running Tokio runtime.
/// - A failure of `libmdns::Responder::new()` panics inside the spawned task,
///   not in the caller.
pub fn mdns_create_broadcast_thread(
    service_type: String,
    service_name: String,
    buffer: &[u8],
) -> ::std::sync::Arc<AtomicBool> {
    let svc_type = format!("_{}{}", service_type, HC_SERVICE_PROTOCOL);
    assert!(
        svc_type.len() < 63,
        "len = {} ({}) ; {}",
        svc_type.len(),
        service_type.len(),
        service_type
    );
    assert!(service_name.len() < 63);

    let keep_running = ::std::sync::Arc::new(AtomicBool::new(true));
    let task_flag = ::std::sync::Arc::clone(&keep_running);

    let encoded = format!(
        "u{}",
        base64::encode_config(buffer, base64::URL_SAFE_NO_PAD)
    );
    // Base64 output is ASCII, so every byte offset is a valid char boundary.
    // `encoded` always holds at least the `u`, so at least one chunk is produced.
    let txt_chunks: Vec<String> = (0..encoded.len())
        .step_by(MAX_TXT_SIZE)
        .map(|start| encoded[start..(start + MAX_TXT_SIZE).min(encoded.len())].to_owned())
        .collect();

    let _task = tokio::task::spawn(async move {
        let txt_refs: Vec<&str> = txt_chunks.iter().map(String::as_str).collect();
        let responder = libmdns::Responder::new().unwrap();
        // Bound to a name so the registration lives until the task finishes.
        let _registration = responder.register(svc_type, service_name, 0, &txt_refs);
        let period = ::std::time::Duration::from_secs(BROADCAST_INTERVAL_SEC);
        tokio::time::sleep(period).await;
        while task_flag.load(Ordering::Relaxed) {
            tokio::time::sleep(period).await;
        }
    });

    keep_running
}
/// One decoded advertisement produced by `mdns_listen`.
#[derive(Clone, Debug)]
pub struct MdnsResponse {
    /// Service type taken from the PTR record name, without the leading `_`.
    pub service_type: String,
    /// Service instance name: the PTR target up to its first `.`.
    pub service_name: String,
    /// IP address of the responding host.
    pub addr: std::net::IpAddr,
    /// Payload decoded from the TXT record (empty when none was present).
    pub buffer: Vec<u8>,
}
/// Listens for mDNS advertisements of `_<service_type>._udp.local` and decodes them.
///
/// `QUERY_INTERVAL_SEC` is passed to `mdns::discover::all` as the discovery
/// interval. Responses with no records or no IP address are skipped. Every
/// other response becomes an [`MdnsResponse`]:
/// - `addr` is the response's IP address.
/// - `buffer` is the URL-safe, unpadded base64 payload from the TXT record,
///   with its first character removed (the broadcaster in this file writes a
///   `u` prefix). If several TXT records are present, the last one wins; with
///   none, `buffer` is empty.
/// - `service_name` is the PTR target up to its first `.`.
/// - `service_type` is the PTR record name up to its first `._`, minus its
///   leading character.
///
/// # Panics
/// Panics if the mDNS discovery cannot be created.
///
/// # Errors
/// - The stream yields `Err(MdnsError::Mdns)` for errors reported by the `mdns`
///   stream.
/// - It yields `Err(MdnsError::Base64)` when a TXT payload is not valid base64.
///
/// Empty TXT records or unusual record names received from the network no
/// longer cause a panic.
pub fn mdns_listen(service_type: String) -> impl Stream<Item = Result<MdnsResponse, MdnsError>> {
    let svc_type = format!("_{}{}.local", service_type, HC_SERVICE_PROTOCOL);
    let query = mdns::discover::all(svc_type, Duration::from_secs(QUERY_INTERVAL_SEC))
        .expect("mdns Discover failed");
    // A single `filter_map` drops uninteresting responses and converts the rest.
    // The IP address is then taken with `?`, not an `unwrap` that depends on a
    // separate earlier filter.
    query.listen().filter_map(|item| {
        let response = match item {
            Ok(response) => response,
            Err(e) => return Some(Err(MdnsError::Mdns(e))),
        };
        if response.is_empty() {
            return None;
        }
        let addr = response.ip_addr()?;
        let mut buffer = Vec::new();
        let mut found_name = String::new();
        let mut found_type = String::new();
        for answer in response.answers {
            match answer.kind {
                RecordKind::TXT(txts) => match decode_txt_payload(&txts) {
                    Ok(decoded) => buffer = decoded,
                    Err(e) => return Some(Err(e)),
                },
                RecordKind::PTR(ptr) => {
                    // `str::split` always yields at least one piece, so the fallback is never used.
                    found_name = ptr.split('.').next().unwrap_or("").to_string();
                    found_type = service_type_from_record_name(&answer.name);
                }
                _ => {}
            }
        }
        Some(Ok(MdnsResponse {
            service_type: found_type,
            service_name: found_name,
            addr,
            buffer,
        }))
    })
}

/// Joins the strings of one TXT record and base64-decodes (URL-safe, unpadded)
/// everything after the first character.
///
/// The first character is dropped because the broadcaster in this file prepends
/// `u`. The prefix is removed char-wise, so an empty record or a non-ASCII
/// first character cannot cause an out-of-bounds or mid-character slice.
fn decode_txt_payload(txts: &[String]) -> Result<Vec<u8>, MdnsError> {
    let joined = txts.concat();
    base64::decode_config(skip_first_char(&joined), base64::URL_SAFE_NO_PAD)
        .map_err(MdnsError::Base64)
}

/// Extracts the service type from a PTR record name such as `_svc._udp.local`.
///
/// It takes the text before the first `._` and drops its leading character,
/// so `_svc._udp.local` yields `svc`.
fn service_type_from_record_name(name: &str) -> String {
    let first_label = name.split("._").next().unwrap_or("");
    skip_first_char(first_label).to_string()
}

/// Returns `s` without its first character; an empty string stays empty.
///
/// Unlike `&s[1..]`, this never panics, including on an empty or non-ASCII `s`.
fn skip_first_char(s: &str) -> &str {
    let mut chars = s.chars();
    chars.next();
    chars.as_str()
}