//! ruzor 0.1.2
//!
//! Ruzor, a 1:1-compatible Rust port of the Pyzor UDP client and server.
use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Duration;

use crate::account::{Account, now_timestamp, sign_for_account};
use crate::config::Address;
use crate::error::PyzorError;
use crate::logging::Logger;
use crate::message::{self, Message, ThreadId};
use crate::python_repr;
use crate::{MAX_PACKET_SIZE, Result};

/// Pyzor UDP client: signs requests, sends them to a server and (for round-trip operations)
/// reads and validates the server's response.
#[derive(Clone, Debug)]
pub struct Client {
    /// Per-server accounts used to sign requests; servers without an entry are signed with
    /// `Account::anonymous()`.
    accounts: HashMap<Address, Account>,
    /// Read timeout applied to the socket while waiting for a response.
    timeout: Duration,
    /// Digest spec as `(offset, length)` pairs, sent along with report/whitelist requests.
    spec: Vec<(usize, usize)>,
    /// Optional sink for debug traces and warnings; `None` disables logging.
    logger: Option<Logger>,
}

impl Default for Client {
    fn default() -> Self {
        Self::new(HashMap::new(), Some(5), crate::digest::DIGEST_SPEC.to_vec())
    }
}

impl Client {
    /// Builds a client from per-server `accounts`, a response timeout and a digest `spec`.
    ///
    /// `timeout_secs` of `None` falls back to 5 seconds. `spec` holds `(offset, length)` pairs
    /// that accompany `report` and `whitelist` requests.
    ///
    /// NOTE(review): `Some(0)` is accepted here, but `UdpSocket::set_read_timeout` rejects a zero
    /// duration, so every round-trip request (`ping`, `check`, ...) would then fail.
    pub fn new(
        accounts: HashMap<Address, Account>,
        timeout_secs: Option<u64>,
        spec: Vec<(usize, usize)>,
    ) -> Self {
        Self {
            accounts,
            timeout: Duration::from_secs(timeout_secs.unwrap_or(5)),
            spec,
            logger: None,
        }
    }

    /// Attaches `logger`; it receives the `sending:`/`received:` debug traces and
    /// thread-id warnings.
    pub fn with_logger(mut self, logger: Logger) -> Self {
        self.logger = Some(logger);
        self
    }

    /// Sends a `ping` request to `address` and returns the server's response.
    pub fn ping(&self, address: &Address) -> Result<Message> {
        self.round_trip(message::request("ping"), address)
    }

    /// Sends a `pong` request for `digest` to `address` and returns the server's response.
    pub fn pong(&self, digest: &str, address: &Address) -> Result<Message> {
        self.round_trip(message::digest_request("pong", digest), address)
    }

    /// Sends a `check` request for `digest` to `address` and returns the server's response.
    pub fn check(&self, digest: &str, address: &Address) -> Result<Message> {
        self.round_trip(message::digest_request("check", digest), address)
    }

    /// Sends an `info` request for `digest` to `address` and returns the server's response.
    pub fn info(&self, digest: &str, address: &Address) -> Result<Message> {
        self.round_trip(message::digest_request("info", digest), address)
    }

    /// Sends a `report` request for `digest` (with this client's digest spec) to `address` and
    /// returns the server's response.
    pub fn report(&self, digest: &str, address: &Address) -> Result<Message> {
        self.round_trip(
            message::spec_digest_request("report", digest, &self.spec),
            address,
        )
    }

    /// Sends a `whitelist` request for `digest` (with this client's digest spec) to `address`
    /// and returns the server's response.
    pub fn whitelist(&self, digest: &str, address: &Address) -> Result<Message> {
        self.round_trip(
            message::spec_digest_request("whitelist", digest, &self.spec),
            address,
        )
    }

    /// Signs `msg` for `address` and sends it without waiting for a reply.
    ///
    /// # Errors
    /// Fails if `address` cannot be resolved or a local socket cannot be bound. A failed send
    /// is reported as `PyzorError::Comm("Unable to send to host:port: ...")`, the same error the
    /// round-trip requests produce.
    pub fn send_only(&self, mut msg: Message, address: &Address) -> Result<()> {
        self.sign(&mut msg, address);
        let packet = msg.as_string();
        self.debug(format!("sending: {}", python_repr::string(&packet)));
        Self::send_packet(packet.as_bytes(), address, None)?;
        Ok(())
    }

