lunatic-distributed 0.13.2

Node-to-node communication
Documentation
use std::{net::SocketAddr, sync::Arc, time::Duration};

use anyhow::{anyhow, Result};
use bytes::Bytes;
use lunatic_process::{env::Environment, state::ProcessState};
use quinn::{ClientConfig, Connecting, ConnectionError, Endpoint, ServerConfig};
use rustls::server::AllowAnyAuthenticatedClient;
use rustls_pemfile::Item;
use wasmtime::ResourceLimiter;

use crate::{distributed, DistributedCtx};

/// Sending half of a bidirectional QUIC stream to another node.
pub struct SendStream {
    pub stream: quinn::SendStream,
}

impl SendStream {
    /// Writes every chunk in `data` to the underlying stream, in order.
    ///
    /// NOTE(review): no framing is added here; presumably callers already
    /// include the 4-byte little-endian length prefix that
    /// `RecvStream::receive` expects — confirm against the callers.
    ///
    /// # Errors
    /// Fails if the underlying QUIC write fails.
    pub async fn send(&mut self, data: &mut [Bytes]) -> Result<()> {
        self.stream
            .write_all_chunks(data)
            .await
            .map_err(Into::into)
    }
}

pub struct RecvStream {
    pub stream: quinn::RecvStream,
}

impl RecvStream {
    pub async fn receive(&mut self) -> Result<Bytes> {
        let mut size = [0u8; 4];
        self.stream.read_exact(&mut size).await?;
        let size = u32::from_le_bytes(size);
        let mut buffer = vec![0u8; size as usize];
        self.stream.read_exact(&mut buffer).await?;
        Ok(buffer.into())
    }

    pub fn id(&self) -> quinn::StreamId {
        self.stream.id()
    }
}

#[derive(Clone)]
pub struct Client {
    inner: Endpoint,
}

impl Client {
    pub async fn connect(
        &self,
        addr: SocketAddr,
        name: &str,
        retry: u32,
    ) -> Result<(SendStream, RecvStream)> {
        for try_num in 1..(retry + 1) {
            match self.connect_once(addr, name).await {
                Ok(r) => return Ok(r),
                Err(e) => {
                    log::error!("Error connecting to {name} at {addr}, try {try_num}. Error: {e}")
                }
            }
            tokio::time::sleep(Duration::from_secs(2)).await;
        }
        Err(anyhow!("Failed to connect to {name} at {addr}"))
    }

    async fn connect_once(&self, addr: SocketAddr, name: &str) -> Result<(SendStream, RecvStream)> {
        let conn = self.inner.connect(addr, name)?.await?;
        let (send, recv) = conn.open_bi().await?;
        Ok((SendStream { stream: send }, RecvStream { stream: recv }))
    }
}

pub fn new_quic_client(ca_cert: &str, cert: &str, key: &str) -> Result<Client> {
    let mut ca_cert = ca_cert.as_bytes();
    let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
    let ca_cert = match ca_cert {
        Item::X509Certificate(ca_cert) => Ok(rustls::Certificate(ca_cert)),
        _ => Err(anyhow!("Not a valid certificate.")),
    }?;
    let mut roots = rustls::RootCertStore::empty();
    roots.add(&ca_cert)?;

    let mut cert = cert.as_bytes();
    let mut key = key.as_bytes();
    let pk = rustls_pemfile::read_one(&mut key)?.unwrap();
    let pk = match pk {
        Item::PKCS8Key(key) => Ok(rustls::PrivateKey(key)),
        _ => Err(anyhow!("Not a valid private key.")),
    }?;
    let cert = rustls_pemfile::read_one(&mut cert)?.unwrap();
    let cert = match cert {
        Item::X509Certificate(cert) => Ok(rustls::Certificate(cert)),
        _ => Err(anyhow!("Not a valid certificate")),
    }?;
    let cert = vec![cert];

    let client_crypto = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(roots)
        .with_single_cert(cert, pk)?;

    let client_config = ClientConfig::new(Arc::new(client_crypto));
    let mut endpoint = Endpoint::client("[::]:0".parse().unwrap())?;
    endpoint.set_default_client_config(client_config);
    Ok(Client { inner: endpoint })
}

