#![feature(try_trait)]
#![feature(never_type)]
#![feature(drain_filter)]
extern crate lib3h_protocol;
use log::{debug, error, trace};
use url::Url;
use zeroize::Zeroize;
use std::{
net::{self, SocketAddr, ToSocketAddrs},
time::Instant,
};
use lib3h_protocol::{
discovery::{error::DiscoveryResult, Discovery},
uri::Lib3hUri,
};
pub mod error;
pub use error::{MulticastDnsError, MulticastDnsResult};
pub mod dns;
pub use dns::*;
pub mod builder;
pub use builder::MulticastDnsBuilder;
pub mod record;
use record::{MapRecord, Record};
/// Size in bytes of the receive buffer used by `MulticastDns::recv`. Per the std
/// `UdpSocket::recv_from` docs, excess bytes of a longer datagram may be discarded.
const READ_BUF_SIZE: usize = 4_096;
/// Delay between probe queries, in milliseconds. Not referenced in this file (underscore prefix).
const _PROBE_QUERY_DELAY_MS: u64 = 250;
/// Service port. NOTE(review): not referenced in this file; presumably the builder's
/// default bind port — confirm.
const SERVICE_LISTENER_PORT: u16 = 8585;
/// Not referenced in this file (underscore prefix); intended use unclear from here.
/// ("TRESHOLD" is a misspelling of "THRESHOLD"; kept to avoid breaking references.)
const _FAIL_SAFE_TRESHOLD: u16 = 1_000;
/// IPv4 multicast group 224.0.0.251 (the group assigned to mDNS by RFC 6762).
/// NOTE(review): presumably the builder's default multicast address — confirm.
/// (Name misspelled: "MULCAST"/"ADRESS"; kept to avoid breaking references.)
const MDNS_MULCAST_IPV4_ADRESS: &str = "224.0.0.251";
/// Bind address covering all IPv4 interfaces. NOTE(review): presumably the builder default.
const DEFAULT_BIND_ADRESS: &str = "0.0.0.0";
/// NOTE(review): presumably the default multicast TTL (same type as the `multicast_ttl`
/// field) — confirm against the builder.
const DEFAULT_TTL: u32 = 255;
/// Default interval between periodic queries issued by `discover`, in milliseconds (30 s).
/// NOTE(review): presumably consumed by the builder — confirm.
const DEFAULT_QUERY_INTERVAL_MS: u128 = 30_000;
/// Multicast DNS discovery endpoint.
///
/// Holds the send/receive sockets, the multicast configuration, and two record maps keyed
/// by network id: the records this node advertises (`own_map_record`) and the cache of
/// known records (`map_record`).
pub struct MulticastDns {
/// Configured bind address (returned by `address()`).
pub(crate) bind_address: String,
/// Configured port; also used as the destination port of every multicast send.
pub(crate) bind_port: u16,
/// Whether multicast loopback was requested.
pub(crate) multicast_loop: bool,
/// Configured multicast TTL.
pub(crate) multicast_ttl: u32,
/// Multicast group address that messages are sent to.
pub(crate) multicast_address: String,
/// Reference instant for periodic queries: `discover` re-queries and resets it once more
/// than `query_interval_ms` milliseconds have elapsed.
pub(crate) timestamp: Instant,
/// Minimum delay between periodic queries, in milliseconds.
pub(crate) query_interval_ms: u128,
/// Socket used for all outgoing datagrams (unicast replies and multicast sends).
pub(crate) send_socket: net::UdpSocket,
/// Socket read by `recv`; a `WouldBlock` error is treated as "nothing to read".
pub(crate) recv_socket: net::UdpSocket,
/// Scratch receive buffer, zeroed before every read.
buffer: [u8; READ_BUF_SIZE],
/// Records this node advertises, keyed by network id.
pub(crate) own_map_record: MapRecord,
/// Cache of known records keyed by network id; aged by `update_ttl`, pruned by `prune_cache`.
pub(crate) map_record: MapRecord,
}
impl MulticastDns {
pub fn address(&self) -> &str {
&self.bind_address
}
pub fn port(&self) -> u16 {
self.bind_port
}
pub fn multicast_loop(&self) -> bool {
self.multicast_loop
}
pub fn multicast_ttl(&self) -> u32 {
self.multicast_ttl
}
pub fn multicast_address(&self) -> &str {
&self.multicast_address
}
pub fn records(&self) -> &MapRecord {
&self.map_record
}
pub fn own_urls(&self) -> Vec<String> {
self.own_map_record
.iter()
.flat_map(|(_, v)| v.iter().map(|r| r.url.clone()).collect::<Vec<String>>())
.collect()
}
pub fn urls(&self) -> Vec<Lib3hUri> {
self.map_record
.iter()
.flat_map(|(_, v)| {
v.iter()
.filter_map(|r| match Url::parse(&r.url) {
Ok(url) => Some(url.into()),
Err(_) => None,
})
.collect::<Vec<Lib3hUri>>()
})
.collect()
}
pub fn own_networkids(&self) -> Vec<&str> {
self.own_map_record.keys().map(|k| k.as_str()).collect()
}
pub fn query_interval_ms(&self) -> u128 {
self.query_interval_ms
}
pub fn insert_own_record(&mut self, netid: &str, records: &[&str]) {
let records: Vec<Record> = records
.iter()
.map(|rec| Record::new(netid, rec, 255))
.collect();
self.own_map_record.insert(netid.to_string(), records);
}
pub fn insert_record(&mut self, netid: &str, records: &[&str]) {
let records: Vec<Record> = records
.iter()
.map(|rec| Record::new(netid, rec, 255))
.collect();
self.map_record.insert(netid.to_string(), records);
}
pub fn update_cache(&mut self, other_map_record: &MapRecord) {
self.map_record.update(other_map_record);
}
pub fn broadcast_message(&self, dmesg: &DnsMessage) -> Result<usize, MulticastDnsError> {
let addr = (self.multicast_address.as_ref(), self.bind_port)
.to_socket_addrs()?
.next()?;
let data = dmesg.to_raw()?;
Ok(self.send_socket.send_to(&data, &addr)?)
}
pub fn broadcast(&self, data: &[u8]) -> Result<usize, MulticastDnsError> {
let addr = (self.multicast_address.as_ref(), self.bind_port)
.to_socket_addrs()?
.next()?;
Ok(self.send_socket.send_to(&data, &addr)?)
}
pub fn recv(&mut self) -> MulticastDnsResult<Option<(Vec<u8>, SocketAddr)>> {
self.clear_buffer();
match self.recv_socket.recv_from(&mut self.buffer) {
Ok((0, _)) => Ok(None),
Ok((num_bytes, addr)) => {
debug!(
"Received '{}' bytes: {:?}",
num_bytes,
&self.buffer.to_vec()[..num_bytes]
);
let packet = self.buffer[..num_bytes].to_vec();
Ok(Some((packet, addr)))
}
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
Ok(None)
} else {
Err(e.into())
}
}
}
}
pub fn clear_buffer(&mut self) {
self.buffer.zeroize();
}
pub fn prune_cache(&mut self) {
for (_, records) in self.map_record.iter_mut() {
let _: Vec<Record> = records.drain_filter(|r| r.ttl == 0).collect();
}
}
pub fn update_ttl(&mut self) {
let own_urls = self.own_urls();
for (_netid, records) in self.map_record.iter_mut() {
for record in records {
if !own_urls.contains(&record.url) && record.ttl > 0 {
record.ttl -= 1;
}
}
}
}
pub fn query(&mut self) -> MulticastDnsResult<()> {
if let Some(query_message) = self.build_query_message() {
self.broadcast_message(&query_message)?;
self.broadcast_message(&query_message)?;
self.broadcast_message(&query_message)?;
}
Ok(())
}
fn responder(&mut self) -> MulticastDnsResult<()> {
loop {
match self.recv() {
Ok(Some((packet, sender_addr))) => {
let dmesg = DnsMessage::from_raw(&packet)?;
if dmesg.nb_answers > 0 {
if let Some(new_map_record) = MapRecord::from_dns_message(&dmesg) {
let own_networkids: Vec<String> = self
.own_networkids()
.iter()
.map(|v| (*v).to_string())
.collect();
for (netid, new_records) in new_map_record.iter() {
if own_networkids.contains(netid) {
let tmp_new_map_record =
MapRecord::with_record(netid, new_records);
self.update_cache(&tmp_new_map_record);
}
}
}
}
else if dmesg.nb_questions > 0 {
let question_list: Vec<&str> = dmesg
.questions
.iter()
.filter_map(|q| {
if q.query_class == 1 && q.query_type == 5 {
Some(q.domain_name.as_str())
} else {
None
}
})
.collect();
if let Some(response) =
self.own_map_record.to_dns_response_message(&question_list)
{
self.send_socket.send_to(&response.to_raw()?, sender_addr)?;
self.broadcast_message(&response)?;
}
}
}
Ok(None) => {
trace!(">> Nothing on the UDP stack");
break;
}
Err(e) => {
error!(
"Something went wrong while processing the UDP stack during update: '{}'",
e
);
break;
}
}
}
Ok(())
}
fn _build_probe_packet(&self) -> DnsMessage {
let questions: Vec<QuerySection> = self
.own_map_record
.keys()
.map(|k| QuerySection::new(k))
.collect();
DnsMessage {
nb_questions: questions.len() as u16,
questions,
..Default::default()
}
}
pub fn build_query_message(&self) -> Option<DnsMessage> {
if self.own_map_record.is_empty() {
None
} else {
let mut questions = Vec::new();
for (_netid, records) in self.own_map_record.iter() {
for rec in records {
questions.push(rec.to_question_section());
}
}
Some(DnsMessage {
nb_questions: questions.len() as u16,
questions,
..Default::default()
})
}
}
fn announcing(&mut self) -> MulticastDnsResult<()> {
let own_net_id_list = self.own_networkids();
if let Some(dmesg) = self
.own_map_record
.to_dns_response_message(&own_net_id_list)
{
self.broadcast_message(&dmesg)?;
self.broadcast_message(&dmesg)?;
self.broadcast_message(&dmesg)?;
}
Ok(())
}
}
impl Discovery for MulticastDns {
    /// Announces this node: first queries for its own records, then multicasts them.
    fn advertise(&mut self) -> DiscoveryResult<()> {
        self.query()?;
        self.announcing()?;
        Ok(())
    }

