enlace 0.2.2

Encrypted mailbox and latest-value slot fan-out.
Documentation
use std::time::Duration;

use async_trait::async_trait;
use reqwest::header::HeaderMap;
use reqwest::{Client, StatusCode, Url};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use crate::config::{BasicAuth, HttpConfig};
use crate::error::TransportError;
use crate::transports::{MailboxTransport, SlotTransport, SlotWatchStream};

/// Header carrying the slot version: sent with slot `PUT`s and read back from
/// successful slot `GET` responses.
const VERSION_HEADER: &str = "x-enlace-version";
/// Capacity of the channel between the background `watch` task and the stream
/// handed back to the caller.
const WATCH_BUFFER: usize = 64;

/// Mailbox and slot transport that talks to a relay over HTTP.
///
/// `watch` clones the whole transport into its background polling task.
#[derive(Debug, Clone)]
pub struct HttpTransport {
    /// HTTP client built in `new` (optionally accepting invalid TLS certificates).
    client: Client,
    /// Relay base URL; channel paths (`/m/<id>`, `/s/<id>`) are appended to its path.
    base_url: Url,
    /// Optional basic-auth credentials attached to every request.
    auth: Option<BasicAuth>,
    /// Long-poll wait requested by each `watch` poll.
    long_poll: Duration,
}

impl HttpTransport {
    /// Slack added to a long-poll `wait` to form the client-side request
    /// deadline.
    ///
    /// The client built in `new` sets no overall timeout, so without a
    /// per-request deadline a connection that stalls while the relay is
    /// holding a long poll would leave the request (and any `watch` loop
    /// driving it) pending indefinitely.
    const LONG_POLL_GRACE: Duration = Duration::from_secs(30);

    /// Builds a transport from `config`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `map_reqwest_error` if the underlying
    /// `reqwest` client cannot be built.
    pub fn new(config: HttpConfig) -> Result<Self, TransportError> {
        // `skip_verify` turns off TLS certificate validation for this client.
        let client = Client::builder()
            .danger_accept_invalid_certs(config.skip_verify)
            .build()
            .map_err(map_reqwest_error)?;
        Ok(Self {
            client,
            base_url: config.url,
            auth: config.auth,
            long_poll: Duration::from_secs(u64::from(config.long_poll_secs)),
        })
    }

    /// Attaches the configured basic-auth credentials (if any) to `builder`.
    fn request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some(auth) = self.auth.as_ref() {
            builder.basic_auth(&auth.username, Some(&auth.password))
        } else {
            builder
        }
    }

    /// URL of mailbox channel `id`: `<base path>/m/<hex id>`.
    fn mailbox_url(&self, id: &[u8; 16]) -> Url {
        self.channel_url("m", id)
    }

    /// URL of slot channel `id`: `<base path>/s/<hex id>`.
    fn slot_url(&self, id: &[u8; 16]) -> Url {
        self.channel_url("s", id)
    }

    /// Builds `<base path>/<prefix>/<hex id>` on top of the base URL.
    ///
    /// Trailing `/` characters on the base path are ignored, and any query
    /// string carried by the base URL is dropped.
    fn channel_url(&self, prefix: &str, id: &[u8; 16]) -> Url {
        let mut url = self.base_url.clone();
        let base_path = self.base_url.path().trim_end_matches('/');
        let id = hex_id(id);
        let path = if base_path.is_empty() {
            format!("/{prefix}/{id}")
        } else {
            format!("{base_path}/{prefix}/{id}")
        };
        url.set_path(&path);
        url.set_query(None);
        url
    }

    /// Long-polls slot `id`, passing `since` and `wait` to the relay as query
    /// parameters (what `since` means exactly is decided by the relay).
    ///
    /// Returns `Ok(Some((version, body)))` on `200 OK` and `Ok(None)` on
    /// `204 No Content`; other statuses go through `map_status`. Only whole
    /// seconds of `wait` are sent. The request is given a client-side deadline
    /// of `wait` plus [`Self::LONG_POLL_GRACE`], after which it fails with
    /// `TransportError::Timeout`.
    async fn slot_get_since(
        &self,
        id: &[u8; 16],
        since: u64,
        wait: Duration,
    ) -> Result<Option<(u64, Vec<u8>)>, TransportError> {
        let mut url = self.slot_url(id);
        url.query_pairs_mut()
            .append_pair("since", &since.to_string())
            .append_pair("wait", &wait.as_secs().to_string());
        let response = self
            .request(
                self.client
                    .get(url)
                    // `saturating_add` so an extreme `wait` cannot overflow and panic.
                    .timeout(wait.saturating_add(Self::LONG_POLL_GRACE)),
            )
            .send()
            .await
            .map_err(map_reqwest_error)?;

        let status = response.status();
        // Headers are cloned because reading the body consumes `response`.
        let headers = response.headers().clone();
        let body = if status == StatusCode::OK {
            response.bytes().await.map_err(map_reqwest_error)?.to_vec()
        } else {
            Vec::new()
        };
        decode_slot_get_response(status, &headers, body)
    }
}