/// Builds a QUIC server endpoint bound to `addr`.
///
/// The endpoint serves the certificate chain `certs` with private key `key`.
/// It requires every client to present a certificate signed by `ca_cert`.
///
/// All inputs are PEM text. From each one only the first PEM item is used:
/// each entry of `certs` contributes one certificate to the served chain, in
/// order. `key` must be a PKCS#8 private key.
///
/// # Errors
/// Returns an error, instead of panicking, when:
/// - an input contains no PEM item (e.g. it is empty) or an item of the wrong kind;
/// - rustls rejects the CA certificate or the certificate/key pair;
/// - the UDP socket cannot be bound to `addr`.
pub fn new_quic_server(
    addr: SocketAddr,
    certs: Vec<String>,
    key: &str,
    ca_cert: &str,
) -> Result<Endpoint> {
    // `read_one` yields `Ok(None)` when no PEM item is present; that is invalid
    // user input, so it is reported as an error rather than unwrapped.
    let mut ca_cert = ca_cert.as_bytes();
    let ca_cert = match rustls_pemfile::read_one(&mut ca_cert)? {
        Some(Item::X509Certificate(der)) => rustls::Certificate(der),
        Some(_) => return Err(anyhow!("Not a valid certificate.")),
        None => return Err(anyhow!("No PEM item found in CA certificate")),
    };
    let mut roots = rustls::RootCertStore::empty();
    roots.add(&ca_cert)?;

    let mut key = key.as_bytes();
    let pk = match rustls_pemfile::read_one(&mut key)? {
        Some(Item::PKCS8Key(der)) => rustls::PrivateKey(der),
        Some(_) => return Err(anyhow!("Not a valid private key.")),
        None => return Err(anyhow!("No PEM item found in private key")),
    };

    // Parse every chain entry, stopping at the first invalid one.
    let cert_chain = certs
        .iter()
        .map(|pem| -> Result<rustls::Certificate> {
            let mut pem = pem.as_bytes();
            match rustls_pemfile::read_one(&mut pem)? {
                Some(Item::X509Certificate(der)) => Ok(rustls::Certificate(der)),
                Some(_) => Err(anyhow!("Not a valid certificate")),
                None => Err(anyhow!("No PEM item found in certificate")),
            }
        })
        .collect::<Result<Vec<_>>>()?;

    let server_crypto = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots))
        .with_single_cert(cert_chain, pk)?;
    let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
    // Allow peers zero concurrent unidirectional streams; node traffic in this
    // module uses bidirectional streams (`open_bi` / `accept_bi`) only.
    Arc::get_mut(&mut server_config.transport)
        .expect("transport config was just created and is not shared yet")
        .max_concurrent_uni_streams(0_u8.into());

    Ok(quinn::Endpoint::server(server_config, addr)?)
}

pub async fn handle_node_server<T, E>(
    quic_server: &mut Endpoint,
    ctx: distributed::server::ServerCtx<T, E>,
) -> Result<()>
where
    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
    E: Environment + 'static,
{
    while let Some(conn) = quic_server.accept().await {
        tokio::spawn(handle_quic_connection_node(ctx.clone(), conn));
    }
    Err(anyhow!("Node server exited"))
}

/// Drives one incoming node connection.
///
/// It completes the handshake, then accepts bidirectional streams until the
/// connection closes. Each accepted stream is served by
/// `handle_quic_stream_node` on its own tokio task.
///
/// # Errors
/// Returns an error if the connection handshake fails. Returns `Ok(())` once
/// the connection is closed.
async fn handle_quic_connection_node<T, E>(
    ctx: distributed::server::ServerCtx<T, E>,
    conn: Connecting,
) -> Result<()>
where
    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
    E: Environment + 'static,
{
    log::info!("New node connection");
    let conn = conn.await?;
    log::info!("Remote {} connected", conn.remote_address());
    loop {
        if let Some(reason) = conn.close_reason() {
            log::info!("Connection {} is closed: {reason}", conn.remote_address());
            break;
        }
        match conn.accept_bi().await {
            Ok((s, r)) => {
                // Only report acceptance once a stream was actually accepted.
                log::info!("Stream from remote {} accepted", conn.remote_address());
                let send = SendStream { stream: s };
                let recv = RecvStream { stream: r };
                tokio::spawn(handle_quic_stream_node(ctx.clone(), send, recv));
            }
            Err(ConnectionError::LocallyClosed) => break,
            Err(reason) => {
                // `ConnectionError` describes why the connection was lost. Stop
                // here rather than silently looping back on a dead connection.
                log::info!("Connection {} is closed: {reason}", conn.remote_address());
                break;
            }
        }
    }
    log::info!("Connection from remote {} closed", conn.remote_address());
    Ok(())
}

async fn handle_quic_stream_node<T, E>(
    ctx: distributed::server::ServerCtx<T, E>,
    mut send: SendStream,
    mut recv: RecvStream,
) where
    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
    E: Environment + 'static,
{
    while let Ok(bytes) = recv.receive().await {
        if let Ok((msg_id, request)) =
            rmp_serde::from_slice::<(u64, distributed::message::Request)>(&bytes)
        {
            distributed::server::handle_message(ctx.clone(), &mut send, msg_id, request).await;
        } else {
            log::debug!("Error deserializing request");
        }
    }
}