tubes 0.6.4

Host/Client protocol based on pipenet
Documentation
use crate::client_id::ClientId;
use crate::prelude::*;

use crate::core::transport::is_aborted_error;
use crate::{core::ReconnectTo, data::MessageDataInternal};
use pipenet::NonBlockStream;
use std::{
    collections::HashMap,
    sync::{
        Mutex,
        mpsc::{Receiver, Sender},
    },
};

type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

pub(crate) fn server_loop<const MS: usize>(
    server_uuid: ClientId,
    clients: &Mutex<HashMap<ClientId, NonBlockStream<MS>>>,
    reconnect_to: &Mutex<Option<ReconnectTo>>,
    rx_writer: &Receiver<MessageDataInternal>,
    tx_reader: &Sender<MessageData>,
) {
    loop {
        let Ok(mut map) = clients.lock() else {
            continue;
        };

        let (aborted, recto) = match server_read::<MS>(server_uuid, &mut map, tx_reader) {
            Ok(x) => x,
            Err(e) => {
                if is_aborted_error(e) {
                    return;
                }
                continue;
            }
        };

        if let Err(e) = abort_clients::<MS>(&mut map, aborted, tx_reader) {
            if is_aborted_error(e) {
                return;
            }
            continue;
        }

        let aborted = match server_write::<MS>(&mut map, rx_writer) {
            Ok(x) => x,
            Err(e) => {
                if is_aborted_error(e) {
                    return;
                }
                continue;
            }
        };

        if let Err(e) = abort_clients::<MS>(&mut map, aborted, tx_reader) {
            if is_aborted_error(e) {
                return;
            }
            continue;
        }

        if let Some(to) = recto
            && let Ok(mut dst) = reconnect_to.lock()
        {
            *dst = Some(to);
        }
    }
}

/// Removes the given clients from `map` and announces their departure.
///
/// Every client still in `map` receives a `ClientLeft` message for each
/// removed client, and the host itself is notified through `tx_reader`.
/// All removals happen before any notification is sent, so departing clients
/// are never sent notices about each other.
///
/// `aborted_clients` may name the same client more than once. For example, a
/// client whose read failed and whose relay writes then failed as well. Only
/// ids actually removed by this call are announced, so each departure is
/// reported exactly once. Write failures while notifying are ignored.
///
/// # Errors
/// Returns an error if a `ClientLeft` message cannot be serialized.
fn abort_clients<const MS: usize>(
    map: &mut HashMap<ClientId, NonBlockStream<MS>>,
    aborted_clients: Vec<ClientId>,
    tx_reader: &Sender<MessageData>,
) -> Result<()> {
    // `remove` yields `Some` only the first time an id is seen. This both
    // deduplicates the list and skips ids that are not in the map.
    let removed: Vec<ClientId> = aborted_clients
        .into_iter()
        .filter(|uuid| map.remove(uuid).is_some())
        .collect();

    for uuid in removed {
        // Serialize once per departed client, and only if someone will receive it.
        if !map.is_empty() {
            let msg: Vec<u8> = MessageDataInternal::ClientLeft(uuid).try_into()?;
            for client in map.values_mut() {
                let _ = client.write(msg.clone());
            }
        }
        // Also notify the server itself.
        let _ = tx_reader.send(MessageData::ClientLeft(uuid));
    }

    Ok(())
}

/// Reads every connected client and relays the messages that must be relayed.
///
/// Returns two values:
/// - The ids of clients to drop. These are clients whose read failed, that
///   announced `ClientLeft`, or that could not be written to while relaying.
///   Each id appears at most once.
/// - The reconnect target seen during this pass, if any. When several clients
///   announce one, the last in map iteration order wins.
///
/// Relayed messages are never echoed back to their originator. A targeted
/// message only goes to its destination. Clients already marked for removal
/// are skipped: they are about to be dropped, and a failed write would only
/// report them again.
///
/// # Errors
/// Returns an error if a relayed message cannot be re-serialized.
#[allow(clippy::type_complexity)]
fn server_read<const MS: usize>(
    server_uuid: ClientId,
    map: &mut HashMap<ClientId, NonBlockStream<MS>>,
    tx_reader: &Sender<MessageData>,
) -> Result<(Vec<ClientId>, Option<ReconnectTo>)> {
    let mut repeats = vec![];
    let mut aborted_clients = vec![];
    let mut reconnect_to = None;

    for (uuid, client) in map.iter_mut() {
        match server_read_one_client::<MS>(server_uuid, *uuid, client, tx_reader) {
            // A read or decode failure means this stream can no longer be used.
            Err(_) => aborted_clients.push(*uuid),
            Ok(ServerSingleClientResult {
                aborted,
                reconnect_to: detected,
                repeats: mut client_repeats,
            }) => {
                repeats.append(&mut client_repeats);
                if aborted {
                    aborted_clients.push(*uuid);
                }
                if let Some(to) = detected {
                    reconnect_to = Some(to);
                }
            }
        }
    }

    for r in repeats {
        let x: Vec<u8> = r.message.try_into()?;
        for (uuid, client) in map.iter_mut() {
            // Don't send repeats back to the originator, nor to clients that
            // are already being dropped.
            if r.from_uuid == *uuid || aborted_clients.contains(uuid) {
                continue;
            }
            // Target only the specific uuid when requested.
            if r.only_to_uuid.is_some_and(|dst| dst != *uuid) {
                continue;
            }
            if client.write(x.clone()).is_err() {
                aborted_clients.push(*uuid);
            }
        }
    }

    Ok((aborted_clients, reconnect_to))
}

