datum-net 0.6.0

Network sources and sinks for Datum streams, built on datum-core
Documentation
//! StreamRefs over QUIC.
//!
//! `datum-core` owns the protobuf protocol and state machine. This module is
//! the carrier: it length-prefixes protobuf frames and pumps them over one
//! reliable, ordered QUIC bidirectional byte stream.

use std::{
    collections::VecDeque,
    sync::{Arc, Mutex, mpsc},
    thread,
};

use datum::{
    NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
    StreamRefPayload, StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer,
    StreamRefSettings, StreamResult,
};

use crate::QuicBidirectionalStream;

/// Size in bytes of the big-endian `u32` length prefix that precedes every
/// encoded frame payload on the wire.
const FRAME_LEN_BYTES: usize = 4;
/// Largest frame payload, in bytes, that the decoder accepts (16 MiB). A length
/// prefix declaring more than this is rejected with
/// `StreamError::LimitExceeded`.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 * 1024 * 1024;

/// Completion handle for a StreamRefs-over-QUIC carrier.
///
/// The carrier runs on background threads; this handle is the only way to
/// observe whether it finished cleanly or failed. Exactly one result is
/// delivered through it.
#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
pub struct StreamRefQuicHandle {
    /// Receiving end of the channel on which the carrier's supervising thread
    /// sends the combined outbound/inbound result once both have finished.
    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
}

impl StreamRefQuicHandle {
    pub fn wait(self) -> StreamResult<NotUsed> {
        self.receiver
            .recv()
            .unwrap_or(Err(StreamError::AbruptTermination))
    }

    #[must_use]
    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
        self.receiver.try_recv().ok()
    }
}

/// Serves a local `SourceRef` over an accepted or opened QUIC bidi stream.
///
/// A protocol producer is built from `source_ref` and then pumped over
/// `stream` in the background; the returned handle reports the carrier's final
/// result.
///
/// # Errors
///
/// Returns the error from `StreamRefProtoProducer::from_source_ref` when the
/// producer cannot be created. In that case nothing is spawned and `stream` is
/// dropped.
pub fn serve_source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    source_ref: SourceRef<T>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
{
    let handle = drive_stream_ref_endpoint(
        stream,
        StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?,
    );
    Ok(handle)
}

/// Serves a local `Source` over an accepted or opened QUIC bidi stream.
///
/// A protocol producer is built from `source` and then pumped over `stream`
/// in the background; the returned handle reports the carrier's final result.
///
/// # Errors
///
/// Returns the error from `StreamRefProtoProducer::from_source` when the
/// producer cannot be created. In that case nothing is spawned and `stream` is
/// dropped.
pub fn serve_source_over_quic<T, Mat>(
    stream: QuicBidirectionalStream,
    source: Source<T, Mat>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
    Mat: Send + 'static,
{
    let handle = drive_stream_ref_endpoint(
        stream,
        StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?,
    );
    Ok(handle)
}

/// Creates a local source fed by a remote QUIC StreamRef producer.
///
/// Returns the `Source` obtained from a fresh protocol consumer together with
/// the handle of the carrier that pumps that consumer over `stream`. The
/// carrier starts running in the background before this function returns.
pub fn source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let remote_consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
    // Take the local source first: the consumer itself is handed to the carrier.
    let inbound = remote_consumer.source();
    (inbound, drive_stream_ref_endpoint(stream, remote_consumer))
}

/// Serves a local `SinkRef` receiver over an accepted or opened QUIC bidi
/// stream, returning a [`Source`] of inbound elements.
///
/// This is the local/receiver side of the SinkRef-over-QUIC pair: the remote
/// sender pushes elements into a [`sink_ref_over_quic`](fn.sink_ref_over_quic)
/// `Sink`, and this side surfaces them as a `Source`. The caller runs the
/// returned source into a local `Sink` (for example `Sink::collect` or a fold).
///
/// The wiring is the transport-agnostic consumer seam
/// ([`StreamRefProtoConsumer`]); the only difference from
/// [`source_ref_over_quic`] is naming and intent — both surface inbound remote
/// elements as a local `Source`, so a SinkRef receiver is symmetric with a
/// SourceRef receiver over the same carrier.
pub fn serve_sink_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let source = consumer.source();
    let handle = drive_stream_ref_endpoint(stream, consumer);
    (source, handle)
}

