use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::Error;
use super::{interval, Chart, Id};
use rand::rngs::OsRng;
use rand::RngCore;
use serde::Serialize;
use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use tracing::info;
/// Typestate marker: the corresponding builder setting has been provided.
#[derive(Debug, Default)]
pub struct Yes;
/// Typestate marker: the corresponding builder setting has not been provided yet.
#[derive(Debug, Default)]
pub struct No;
/// Common bound for the typestate markers ([`Yes`] and [`No`]) used by `ChartBuilder`.
pub trait ToAssign: core::fmt::Debug {}
/// Marker-level "is set" trait; implemented for [`Yes`] in this module.
// NOTE(review): not used as a bound anywhere in this file — confirm it is needed elsewhere.
pub trait Assigned: ToAssign {}
/// Marker-level "is not set" trait; implemented for [`No`] in this module.
// NOTE(review): not used as a bound anywhere in this file — confirm it is needed elsewhere.
pub trait NotAssigned: ToAssign {}
impl ToAssign for Yes {}
impl ToAssign for No {}
impl Assigned for Yes {}
impl NotAssigned for No {}
/// Default value of the builder's `header`, copied unchanged into the finished chart.
// NOTE(review): magic value (hex 0x5CCD_94BA_DEB6_ECFB). Presumably it lets peers recognise
// packets of this protocol; confirm against the `Chart` send/receive code.
const DEFAULT_HEADER: u64 = 6_687_164_552_036_412_667;
/// Default UDP port that the discovery socket binds to (override with `with_discovery_port`).
const DEFAULT_PORT: u16 = 8080;
/// Message type of charts that advertise service port(s) rather than a custom message.
pub type Port = u16;
/// Typestate builder for a [`Chart`].
///
/// The three marker parameters record at compile time which settings have been provided
/// ([`Yes`]) or not ([`No`]):
/// - `IdSet`: a service id was set (`with_id` / `with_random_id`).
/// - `PortSet`: a single service port was set (`with_service_port`).
/// - `PortsSet`: an array of `N` service ports was set (`with_service_ports`).
///
/// The finishing methods (`finish`, `custom_msg`) are implemented only for the marker
/// combinations that are complete. An incomplete builder therefore fails to compile instead
/// of failing at runtime.
// NOTE(review): the reason for this blanket pedantic allow is not visible here. Consider
// narrowing it to the specific lint(s) it is meant to silence.
#[allow(clippy::pedantic)]
pub struct ChartBuilder<const N: usize, IdSet, PortSet, PortsSet>
where
    IdSet: ToAssign,
    PortSet: ToAssign,
    PortsSet: ToAssign,
{
    /// Copied into the chart's `header`; starts as `DEFAULT_HEADER`.
    header: u64,
    /// `Some` exactly when `IdSet` is `Yes`.
    service_id: Option<Id>,
    /// UDP port that the discovery socket binds to; starts as `DEFAULT_PORT`.
    discovery_port: u16,
    /// `Some` exactly when `PortSet` is `Yes`.
    service_port: Option<u16>,
    /// Ports advertised when `PortsSet` is `Yes`; zero-filled by `new`.
    service_ports: [u16; N],
    /// Converted (via `Into`) into the chart's `interval` when the builder is finished.
    rampdown: interval::Params,
    /// When true, the discovery socket is opened with port reuse enabled.
    local: bool,
    // Zero-sized typestate markers; they carry no data.
    id_set: PhantomData<IdSet>,
    port_set: PhantomData<PortSet>,
    ports_set: PhantomData<PortsSet>,
}
impl<const N: usize> ChartBuilder<N, No, No, No> {
    /// Creates a builder with nothing assigned yet.
    ///
    /// The builder starts with the following settings:
    /// - header `DEFAULT_HEADER` and discovery port `DEFAULT_PORT`;
    /// - default ramp-down parameters;
    /// - local discovery disabled;
    /// - no id and no single service port;
    /// - all `N` service-port slots zeroed.
    // `new` is the entry point of the typestate chain and no `Default` impl is provided in
    // this file, hence the lint allow.
    #[allow(clippy::new_without_default)]
    #[must_use]
    pub fn new() -> ChartBuilder<N, No, No, No> {
        let zeroed_ports = [0u16; N];
        let default_rampdown = interval::Params::default();
        ChartBuilder {
            header: DEFAULT_HEADER,
            discovery_port: DEFAULT_PORT,
            local: false,
            service_id: None,
            service_port: None,
            service_ports: zeroed_ports,
            rampdown: default_rampdown,
            id_set: PhantomData,
            port_set: PhantomData,
            ports_set: PhantomData,
        }
    }
}
impl<const N: usize, IdSet, PortSet, PortsSet> ChartBuilder<N, IdSet, PortSet, PortsSet>
where
    IdSet: ToAssign,
    PortSet: ToAssign,
    PortsSet: ToAssign,
{
    /// Moves every field into a builder with different typestate markers.
    ///
    /// This is the single place where the builder's fields are carried across a typestate
    /// transition. Each transition then adjusts only the field(s) it is about.
    fn retype<I, P, Ps>(self) -> ChartBuilder<N, I, P, Ps>
    where
        I: ToAssign,
        P: ToAssign,
        Ps: ToAssign,
    {
        ChartBuilder {
            header: self.header,
            service_id: self.service_id,
            discovery_port: self.discovery_port,
            service_port: self.service_port,
            service_ports: self.service_ports,
            rampdown: self.rampdown,
            local: self.local,
            id_set: PhantomData,
            port_set: PhantomData,
            ports_set: PhantomData,
        }
    }

    /// Sets the service id stored in the finished chart.
    #[must_use]
    pub fn with_id(self, id: Id) -> ChartBuilder<N, Yes, PortSet, PortsSet> {
        let mut next = self.retype::<Yes, PortSet, PortsSet>();
        next.service_id = Some(id);
        next
    }

    /// Sets a random service id drawn from `OsRng` (the operating system's RNG).
    ///
    /// The chosen id is logged at `info` level.
    #[must_use]
    pub fn with_random_id(self) -> ChartBuilder<N, Yes, PortSet, PortsSet> {
        let mut rng = OsRng::default();
        let id = rng.next_u64();
        info!("Using random id: {id}");
        self.with_id(id)
    }

    /// Sets the single service port to advertise.
    ///
    /// `PortsSet` is reset to `No`. Any ports stored earlier by `with_service_ports` are
    /// therefore no longer used by a finishing method. The single-port `finish` is only
    /// implemented for `N == 1`.
    #[must_use]
    pub fn with_service_port(self, port: u16) -> ChartBuilder<N, IdSet, Yes, No> {
        let mut next = self.retype::<IdSet, Yes, No>();
        next.service_port = Some(port);
        next
    }

    /// Sets the `N` service ports to advertise.
    ///
    /// Any port set earlier with `with_service_port` is cleared, and `PortSet` is reset
    /// to `No`.
    #[must_use]
    pub fn with_service_ports(self, ports: [u16; N]) -> ChartBuilder<N, IdSet, No, Yes> {
        let mut next = self.retype::<IdSet, No, Yes>();
        next.service_port = None;
        next.service_ports = ports;
        next
    }

    /// Overrides the header value (default `DEFAULT_HEADER`) copied into the finished chart.
    #[must_use]
    pub fn with_header(mut self, header: u64) -> ChartBuilder<N, IdSet, PortSet, PortsSet> {
        self.header = header;
        self
    }

    /// Overrides the UDP port that the discovery socket binds to (default `DEFAULT_PORT`).
    ///
    /// A port of `0` is rejected with a panic when the chart is finished.
    #[must_use]
    pub fn with_discovery_port(mut self, port: u16) -> ChartBuilder<N, IdSet, PortSet, PortsSet> {
        self.discovery_port = port;
        self
    }

    /// Sets the parameters that are converted into the chart's `interval`.
    ///
    /// NOTE(review): presumably the interval starts at `min` and ramps towards `max` over
    /// `rampdown`. Confirm against `interval::Params`.
    ///
    /// # Panics
    /// Panics if `min > max`.
    #[must_use]
    pub fn with_rampdown(
        mut self,
        min: Duration,
        max: Duration,
        rampdown: Duration,
    ) -> ChartBuilder<N, IdSet, PortSet, PortsSet> {
        assert!(
            min <= max,
            "minimum duration: {min:?} must be smaller or equal to the maximum: {max:?}"
        );
        self.rampdown = interval::Params { rampdown, min, max };
        self
    }

    /// Enables or disables local discovery.
    ///
    /// When enabled, the discovery socket is opened with `set_reuse_port(true)`. This lets
    /// several instances on the same host bind the same discovery port.
    #[must_use]
    pub fn local_discovery(
        mut self,
        is_enabled: bool,
    ) -> ChartBuilder<N, IdSet, PortSet, PortsSet> {
        self.local = is_enabled;
        self
    }
}
impl ChartBuilder<1, Yes, No, No> {
#[allow(clippy::missing_panics_doc)] pub fn custom_msg<Msg>(self, msg: Msg) -> Result<Chart<1, Msg>, Error>
where
Msg: Debug + Serialize + Clone,
{
let sock = open_socket(self.discovery_port, self.local)?;
Ok(Chart {
header: self.header,
service_id: self.service_id.unwrap(),
msg: [msg],
sock: Arc::new(sock),
map: Arc::new(Mutex::new(HashMap::new())),
interval: self.rampdown.into(),
broadcast: broadcast::channel(256).0,
})
}
}
impl ChartBuilder<1, Yes, Yes, No> {
    /// Finishes the builder into a [`Chart`] that advertises the single port set with
    /// `with_service_port`.
    ///
    /// # Errors
    /// Returns an [`Error`] if opening or configuring the discovery socket fails.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - the discovery port was set to `0`;
    /// - it is called outside a Tokio runtime with I/O enabled (required by
    ///   `tokio::net::UdpSocket::from_std`).
    pub fn finish(self) -> Result<Chart<1, Port>, Error> {
        let sock = open_socket(self.discovery_port, self.local)?;
        // Both `expect`s are unreachable. The `Yes` markers are only produced by
        // `with_id`/`with_random_id` and `with_service_port`, which store `Some(..)`.
        let service_id = self
            .service_id
            .expect("IdSet = Yes guarantees service_id is Some");
        let service_port = self
            .service_port
            .expect("PortSet = Yes guarantees service_port is Some");
        Ok(Chart {
            header: self.header,
            service_id,
            msg: [service_port],
            sock: Arc::new(sock),
            map: Arc::new(Mutex::new(HashMap::new())),
            interval: self.rampdown.into(),
            // Only the sender half is kept; the initial receiver is dropped here.
            broadcast: broadcast::channel(256).0,
        })
    }
}
impl<const N: usize> ChartBuilder<N, Yes, No, Yes> {
    /// Finishes the builder into a [`Chart`] that advertises the `N` ports set with
    /// `with_service_ports`.
    ///
    /// # Errors
    /// Returns an [`Error`] if opening or configuring the discovery socket fails.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - the discovery port was set to `0`;
    /// - it is called outside a Tokio runtime with I/O enabled (required by
    ///   `tokio::net::UdpSocket::from_std`).
    pub fn finish(self) -> Result<Chart<N, Port>, Error> {
        let sock = open_socket(self.discovery_port, self.local)?;
        // Cannot fail: the `Yes` id marker is only produced by `with_id`/`with_random_id`,
        // and both store `Some(id)`.
        let service_id = self
            .service_id
            .expect("IdSet = Yes guarantees service_id is Some");
        Ok(Chart {
            header: self.header,
            service_id,
            msg: self.service_ports,
            sock: Arc::new(sock),
            map: Arc::new(Mutex::new(HashMap::new())),
            interval: self.rampdown.into(),
            // Only the sender half is kept; the initial receiver is dropped here.
            broadcast: broadcast::channel(256).0,
        })
    }
}
/// Opens the UDP socket used for discovery and hands it to Tokio.
///
/// The socket is configured as follows:
/// - it is bound to `0.0.0.0:port`;
/// - it joins the `224.0.0.251` multicast group on the unspecified interface;
/// - broadcast and multicast loopback are enabled and the TTL is set to 4;
/// - it is switched to non-blocking mode before conversion.
///
/// When `local_discovery` is true, port reuse is enabled before binding.
///
/// # Errors
/// Each failing socket operation is reported through its own [`Error`] variant.
///
/// # Panics
/// Panics in either of these cases:
/// - `port` is `0`;
/// - it is called outside a Tokio runtime with I/O enabled (required by
///   `UdpSocket::from_std`).
fn open_socket(port: u16, local_discovery: bool) -> Result<UdpSocket, Error> {
    use socket2::{Domain, SockAddr, Socket, Type};
    use Error::{
        Bind, Construct, JoinMulticast, SetBroadcast, SetMulticast, SetNonBlocking, SetReuse,
        SetTTL, ToTokio,
    };

    // Presumably rejected because port 0 would bind an OS-chosen ephemeral port that peers
    // could not know about.
    assert_ne!(port, 0);

    let unspecified = Ipv4Addr::UNSPECIFIED;
    let group = Ipv4Addr::new(224, 0, 0, 251);

    let raw = Socket::new(Domain::IPV4, Type::DGRAM, None).map_err(Construct)?;
    if local_discovery {
        // NOTE(review): socket2 only exposes `set_reuse_port` on Unix (behind its `all`
        // feature).
        raw.set_reuse_port(true).map_err(SetReuse)?;
    }
    raw.set_broadcast(true).map_err(SetBroadcast)?;
    raw.set_multicast_loop_v4(true).map_err(SetMulticast)?;
    // NOTE(review): `set_ttl` sets IP_TTL, which applies to unicast datagrams. Multicast
    // datagrams use IP_MULTICAST_TTL (`set_multicast_ttl_v4`). Confirm which was intended.
    raw.set_ttl(4).map_err(SetTTL)?;

    let bind_addr: SockAddr = SocketAddr::from((unspecified, port)).into();
    raw.bind(&bind_addr).map_err(|error| Bind { error, port })?;
    raw.join_multicast_v4(&group, &unspecified)
        .map_err(JoinMulticast)?;

    let std_sock: std::net::UdpSocket = raw.into();
    std_sock.set_nonblocking(true).map_err(SetNonBlocking)?;
    UdpSocket::from_std(std_sock).map_err(ToTokio)
}
/// Compile-level smoke tests for the builder's typestate paths.
///
/// Each test drives the builder down one typestate path to a finished chart, then calls the
/// accessor that path enables. No discovery port is configured, so every test opens a real
/// socket on `DEFAULT_PORT` with local discovery (port reuse) enabled.
#[cfg(test)]
mod compiles {
    use super::*;

    /// `with_id` + `with_service_port`, then `finish`, yields a single-port chart.
    #[tokio::test]
    async fn with_service_port() {
        let builder = ChartBuilder::new().with_id(0).with_service_port(15);
        let chart = builder.local_discovery(true).finish().unwrap();
        let _ = chart.our_service_port();
    }

    /// `with_id` + `with_service_ports`, then `finish`, yields a two-port chart.
    #[tokio::test]
    async fn with_service_ports() {
        let builder = ChartBuilder::new().with_id(0).with_service_ports([1, 2]);
        let chart = builder.local_discovery(true).finish().unwrap();
        let _ = chart.our_service_ports();
    }

    /// `with_id`, then `custom_msg`, yields a chart that carries a custom message.
    #[tokio::test]
    async fn custom_msg() {
        let builder = ChartBuilder::new().with_id(0).local_discovery(true);
        let chart = builder.custom_msg("hi").unwrap();
        let _ = chart.our_msg();
    }
}