#[async_trait]
impl MailboxTransport for HttpTransport {
    /// Posts `sealed` to mailbox `id`.
    ///
    /// Only `204 No Content` counts as success; other statuses go through
    /// `map_status`. Fails with `TransportError::Network` if `id` is not
    /// 16 bytes.
    async fn send(&self, id: &[u8], sealed: &[u8]) -> Result<(), TransportError> {
        let id = http_channel_id(id)?;
        let response = self
            .request(
                self.client
                    .post(self.mailbox_url(&id))
                    .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
                    .body(sealed.to_vec()),
            )
            .send()
            .await
            .map_err(map_reqwest_error)?;

        decode_empty_response(response.status())
    }

    /// Long-polls mailbox `id` for up to `wait` (whole seconds only).
    ///
    /// Returns `Ok(Some(body))` on `200 OK` and `Ok(None)` on
    /// `204 No Content`. The request carries a client-side deadline of `wait`
    /// plus a grace period, so a stalled connection ends in
    /// `TransportError::Timeout` instead of hanging forever.
    async fn recv(&self, id: &[u8], wait: Duration) -> Result<Option<Vec<u8>>, TransportError> {
        // Slack on top of `wait` before the client abandons the long poll.
        const LONG_POLL_GRACE: Duration = Duration::from_secs(30);

        let id = http_channel_id(id)?;
        let mut url = self.mailbox_url(&id);
        url.query_pairs_mut()
            .append_pair("wait", &wait.as_secs().to_string());
        let response = self
            .request(
                self.client
                    .get(url)
                    // `saturating_add` so an extreme `wait` cannot overflow and panic.
                    .timeout(wait.saturating_add(LONG_POLL_GRACE)),
            )
            .send()
            .await
            .map_err(map_reqwest_error)?;

        let status = response.status();
        let body = if status == StatusCode::OK {
            response.bytes().await.map_err(map_reqwest_error)?.to_vec()
        } else {
            Vec::new()
        };
        decode_mailbox_recv_response(status, body)
    }
}

#[async_trait]
impl SlotTransport for HttpTransport {
    /// Writes `sealed` to slot `id` as `version`.
    ///
    /// The version is sent in the `x-enlace-version` header. Only
    /// `204 No Content` counts as success; `409 Conflict` surfaces as
    /// `TransportError::Stale` via `map_status`.
    async fn put(&self, id: &[u8], version: u64, sealed: &[u8]) -> Result<(), TransportError> {
        let id = http_channel_id(id)?;
        let response = self
            .request(
                self.client
                    .put(self.slot_url(&id))
                    .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
                    .header(VERSION_HEADER, version.to_string())
                    .body(sealed.to_vec()),
            )
            .send()
            .await
            .map_err(map_reqwest_error)?;

        decode_empty_response(response.status())
    }

    /// Reads slot `id` without long-polling (`since = 0`, `wait = 0`).
    ///
    /// Returns `Ok(None)` when the relay answers `204 No Content`.
    async fn get(&self, id: &[u8]) -> Result<Option<(u64, Vec<u8>)>, TransportError> {
        let id = http_channel_id(id)?;
        self.slot_get_since(&id, 0, Duration::ZERO).await
    }

    /// Streams slot updates by long-polling the relay from a background task.
    ///
    /// Each poll sends the last version seen (starting at `since`) and waits up
    /// to the configured long-poll duration. Errors are forwarded into the
    /// stream and followed by a one-second pause before the next poll. Once the
    /// returned stream is dropped, the task exits after its in-flight poll
    /// finishes. A malformed `id` yields a stream containing a single error.
    ///
    /// # Panics
    ///
    /// Uses `tokio::spawn`, which panics when called outside a Tokio runtime.
    fn watch(&self, id: &[u8], since: u64) -> SlotWatchStream {
        let Ok(id) = http_channel_id(id) else {
            return Box::pin(tokio_stream::iter([Err(TransportError::Network(
                "HTTP channel id must be 16 bytes".to_owned(),
            ))]));
        };
        let transport = self.clone();
        let (tx, rx) = mpsc::channel(WATCH_BUFFER);

        tokio::spawn(async move {
            let mut since = since;
            // `send` only fails after the receiver is gone, and a quiet slot
            // (polls returning no update) never sends. Checking before every
            // poll keeps an abandoned watch from polling the relay forever.
            while !tx.is_closed() {
                match transport
                    .slot_get_since(&id, since, transport.long_poll)
                    .await
                {
                    Ok(Some((version, body))) => {
                        since = version;
                        if tx.send(Ok((version, body))).await.is_err() {
                            break;
                        }
                    }
                    Ok(None) => {}
                    Err(err) => {
                        if tx.send(Err(err)).await.is_err() {
                            break;
                        }
                        tokio::time::sleep(Duration::from_secs(1)).await;
                    }
                }
            }
        });

        Box::pin(ReceiverStream::new(rx))
    }
}