/// Creates a local `Sink` that sends its incoming elements over QUIC to a
/// remote `SinkRef` receiver.
///
/// This is the remote/sender side of the SinkRef-over-QUIC pair: elements
/// pushed into the returned [`Sink`] are framed as `SequencedOnNext` and sent
/// to the [`serve_sink_ref_over_quic`] receiver. The sink materializes a
/// [`StreamCompletion`] that resolves when the producer reaches its terminal
/// state (all elements sent and acknowledged, the receiver cancelled/failed,
/// or the carrier failed).
///
/// The wiring is the transport-agnostic producer seam
/// ([`StreamRefProtoProducer`]) in lazy-input mode: the input stream is
/// attached when the sink is materialized, and the producer waits (without
/// spinning) until both the input is attached and the remote has subscribed
/// with demand.
///
/// The carrier starts running in the background before this function
/// returns; its final result is reported through the returned handle.
pub fn sink_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let lazy_producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
    // Take the local sink first: the producer itself is handed to the carrier.
    let outbound = lazy_producer.sink();
    (outbound, drive_stream_ref_endpoint(stream, lazy_producer))
}

/// Pumps `endpoint` over `stream` on background threads and returns a handle
/// that reports the carrier's final result.
///
/// Three threads are spawned:
///
/// * a writer that drains frames from the endpoint into the stream's byte sink,
/// * a reader that decodes frames from the stream's byte source and feeds each
///   one to `handle_frame`,
/// * a supervisor that joins the writer, then the reader, and sends one
///   combined result to the handle.
///
/// A direction that ends with an error calls `fail_connection` on the endpoint
/// with a clone of that error. When both directions fail, the writer's error
/// is the one reported.
fn drive_stream_ref_endpoint<E>(stream: QuicBidirectionalStream, endpoint: E) -> StreamRefQuicHandle
where
    E: StreamRefProtoEndpoint,
{
    let (result_tx, result_rx) = mpsc::channel();
    let (byte_source, byte_sink) = stream.into_parts();

    let writer_endpoint = endpoint.clone();
    let writer = thread::spawn(move || {
        let outcome = outbound_frames(writer_endpoint.clone())
            .run_with(byte_sink)
            .and_then(|completion| completion.wait());
        if let Some(error) = outcome.as_ref().err() {
            writer_endpoint.fail_connection(error.clone());
        }
        outcome
    });

    let reader_endpoint = endpoint.clone();
    let reader = thread::spawn(move || {
        let outcome = inbound_frames(byte_source)
            .run_with(Sink::foreach_result({
                let frame_target = reader_endpoint.clone();
                move |frame| frame_target.handle_frame(frame)
            }))
            .and_then(|completion| completion.wait());
        if let Some(error) = outcome.as_ref().err() {
            reader_endpoint.fail_connection(error.clone());
        }
        outcome
    });

    thread::spawn(move || {
        // Join both directions before reporting, writer first.
        let written = join_carrier_thread(writer);
        let read = join_carrier_thread(reader);
        // `and` keeps the writer's error when present, otherwise the reader's.
        let combined = written.and(read).map(|()| NotUsed);
        // The handle may already be gone; a failed send is deliberately ignored.
        let _ = result_tx.send(combined);
    });

    StreamRefQuicHandle {
        receiver: result_rx,
    }
}

fn outbound_frames<E>(endpoint: E) -> Source<Vec<u8>, NotUsed>
where
    E: StreamRefProtoEndpoint,
{
    Source::unfold_resource(
        move || Ok(endpoint.clone()),
        |endpoint| match endpoint.next_frame() {
            Some(Ok(frame)) => Ok(Some(encode_carrier_frame(frame)?)),
            Some(Err(error)) => Err(error),
            None => Ok(None),
        },
        |_endpoint| Ok(()),
    )
}

/// Turns a source of raw byte chunks into a source of decoded frames.
///
/// Chunks are fed, in order, into one shared `FrameDecoder`. A single chunk
/// may complete zero, one, or several frames, and each call passes on whatever
/// `push_chunk` returns for that chunk.
///
/// # Panics
///
/// The mapping closure panics if the decoder mutex has been poisoned.
fn inbound_frames(byte_source: Source<Vec<u8>, NotUsed>) -> Source<StreamRefFrame, NotUsed> {
    let shared_decoder = Arc::new(Mutex::new(FrameDecoder::default()));
    byte_source.map_concat_result(move |chunk| {
        let mut decoder = shared_decoder
            .lock()
            .expect("stream ref frame decoder poisoned");
        decoder.push_chunk(chunk)
    })
}

