simengine 0.2.7

A plugin-based simulation engine runtime and plugin API
Documentation
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    io::{self, Read, Write},
    net::{TcpListener, TcpStream},
    sync::Arc,
    thread,
    time::Duration,
};

/// One message carried by the TCP transport in this module.
///
/// Values are sent with [`send`] or [`NetworkSender::send`] and handed to the callback passed
/// to [`start_listener`]. On the wire each message is encoded as JSON with `serde_json` and
/// preceded by a 4-byte big-endian body length (see `write_message` / `read_message`).
///
/// The field names are part of that JSON encoding, so renaming a field changes the wire format
/// seen by peers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMessage {
    /// NOTE(review): presumably names the receiving input (e.g. a plugin input) this payload is
    /// addressed to — inferred from the field name only; confirm against listener callbacks.
    pub input: String,
    /// Opaque payload bytes. With serde's default `Vec<u8>` handling, `serde_json` encodes this
    /// as a JSON array of numbers rather than a compact byte string.
    pub payload: Vec<u8>,
}

/// Sends [`NetworkMessage`]s over persistent TCP connections, one connection per endpoint.
///
/// Connections are opened lazily on first use and cached by endpoint string, so repeated sends
/// to the same endpoint reuse one socket.
#[derive(Default)]
pub struct NetworkSender {
    /// Open connections, keyed by the endpoint string passed to [`NetworkSender::send`].
    streams: HashMap<String, TcpStream>,
}

impl NetworkSender {
    /// Sends `message` to `endpoint`, reusing a cached connection when one exists.
    ///
    /// If a *cached* connection fails (for example because the peer restarted), it is evicted
    /// and the send is retried exactly once on a fresh connection. A send that already had to
    /// open a new connection is not retried, so an unreachable endpoint costs one connection
    /// attempt rather than two. Encoding failures leave the cached connection in place and are
    /// never retried, since nothing was written and a retry could not succeed.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be encoded, if connecting fails, or if writing
    /// fails. When the retry on a fresh connection fails, the retry's error is returned: it
    /// describes the endpoint's current state, whereas the first error came from a stale socket.
    pub fn send(&mut self, endpoint: &str, message: &NetworkMessage) -> io::Result<()> {
        let reused_stream = self.streams.contains_key(endpoint);

        let err = match self.write_to_endpoint(endpoint, message) {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };

        // `write_message` reports encoding failures (serialization -> `Other`, oversize ->
        // `InvalidInput`) before touching the socket, so any cached stream is still usable.
        if matches!(err.kind(), io::ErrorKind::Other | io::ErrorKind::InvalidInput) {
            return Err(err);
        }

        // Any other failure may have left a partial frame on the wire, so the stream can no
        // longer be trusted to stay in sync with the peer's framing.
        self.streams.remove(endpoint);

        if !reused_stream {
            // The connection was brand new; reconnecting immediately would just repeat it.
            return Err(err);
        }

        let retry = self.write_to_endpoint(endpoint, message);
        if retry.is_err() {
            // Don't leave a broken replacement stream cached for the next send.
            self.streams.remove(endpoint);
        }
        retry
    }

    /// Writes one framed message to `endpoint`, first connecting and caching a stream if none
    /// is cached yet.
    ///
    /// The stream stays cached even if the write fails; [`NetworkSender::send`] decides whether
    /// to evict it.
    fn write_to_endpoint(&mut self, endpoint: &str, message: &NetworkMessage) -> io::Result<()> {
        // Look up by `&str` first so the common (already connected) path allocates no key.
        if let Some(stream) = self.streams.get_mut(endpoint) {
            return write_message(stream, message);
        }

        let fresh = connect_stream(endpoint)?;
        let stream = self.streams.entry(endpoint.to_owned()).or_insert(fresh);
        write_message(stream, message)
    }
}

/// Sends a single message over a brand-new connection to `endpoint`.
///
/// Unlike [`NetworkSender::send`], nothing is cached: the connection exists only for this one
/// message and is closed when the stream is dropped on return.
///
/// # Errors
///
/// Returns any error from connecting to `endpoint` or from encoding/writing the message.
pub fn send(endpoint: &str, message: &NetworkMessage) -> io::Result<()> {
    connect_stream(endpoint).and_then(|mut conn| write_message(&mut conn, message))
}

