use std::collections::HashSet;
use crate::{error::Error, net::build_multicast_socket, Service, ServiceInfo, Udis};
use log::{error, trace};
use tokio::{
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
task::JoinHandle,
};
/// Handle to a udis background task spawned on the tokio runtime.
///
/// The handle owns the channels used to talk to the task: commands go out
/// through `cmd_tx`, and `ServiceInfo` values produced by the task come back
/// through `serv_info_rx`.
#[derive(Debug)]
pub struct AsyncUdis {
/// The `Udis` value this handle was built from. A clone of it is handed to
/// the background task; the original is only stored here.
_udis: Udis,
/// Join handle of the spawned background task; awaited when shutting down
/// to obtain the task's final result.
bg_task_jh: JoinHandle<Result<(), Error>>,
/// Sending half of the command channel read by the background task.
cmd_tx: UnboundedSender<Cmd>,
/// Receiving half of the channel on which the background task reports
/// services.
serv_info_rx: UnboundedReceiver<ServiceInfo>,
}
/// Commands sent from an `AsyncUdis` handle to its background task.
enum Cmd {
/// Ask the background task to leave its receive loop and finish.
Shutdown,
}
impl AsyncUdis {
pub(crate) fn build(udis: Udis) -> Self {
let (cmd_tx, cmd_rx) = unbounded_channel();
let (serv_info_tx, serv_info_rx) = unbounded_channel();
let udis_bg = udis.clone();
let bg_task_jh =
tokio::task::spawn(async move { async_task(udis_bg, cmd_rx, serv_info_tx).await });
Self {
_udis: udis,
bg_task_jh,
cmd_tx,
serv_info_rx,
}
}
pub async fn find_service(&mut self) -> Result<ServiceInfo, Error> {
if let Some(serv_info) = self.serv_info_rx.recv().await {
Ok(serv_info)
} else {
Err(Error::ServiceInfoChannelClosed)
}
}
pub async fn shutdown(self) -> Result<(), Error> {
self.cmd_tx
.send(Cmd::Shutdown)
.map_err(|_| Error::FailedToShutdownUdisTask)?;
self.bg_task_jh.await??;
Ok(())
}
}
/// Background task behind an `AsyncUdis` handle: announces `udis` on the
/// multicast discovery network and reports matching peers until told to stop.
///
/// Steps, in order:
/// 1. Build the multicast socket and convert it into a tokio `UdpSocket`.
/// 2. Serialise `udis` to JSON once and send it to the discovery address.
/// 3. Loop on whichever happens first:
///    - a command on `cmd_rx` (`Cmd::Shutdown`, or the channel closing) ends
///      the loop;
///    - a datagram on the socket is decoded as a peer `Udis` and processed.
///
/// A peer equal to `udis`, or one already recorded in the registry, is
/// skipped. For a new peer, the notify message is re-sent when
/// `udis.get_wanted_services(&peer)` yields anything (per the trace message:
/// the peer wants one of our services), and a `ServiceInfo` is pushed to
/// `serv_info_tx` for every `Service::Host` yielded by
/// `peer.get_wanted_services(&udis)`.
///
/// Receive errors and undecodable datagrams are logged and skipped so that a
/// single bad packet cannot stop discovery.
///
/// # Errors
///
/// Returns an error if the socket cannot be built or converted, if `udis`
/// cannot be serialised, if sending the notify message fails, or if a
/// `ServiceInfo` cannot be sent on `serv_info_tx`.
async fn async_task(
    udis: Udis,
    mut cmd_rx: UnboundedReceiver<Cmd>,
    serv_info_tx: UnboundedSender<ServiceInfo>,
) -> Result<(), Error> {
    let (disc_addr, socket) = build_multicast_socket()?;
    trace!("joined udis notify network on {disc_addr}");
    for service in &udis.services {
        match service {
            Service::Host { kind, port } => {
                trace!("hosting service `{}` on port {}", kind, port);
            }
            Service::Search { kind } => {
                trace!("searching for service `{}`", kind);
            }
        }
    }

    let socket: tokio::net::UdpSocket = tokio::net::UdpSocket::from_std(socket.into())?;
    // Peers already handled; each distinct peer description is processed once.
    let mut registry = HashSet::<Udis>::new();
    // Serialised once up front and reused for every announcement.
    let notify_message = serde_json::to_vec(&udis).map_err(Error::FailedToSerialiseNotifyMsg)?;
    socket.send_to(&notify_message[..], &disc_addr).await?;

    let mut buf = [0; 1024];
    loop {
        tokio::select! {
            cmd = cmd_rx.recv() => {
                match cmd {
                    // An explicit shutdown and a closed command channel both
                    // end the task.
                    Some(Cmd::Shutdown) | None => break,
                }
            },
            recv_res = socket.recv(&mut buf) => {
                let received = match recv_res {
                    Ok(r) => r,
                    Err(e) => {
                        error!("Error while receiving udis notify messages (will continue): {e}");
                        continue;
                    }
                };

                // Datagrams come from the network and may be malformed or
                // truncated by the fixed-size buffer; skip them rather than
                // tearing down the whole task.
                let peer: Udis = match serde_json::from_slice(&buf[..received]) {
                    Ok(peer) => peer,
                    Err(e) => {
                        error!("Ignoring malformed udis notify message (will continue): {e}");
                        continue;
                    }
                };

                // Skip our own announcement and peers that were already handled.
                if peer == udis || registry.contains(&peer) {
                    continue;
                }
                registry.insert(peer.clone());

                // Only non-emptiness matters, so stop at the first item.
                if udis.get_wanted_services(&peer).next().is_some() {
                    trace!(
                        "notified of peer `{}` that wants one of our services",
                        peer.name
                    );
                    socket.send_to(&notify_message[..], &disc_addr).await?;
                }

                for service in peer.get_wanted_services(&udis) {
                    let Service::Host { kind, port } = service else {
                        trace!("Non-host service returned by get_wanted_services, skipping");
                        continue;
                    };
                    trace!(
                        "found peer `{}` that hosts a service we want `{}` at {}:{}",
                        peer.name,
                        kind,
                        peer.addr,
                        port
                    );
                    let serv_info = ServiceInfo {
                        name: peer.name.clone(),
                        kind: kind.clone(),
                        addr: peer.addr,
                        port: *port,
                    };
                    serv_info_tx.send(serv_info)?;
                }
            }
        }
    }

    trace!("udis background task shutting down");
    Ok(())
}