    /// Signs and sends `msg`, then waits for and validates the matching response.
    fn round_trip(&self, mut msg: Message, address: &Address) -> Result<Message> {
        self.sign(&mut msg, address);
        let expected_id = msg.thread()?;
        let packet = msg.as_string();
        self.debug(format!("sending: {}", python_repr::string(&packet)));
        let socket = Self::send_packet(packet.as_bytes(), address, Some(self.timeout))?;
        self.read_response(&socket, expected_id)
    }

    /// Resolves `address` once, binds a wildcard local socket of the same address family,
    /// applies `read_timeout` if given, and sends `payload` as a single datagram.
    ///
    /// Resolving a single time (instead of once to choose the bind family and again for the
    /// destination) avoids a second DNS lookup per request. It also guarantees the socket's
    /// family matches the target even if the resolver's first answer changes between calls.
    fn send_packet(
        payload: &[u8],
        address: &Address,
        read_timeout: Option<Duration>,
    ) -> Result<UdpSocket> {
        let target = resolve(address)?;
        let local = if target.is_ipv6() {
            "[::]:0"
        } else {
            "0.0.0.0:0"
        };
        let socket = UdpSocket::bind(local).map_err(PyzorError::from)?;
        if let Some(timeout) = read_timeout {
            // Set before sending so a rejected timeout fails without emitting a request.
            socket.set_read_timeout(Some(timeout))?;
        }
        socket.send_to(payload, target).map_err(|error| {
            PyzorError::Comm(format!(
                "Unable to send to {}:{}: {}",
                address.0, address.1, error
            ))
        })?;
        Ok(socket)
    }

    /// Receives one datagram on `socket` and checks that it is a response to `expected_id`.
    ///
    /// # Errors
    /// * `PyzorError::Timeout` when the read timeout elapses.
    /// * `PyzorError::Comm` for any other socket error.
    /// * `PyzorError::Protocol` when the reply carries a different thread id that is in the
    ///   "ok" range. An out-of-range mismatch only logs a warning and the response is returned.
    /// * Errors from `ensure_response` / `thread` on the parsed message are propagated.
    fn read_response(&self, socket: &UdpSocket, expected_id: ThreadId) -> Result<Message> {
        let mut buf = [0u8; MAX_PACKET_SIZE];
        let (len, peer) = socket.recv_from(&mut buf).map_err(|error| {
            // An elapsed read timeout surfaces as WouldBlock on Unix and TimedOut on Windows.
            if matches!(
                error.kind(),
                std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock
            ) {
                PyzorError::Timeout("Reading response timed-out.".to_string())
            } else {
                PyzorError::Comm(format!("Socket error while reading response: {}", error))
            }
        })?;
        let packet = &buf[..len];
        self.debug(format!(
            "received: {}/{}",
            python_repr::bytes(packet),
            python_socket_addr_repr(peer)
        ));
        let response = Message::parse(packet);
        response.ensure_response()?;
        let thread_id = response.thread()?;
        if thread_id != expected_id {
            if thread_id.in_ok_range() {
                return Err(PyzorError::Protocol(format!(
                    "received unexpected thread id {} (expected {})",
                    thread_id, expected_id
                )));
            }
            self.warning(format!(
                "received error thread id {} (expected {})",
                thread_id, expected_id
            ));
        }
        Ok(response)
    }

    /// Prepares `msg` for sending and signs it with the account configured for `address`,
    /// falling back to the anonymous account.
    fn sign(&self, msg: &mut Message, address: &Address) {
        message::init_for_sending(msg);
        // Borrow the stored account instead of cloning it on every request.
        let account = match self.accounts.get(address) {
            Some(account) => std::borrow::Cow::Borrowed(account),
            None => std::borrow::Cow::Owned(Account::anonymous()),
        };
        sign_for_account(msg, &account, now_timestamp());
    }

    /// Forwards `message` to the logger at debug level, if a logger is attached.
    fn debug(&self, message: impl AsRef<str>) {
        if let Some(logger) = &self.logger {
            logger.debug(message);
        }
    }

    /// Forwards `message` to the logger at warning level, if a logger is attached.
    fn warning(&self, message: impl AsRef<str>) {
        if let Some(logger) = &self.logger {
            logger.warning(message);
        }
    }
}