pub fn start_listener<F>(endpoint: String, on_message: F) -> io::Result<thread::JoinHandle<()>>
where
    F: Fn(NetworkMessage) + Send + Sync + 'static,
{
    let listener = TcpListener::bind(&endpoint)?;
    let on_message = Arc::new(on_message);

    let handle = thread::spawn(move || {
        println!("[network] listening on {endpoint}");

        for stream in listener.incoming() {
            let Ok(mut stream) = stream else {
                continue;
            };

            let on_message = Arc::clone(&on_message);
            thread::spawn(move || {
                loop {
                    match read_message(&mut stream) {
                        Ok(message) => on_message(message),
                        Err(err) if is_connection_closed(&err) => break,
                        Err(err) => {
                            eprintln!("[network] failed to read message: {err}");
                            break;
                        }
                    }
                }
            });
        }
    });

    Ok(handle)
}

/// Opens a TCP connection to `endpoint`, configured for framed, low-latency writes.
///
/// `endpoint` is anything `std::net::ToSocketAddrs` accepts for a string, such as
/// `"127.0.0.1:9000"` or `"host:9000"`. The resolved addresses are tried in order, and each
/// attempt is bounded by a connect timeout. Without it, an unreachable host could block the
/// caller for the OS default TCP connect timeout, which is far longer than the write timeout.
/// The returned stream has a write timeout and `TCP_NODELAY` set.
///
/// # Errors
///
/// Returns:
/// - the resolution error;
/// - the last connection error if no address accepted the connection;
/// - an `InvalidInput` error if `endpoint` resolved to no addresses;
/// - any error from configuring the socket.
fn connect_stream(endpoint: &str) -> io::Result<TcpStream> {
    use std::net::ToSocketAddrs;

    // Bound both connecting and writing so a dead or stalled peer cannot block the caller
    // indefinitely.
    const TIMEOUT: Duration = Duration::from_secs(2);

    let mut last_err = None;
    for addr in endpoint.to_socket_addrs()? {
        match TcpStream::connect_timeout(&addr, TIMEOUT) {
            Ok(stream) => {
                stream.set_write_timeout(Some(TIMEOUT))?;
                stream.set_nodelay(true)?;
                return Ok(stream);
            }
            Err(err) => last_err = Some(err),
        }
    }

    Err(last_err.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidInput, "could not resolve to any addresses")
    }))
}

/// Reports whether `err` means the peer went away, as opposed to a genuine I/O or protocol
/// failure.
///
/// Callers use this to end a connection quietly instead of logging it as an error.
fn is_connection_closed(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 4] = [
        io::ErrorKind::UnexpectedEof,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::BrokenPipe,
    ];

    CLOSED_KINDS.contains(&err.kind())
}

fn write_message(stream: &mut TcpStream, message: &NetworkMessage) -> io::Result<()> {
    let body = serde_json::to_vec(message).map_err(io::Error::other)?;
    let len = u32::try_from(body.len())
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "network message too large"))?;

    stream.write_all(&len.to_be_bytes())?;
    stream.write_all(&body)?;
    stream.flush()?;
    Ok(())
}

/// Reads one length-prefixed frame from `stream` and decodes it as a [`NetworkMessage`].
///
/// Expects the layout written by `write_message`: a 4-byte big-endian `u32` body length,
/// followed by that many bytes of JSON.
///
/// The length prefix comes from the peer and is untrusted. The body buffer therefore grows as
/// bytes actually arrive, instead of being allocated up front from that length. Otherwise a
/// bogus 4-byte header could force an allocation of up to 4 GiB.
///
/// # Errors
///
/// - [`io::ErrorKind::UnexpectedEof`] if the stream ends before the prefix or the full body
///   has been read.
/// - [`io::ErrorKind::Other`] if the body is not a valid JSON-encoded message.
/// - Any other error from reading `stream`.
fn read_message<R: Read + ?Sized>(stream: &mut R) -> io::Result<NetworkMessage> {
    // Cap on the speculative allocation made before any body bytes have been received.
    const INITIAL_BODY_CAPACITY: u64 = 64 * 1024;

    let mut len_bytes = [0u8; 4];
    stream.read_exact(&mut len_bytes)?;
    let len = u64::from(u32::from_be_bytes(len_bytes));

    // The `as usize` cast cannot truncate: the value is at most INITIAL_BODY_CAPACITY.
    let mut body = Vec::with_capacity(len.min(INITIAL_BODY_CAPACITY) as usize);
    let received = stream.take(len).read_to_end(&mut body)?;
    if received as u64 != len {
        // Same error kind `read_exact` would report, so callers still treat it as a close.
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "stream ended in the middle of a frame",
        ));
    }

    serde_json::from_slice(&body).map_err(io::Error::other)
}