use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use bytes::BytesMut;
use clap::Parser;
use env_logger::Target;
use rtc_mdns::{Mdns, MdnsConfig, MdnsEvent, MulticastSocket};
use sansio::Protocol;
use shared::{TaggedBytesMut, TransportContext, TransportProtocol};
use std::fs::OpenOptions;
use std::{io::Write, str::FromStr};
use tokio::net::UdpSocket;
// Command-line options for the mDNS server + query example.
#[derive(Parser, Debug)]
#[command(name = "mDNS Server + Query")]
#[command(version = "0.1.0")]
#[command(author = "Rain Liu <yliu@webrtc.rs>")]
#[command(about = "An example of mDNS Server + Query using sans-I/O rtc-mdns")]
struct Cli {
    /// Enable logging (to stdout, or to the file given by --output-log-file).
    #[arg(short, long)]
    debug: bool,
    /// Log level filter (e.g. ERROR, WARN, INFO, DEBUG, TRACE); applied only with --debug.
    #[arg(short, long, default_value_t = String::from("INFO"))]
    log_level: String,
    /// File to write logs to; when empty, logs go to stdout.
    #[arg(short, long, default_value_t = String::new())]
    output_log_file: String,
    /// Query timeout in seconds.
    #[arg(long, default_value = "10")]
    timeout: u64,
    /// Query interval in milliseconds.
    #[arg(long, default_value = "1000")]
    interval: u64,
}
/// Best-effort lookup of this host's outbound IPv4 address.
///
/// A throwaway UDP socket is bound to an ephemeral port and "connected" to
/// 8.8.8.8:80. This only records a default peer, but it makes the OS choose a
/// route and therefore a concrete local address, which is then read back.
/// If any step fails, or if the resulting address is not IPv4, the IPv4
/// loopback address 127.0.0.1 is returned instead.
fn get_local_ip() -> IpAddr {
    let routed = std::net::UdpSocket::bind("0.0.0.0:0")
        .ok()
        .filter(|sock| sock.connect("8.8.8.8:80").is_ok())
        .and_then(|sock| sock.local_addr().ok())
        .map(|addr| addr.ip())
        .filter(IpAddr::is_ipv4);
    routed.unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
}
/// Runs an mDNS server and an mDNS client over one shared multicast socket.
///
/// The server (`mdns_server`) is configured with the local names
/// `webrtc-rs-mdns-1.local` and `webrtc-rs-mdns-2.local` and the address
/// returned by `get_local_ip`. The client (`mdns_client`) first queries
/// `webrtc-rs-mdns-1.local`. Once that query is answered, it queries
/// `webrtc-rs-mdns-2.local`. Each answer is printed to stdout.
///
/// Both instances are sans-I/O. Each loop iteration does the following:
/// 1. Flushes their queued outgoing packets to the socket.
/// 2. Stops once the client has no pending queries.
/// 3. Waits for either a datagram, which is fed to both instances, or the
///    nearest deadline reported by either instance, which triggers
///    `handle_timeout` on both.
/// 4. Drains the client's events.
///
/// # Errors
///
/// An error is returned in any of these cases:
/// - the log level cannot be parsed;
/// - the log file cannot be opened;
/// - the socket cannot be created, queried for its local address, or sent on;
/// - closing either instance fails;
/// - a query reports `MdnsEvent::QueryTimeout`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // RFC 6762 §17: an mDNS packet, including IP and UDP headers, may be up to
    // 9000 bytes. A 1500-byte (Ethernet MTU) buffer could truncate larger datagrams.
    const MAX_MDNS_PACKET_SIZE: usize = 9000;

    let args = Cli::parse();
    let log_level = log::LevelFilter::from_str(&args.log_level)?;
    let output_log_file = args.output_log_file;
    if args.debug {
        env_logger::Builder::new()
            .target(if !output_log_file.is_empty() {
                Target::Pipe(Box::new(
                    OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(true)
                        .open(output_log_file)?,
                ))
            } else {
                Target::Stdout
            })
            .format(|buf, record| {
                writeln!(
                    buf,
                    "{}:{} [{}] {} - {}",
                    record.file().unwrap_or("unknown"),
                    record.line().unwrap_or(0),
                    record.level(),
                    chrono::Local::now().format("%H:%M:%S.%6f"),
                    record.args()
                )
            })
            .filter(None, log_level)
            .init();
    }

    log::info!("Creating mDNS server with local names and local ip");
    let config_server = MdnsConfig::default()
        .with_local_names(vec![
            "webrtc-rs-mdns-1.local".to_string(),
            "webrtc-rs-mdns-2.local".to_string(),
        ])
        .with_local_ip(get_local_ip());
    let mut mdns_server = Mdns::new(config_server);

    let config_client = MdnsConfig::default()
        .with_query_interval(Duration::from_millis(args.interval))
        .with_query_timeout(Duration::from_secs(args.timeout));
    let mut mdns_client = Mdns::new(config_client);

    let multicast_udp_socket = UdpSocket::from_std(MulticastSocket::new().into_std()?)?;
    // The socket's local address does not change once it is bound, so look it
    // up once here instead of for every received datagram.
    let local_addr = multicast_udp_socket.local_addr()?;

    let query_id_1 = mdns_client.query("webrtc-rs-mdns-1.local");
    log::info!(
        "Started query for webrtc-rs-mdns-1.local (query_id={}, timeout={}s, interval={}ms)",
        query_id_1,
        args.timeout,
        args.interval
    );

    let mut query_1_answered = false;
    let mut query_2_answered = false;
    // The second query is only issued after the first one is answered.
    let mut query_id_2: Option<u64> = None;
    let mut buf = vec![0u8; MAX_MDNS_PACKET_SIZE];

    loop {
        // Flush everything both instances have queued for transmission.
        while let Some(packet) = mdns_server.poll_write() {
            log::trace!(
                "mdns_server sending {} bytes from {} to {}",
                packet.message.len(),
                packet.transport.local_addr,
                packet.transport.peer_addr,
            );
            multicast_udp_socket
                .send_to(&packet.message, packet.transport.peer_addr)
                .await?;
        }
        while let Some(packet) = mdns_client.poll_write() {
            log::trace!(
                "mdns_client sending {} bytes from {} to {}",
                packet.message.len(),
                packet.transport.local_addr,
                packet.transport.peer_addr,
            );
            multicast_udp_socket
                .send_to(&packet.message, packet.transport.peer_addr)
                .await?;
        }

        if mdns_client.pending_query_count() == 0 {
            if query_1_answered && query_2_answered {
                log::info!("All queries answered successfully");
            } else {
                log::debug!("No more pending queries, exiting");
            }
            break;
        }

        // Both instances are driven by `handle_timeout`, so wait for the earliest
        // deadline either one reports. Fall back to 100ms when neither has one.
        let wait_duration = mdns_client
            .poll_timeout()
            .into_iter()
            .chain(mdns_server.poll_timeout())
            .min()
            .map(|t| t.saturating_duration_since(Instant::now()))
            .unwrap_or(Duration::from_millis(100));

        tokio::select! {
            result = multicast_udp_socket.recv_from(&mut buf) => {
                match result {
                    Ok((len, peer_addr)) => {
                        log::trace!("Received {} bytes from {} to {}", len, peer_addr, local_addr);
                        let now = Instant::now();
                        let msg = TaggedBytesMut {
                            now,
                            transport: TransportContext {
                                local_addr,
                                peer_addr,
                                transport_protocol: TransportProtocol::UDP,
                                ecn: None,
                            },
                            message: BytesMut::from(&buf[..len]),
                        };
                        // Both instances share one socket, so each receives its own
                        // copy of every datagram.
                        let msg_clone = TaggedBytesMut {
                            now,
                            transport: msg.transport.clone(),
                            message: msg.message.clone(),
                        };
                        if let Err(e) = mdns_server.handle_read(msg) {
                            log::trace!("mdns_server handle_read: {}", e);
                        }
                        if let Err(e) = mdns_client.handle_read(msg_clone) {
                            log::trace!("mdns_client handle_read: {}", e);
                        }
                    }
                    Err(e) => {
                        log::warn!("Socket recv error: {}", e);
                    }
                }
            }
            _ = tokio::time::sleep(wait_duration) => {
                let now = Instant::now();
                if let Err(e) = mdns_server.handle_timeout(now) {
                    log::debug!("mdns_server handle_timeout: {}", e);
                }
                if let Err(e) = mdns_client.handle_timeout(now) {
                    log::warn!("Failed to handle timeout: {}", e);
                }
            }
        }

        while let Some(event) = mdns_client.poll_event() {
            match event {
                MdnsEvent::QueryAnswered(id, addr) => {
                    if id == query_id_1 {
                        println!("query_id = {}, addr = {}", id, addr);
                        query_1_answered = true;
                        if query_id_2.is_none() {
                            let id = mdns_client.query("webrtc-rs-mdns-2.local");
                            query_id_2 = Some(id);
                            log::info!(
                                "Started query for webrtc-rs-mdns-2.local (query_id={}, timeout={}s, interval={}ms)",
                                id,
                                args.timeout,
                                args.interval
                            );
                        }
                    } else if query_id_2 == Some(id) {
                        println!("query_id = {}, addr = {}", id, addr);
                        query_2_answered = true;
                    }
                }
                MdnsEvent::QueryTimeout(id) => {
                    log::error!("Query {} timed out after {} seconds", id, args.timeout);
                    mdns_server.close()?;
                    mdns_client.close()?;
                    return Err(
                        format!("Query {} timed out after {} seconds", id, args.timeout).into(),
                    );
                }
            }
        }
    }

    mdns_server.close()?;
    mdns_client.close()?;
    Ok(())
}