/// Client that accumulates `report`/`whitelist` digests per server address and sends each
/// server's request as one batch, without waiting for a reply.
///
/// A batch is sent as soon as it holds at least `batch_size` digests. Anything still pending
/// is sent by `force()`, which also runs when the value is dropped.
///
/// NOTE(review): the derived `Clone` copies the pending batches, and every copy sends them
/// again from `Drop`. Cloning a `BatchClient` with pending digests therefore duplicates
/// reports — confirm this is intended.
#[derive(Clone, Debug)]
pub struct BatchClient {
    /// Client used to sign and send each batch.
    client: Client,
    /// Number of `Op-Digest` headers at which a batch is sent immediately.
    batch_size: usize,
    /// Pending `report` request per server address.
    reports: HashMap<Address, Message>,
    /// Pending `whitelist` request per server address.
    whitelists: HashMap<Address, Message>,
}

impl Drop for BatchClient {
    fn drop(&mut self) {
        self.force();
    }
}

impl BatchClient {
    pub fn new(client: Client, batch_size: usize) -> Self {
        Self {
            client,
            batch_size,
            reports: HashMap::new(),
            whitelists: HashMap::new(),
        }
    }

    pub fn report(&mut self, digest: &str, address: &Address) -> Result<()> {
        Self::add_digest(
            &self.client,
            self.batch_size,
            &mut self.reports,
            "report",
            digest,
            address,
        )
    }

    pub fn whitelist(&mut self, digest: &str, address: &Address) -> Result<()> {
        Self::add_digest(
            &self.client,
            self.batch_size,
            &mut self.whitelists,
            "whitelist",
            digest,
            address,
        )
    }

    pub fn flush(&mut self) {
        self.reports.clear();
        self.whitelists.clear();
    }

    pub fn force(&mut self) {
        for (address, msg) in std::mem::take(&mut self.reports) {
            let _ = self.client.send_only(msg, &address);
        }
        for (address, msg) in std::mem::take(&mut self.whitelists) {
            let _ = self.client.send_only(msg, &address);
        }
    }

    fn add_digest(
        client: &Client,
        batch_size: usize,
        requests: &mut HashMap<Address, Message>,
        op: &str,
        digest: &str,
        address: &Address,
    ) -> Result<()> {
        let msg = requests.entry(address.clone()).or_insert_with(|| {
            let mut msg = message::request(op);
            let flat = client
                .spec
                .iter()
                .flat_map(|(offset, length)| [offset.to_string(), length.to_string()])
                .collect::<Vec<_>>()
                .join(",");
            msg.add_header("Op-Spec", flat);
            msg
        });
        msg.add_header("Op-Digest", digest);
        if msg.get_all("Op-Digest").len() >= batch_size {
            let msg = requests.remove(address).expect("entry just existed");
            client.send_only(msg, address)?;
        }
        Ok(())
    }
}

/// Resolves `address` (host, port) and returns the first socket address the resolver yields.
///
/// # Errors
/// Propagates resolver failures. Returns `PyzorError::Comm` when the lookup succeeds but
/// yields no addresses.
fn resolve(address: &Address) -> Result<std::net::SocketAddr> {
    let (host, port) = (address.0.as_str(), address.1);
    let mut candidates = (host, port).to_socket_addrs().map_err(PyzorError::from)?;
    match candidates.next() {
        Some(first) => Ok(first),
        None => Err(PyzorError::Comm(format!(
            "Unable to send to {}:{}",
            host, port
        ))),
    }
}

fn bind_for(address: &Address) -> Result<UdpSocket> {
    let target = resolve(address)?;
    let bind = if target.is_ipv6() {
        "[::]:0"
    } else {
        "0.0.0.0:0"
    };
    UdpSocket::bind(bind).map_err(PyzorError::from)
}
/// Formats `address` like Python's `repr()` of a socket-address tuple.
///
/// IPv4 renders as `('host', port)`. IPv6 renders as `('host', port, flowinfo, scope_id)`.
fn python_socket_addr_repr(address: SocketAddr) -> String {
    if let SocketAddr::V6(v6) = address {
        return format!(
            "('{}', {}, {}, {})",
            v6.ip(),
            v6.port(),
            v6.flowinfo(),
            v6.scope_id()
        );
    }
    format!("('{}', {})", address.ip(), address.port())
}