fn encode_carrier_frame(frame: StreamRefFrame) -> StreamResult<Vec<u8>> {
    let payload = frame.encode_to_vec();
    let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
        max: u32::MAX as u64,
    })?;
    let mut bytes = Vec::with_capacity(FRAME_LEN_BYTES + payload.len());
    bytes.extend(len.to_be_bytes());
    bytes.extend(payload);
    Ok(bytes)
}

/// Incremental decoder for length-prefixed frames arriving as arbitrary byte
/// chunks.
#[derive(Default)]
struct FrameDecoder {
    /// Bytes received so far that have not yet been consumed as part of a
    /// complete frame (a partial length prefix and/or a partial payload).
    buffer: VecDeque<u8>,
}

impl FrameDecoder {
    /// Appends `chunk` to the buffered bytes and decodes every frame that is
    /// now complete, in arrival order.
    ///
    /// Bytes belonging to an incomplete trailing frame stay buffered for the
    /// next call, so the returned vector may be empty.
    ///
    /// # Errors
    ///
    /// Fails when a length prefix declares more than
    /// `MAX_STREAM_REF_FRAME_BYTES`, or when `StreamRefFrame::decode` rejects a
    /// payload. Frames decoded earlier in the same call are not returned in
    /// that case.
    fn push_chunk(&mut self, chunk: Vec<u8>) -> StreamResult<Vec<StreamRefFrame>> {
        self.buffer.extend(chunk);
        let mut decoded = Vec::new();
        loop {
            let Some(payload_len) = self.peek_len()? else {
                break;
            };
            let frame_len = FRAME_LEN_BYTES + payload_len;
            if self.buffer.len() < frame_len {
                break;
            }
            // Remove prefix and payload in one pass, keeping only the payload.
            let payload: Vec<u8> = self
                .buffer
                .drain(..frame_len)
                .skip(FRAME_LEN_BYTES)
                .collect();
            decoded.push(StreamRefFrame::decode(&payload)?);
        }
        Ok(decoded)
    }

    /// Reads the payload length declared by the next frame's prefix without
    /// consuming any bytes.
    ///
    /// Returns `Ok(None)` while fewer than `FRAME_LEN_BYTES` bytes are
    /// buffered.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::LimitExceeded` when the declared length is above
    /// `MAX_STREAM_REF_FRAME_BYTES`.
    fn peek_len(&self) -> StreamResult<Option<usize>> {
        if self.buffer.len() < FRAME_LEN_BYTES {
            return Ok(None);
        }
        // The length check above guarantees these indices are in bounds.
        let prefix: [u8; FRAME_LEN_BYTES] = std::array::from_fn(|index| self.buffer[index]);
        match u32::from_be_bytes(prefix) as usize {
            declared if declared > MAX_STREAM_REF_FRAME_BYTES => Err(StreamError::LimitExceeded {
                max: MAX_STREAM_REF_FRAME_BYTES as u64,
            }),
            declared => Ok(Some(declared)),
        }
    }
}

/// Waits for one carrier thread and flattens its outcome.
///
/// # Errors
///
/// Passes through the error the thread returned. If the thread panicked
/// instead, reports `StreamError::Failed` with a fixed message.
fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
    let Ok(outcome) = handle.join() else {
        return Err(StreamError::Failed(
            "StreamRefs QUIC carrier thread panicked".to_owned(),
        ));
    };
    outcome.map(|_| ())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A frame delivered as two partial chunks yields nothing after the first
    /// chunk and exactly the original frame after the second.
    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let expected = StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let encoded = encode_carrier_frame(expected.clone()).unwrap();
        let (head, tail) = encoded.split_at(encoded.len() / 2);
        let mut decoder = FrameDecoder::default();

        let after_head = decoder.push_chunk(head.to_vec()).unwrap();
        assert!(after_head.is_empty());

        let after_tail = decoder.push_chunk(tail.to_vec()).unwrap();
        assert_eq!(after_tail, vec![expected]);
    }
}