fn server_read_one_client<const MS: usize>(
    server_uuid: ClientId,
    client_uuid: ClientId,
    socket: &mut NonBlockStream<MS>,
    tx_reader: &Sender<MessageData>,
) -> Result<ServerSingleClientResult> {
    let mut aborted = false;
    let mut reconnec_to = None;
    let mut repeats = vec![];
    while let Some(msg) = socket.read()? {
        // Vec is cloneable, but deserialized MessageData is not.
        // Hold 2 copies from this point forward so it can be added into the
        // repeats too.
        let msg_orig = msg.clone();
        let msg: MessageDataInternal = (*msg).try_into()?;
        let msg_orig: MessageDataInternal = (*msg_orig).try_into()?;
        match msg {
            MessageDataInternal::Broadcast(sender, m) => {
                let _ = tx_reader.send(MessageData::Broadcast {
                    from: sender,
                    data: m,
                });
                repeats.push((sender, msg_orig).into());
            }
            // Checks if destination is the server, if so just add it to the
            // channel, otherwise forward it to the correct client.
            MessageDataInternal::Send(sender, dst, m) => {
                if dst == server_uuid {
                    let _ = tx_reader.send(MessageData::Send {
                        from: sender,
                        to: dst,
                        data: m,
                    });
                } else {
                    let mut rep: ServerRepeats = (sender, msg_orig).into();
                    rep.only_to_uuid = Some(dst);
                    repeats.push(rep);
                }
            }
            MessageDataInternal::ClientLeft(_) => {
                aborted = true;
                break;
            }
            MessageDataInternal::NewHost(ip, port) => {
                reconnec_to = Some(ReconnectTo {
                    become_server: false,
                    address: ip,
                    port,
                });
                repeats.push((client_uuid, msg).into());
            }
            MessageDataInternal::ClientJoined(_)
            | MessageDataInternal::ServerUuid(_)
            | MessageDataInternal::PromoteToHost(..) => {}
        }
    }
    Ok(ServerSingleClientResult {
        aborted,
        reconnect_to: reconnec_to,
        repeats,
    })
}

/// Flushes every message currently queued on `rx_writer` to the clients.
///
/// A `Broadcast` goes to every client. `Send` and `PromoteToHost` go only to
/// their destination, and are dropped if that client is not in `map`. Other
/// message kinds are ignored here.
///
/// Returns the ids of clients whose write failed. An id appears once per
/// failed write.
///
/// # Errors
/// Returns an error if a queued message cannot be serialized.
#[allow(clippy::type_complexity)]
fn server_write<const MS: usize>(
    map: &mut HashMap<ClientId, NonBlockStream<MS>>,
    rx_writer: &Receiver<MessageDataInternal>,
) -> Result<Vec<ClientId>> {
    let mut failed = Vec::new();

    // `try_iter` never blocks: it stops once the queue is empty or the
    // sending side has hung up.
    for queued in rx_writer.try_iter() {
        match queued {
            MessageDataInternal::Broadcast(..) => {
                let bytes: Vec<u8> = queued.try_into()?;
                failed.extend(map.iter_mut().filter_map(|(id, stream)| {
                    stream.write(bytes.clone()).is_err().then_some(*id)
                }));
            }
            MessageDataInternal::PromoteToHost(dst, ..) | MessageDataInternal::Send(_, dst, _) => {
                let Some(stream) = map.get_mut(&dst) else {
                    continue;
                };
                if stream.write(queued.try_into()?).is_err() {
                    failed.push(dst);
                }
            }
            MessageDataInternal::ServerUuid(..)
            | MessageDataInternal::ClientJoined(..)
            | MessageDataInternal::ClientLeft(..)
            | MessageDataInternal::NewHost(..) => {}
        }
    }

    Ok(failed)
}

/// Outcome of draining one client's socket in a single server pass.
struct ServerSingleClientResult {
    /// Messages from this client that must be relayed to other clients.
    repeats: Vec<ServerRepeats>,
    /// Reconnect target most recently announced by this client, if any.
    reconnect_to: Option<ReconnectTo>,
    /// Set when the client announced that it is leaving.
    aborted: bool,
}

/// A message read from one client that the host must forward to others.
struct ServerRepeats {
    /// Originator of the message; the repeat is never sent back to this id.
    from_uuid: ClientId,
    /// When set, the repeat goes only to this client instead of to everyone.
    only_to_uuid: Option<ClientId>,
    /// The decoded message, re-serialized at forwarding time.
    message: MessageDataInternal,
}

/// Builds an untargeted repeat from an `(originator, message)` pair.
impl From<(ClientId, MessageDataInternal)> for ServerRepeats {
    fn from((from_uuid, message): (ClientId, MessageDataInternal)) -> Self {
        Self {
            from_uuid,
            only_to_uuid: None,
            message,
        }
    }
}