// pushwire-server 0.1.1
//
// Generic multiplexed push server with WebSocket and SSE transports.
// Documentation
use std::time::Instant;

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::{TransportDispatcher, TransportError, TransportPacket};

/// Per-sender bandwidth limits used by the relay's token bucket.
///
/// Budgets are counted in bytes. Each sender starts with `burst_bytes` of allowance and
/// regains `bytes_per_second` of allowance per second. It never holds more than `burst_bytes`.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub struct RelayBandwidth {
    /// Sustained refill rate for each sender, in bytes per second.
    pub bytes_per_second: u64,
    /// Bucket capacity: the most bytes a sender may relay at once after sitting idle.
    pub burst_bytes: u64,
}

impl RelayBandwidth {
    /// Returns limits that are, for practical purposes, unlimited.
    ///
    /// The refill rate is `u64::MAX`. The burst is `u64::MAX >> 1` (equal to `u64::MAX / 2`).
    /// NOTE(review): the reason for halving the burst is not visible here; presumably it leaves
    /// headroom in the bucket arithmetic. Confirm before changing it.
    pub fn unbounded() -> Self {
        let bytes_per_second = u64::MAX;
        let burst_bytes = u64::MAX >> 1;
        Self {
            bytes_per_second,
            burst_bytes,
        }
    }
}

/// Token-bucket bookkeeping for a single sender.
#[derive(Debug)]
struct RelayState {
    /// Bytes the sender may relay right now. It is fractional because refill is continuous.
    allowance: f64,
    /// Moment the allowance was last topped up.
    last_refill: Instant,
    /// Running count of bytes relayed for this sender.
    total_bytes: u64,
}

impl RelayState {
    /// Creates state for a new sender: a full bucket of `burst_bytes` and nothing relayed yet.
    fn new(now: Instant, burst_bytes: u64) -> Self {
        let allowance = burst_bytes as f64;
        Self {
            last_refill: now,
            total_bytes: 0,
            allowance,
        }
    }
}

/// Details of a packet that was successfully handed to the dispatcher for relay.
#[derive(Clone, Debug)]
pub struct RelayOutcome {
    /// Sender whose bandwidth allowance was charged.
    pub from: Uuid,
    /// Recipient the packet was dispatched to.
    pub to: Uuid,
    /// Estimated packet size, in bytes, charged to the sender.
    pub bytes: u64,
}

/// Relays packets between peers on the server.
///
/// Each sender gets its own token-bucket bandwidth limit, and the controller keeps a running
/// count of the bytes each sender has relayed.
pub struct RelayController {
    /// Limits applied identically to every sender.
    limits: RelayBandwidth,
    /// Per-sender bucket state, created lazily the first time a sender relays.
    /// NOTE(review): nothing in this file removes entries, so the map grows with every distinct
    /// sender. Confirm that eviction happens elsewhere.
    peers: DashMap<Uuid, RelayState>,
}

impl RelayController {
    /// Creates a controller that applies `limits` to each sender independently.
    pub fn new(limits: RelayBandwidth) -> Self {
        Self {
            limits,
            peers: DashMap::new(),
        }
    }

    /// Relay a packet from one peer to another with rate limiting applied to the sender.
    ///
    /// The packet's estimated size is withdrawn from `from`'s token bucket before the packet is
    /// handed to `dispatcher.send_relay(to, ..)`. The size is added to `from`'s running
    /// [`total_bytes`](Self::total_bytes) only after a successful dispatch.
    ///
    /// # Errors
    ///
    /// * The size-estimation error if the payload cannot be serialized.
    /// * [`TransportError::RateLimited`] if the sender lacks allowance. Nothing is dispatched.
    /// * Any error returned by the dispatcher. The allowance already withdrawn for this packet
    ///   is not refunded, and no bytes are recorded.
    pub fn relay<D: TransportDispatcher>(
        &self,
        from: Uuid,
        to: Uuid,
        packet: TransportPacket,
        dispatcher: &D,
    ) -> Result<RelayOutcome, TransportError> {
        let size = estimate_packet_size(&packet)?;

        // Charge the sender inside a short scope so the DashMap shard lock is released before
        // dispatching. Holding the `RefMut` across `send_relay` would stall every peer hashed
        // to the same shard for the whole send. It would also deadlock if the dispatcher called
        // back into this controller (e.g. `total_bytes`) for a peer on that shard.
        {
            let mut state = self.ensure_state(from);
            self.consume_allowance(&mut state, size)?;
        }

        dispatcher.send_relay(to, packet)?;

        // Re-acquire the entry only for the brief accounting update.
        let mut state = self.ensure_state(from);
        state.total_bytes = state.total_bytes.saturating_add(size);
        drop(state);

        Ok(RelayOutcome {
            from,
            to,
            bytes: size,
        })
    }

    /// Returns the total number of bytes relayed for a peer (sender-scoped).
    ///
    /// Peers with no recorded state report `0`.
    pub fn total_bytes(&self, peer: Uuid) -> u64 {
        self.peers.get(&peer).map_or(0, |state| state.total_bytes)
    }

