use std::{
collections::VecDeque,
sync::{Arc, Mutex, mpsc},
thread,
};
use datum::{
NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
StreamRefPayload, StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer,
StreamRefSettings, StreamResult,
};
use crate::QuicBidirectionalStream;
/// Width in bytes of the big-endian `u32` length prefix written in front of
/// every encoded `StreamRefFrame` on the QUIC byte stream.
const FRAME_LEN_BYTES: usize = std::mem::size_of::<u32>();
/// Largest frame payload, in bytes, accepted by the inbound decoder (16 MiB).
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 << 20;
/// Completion handle for a StreamRefs endpoint carried over a QUIC
/// bidirectional stream.
///
/// The carrier runs on background threads. This handle receives their single
/// combined outcome, either by blocking or by polling.
#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
#[derive(Debug)]
pub struct StreamRefQuicHandle {
    /// Channel on which the carrier reports its final outcome.
    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
}
impl StreamRefQuicHandle {
    /// Blocks until the carrier reports its final outcome.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the carrier. Returns
    /// `StreamError::AbruptTermination` if the reporting side disconnected
    /// without delivering an outcome, including the case where the outcome was
    /// already taken by [`Self::try_wait`].
    pub fn wait(self) -> StreamResult<NotUsed> {
        match self.receiver.recv() {
            Ok(outcome) => outcome,
            Err(_disconnected) => Err(StreamError::AbruptTermination),
        }
    }