/// Interprets the status of a request that expects no response body.
///
/// `204 No Content` is the only success; every other status is converted by
/// `map_status`.
pub(crate) fn decode_empty_response(status: StatusCode) -> Result<(), TransportError> {
    if status == StatusCode::NO_CONTENT {
        Ok(())
    } else {
        Err(map_status(status))
    }
}

/// Interprets a mailbox receive response.
///
/// `204 No Content` means nothing was waiting (`Ok(None)`); `200 OK` yields
/// the body; any other status is converted by `map_status`.
pub(crate) fn decode_mailbox_recv_response(
    status: StatusCode,
    body: Vec<u8>,
) -> Result<Option<Vec<u8>>, TransportError> {
    if status == StatusCode::NO_CONTENT {
        return Ok(None);
    }
    if status != StatusCode::OK {
        return Err(map_status(status));
    }
    Ok(Some(body))
}

/// Interprets a slot read response.
///
/// `204 No Content` yields `Ok(None)`. `200 OK` yields the body paired with
/// the version from the `x-enlace-version` header, and fails if that header is
/// missing or unparseable. Any other status is converted by `map_status`.
pub(crate) fn decode_slot_get_response(
    status: StatusCode,
    headers: &HeaderMap,
    body: Vec<u8>,
) -> Result<Option<(u64, Vec<u8>)>, TransportError> {
    if status == StatusCode::NO_CONTENT {
        return Ok(None);
    }
    if status != StatusCode::OK {
        return Err(map_status(status));
    }
    parse_version(headers).map(|version| Some((version, body)))
}

/// Reads the slot version from the `x-enlace-version` header as a `u64`.
///
/// A missing header, a value that is not visible ASCII, or a non-numeric
/// value all produce the same `TransportError::Network` error.
fn parse_version(headers: &HeaderMap) -> Result<u64, TransportError> {
    let parsed = headers
        .get(VERSION_HEADER)
        .and_then(|raw| raw.to_str().ok()?.parse::<u64>().ok());
    match parsed {
        Some(version) => Ok(version),
        None => Err(TransportError::Network(
            "relay omitted slot version".to_owned(),
        )),
    }
}

fn map_status(status: StatusCode) -> TransportError {
    match status {
        StatusCode::UNAUTHORIZED => TransportError::Auth,
        StatusCode::CONFLICT => TransportError::Stale,
        StatusCode::PAYLOAD_TOO_LARGE => TransportError::BodyTooLarge,
        StatusCode::REQUEST_TIMEOUT | StatusCode::GATEWAY_TIMEOUT => TransportError::Timeout,
        _ => TransportError::Network(format!("relay returned status {status}")),
    }
}

/// Converts a `reqwest` failure into a `TransportError`.
///
/// Timeouts become `TransportError::Timeout`; anything else becomes
/// `TransportError::Network` carrying the error's display text.
fn map_reqwest_error(err: reqwest::Error) -> TransportError {
    if err.is_timeout() {
        return TransportError::Timeout;
    }
    TransportError::Network(err.to_string())
}

/// Validates that a channel id is exactly 16 bytes and copies it into an array.
///
/// Any other length is rejected with `TransportError::Network`.
fn http_channel_id(id: &[u8]) -> Result<[u8; 16], TransportError> {
    match <[u8; 16]>::try_from(id) {
        Ok(fixed) => Ok(fixed),
        Err(_) => Err(TransportError::Network(
            "HTTP channel id must be 16 bytes".to_owned(),
        )),
    }
}

/// Renders a 16-byte channel id as 32 lowercase hexadecimal characters,
/// high nibble first for each byte.
fn hex_id(id: &[u8; 16]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let nibble = |n: u8| char::from(DIGITS[usize::from(n)]);

    id.iter()
        .fold(String::with_capacity(id.len() * 2), |mut hex, &byte| {
            hex.push(nibble(byte >> 4));
            hex.push(nibble(byte & 0x0f));
            hex
        })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `hex_id` renders each byte as two lowercase hex digits, in order.
    #[test]
    fn hex_id_is_lowercase_32_chars() {
        let mut id = [0u8; 16];
        for (byte, value) in id.iter_mut().zip(0u8..) {
            *byte = value;
        }
        let rendered = hex_id(&id);
        assert_eq!(rendered, "000102030405060708090a0b0c0d0e0f");
    }

    /// Statuses with a dedicated meaning map to their specific error variants.
    #[test]
    fn status_mapping_matches_transport_errors() {
        let unauthorized = map_status(StatusCode::UNAUTHORIZED);
        assert!(matches!(unauthorized, TransportError::Auth));

        let conflict = map_status(StatusCode::CONFLICT);
        assert!(matches!(conflict, TransportError::Stale));

        let too_large = map_status(StatusCode::PAYLOAD_TOO_LARGE);
        assert!(matches!(too_large, TransportError::BodyTooLarge));
    }
}