mm1-multinode 0.7.23

An Erlang-style actor runtime for Rust.
Documentation
use std::pin::pin;
use std::{io, mem};

use eyre::Context;
use futures::future;
use mm1_common::types::AnyError;
use mm1_proto_network_management as nm;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use super::{config, iostream_util, pdu};

const HELLO: &[u8] = b"MM1!";
const BINCODE_CONFIG: bincode::config::Configuration = bincode::config::standard();

pub(super) async fn run<IO>(io: IO, options: &nm::Options) -> Result<HandshakeDone, AnyError>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    let io = pin!(io);
    let (mut io_r, mut io_w) = tokio::io::split(io);

    let config: config::Config = options
        .clone()
        .deserialize_into()
        .wrap_err("options.deserialize (Config)")?;
    let local_capabilities = capabilities(&config).wrap_err("capabilities from config")?;

    let inbound_header = async {
        iostream_util::read_header(&mut io_r)
            .await
            .wrap_err("read_header (Hello)")
    };
    let outbound_header_sent = async {
        let mut outbound_hello = [0u8; 30];
        let mut outbound_hello_w = io::Cursor::new(&mut outbound_hello[..]);
        outbound_hello_w
            .write_all(HELLO)
            .await
            .expect("should be fine");
        bincode::serde::encode_into_std_write(
            &local_capabilities,
            &mut outbound_hello_w,
            BINCODE_CONFIG,
        )
        .wrap_err("capabilities write error")?;

        iostream_util::write_header(&mut io_w, pdu::Hello(outbound_hello))
            .await
            .wrap_err("write_header (Hello)")?;

        io_w.flush().await?;

        Ok(())
    };

    let (inbound_header, ()) = future::try_join(inbound_header, outbound_header_sent).await?;
    let pdu::Header::Hello(pdu::Hello(inbound_hello)) = inbound_header else {
        return Err(eyre::format_err!(
            "unexpected header type: {:?}",
            inbound_header
        ));
    };
    if &inbound_hello[0..HELLO.len()] != HELLO {
        return Err(eyre::format_err!(
            "unexpected header value: {:?}",
            inbound_hello
        ));
    }
    let mut hello_data_r = io::Cursor::new(&inbound_hello[HELLO.len()..]);
    let peer_capabilities: Capabilities =
        bincode::serde::decode_from_std_read(&mut hello_data_r, BINCODE_CONFIG)
            .wrap_err("bincde::decode (Capabilities)")?;

    let mut should_be_zeroes = Vec::with_capacity(inbound_hello.len());
    hello_data_r
        .read_to_end(&mut should_be_zeroes)
        .await
        .wrap_err("hello_data_r.read_to_end")?;
    if !should_be_zeroes.into_iter().all(|b| b == 0) {
        return Err(eyre::format_err!(
            "stray data in HELLO: the peer probably has different view on Capabilities"
        ));
    }

    let done = negotiate(
        &mut io_r,
        &mut io_w,
        &config,
        local_capabilities,
        peer_capabilities,
    )
    .await
    .wrap_err("negotiate")?;

    Ok(done)
}

fn capabilities(config: &config::Config) -> Result<Capabilities, AnyError> {
    let config::Config { authc } = config;
    let authc = {
        use config::AuthcConfig::*;
        match authc {
            Trusted => CapAuthc::Trusted,
            Cookie { .. } => CapAuthc::Cookie,
            SharedSecret { .. } => CapAuthc::SharedSecret,
        }
    };
    let caps = Capabilities { authc };
    Ok(caps)
}

async fn negotiate<R, W>(
    mut io_r: R,
    mut io_w: W,
    config: &config::Config,
    local_capabilities: Capabilities,
    peer_capabilities: Capabilities,
) -> Result<HandshakeDone, AnyError>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    negotiate_authc(
        &mut io_r,
        &mut io_w,
        config,
        &local_capabilities,
        &peer_capabilities,
    )
    .await
    .wrap_err("negotiate_authc")?;

    let () = negotiate_done(io_r, io_w)
        .await
        .wrap_err("negotiate_done")?;
    Ok(HandshakeDone {})
}