    /// Returns the carrier's outcome if it is already available, without blocking.
    ///
    /// Yields `None` while no outcome has been delivered. The outcome is
    /// delivered at most once, so after this method has returned it, further
    /// calls yield `None`.
    #[must_use]
    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
        match self.receiver.try_recv() {
            Ok(outcome) => Some(outcome),
            Err(_empty_or_disconnected) => None,
        }
    }
}
/// Serves an existing [`SourceRef`] to the peer on the other end of `stream`.
///
/// The local side acts as the StreamRefs producer. Frames it emits are written
/// to the QUIC stream, and frames arriving from the peer are handed back to it.
///
/// # Errors
///
/// Fails if the producer cannot be built from `source_ref`.
pub fn serve_source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    source_ref: SourceRef<T>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
{
    let endpoint =
        StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
    let carrier = drive_stream_ref_endpoint(stream, endpoint);
    Ok(carrier)
}
/// Serves a local [`Source`] to the peer on the other end of `stream`.
///
/// The source is wrapped in a StreamRefs producer. Frames it emits are written
/// to the QUIC stream, and frames arriving from the peer are handed back to it.
///
/// # Errors
///
/// Fails if the producer cannot be built from `source`.
pub fn serve_source_over_quic<T, Mat>(
    stream: QuicBidirectionalStream,
    source: Source<T, Mat>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
    Mat: Send + 'static,
{
    let endpoint = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
    let carrier = drive_stream_ref_endpoint(stream, endpoint);
    Ok(carrier)
}
/// Consumes a SourceRef served by the peer on the other end of `stream`.
///
/// Returns the StreamRefs consumer's local [`Source`] together with a handle
/// that reports when the QUIC carrier finishes or fails.
pub fn source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
    // Take the local source before the consumer is moved into the carrier.
    let inbound_source = consumer.source();
    (inbound_source, drive_stream_ref_endpoint(stream, consumer))
}
pub fn serve_sink_ref_over_quic<T>(
stream: QuicBidirectionalStream,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
T: StreamRefPayload,
{
let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
let source = consumer.source();
let handle = drive_stream_ref_endpoint(stream, consumer);
(source, handle)
}
/// Writes into a SinkRef served by the peer on the other end of `stream`.
///
/// Returns a lazy StreamRefs producer's local [`Sink`] together with a handle
/// that reports when the QUIC carrier finishes or fails.
pub fn sink_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
    // Take the local sink before the producer is moved into the carrier.
    let outbound_sink = producer.sink();
    (outbound_sink, drive_stream_ref_endpoint(stream, producer))
}
/// Runs the StreamRefs protocol `endpoint` over a QUIC bidirectional `stream`.
///
/// The stream is split into its inbound and outbound byte halves, and each
/// direction runs on its own spawned thread:
///
/// * outbound: frames produced by the endpoint are length-prefixed by
///   `outbound_frames` and written to the byte sink;
/// * inbound: bytes from the peer are reassembled into frames by
///   `inbound_frames` and passed to the endpoint's `handle_frame`.
///
/// If either direction fails, it calls `fail_connection` on its copy of the
/// endpoint. A third thread joins both directions and sends one combined
/// result to the returned handle. This function does not block.
fn drive_stream_ref_endpoint<E>(stream: QuicBidirectionalStream, endpoint: E) -> StreamRefQuicHandle
where
    E: StreamRefProtoEndpoint,
{
    let (byte_source, byte_sink) = stream.into_parts();
    // Carries the combined outcome from the supervisor thread to the handle.
    let (sender, receiver) = mpsc::channel();
    let outbound_endpoint = endpoint.clone();
    let outbound_thread = thread::spawn(move || {
        // Stream encoded frames into the send half and wait for that run to finish.
        let result = outbound_frames(outbound_endpoint.clone())
            .run_with(byte_sink)
            .and_then(|completion| completion.wait());
        // NOTE(review): presumably fail_connection lets the inbound direction
        // observe the failure and stop; confirm against the endpoint impl.
        if let Err(error) = &result {
            outbound_endpoint.fail_connection(error.clone());
        }
        result
    });
    let inbound_endpoint = endpoint.clone();
    let inbound_thread = thread::spawn(move || {
        // Decode received bytes into frames and feed each one to the endpoint.
        let result = inbound_frames(byte_source)
            .run_with(Sink::foreach_result({
                let inbound_endpoint = inbound_endpoint.clone();
                move |frame| inbound_endpoint.handle_frame(frame)
            }))
            .and_then(|completion| completion.wait());
        if let Err(error) = &result {
            inbound_endpoint.fail_connection(error.clone());
        }
        result
    });
    // Supervisor: the result is reported only after both directions have
    // exited. If both failed, the outbound error takes precedence.
    thread::spawn(move || {
        let outbound = join_carrier_thread(outbound_thread);
        let inbound = join_carrier_thread(inbound_thread);
        let result = match (outbound, inbound) {
            (Err(error), _) => Err(error),
            (_, Err(error)) => Err(error),
            (Ok(()), Ok(())) => Ok(NotUsed),
        };
        // Best-effort: the send fails only if the handle was already dropped.
        let _ = sender.send(result);
    });
    StreamRefQuicHandle { receiver }
}
/// Builds the outbound byte stream for `endpoint`.
///
/// Each frame yielded by `next_frame` is length-prefixed by
/// `encode_carrier_frame`. The stream ends when the endpoint yields `None` and
/// fails with the first error the endpoint or the encoder reports.
fn outbound_frames<E>(endpoint: E) -> Source<Vec<u8>, NotUsed>
where
    E: StreamRefProtoEndpoint,
{
    Source::unfold_resource(
        move || Ok(endpoint.clone()),
        |carrier| {
            let Some(next) = carrier.next_frame() else {
                return Ok(None);
            };
            let encoded = encode_carrier_frame(next?)?;
            Ok(Some(encoded))
        },
        // Nothing to release once the endpoint has been drained.
        |_carrier| Ok(()),
    )
}
/// Turns the inbound QUIC byte stream into decoded StreamRefs frames.
///
/// Chunk boundaries need not line up with frame boundaries. A single decoder
/// is shared across all chunks so that partial frames carry over between them.
fn inbound_frames(byte_source: Source<Vec<u8>, NotUsed>) -> Source<StreamRefFrame, NotUsed> {
    let shared_decoder = Arc::new(Mutex::new(FrameDecoder::default()));
    byte_source.map_concat_result(move |chunk| {
        let mut decoder = shared_decoder
            .lock()
            .expect("stream ref frame decoder poisoned");
        decoder.push_chunk(chunk)
    })
}
fn encode_carrier_frame(frame: StreamRefFrame) -> StreamResult<Vec<u8>> {
let payload = frame.encode_to_vec();
let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
max: u32::MAX as u64,
})?;
let mut bytes = Vec::with_capacity(FRAME_LEN_BYTES + payload.len());
bytes.extend(len.to_be_bytes());
bytes.extend(payload);
Ok(bytes)
}
/// Incremental decoder for the length-prefixed carrier framing.
///
/// Bytes arrive in arbitrarily sized chunks. Any bytes that do not yet form a
/// complete frame are buffered until later chunks complete them.
#[derive(Default)]
struct FrameDecoder {
    /// Received bytes not yet consumed as part of a complete frame.
    buffer: VecDeque<u8>,
}

