1use std::{
8 collections::VecDeque,
9 sync::{Arc, Mutex, mpsc},
10 thread,
11};
12
13use datum::{
14 NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
15 StreamRefPayload, StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer,
16 StreamRefSettings, StreamResult,
17};
18
19use crate::QuicBidirectionalStream;
20
/// Width in bytes of the length prefix that precedes every carrier frame.
///
/// The prefix is a `u32`, so the width is tied to that type's size.
const FRAME_LEN_BYTES: usize = std::mem::size_of::<u32>();
/// Largest frame payload, in bytes, that the frame decoder accepts (16 MiB).
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 << 20;
23
24#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
26pub struct StreamRefQuicHandle {
27 receiver: mpsc::Receiver<StreamResult<NotUsed>>,
28}
29
30impl StreamRefQuicHandle {
31 pub fn wait(self) -> StreamResult<NotUsed> {
32 self.receiver
33 .recv()
34 .unwrap_or(Err(StreamError::AbruptTermination))
35 }
36
37 #[must_use]
38 pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
39 self.receiver.try_recv().ok()
40 }
41}
42
43pub fn serve_source_ref_over_quic<T>(
45 stream: QuicBidirectionalStream,
46 source_ref: SourceRef<T>,
47 stream_ref_id: StreamRefId,
48 settings: StreamRefSettings,
49) -> StreamResult<StreamRefQuicHandle>
50where
51 T: StreamRefPayload,
52{
53 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
54 Ok(drive_stream_ref_endpoint(stream, producer))
55}
56
57pub fn serve_source_over_quic<T, Mat>(
59 stream: QuicBidirectionalStream,
60 source: Source<T, Mat>,
61 stream_ref_id: StreamRefId,
62 settings: StreamRefSettings,
63) -> StreamResult<StreamRefQuicHandle>
64where
65 T: StreamRefPayload,
66 Mat: Send + 'static,
67{
68 let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
69 Ok(drive_stream_ref_endpoint(stream, producer))
70}
71
72pub fn source_ref_over_quic<T>(
74 stream: QuicBidirectionalStream,
75 stream_ref_id: StreamRefId,
76 settings: StreamRefSettings,
77) -> (Source<T, NotUsed>, StreamRefQuicHandle)
78where
79 T: StreamRefPayload,
80{
81 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
82 let source = consumer.source();
83 let handle = drive_stream_ref_endpoint(stream, consumer);
84 (source, handle)
85}
86
87pub fn serve_sink_ref_over_quic<T>(
101 stream: QuicBidirectionalStream,
102 stream_ref_id: StreamRefId,
103 settings: StreamRefSettings,
104) -> (Source<T, NotUsed>, StreamRefQuicHandle)
105where
106 T: StreamRefPayload,
107{
108 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
109 let source = consumer.source();
110 let handle = drive_stream_ref_endpoint(stream, consumer);
111 (source, handle)
112}
113
114pub fn sink_ref_over_quic<T>(
130 stream: QuicBidirectionalStream,
131 stream_ref_id: StreamRefId,
132 settings: StreamRefSettings,
133) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
134where
135 T: StreamRefPayload,
136{
137 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
138 let sink = producer.sink();
139 let handle = drive_stream_ref_endpoint(stream, producer);
140 (sink, handle)
141}
142
143fn drive_stream_ref_endpoint<E>(stream: QuicBidirectionalStream, endpoint: E) -> StreamRefQuicHandle
144where
145 E: StreamRefProtoEndpoint,
146{
147 let (byte_source, byte_sink) = stream.into_parts();
148 let (sender, receiver) = mpsc::channel();
149
150 let outbound_endpoint = endpoint.clone();
151 let outbound_thread = thread::spawn(move || {
152 let result = outbound_frames(outbound_endpoint.clone())
153 .run_with(byte_sink)
154 .and_then(|completion| completion.wait());
155 if let Err(error) = &result {
156 outbound_endpoint.fail_connection(error.clone());
157 }
158 result
159 });
160
161 let inbound_endpoint = endpoint.clone();
162 let inbound_thread = thread::spawn(move || {
163 let result = inbound_frames(byte_source)
164 .run_with(Sink::foreach_result({
165 let inbound_endpoint = inbound_endpoint.clone();
166 move |frame| inbound_endpoint.handle_frame(frame)
167 }))
168 .and_then(|completion| completion.wait());
169 if let Err(error) = &result {
170 inbound_endpoint.fail_connection(error.clone());
171 }
172 result
173 });
174
175 thread::spawn(move || {
176 let outbound = join_carrier_thread(outbound_thread);
177 let inbound = join_carrier_thread(inbound_thread);
178 let result = match (outbound, inbound) {
179 (Err(error), _) => Err(error),
180 (_, Err(error)) => Err(error),
181 (Ok(()), Ok(())) => Ok(NotUsed),
182 };
183 let _ = sender.send(result);
184 });
185
186 StreamRefQuicHandle { receiver }
187}
188
189fn outbound_frames<E>(endpoint: E) -> Source<Vec<u8>, NotUsed>
190where
191 E: StreamRefProtoEndpoint,
192{
193 Source::unfold_resource(
194 move || Ok(endpoint.clone()),
195 |endpoint| match endpoint.next_frame() {
196 Some(Ok(frame)) => Ok(Some(encode_carrier_frame(frame)?)),
197 Some(Err(error)) => Err(error),
198 None => Ok(None),
199 },
200 |_endpoint| Ok(()),
201 )
202}
203
204fn inbound_frames(byte_source: Source<Vec<u8>, NotUsed>) -> Source<StreamRefFrame, NotUsed> {
205 let decoder = Arc::new(Mutex::new(FrameDecoder::default()));
206 byte_source.map_concat_result(move |chunk| {
207 decoder
208 .lock()
209 .expect("stream ref frame decoder poisoned")
210 .push_chunk(chunk)
211 })
212}
213
214fn encode_carrier_frame(frame: StreamRefFrame) -> StreamResult<Vec<u8>> {
215 let payload = frame.encode_to_vec();
216 let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
217 max: u32::MAX as u64,
218 })?;
219 let mut bytes = Vec::with_capacity(FRAME_LEN_BYTES + payload.len());
220 bytes.extend(len.to_be_bytes());
221 bytes.extend(payload);
222 Ok(bytes)
223}
224
225#[derive(Default)]
226struct FrameDecoder {
227 buffer: VecDeque<u8>,
228}
229
230impl FrameDecoder {
231 fn push_chunk(&mut self, chunk: Vec<u8>) -> StreamResult<Vec<StreamRefFrame>> {
232 self.buffer.extend(chunk);
233 let mut frames = Vec::new();
234 while let Some(len) = self.peek_len()? {
235 if self.buffer.len() < FRAME_LEN_BYTES + len {
236 break;
237 }
238 self.buffer.drain(..FRAME_LEN_BYTES);
239 let payload = self.buffer.drain(..len).collect::<Vec<_>>();
240 frames.push(StreamRefFrame::decode(&payload)?);
241 }
242 Ok(frames)
243 }
244
245 fn peek_len(&self) -> StreamResult<Option<usize>> {
246 if self.buffer.len() < FRAME_LEN_BYTES {
247 return Ok(None);
248 }
249 let mut len = [0_u8; FRAME_LEN_BYTES];
250 for (target, source) in len.iter_mut().zip(self.buffer.iter().take(FRAME_LEN_BYTES)) {
251 *target = *source;
252 }
253 let len = u32::from_be_bytes(len) as usize;
254 if len > MAX_STREAM_REF_FRAME_BYTES {
255 return Err(StreamError::LimitExceeded {
256 max: MAX_STREAM_REF_FRAME_BYTES as u64,
257 });
258 }
259 Ok(Some(len))
260 }
261}
262
263fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
264 match handle.join() {
265 Ok(Ok(NotUsed)) => Ok(()),
266 Ok(Err(error)) => Err(error),
267 Err(_) => Err(StreamError::Failed(
268 "StreamRefs QUIC carrier thread panicked".to_owned(),
269 )),
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// A frame delivered in two chunks is emitted only once it is complete.
    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let expected = StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let encoded = encode_carrier_frame(expected.clone()).unwrap();
        let (head, tail) = encoded.split_at(encoded.len() / 2);
        let mut decoder = FrameDecoder::default();

        // The first half alone does not contain a full frame.
        let after_head = decoder.push_chunk(head.to_vec()).unwrap();
        assert!(after_head.is_empty());

        // The second half completes it.
        let after_tail = decoder.push_chunk(tail.to_vec()).unwrap();
        assert_eq!(after_tail, vec![expected]);
    }
}