async fn negotiate_authc<R, W>(
    io_r: R,
    io_w: W,
    config: &config::Config,
    local_capabilities: &Capabilities,
    peer_capabilities: &Capabilities,
) -> Result<(), AnyError>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let mut io_r = pin!(io_r);
    let mut io_w = pin!(io_w);

    let Capabilities { authc: local, .. } = local_capabilities;
    let Capabilities { authc: peer, .. } = peer_capabilities;
    let config::Config { authc } = config;

    if local != peer {
        return Err(eyre::format_err!(
            "incompatible authc-methods [local: {:?}; peer: {:?}]",
            local,
            peer
        ))
    }

    assert_eq!(local, peer);
    {
        use config::AuthcConfig::*;
        match authc {
            Trusted => Ok(()),
            Cookie(cookie) => {
                let cookie = cookie.as_ref();
                let r = async {
                    let negotiate = read_negotiate(&mut io_r).await.wrap_err("read_negotiate")?;
                    let Negotiate::Cookie(cookie_len) = negotiate else {
                        return Err(eyre::format_err!(
                            "expected Cookie; received: {:?}",
                            negotiate
                        ))
                    };
                    let mut buf = vec![0u8; cookie_len as usize];
                    io_r.read_exact(&mut buf[..])
                        .await
                        .wrap_err("io_r.read_exact (cookie)")?;
                    if &buf[..] != cookie.as_bytes() {
                        return Err(eyre::format_err!("cookie mismatch"))
                    }
                    Ok(())
                };
                let w = async {
                    let negotiate =
                        Negotiate::Cookie(cookie.len().try_into().wrap_err("cookie too long?")?);
                    write_negotiate(&mut io_w, negotiate)
                        .await
                        .wrap_err("write_negotiate")?;
                    io_w.write_all(cookie.as_bytes())
                        .await
                        .wrap_err("write cookie")?;
                    Ok(())
                };

                future::try_join(r, w).await.map(|((), ())| ())
            },
            SharedSecret { .. } => Err(eyre::format_err!("not implemented: shared-secret")),
        }
    }
}

async fn negotiate_done<R, W>(io_r: R, io_w: W) -> Result<(), AnyError>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let w = async {
        write_negotiate(io_w, Negotiate::Done)
            .await
            .wrap_err("write_negotiate")?;
        Ok(())
    };
    let r = async {
        match read_negotiate(io_r).await.wrap_err("read_negotiate")? {
            Negotiate::Done => Ok(()),
            unexpected => {
                Err(eyre::format_err!(
                    "expected Done, received: {:?}",
                    unexpected
                ))
            },
        }
    };

    future::try_join(w, r).await.map(|((), ())| ())
}

async fn read_negotiate<R>(io_r: R) -> Result<Negotiate, AnyError>
where
    R: AsyncRead + Unpin,
{
    let mut io_r = pin!(io_r);
    let mut buf = [0u8; mem::size_of::<Negotiate>() * 2];
    io_r.read_exact(&mut buf[..])
        .await
        .wrap_err("io_r.read_exact")?;
    let mut buf_r = io::Cursor::new(&buf);
    let negotiate = bincode::serde::decode_from_std_read(&mut buf_r, BINCODE_CONFIG)
        .wrap_err("bincode::decode (Negotiate)")?;
    Ok(negotiate)
}

async fn write_negotiate<W>(io_w: W, negotiate: Negotiate) -> Result<(), AnyError>
where
    W: AsyncWrite + Unpin,
{
    let mut io_w = pin!(io_w);
    let mut buf = [0u8; mem::size_of::<Negotiate>() * 2];
    let mut buf_w = io::Cursor::new(&mut buf[..]);
    bincode::serde::encode_into_std_write(negotiate, &mut buf_w, BINCODE_CONFIG)
        .wrap_err("bincode::encode (Negotiate)")?;
    io_w.write_all(&buf[..]).await.wrap_err("io_w.write_all")?;
    Ok(())
}

#[derive(Debug, Serialize, Deserialize)]
#[repr(C)]
struct Capabilities {
    authc: CapAuthc,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[repr(C)]
pub enum CapAuthc {
    #[default]
    Trusted      = 0,
    Cookie       = 1,
    SharedSecret = 2,
}

#[must_use]
pub(super) struct HandshakeDone {}

#[derive(Debug, Serialize, Deserialize)]
#[repr(C)]
enum Negotiate {
    Cookie(u16),
    Done,
}