    /// Returns the bucket state for `peer`, creating it with a full burst allowance if absent.
    ///
    /// The returned guard holds the map shard's write lock until it is dropped.
    fn ensure_state(&self, peer: Uuid) -> dashmap::mapref::one::RefMut<'_, Uuid, RelayState> {
        let burst = self.limits.burst_bytes;
        self.peers
            .entry(peer)
            .or_insert_with(|| RelayState::new(Instant::now(), burst))
    }

    /// Refills `state` for the time elapsed since its last refill, then withdraws `size` bytes.
    ///
    /// The refill is `elapsed_seconds * bytes_per_second`. The allowance is capped at
    /// `burst_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::RateLimited`] if fewer than `size` bytes of allowance remain.
    /// The refill is still recorded, but nothing is withdrawn.
    fn consume_allowance(&self, state: &mut RelayState, size: u64) -> Result<(), TransportError> {
        let now = Instant::now();
        let burst = self.limits.burst_bytes as f64;
        let elapsed = now.saturating_duration_since(state.last_refill);
        let refill = elapsed.as_secs_f64() * self.limits.bytes_per_second as f64;
        // The allowance never drops below zero, so capping the sum alone is enough.
        state.allowance = (state.allowance + refill).min(burst);
        state.last_refill = now;

        let needed = size as f64;
        if state.allowance < needed {
            return Err(TransportError::RateLimited("relay bandwidth exceeded"));
        }

        state.allowance -= needed;
        Ok(())
    }
}

/// Estimates how many bytes a packet occupies on the relay, for bandwidth accounting.
///
/// The estimate is the sum of three parts:
/// * the length of the payload's compact JSON encoding;
/// * the channel name's length in bytes;
/// * 8 bytes when a cursor is present. Presumably this reflects a fixed-width cursor; confirm
///   against the `cursor` field's type.
///
/// # Errors
///
/// Returns [`TransportError::DispatchFailed`] if the payload cannot be serialized to JSON.
fn estimate_packet_size(packet: &TransportPacket) -> Result<u64, TransportError> {
    /// `io::Write` sink that discards bytes and only counts them.
    struct ByteCounter(u64);

    impl std::io::Write for ByteCounter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0 = self.0.saturating_add(buf.len() as u64);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    // Measure the serialized length without building a copy of the payload. `to_vec` would
    // allocate a buffer as large as the payload on every relay just to read its `.len()`.
    let mut counter = ByteCounter(0);
    serde_json::to_writer(&mut counter, &packet.payload)
        .map_err(|_| TransportError::DispatchFailed("payload serialization failed"))?;
    let payload_len = counter.0;
    let channel_len = packet.channel.len() as u64;
    let cursor_len = if packet.cursor.is_some() { 8 } else { 0 };
    Ok(payload_len
        .saturating_add(channel_len)
        .saturating_add(cursor_len))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TransportRoute;
    use serde_json::json;
    use std::sync::Mutex;

    /// Dispatcher that accepts everything and records each relayed packet.
    #[derive(Default)]
    struct MockDispatcher {
        calls: Mutex<Vec<(TransportRoute, Uuid, TransportPacket)>>,
    }

    impl MockDispatcher {
        /// Recipients of recorded relay calls, in call order.
        fn relayed_to(&self) -> Vec<Uuid> {
            self.calls
                .lock()
                .unwrap()
                .iter()
                .map(|(_, peer, _)| *peer)
                .collect()
        }
    }

    impl TransportDispatcher for MockDispatcher {
        fn send_direct(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_p2p(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Relay, peer, packet));
            Ok(())
        }
    }

    fn hello_packet() -> TransportPacket {
        TransportPacket::new("data", json!({ "hello": "world" }))
    }

    #[test]
    fn relays_and_accounts_bytes() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 10_000,
            burst_bytes: 10_000,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        let outcome = controller
            .relay(from, to, hello_packet(), &dispatcher)
            .unwrap();

        assert_eq!(outcome.from, from);
        assert_eq!(outcome.to, to);
        assert!(outcome.bytes > 0);
        assert_eq!(controller.total_bytes(from), outcome.bytes);
        assert_eq!(dispatcher.relayed_to(), vec![to]);
    }

    #[test]
    fn rate_limits_when_budget_exceeded() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 1,
            burst_bytes: 4,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();
        let packet = TransportPacket::new("data", json!({ "blob": "12345" })); // >4 bytes

        let err = controller
            .relay(from, to, packet, &dispatcher)
            .expect_err("should be rate limited");
        assert_eq!(err, TransportError::RateLimited("relay bandwidth exceeded"));
        assert_eq!(controller.total_bytes(from), 0);
        assert!(dispatcher.relayed_to().is_empty());
    }

    #[test]
    fn unknown_peer_reports_zero_bytes() {
        let controller = RelayController::new(RelayBandwidth::unbounded());
        assert_eq!(controller.total_bytes(Uuid::new_v4()), 0);
    }

    #[test]
    fn unbounded_limits_accept_repeated_relays() {
        let controller = RelayController::new(RelayBandwidth::unbounded());
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        let mut expected = 0;
        for _ in 0..3 {
            let outcome = controller
                .relay(from, to, hello_packet(), &dispatcher)
                .unwrap();
            expected += outcome.bytes;
        }

        assert_eq!(controller.total_bytes(from), expected);
        assert_eq!(dispatcher.relayed_to(), vec![to, to, to]);
    }

    #[test]
    fn bucket_exhausts_without_refill() {
        let size = estimate_packet_size(&hello_packet()).unwrap();
        // A zero refill rate makes the second attempt deterministic: the burst covers exactly
        // one packet and is never topped up.
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 0,
            burst_bytes: size,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        controller
            .relay(from, to, hello_packet(), &dispatcher)
            .expect("first packet fits the burst");
        let err = controller
            .relay(from, to, hello_packet(), &dispatcher)
            .expect_err("bucket should be empty");

        assert_eq!(err, TransportError::RateLimited("relay bandwidth exceeded"));
        assert_eq!(controller.total_bytes(from), size);
        assert_eq!(dispatcher.relayed_to(), vec![to]);
    }
}