impl FrameDecoder {
    /// Appends `chunk` and decodes every frame that is now complete.
    ///
    /// A trailing partial frame stays buffered for the next call.
    ///
    /// # Errors
    ///
    /// Fails if a length prefix exceeds `MAX_STREAM_REF_FRAME_BYTES` or a
    /// payload fails `StreamRefFrame::decode`. Frames already decoded during
    /// the same call are discarded in that case.
    fn push_chunk(&mut self, chunk: Vec<u8>) -> StreamResult<Vec<StreamRefFrame>> {
        self.buffer.extend(chunk);
        let mut decoded = Vec::new();
        loop {
            let Some(payload_len) = self.peek_len()? else {
                break;
            };
            let frame_len = FRAME_LEN_BYTES + payload_len;
            if self.buffer.len() < frame_len {
                break;
            }
            // Remove the whole frame from the buffer, skipping its length prefix.
            let payload: Vec<u8> = self
                .buffer
                .drain(..frame_len)
                .skip(FRAME_LEN_BYTES)
                .collect();
            decoded.push(StreamRefFrame::decode(&payload)?);
        }
        Ok(decoded)
    }

    /// Reads the big-endian length prefix at the front of the buffer without
    /// consuming it.
    ///
    /// Returns `Ok(None)` until every prefix byte has arrived.
    fn peek_len(&self) -> StreamResult<Option<usize>> {
        if self.buffer.len() < FRAME_LEN_BYTES {
            return Ok(None);
        }
        // In bounds: the check above guarantees FRAME_LEN_BYTES buffered bytes.
        let header: [u8; FRAME_LEN_BYTES] = std::array::from_fn(|index| self.buffer[index]);
        let len = u32::from_be_bytes(header) as usize;
        if len <= MAX_STREAM_REF_FRAME_BYTES {
            Ok(Some(len))
        } else {
            Err(StreamError::LimitExceeded {
                max: MAX_STREAM_REF_FRAME_BYTES as u64,
            })
        }
    }
}
/// Joins one carrier direction's thread and flattens its outcome to `()`.
///
/// A panic on that thread is reported as `StreamError::Failed` rather than
/// being propagated to the joining thread.
fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
    let outcome = handle.join().map_err(|_panic_payload| {
        StreamError::Failed("StreamRefs QUIC carrier thread panicked".to_owned())
    })?;
    outcome.map(|_not_used| ())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small, representative frame shared by the framing tests.
    fn sample_frame() -> StreamRefFrame {
        StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        )
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let frame = sample_frame();
        let bytes = encode_carrier_frame(frame.clone()).unwrap();
        let split = bytes.len() / 2;
        let mut decoder = FrameDecoder::default();
        assert!(
            decoder
                .push_chunk(bytes[..split].to_vec())
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            decoder.push_chunk(bytes[split..].to_vec()).unwrap(),
            vec![frame]
        );
    }

    #[test]
    fn carrier_frame_decoder_yields_every_frame_in_one_chunk() {
        let frame = sample_frame();
        let mut bytes = encode_carrier_frame(frame.clone()).unwrap();
        bytes.extend(encode_carrier_frame(frame.clone()).unwrap());
        let mut decoder = FrameDecoder::default();
        assert_eq!(decoder.push_chunk(bytes).unwrap(), vec![frame.clone(), frame]);
    }

    #[test]
    fn carrier_frame_decoder_handles_byte_at_a_time_delivery() {
        let frame = sample_frame();
        let bytes = encode_carrier_frame(frame.clone()).unwrap();
        let mut decoder = FrameDecoder::default();
        let mut decoded = Vec::new();
        for byte in bytes {
            decoded.extend(decoder.push_chunk(vec![byte]).unwrap());
        }
        assert_eq!(decoded, vec![frame]);
    }

    #[test]
    fn carrier_frame_decoder_rejects_oversized_length_prefix() {
        let oversized = u32::try_from(MAX_STREAM_REF_FRAME_BYTES + 1).unwrap();
        let mut decoder = FrameDecoder::default();
        let result = decoder.push_chunk(oversized.to_be_bytes().to_vec());
        assert!(matches!(
            result,
            Err(StreamError::LimitExceeded { max }) if max == MAX_STREAM_REF_FRAME_BYTES as u64
        ));
    }
}