    /// Runs one discovery step and returns the currently cached URLs.
    ///
    /// The step handles pending inbound traffic, re-queries if more than
    /// `query_interval_ms` has passed since the last query, then ages and prunes the cache.
    fn discover(&mut self) -> DiscoveryResult<Vec<Lib3hUri>> {
        self.responder()?;
        let query_due = self.timestamp.elapsed().as_millis() > self.query_interval_ms;
        if query_due {
            self.query()?;
            self.timestamp = Instant::now();
        }
        self.update_ttl();
        self.prune_cache();
        Ok(self.urls())
    }

    /// Sets the TTL of all our own records to zero and multicasts them four times.
    /// Presumably this lets peers drop us from their caches — confirm.
    fn release(&mut self) -> DiscoveryResult<()> {
        self.own_map_record
            .iter_mut()
            .flat_map(|(_, records)| records.iter_mut())
            .for_each(|rec| rec.ttl = 0);
        let net_ids = self.own_networkids();
        let release_dmesg = match self.own_map_record.to_dns_response_message(&net_ids) {
            Some(dmesg) => dmesg,
            None => return Ok(()),
        };
        for _ in 0..4 {
            self.broadcast_message(&release_dmesg)?;
        }
        Ok(())
    }

    /// Empties the record cache (`map_record`); our own advertised records are kept.
    fn flush(&mut self) -> DiscoveryResult<()> {
        self.map_record.clear();
        Ok(())
    }
}
// Socket-level tests: each builds real UDP endpoints and exchanges multicast traffic, so the
// host must permit IPv4 multicast. Every test uses its own multicast group / port pair.
#[cfg(test)]
mod tests {
use super::*;
use nanoid;
// Broadcasts a single question with loopback enabled and expects to read it back.
#[test]
fn it_should_loop_question() {
let mut mdns = MulticastDnsBuilder::new()
.bind_address("0.0.0.0")
.multicast_address("224.0.0.247")
.bind_port(55247)
.multicast_loop(true)
.multicast_ttl(255)
.build()
.expect("build fail");
let mut dmesg = DnsMessage::new();
dmesg.nb_questions = 1;
dmesg.questions = vec![QuerySection::new("lib3h.test.service")];
// Read (and discard) up to two datagrams before sending — presumably to drain stale
// traffic; relies on `recv` not blocking when the socket is idle.
let _ = mdns.recv().expect("Fail to receive from the UDP socket.");
let _ = mdns.recv().expect("Fail to receive from the UDP socket.");
mdns.broadcast_message(&dmesg)
.expect("Fail to broadcast DNS Message.");
// NOTE(review): if nothing is received, the `if let` is skipped and the test passes vacuously.
if let Some((resp, _addr)) = mdns.recv().expect("Fail to receive from the UDP socket.") {
let dmesg_from_resp = DnsMessage::from_raw(&resp).unwrap();
assert_eq!(
&dmesg_from_resp.questions[0].domain_name,
"lib3h.test.service"
);
}
}
// Broadcasts a message with two answers and expects the looped-back copy to round-trip equal.
#[test]
fn it_should_loop_answer() {
let mut mdns = MulticastDnsBuilder::new()
.bind_address("0.0.0.0")
.multicast_address("224.0.0.248")
.bind_port(56248)
.multicast_loop(true)
.multicast_ttl(255)
.build()
.expect("build fail");
let mut dmesg = DnsMessage::new();
let answers = vec![
AnswerSection::new("holonaute.local.", &Target::new("wss://192.168.0.88")),
AnswerSection::new("mistral.local.", &Target::new("wss://192.168.0.77")),
];
dmesg.nb_answers = answers.len() as u16;
dmesg.answers = answers;
// Read (and discard) up to three datagrams before sending — presumably to drain stale traffic.
let _ = mdns.recv().expect("Fail to receive from the UDP socket.");
let _ = mdns.recv().expect("Fail to receive from the UDP socket.");
let _ = mdns.recv().expect("Fail to receive from the UDP socket.");
mdns.broadcast_message(&dmesg)
.expect("Fail to broadcast DNS Message.");
// NOTE(review): if nothing is received, the `if let` is skipped and the test passes vacuously.
if let Some((resp, _addr)) = mdns.recv().expect("Fail to receive from the UDP socket.") {
let dmesg_from_resp = DnsMessage::from_raw(&resp).unwrap();
println!("dmesg = {:#?}", &dmesg);
println!("dmesg_from_resp = {:#?}", &dmesg_from_resp);
assert_eq!(dmesg, dmesg_from_resp);
}
}
// Two nodes share one network id. After the releaser advertises, `mdns` should cache two
// records for it; after the releaser releases, one record should remain.
#[test]
fn release_test() {
// Random suffix so concurrent or repeated runs do not collide on the network id.
let networkid = format!("holonaute-release-{}.holo.host", nanoid::simple());
let mut mdns = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.88:88088?a=to-keep"])
.multicast_address("224.0.0.251")
.bind_port(8251)
.build()
.expect("Fail to build mDNS.");
let mut mdns_releaser = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.87:88088?a=to-release"])
.multicast_address("224.0.0.251")
.bind_port(8251)
.build()
.expect("Fail to build mDNS.");
mdns_releaser
.advertise()
.expect("Fail to advertise my existence during release test.");
// Short sleeps give the datagrams time to arrive before the non-blocking drain.
::std::thread::sleep(::std::time::Duration::from_millis(100));
mdns.discover().expect("Fail to discover.");
::std::thread::sleep(::std::time::Duration::from_millis(100));
println!("mdns = {:#?}", &mdns.map_record);
{
// Expects two records: presumably our own (looped back) and the releaser's.
let records = mdns
.map_record
.get(&networkid)
.expect("Fail to get records from the networkid after 'Advertising'.");
assert_eq!(records.len(), 2);
}
mdns_releaser
.release()
.expect("Fail to release myself from the participants on the network.");
::std::thread::sleep(::std::time::Duration::from_millis(100));
mdns.discover().expect("Fail to discover.");
::std::thread::sleep(::std::time::Duration::from_millis(100));
println!("mdns = {:#?}", &mdns.map_record);
{
// The released record arrived with TTL 0 and should have been pruned by `discover`.
let records = mdns
.map_record
.get(&networkid)
.expect("Fail to get records from the networkid after 'Releasing'.");
assert_eq!(records.len(), 1);
}
}
// Two actors query each other. Actor 1 is expected to know one record; actor 2 is then
// expected to know both.
#[test]
fn query_test() -> MulticastDnsResult<()> {
let networkid = format!("holonaute-query-{}.holo.host", nanoid::simple());
let mut mdns_actor1 = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.88:88088?a=hc-actor1"])
.multicast_address("224.0.0.223")
.bind_port(8223)
.query_interval_ms(1)
.build()
.expect("Fail to build mDNS.");
let mut mdns_actor2 = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.87:88088?a=hc-actor2"])
.multicast_address("224.0.0.223")
.bind_port(8223)
.query_interval_ms(1)
.build()
.expect("Fail to build mDNS.");
mdns_actor1.query()?;
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor1.query()?;
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor1.discover()?;
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor1.discover()?;
let records = mdns_actor1
.map_record
.get(&networkid)
.expect("Fail to get records from the networkid during Query test on mdns_actor1")
.to_vec();
assert_eq!(records.len(), 1);
eprintln!("mdns_actor1 = {:#?}", &mdns_actor1.map_record);
mdns_actor2.query()?;
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor2.discover()?;
let mut records = mdns_actor2
.map_record
.get(&networkid)
.expect("Fail to get records from the networkid during Query test on mdns_actor2")
.to_vec();
eprintln!("mdns_actor2 = {:#?}", &mdns_actor2.map_record);
// Sort by URL so the assertions below do not depend on arrival order.
records.sort_by(|a, b| a.url.cmp(&b.url));
assert_eq!(records.len(), 2);
assert_eq!(records[0].url, "wss://192.168.0.87:88088?a=hc-actor2");
assert_eq!(records[1].url, "wss://192.168.0.88:88088?a=hc-actor1");
Ok(())
}
// Both actors advertise; actor 1 is then expected to hold both records for the network id.
// NOTE(review): the two `advertise` expect messages name the opposite actor and say
// "release test"; they look copy-pasted.
#[test]
fn advertise_test() {
let networkid = format!("holonaute-advertise-{}.holo.host", nanoid::simple());
let mut mdns_actor1 = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.88:88088?a=hc-actor1"])
.multicast_address("224.0.0.252")
.bind_port(8252)
.build()
.expect("Fail to build mDNS.");
eprintln!("bind addr = {}", mdns_actor1.multicast_address());
let mut mdns_actor2 = MulticastDnsBuilder::new()
.own_record(&networkid, &["wss://192.168.0.88:88088?a=hc-actor2"])
.multicast_address("224.0.0.252")
.bind_port(8252)
.build()
.expect("Fail to build mDNS.");
mdns_actor2
.advertise()
.expect("Fail to advertise mdns_actor1 existence during release test.");
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor1
.advertise()
.expect("Fail to advertise mdns_actor2 existence during release test.");
::std::thread::sleep(::std::time::Duration::from_millis(10));
mdns_actor1.discover().expect("Fail to discover.");
eprintln!("mdns = {:#?}", &mdns_actor1.map_record);
let records = mdns_actor1
.map_record
.get(&networkid)
.expect("Fail to get records from the networkid");
assert_eq!(records.len(), 2);
}
}