use std::num::NonZeroU32;
use anyhow::{bail, ensure, Result};
use bytes::Bytes;
use postcard::experimental::max_size::MaxSize;
use serde::{Deserialize, Serialize};
use super::{client_conn::ClientConnBuilder, codec::PROTOCOL_VERSION};
use crate::key::PublicKey;
/// Fixed-size 32-byte key, carried optionally in [`ClientInfo::mesh_key`].
///
/// NOTE(review): presumably a shared secret that identifies members of a server mesh — confirm
/// against the code that checks it.
pub type MeshKey = [u8; 32];
/// Token-bucket rate limiter measured in bytes.
///
/// Thin wrapper around a `governor` limiter that is direct (not keyed), keeps its state in memory,
/// uses `governor`'s default clock and has no middleware. Construct it with [`RateLimiter::new`].
pub(crate) struct RateLimiter {
    inner: governor::RateLimiter<
        governor::state::direct::NotKeyed,
        governor::state::InMemoryState,
        governor::clock::DefaultClock,
        governor::middleware::NoOpMiddleware,
    >,
}
impl RateLimiter {
    /// Builds a limiter that refills `bytes_per_second` bytes per second and allows bursts of up
    /// to `bytes_burst` bytes.
    ///
    /// Returns `Ok(None)` when either parameter is `0`. That is, a zero value disables rate
    /// limiting entirely.
    ///
    /// # Errors
    ///
    /// Returns an error if either value does not fit in a `u32`.
    pub(crate) fn new(bytes_per_second: usize, bytes_burst: usize) -> Result<Option<Self>> {
        if bytes_per_second == 0 || bytes_burst == 0 {
            return Ok(None);
        }
        // Both values are known to be non-zero here, so `NonZeroU32::new` cannot yield `None`.
        let bytes_per_second = NonZeroU32::new(u32::try_from(bytes_per_second)?)
            .expect("bytes_per_second checked non-zero above");
        let bytes_burst = NonZeroU32::new(u32::try_from(bytes_burst)?)
            .expect("bytes_burst checked non-zero above");
        Ok(Some(Self {
            inner: governor::RateLimiter::direct(
                governor::Quota::per_second(bytes_per_second).allow_burst(bytes_burst),
            ),
        }))
    }

    /// Tries to consume `n` bytes' worth of tokens from the bucket.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `n` is `0`;
    /// - `n` does not fit in a `u32`;
    /// - `n` exceeds the configured burst size, so it could never be admitted;
    /// - not enough tokens are available right now, i.e. the caller is being rate limited.
    pub(crate) fn check_n(&self, n: usize) -> Result<()> {
        ensure!(n != 0);
        let n = NonZeroU32::new(u32::try_from(n)?).expect("n checked non-zero above");
        // `governor::RateLimiter::check_n` returns a nested result:
        // - the outer `Err(InsufficientCapacity)` means `n` is larger than the burst size;
        // - the inner `Err(NotUntil)` means there are not enough tokens yet.
        // Both must reject the batch. Matching only the outer `Ok(_)` would silently let
        // rate-limited batches through.
        match self.inner.check_n(n) {
            Ok(Ok(())) => Ok(()),
            Ok(Err(_)) | Err(_) => bail!("batch cannot go through"),
        }
    }
}
/// A packet payload tagged with the public key it came from.
#[derive(Debug, Clone)]
pub(crate) struct Packet {
    /// Public key of the packet's source (presumably the sending client — confirm).
    pub(crate) src: PublicKey,
    /// Raw packet payload.
    pub(crate) bytes: Bytes,
}
/// Presence state of a single peer.
#[derive(Debug, Clone)]
pub(crate) struct PeerConnState {
    /// Public key identifying the peer.
    pub(crate) peer: PublicKey,
    /// Whether the peer is present.
    // NOTE(review): presumably "currently connected to this server" — confirm with producers.
    pub(crate) present: bool,
}
/// Client-side information exchanged with the server, serialized via serde.
///
/// `MaxSize` provides a compile-time upper bound on the postcard-encoded size. Field order is part
/// of the encoding, so do not reorder fields.
#[derive(Debug, Serialize, Deserialize, MaxSize, PartialEq, Eq)]
pub(crate) struct ClientInfo {
    /// Protocol version the client reports (presumably compared to `PROTOCOL_VERSION` — confirm).
    pub(crate) version: usize,
    /// Optional [`MeshKey`] supplied by the client.
    pub(crate) mesh_key: Option<MeshKey>,
    /// Whether the client says it can acknowledge pings.
    pub(crate) can_ack_pings: bool,
    /// Whether the client identifies itself as a prober (semantics defined by the server — confirm).
    pub(crate) is_prober: bool,
}
/// Server-side information exchanged with clients, serialized via serde.
///
/// Field order is part of the postcard encoding, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, MaxSize)]
pub(crate) struct ServerInfo {
    /// Protocol version the server reports.
    pub(crate) version: usize,
    /// Token-bucket refill rate in bytes per second; `0` is used by
    /// [`ServerInfo::no_rate_limit`] to express "no rate limit".
    pub(crate) token_bucket_bytes_per_second: usize,
    /// Token-bucket burst size in bytes; `0` is used by [`ServerInfo::no_rate_limit`] to express
    /// "no rate limit".
    pub(crate) token_bucket_bytes_burst: usize,
}
impl ServerInfo {
pub fn no_rate_limit() -> Self {
Self {
version: PROTOCOL_VERSION,
token_bucket_bytes_burst: 0,
token_bucket_bytes_per_second: 0,
}
}
}
/// Sink that forwards packets between two public keys.
///
/// Implementors must be `Send + Sync + 'static`. That is, they can be moved and shared across
/// threads and must not borrow non-`'static` data.
pub trait PacketForwarder: Send + Sync + 'static {
    /// Forwards `packet`, originating from `srckey`, to `dstkey`.
    fn forward_packet(&mut self, srckey: PublicKey, dstkey: PublicKey, packet: Bytes);
}
/// Messages processed by the server, generic over the [`PacketForwarder`] implementation `P`.
///
/// `Debug` is derived via `derive_more`. Variants carrying builders or forwarders print a fixed
/// label instead of their contents.
#[derive(derive_more::Debug)]
pub(crate) enum ServerMessage<P>
where
    P: PacketForwarder,
{
    /// Registers a watcher identified by this key
    /// (NOTE(review): presumably watching peer presence changes — confirm).
    AddWatcher(PublicKey),
    /// Requests closing the peer identified by this key.
    ClosePeer(PublicKey),
    /// Sends a [`Packet`]; the key is presumably the destination — confirm.
    SendPacket((PublicKey, Packet)),
    /// Same shape as `SendPacket`, for disco packets; the key is presumably the destination.
    SendDiscoPacket((PublicKey, Packet)),
    /// Creates a client from the given builder. `Debug` prints only `CreateClient`.
    #[debug("CreateClient")]
    CreateClient(ClientConnBuilder<P>),
    /// Removes a client. The `usize` is presumably a connection number distinguishing
    /// multiple connections with the same key — confirm.
    RemoveClient((PublicKey, usize)),
    /// Registers `forwarder` under `key`.
    AddPacketForwarder {
        key: PublicKey,
        /// `Debug` prints only `PacketForwarder` for this field.
        #[debug("PacketForwarder")]
        forwarder: P,
    },
    /// Removes the packet forwarder registered under this key.
    RemovePacketForwarder(PublicKey),
    /// Requests server shutdown.
    Shutdown,
}