1use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
10use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
11pub use quinn::{self, crypto, rustls};
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
13use std::sync::{Arc, Mutex, mpsc as std_mpsc};
14use tokio::net::ToSocketAddrs;
15use tokio::runtime::Handle;
16use tokio::sync::{mpsc, watch};
17use tokio::task::JoinHandle;
18
/// Chunk size, in bytes, used by the `*_default` constructors such as
/// [`TokioQuic::bind_default`] and [`TokioQuic::connect_default`].
pub const DEFAULT_CHUNK_SIZE: usize = 8192;

/// Receive-buffer size passed to the carrier and to the read source when a
/// bidirectional stream is split into its byte source and byte sink.
const DEFAULT_RECEIVE_BUFFER: usize = 64;

/// Source of byte chunks read from a QUIC stream.
pub type QuicByteSource = Source<Vec<u8>, NotUsed>;

/// Sink of byte chunks written to a QUIC stream.
pub type QuicByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
36
/// Reply delivered from a background task to the blocking stream side.
enum DemandResponse<T> {
    /// One element produced for the consumer.
    Item(T),
    /// The producer finished normally; no further items follow.
    Complete,
    /// The producer failed with the given error.
    Error(StreamError),
}
42
/// State owned by a materialized QUIC read source.
struct ReadResource {
    /// Channel on which the carrier task delivers read responses.
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    /// Handle used to request more demand and to close the read side.
    carrier: QuicCarrier,
    /// Tracks consumed items and decides when to request more demand.
    demand: DemandBatcher,
    /// Response fetched during setup that must be delivered before reading
    /// from `receiver` again.
    pending: Option<DemandResponse<Vec<u8>>>,
}

impl Drop for ReadResource {
    /// Asks the carrier to close the read side when the resource is dropped.
    fn drop(&mut self) {
        self.carrier.close_read();
    }
}
55
/// Commands sent from the blocking stream side to the carrier task.
enum QuicCarrierCommand {
    /// Adds the given number of items to the outstanding read demand.
    Demand(usize),
    /// Writes a single chunk to the send stream.
    SendOne(Vec<u8>),
    /// Writes several chunks to the send stream, in order.
    SendBatch(Vec<Vec<u8>>),
    /// Marks the read side as closed.
    CloseRead,
    /// Closes the write side and reports the outcome on `ack`.
    CloseWrite {
        ack: std_mpsc::Sender<StreamResult<()>>,
    },
}
65
/// Cloneable handle to the carrier task that owns both halves of a QUIC stream.
#[derive(Clone)]
struct QuicCarrier {
    /// State shared by every clone of the handle.
    inner: Arc<QuicCarrierInner>,
}

/// Shared state behind a [`QuicCarrier`].
struct QuicCarrierInner {
    /// Command channel into the carrier task.
    commands: AsyncCommandSender<QuicCarrierCommand>,
    /// Errors reported by the carrier task, polled before send operations.
    send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
    /// Join handle of the spawned carrier task; taken and aborted on drop.
    task: Mutex<Option<JoinHandle<()>>>,
}
76
77impl Drop for QuicCarrierInner {
78 fn drop(&mut self) {
79 if let Some(task) = self.task.lock().expect("QUIC carrier task poisoned").take() {
80 task.abort();
81 }
82 }
83}
84
85impl QuicCarrier {
86 fn close_read(&self) {
87 let _ = self.inner.commands.try_send(QuicCarrierCommand::CloseRead);
88 }
89
90 fn request_demand(&self, demand: usize) -> StreamResult<()> {
91 self.inner
92 .commands
93 .send_or_blocking(QuicCarrierCommand::Demand(demand))
94 }
95
96 fn send_items(&self, items: Vec<Vec<u8>>) -> StreamResult<()> {
97 self.check_send_error()?;
98 self.inner
99 .commands
100 .send_or_blocking(QuicCarrierCommand::SendBatch(items))
101 .map_err(|error| StreamError::Failed(format!("QUIC send batch failed: {error:?}")))
102 }
103
104 fn send_one(&self, item: Vec<u8>) -> StreamResult<()> {
105 self.check_send_error()?;
106 self.inner
107 .commands
108 .send_or_blocking(QuicCarrierCommand::SendOne(item))
109 .map_err(|error| StreamError::Failed(format!("QUIC send failed: {error:?}")))
110 }
111
112 fn close_write(&self) -> StreamResult<()> {
113 self.check_send_error()?;
114 let (ack_sender, ack_receiver) = std_mpsc::channel();
115 if self
116 .inner
117 .commands
118 .send_or_blocking(QuicCarrierCommand::CloseWrite { ack: ack_sender })
119 .is_err()
120 {
121 return Ok(());
122 }
123 match ack_receiver.recv() {
124 Ok(result) => result,
125 Err(_) => Err(abrupt_termination()),
126 }?;
127 self.check_send_error()
128 }
129
130 fn check_send_error(&self) -> StreamResult<()> {
131 match self
132 .inner
133 .send_errors
134 .lock()
135 .expect("QUIC carrier send error receiver poisoned")
136 .try_recv()
137 {
138 Ok(error) => Err(error),
139 Err(std_mpsc::TryRecvError::Empty) | Err(std_mpsc::TryRecvError::Disconnected) => {
140 Ok(())
141 }
142 }
143 }
144}
145
/// State owned by a materialized QUIC write sink.
struct SendResource {
    /// Handle used to hand chunks to the carrier task.
    carrier: QuicCarrier,
    /// Chunks buffered until `batch_size` of them have accumulated.
    pending: Vec<Vec<u8>>,
    /// Number of chunks per batch; values of one or less send each chunk
    /// individually.
    batch_size: usize,
}

/// Read-side settings for the carrier task.
#[derive(Clone, Copy)]
struct QuicReadConfig {
    /// Size of the read buffer and of each full chunk emitted downstream.
    chunk_size: usize,
    /// When `true`, bytes that do not fill a whole chunk are emitted right
    /// away instead of being held back until the chunk is complete.
    emit_available: bool,
}
157
/// State owned by a materialized bind source (incoming connections).
struct BindResource {
    /// Sends a reply channel to the bind task each time an element is demanded.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>>,
    /// Cancellation flag for the bind task; set to `true` on drop.
    cancel: watch::Sender<bool>,
    /// Join handle of the spawned bind task; aborted on drop.
    task: JoinHandle<()>,
}

impl Drop for BindResource {
    /// Signals cancellation and aborts the bind task.
    fn drop(&mut self) {
        // The send result is ignored: the task may already have gone away.
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
170
/// State owned by a materialized accept source (incoming bidirectional streams).
struct AcceptBiResource {
    /// Sends a reply channel to the accept task each time an element is demanded.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>>,
    /// Cancellation flag for the accept task; set to `true` on drop.
    cancel: watch::Sender<bool>,
    /// Join handle of the spawned accept task; aborted on drop.
    task: JoinHandle<()>,
}

impl Drop for AcceptBiResource {
    /// Signals cancellation and aborts the accept task.
    fn drop(&mut self) {
        // The send result is ignored: the task may already have gone away.
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
183
184fn quic_error(error: impl std::fmt::Display) -> StreamError {
185 StreamError::Failed(error.to_string())
186}
187
188fn io_error(error: std::io::Error) -> StreamError {
189 StreamError::Failed(error.to_string())
190}
191
/// Error used when a background task or channel goes away without replying.
fn abrupt_termination() -> StreamError {
    StreamError::AbruptTermination
}
195
/// Application close code (zero) used by [`QuicConnection::close`].
fn close_code() -> quinn::VarInt {
    quinn::VarInt::from_u32(0)
}
199
/// Materialized value of a bound QUIC endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuicBinding {
    /// Local address the endpoint is bound to.
    pub local_addr: SocketAddr,
}

impl QuicBinding {
    /// Returns the local address the endpoint is bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
213
/// Metadata describing one QUIC stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuicStream {
    /// Stream identifier, taken from the send half of the stream.
    pub id: quinn::StreamId,
}

impl QuicStream {
    /// Returns the stream identifier.
    #[must_use]
    pub fn id(&self) -> quinn::StreamId {
        self.id
    }
}
227
/// An established QUIC connection together with the context needed to open
/// and accept streams on it.
#[derive(Debug, Clone)]
pub struct QuicConnection {
    /// Endpoint the connection belongs to.
    endpoint: quinn::Endpoint,
    /// Underlying quinn connection.
    connection: quinn::Connection,
    /// Tokio runtime handle on which background tasks are spawned.
    handle: Handle,
    /// Local socket address recorded for the connection.
    local_addr: SocketAddr,
    /// Remote socket address recorded for the connection.
    remote_addr: SocketAddr,
    /// Chunk size used by the `*_default` methods.
    chunk_size: usize,
}
238
239impl QuicConnection {
240 #[must_use]
242 pub fn local_addr(&self) -> SocketAddr {
243 self.local_addr
244 }
245
246 #[must_use]
248 pub fn remote_addr(&self) -> SocketAddr {
249 self.remote_addr
250 }
251
252 #[must_use]
254 pub fn chunk_size(&self) -> usize {
255 self.chunk_size
256 }
257
258 #[must_use]
260 pub fn quinn_connection(&self) -> &quinn::Connection {
261 &self.connection
262 }
263
264 #[must_use]
266 pub fn quinn_endpoint(&self) -> &quinn::Endpoint {
267 &self.endpoint
268 }
269
270 #[must_use]
277 pub fn open_bi(
278 &self,
279 chunk_size: usize,
280 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
281 assert!(chunk_size > 0, "chunk size must be greater than zero");
282 let connection = self.connection.clone();
283 let handle = self.handle.clone();
284 Flow::future_flow(move || {
285 let connection = connection.clone();
286 let handle = handle.clone();
287 async move {
288 let (send, recv) = connection.open_bi().await.map_err(quic_error)?;
289 Ok(quic_bi_stream_from_halves(send, recv, handle, chunk_size, false).into_flow())
290 }
291 })
292 }
293
294 #[must_use]
296 pub fn open_bi_default(&self) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
297 self.open_bi(self.chunk_size)
298 }
299
300 #[must_use]
306 pub fn open_bi_stream(
307 &self,
308 chunk_size: usize,
309 ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
310 assert!(chunk_size > 0, "chunk size must be greater than zero");
311 let connection = self.connection.clone();
312 let handle = self.handle.clone();
313 Source::lazy_future_source(move || {
314 let connection = connection.clone();
315 let handle = handle.clone();
316 async move {
317 let (send, recv) = connection.open_bi().await.map_err(quic_error)?;
318 let stream = quic_bi_stream_from_halves(send, recv, handle, chunk_size, false);
319 let metadata = stream.stream();
320 let stream = Arc::new(Mutex::new(Some(stream)));
321 Ok(Source::unfold_resource(
322 {
323 let stream = Arc::clone(&stream);
324 move || {
325 stream
326 .lock()
327 .expect("single-use QUIC bidi stream poisoned")
328 .take()
329 .map(Some)
330 .ok_or_else(|| {
331 StreamError::Failed(
332 "QUIC bidi stream already materialized".into(),
333 )
334 })
335 }
336 },
337 |stream| Ok(stream.take()),
338 |_stream| Ok(()),
339 )
340 .map_materialized_value(move |_| metadata))
341 }
342 })
343 }
344
345 #[must_use]
347 pub fn open_bi_stream_default(
348 &self,
349 ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
350 self.open_bi_stream(self.chunk_size)
351 }
352
353 #[must_use]
360 pub fn open_bi_stream_available(
361 &self,
362 chunk_size: usize,
363 ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
364 assert!(chunk_size > 0, "chunk size must be greater than zero");
365 let connection = self.connection.clone();
366 let handle = self.handle.clone();
367 Source::lazy_future_source(move || {
368 let connection = connection.clone();
369 let handle = handle.clone();
370 async move {
371 let (send, recv) = connection.open_bi().await.map_err(quic_error)?;
372 let stream = quic_bi_stream_from_halves(send, recv, handle, chunk_size, true);
373 let metadata = stream.stream();
374 let stream = Arc::new(Mutex::new(Some(stream)));
375 Ok(Source::unfold_resource(
376 {
377 let stream = Arc::clone(&stream);
378 move || {
379 stream
380 .lock()
381 .expect("single-use QUIC bidi stream poisoned")
382 .take()
383 .map(Some)
384 .ok_or_else(|| {
385 StreamError::Failed(
386 "QUIC bidi stream already materialized".into(),
387 )
388 })
389 }
390 },
391 |stream| Ok(stream.take()),
392 |_stream| Ok(()),
393 )
394 .map_materialized_value(move |_| metadata))
395 }
396 })
397 }
398
399 #[must_use]
405 pub fn accept_bi(&self, chunk_size: usize) -> Source<QuicBidirectionalStream, QuicConnection> {
406 assert!(chunk_size > 0, "chunk size must be greater than zero");
407 let connection = self.clone();
408 Source::unfold_resource(
409 {
410 let connection = connection.clone();
411 move || {
412 let handle = connection.handle.clone();
413 let (demand_sender, demand_receiver) = mpsc::channel(1);
414 let (cancel_sender, cancel_receiver) = watch::channel(false);
415 let task = handle.spawn(run_accept_bi_task(
416 connection.connection.clone(),
417 chunk_size,
418 false,
419 handle.clone(),
420 demand_receiver,
421 cancel_receiver,
422 ));
423 Ok(AcceptBiResource {
424 demands: demand_sender,
425 cancel: cancel_sender,
426 task,
427 })
428 }
429 },
430 receive_demand_response,
431 close_accept_bi_resource,
432 )
433 .map_materialized_value(move |_| connection.clone())
434 }
435
436 #[must_use]
439 pub fn accept_bi_default(&self) -> Source<QuicBidirectionalStream, QuicConnection> {
440 self.accept_bi(self.chunk_size)
441 }
442
443 #[must_use]
448 pub fn accept_bi_available(
449 &self,
450 chunk_size: usize,
451 ) -> Source<QuicBidirectionalStream, QuicConnection> {
452 assert!(chunk_size > 0, "chunk size must be greater than zero");
453 let connection = self.clone();
454 Source::unfold_resource(
455 {
456 let connection = connection.clone();
457 move || {
458 let handle = connection.handle.clone();
459 let (demand_sender, demand_receiver) = mpsc::channel(1);
460 let (cancel_sender, cancel_receiver) = watch::channel(false);
461 let task = handle.spawn(run_accept_bi_task(
462 connection.connection.clone(),
463 chunk_size,
464 true,
465 handle.clone(),
466 demand_receiver,
467 cancel_receiver,
468 ));
469 Ok(AcceptBiResource {
470 demands: demand_sender,
471 cancel: cancel_sender,
472 task,
473 })
474 }
475 },
476 receive_demand_response,
477 close_accept_bi_resource,
478 )
479 .map_materialized_value(move |_| connection.clone())
480 }
481
482 pub fn close(&self, reason: &[u8]) {
484 self.connection.close(close_code(), reason);
485 }
486}
487
/// A connection emitted by a bound endpoint; wraps a [`QuicConnection`].
#[derive(Debug, Clone)]
pub struct QuicIncomingConnection {
    /// The wrapped connection that every method delegates to.
    connection: QuicConnection,
}
493
impl QuicIncomingConnection {
    /// Returns the local socket address of the wrapped connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.connection.local_addr()
    }

    /// Returns the remote socket address of the wrapped connection.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection.remote_addr()
    }

    /// Returns a clone of the wrapped connection.
    #[must_use]
    pub fn connection(&self) -> QuicConnection {
        self.connection.clone()
    }

    /// Consumes the wrapper and returns the wrapped connection.
    #[must_use]
    pub fn into_connection(self) -> QuicConnection {
        self.connection
    }

    /// Delegates to [`QuicConnection::open_bi`].
    #[must_use]
    pub fn open_bi(
        &self,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        self.connection.open_bi(chunk_size)
    }

    /// Delegates to [`QuicConnection::open_bi_default`].
    #[must_use]
    pub fn open_bi_default(&self) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        self.connection.open_bi_default()
    }

    /// Delegates to [`QuicConnection::open_bi_stream`].
    #[must_use]
    pub fn open_bi_stream(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.connection.open_bi_stream(chunk_size)
    }

    /// Delegates to [`QuicConnection::open_bi_stream_default`].
    #[must_use]
    pub fn open_bi_stream_default(
        &self,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.connection.open_bi_stream_default()
    }

    /// Delegates to [`QuicConnection::open_bi_stream_available`].
    #[must_use]
    pub fn open_bi_stream_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.connection.open_bi_stream_available(chunk_size)
    }

    /// Delegates to [`QuicConnection::accept_bi`].
    #[must_use]
    pub fn accept_bi(&self, chunk_size: usize) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.connection.accept_bi(chunk_size)
    }

    /// Delegates to [`QuicConnection::accept_bi_default`].
    #[must_use]
    pub fn accept_bi_default(&self) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.connection.accept_bi_default()
    }

    /// Delegates to [`QuicConnection::accept_bi_available`].
    #[must_use]
    pub fn accept_bi_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.connection.accept_bi_available(chunk_size)
    }
}
581
/// Both halves of a QUIC bidirectional stream, not yet turned into stream
/// stages.
pub struct QuicBidirectionalStream {
    /// Metadata (stream id) of this stream.
    stream: QuicStream,
    /// Write half.
    send: quinn::SendStream,
    /// Read half.
    recv: quinn::RecvStream,
    /// Tokio runtime handle on which the carrier task is spawned.
    handle: Handle,
    /// Size of the read buffer and of each full chunk emitted downstream.
    chunk_size: usize,
    /// When `true`, bytes that do not fill a whole chunk are emitted right away.
    emit_available: bool,
}
591
592impl QuicBidirectionalStream {
593 #[must_use]
595 pub fn stream(&self) -> QuicStream {
596 self.stream
597 }
598
599 #[must_use]
601 pub fn into_parts(self) -> (QuicByteSource, QuicByteSink) {
602 let Self {
603 send,
604 recv,
605 handle,
606 chunk_size,
607 emit_available,
608 ..
609 } = self;
610 single_use_quic_halves(send, recv, handle, chunk_size, emit_available)
611 }
612
613 #[must_use]
615 pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, QuicStream> {
616 let stream = self.stream;
617 let (source, sink) = self.into_parts();
618 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| stream)
619 }
620
621 pub(crate) fn into_stream_ref_parts(
622 self,
623 ) -> (quinn::RecvStream, quinn::SendStream, Handle, usize, bool) {
624 (
625 self.recv,
626 self.send,
627 self.handle,
628 self.chunk_size,
629 self.emit_available,
630 )
631 }
632}
633
/// Entry point for the Tokio-backed QUIC constructors (`bind`, `connect`).
pub struct TokioQuic;

/// Short alias for [`TokioQuic`].
pub type Quic = TokioQuic;
639
640impl TokioQuic {
641 #[must_use]
647 pub fn bind<A>(
648 addr: A,
649 server_config: quinn::ServerConfig,
650 chunk_size: usize,
651 ) -> Source<QuicIncomingConnection, StreamCompletion<QuicBinding>>
652 where
653 A: ToSocketAddrs + Clone + Send + Sync + 'static,
654 {
655 assert!(chunk_size > 0, "chunk size must be greater than zero");
656 Source::lazy_future_source(move || {
657 let addr = addr.clone();
658 let server_config = server_config.clone();
659 async move {
660 let handle = Handle::current();
661 let addr = resolve_addr(addr).await?;
662 let endpoint = quinn::Endpoint::server(server_config, addr).map_err(io_error)?;
663 let local_addr = endpoint.local_addr().map_err(io_error)?;
664 Ok(quic_bind_source(endpoint, local_addr, handle, chunk_size))
665 }
666 })
667 }
668
669 #[must_use]
671 pub fn bind_default<A>(
672 addr: A,
673 server_config: quinn::ServerConfig,
674 ) -> Source<QuicIncomingConnection, StreamCompletion<QuicBinding>>
675 where
676 A: ToSocketAddrs + Clone + Send + Sync + 'static,
677 {
678 Self::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
679 }
680
681 #[must_use]
687 pub fn connect<A>(
688 addr: A,
689 server_name: impl Into<String>,
690 client_config: quinn::ClientConfig,
691 chunk_size: usize,
692 ) -> Source<QuicConnection, StreamCompletion<QuicConnection>>
693 where
694 A: ToSocketAddrs + Clone + Send + Sync + 'static,
695 {
696 assert!(chunk_size > 0, "chunk size must be greater than zero");
697 let server_name = server_name.into();
698 Source::lazy_future_source(move || {
699 let addr = addr.clone();
700 let server_name = server_name.clone();
701 let client_config = client_config.clone();
702 async move {
703 let remote_addr = resolve_addr(addr).await?;
704 let local_addr = client_bind_addr(remote_addr);
705 let mut endpoint = quinn::Endpoint::client(local_addr).map_err(io_error)?;
706 endpoint.set_default_client_config(client_config);
707 let connecting = endpoint
708 .connect(remote_addr, &server_name)
709 .map_err(quic_error)?;
710 let connection = connecting.await.map_err(quic_error)?;
711 let endpoint_local_addr = endpoint.local_addr().map_err(io_error)?;
712 let connection = QuicConnection {
713 local_addr: connection_local_addr(
714 &connection,
715 endpoint_local_addr,
716 remote_addr.ip(),
717 ),
718 remote_addr: connection.remote_address(),
719 endpoint,
720 connection,
721 handle: Handle::current(),
722 chunk_size,
723 };
724 let materialized = connection.clone();
725 Ok(
726 Source::single(connection)
727 .map_materialized_value(move |_| materialized.clone()),
728 )
729 }
730 })
731 }
732
733 #[must_use]
735 pub fn connect_default<A>(
736 addr: A,
737 server_name: impl Into<String>,
738 client_config: quinn::ClientConfig,
739 ) -> Source<QuicConnection, StreamCompletion<QuicConnection>>
740 where
741 A: ToSocketAddrs + Clone + Send + Sync + 'static,
742 {
743 Self::connect(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
744 }
745}
746
747async fn resolve_addr<A>(addr: A) -> StreamResult<SocketAddr>
748where
749 A: ToSocketAddrs,
750{
751 let mut addrs = tokio::net::lookup_host(addr).await.map_err(io_error)?;
752 addrs
753 .next()
754 .ok_or_else(|| StreamError::Failed("address resolved to no socket addresses".into()))
755}
756
/// Picks the wildcard address (port zero) of the same IP family as
/// `remote_addr`, for binding a client endpoint.
fn client_bind_addr(remote_addr: SocketAddr) -> SocketAddr {
    let wildcard = match remote_addr {
        SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
    };
    SocketAddr::new(wildcard, 0)
}
764
765fn connection_local_addr(
766 connection: &quinn::Connection,
767 endpoint_addr: SocketAddr,
768 fallback_ip: IpAddr,
769) -> SocketAddr {
770 connection
771 .local_ip()
772 .map(|ip| SocketAddr::new(ip, endpoint_addr.port()))
773 .or_else(|| {
774 endpoint_addr
775 .ip()
776 .is_unspecified()
777 .then(|| SocketAddr::new(fallback_ip, endpoint_addr.port()))
778 })
779 .unwrap_or(endpoint_addr)
780}
781
782fn quic_bi_stream_from_halves(
783 send: quinn::SendStream,
784 recv: quinn::RecvStream,
785 handle: Handle,
786 chunk_size: usize,
787 emit_available: bool,
788) -> QuicBidirectionalStream {
789 let stream = QuicStream { id: send.id() };
790 QuicBidirectionalStream {
791 stream,
792 send,
793 recv,
794 handle,
795 chunk_size,
796 emit_available,
797 }
798}
799
800fn single_use_quic_halves(
801 send: quinn::SendStream,
802 recv: quinn::RecvStream,
803 handle: Handle,
804 chunk_size: usize,
805 emit_available: bool,
806) -> (QuicByteSource, QuicByteSink) {
807 let (carrier, receiver) = start_quic_carrier(
808 send,
809 recv,
810 handle,
811 chunk_size,
812 emit_available,
813 DEFAULT_RECEIVE_BUFFER,
814 );
815 let source =
816 single_use_quic_read_source_from_carrier(carrier.clone(), receiver, DEFAULT_RECEIVE_BUFFER);
817 let sink = single_use_quic_write_sink_from_carrier(carrier, 1);
818 (source, sink)
819}
820
821fn single_use_quic_read_source_from_carrier(
822 carrier: QuicCarrier,
823 receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
824 receive_buffer: usize,
825) -> QuicByteSource {
826 let receiver = Arc::new(Mutex::new(Some(receiver)));
827 Source::unfold_resource(
828 {
829 let receiver = Arc::clone(&receiver);
830 move || {
831 let receiver = receiver
832 .lock()
833 .expect("single-use QUIC receiver poisoned")
834 .take()
835 .ok_or_else(|| {
836 StreamError::Failed("QUIC source already materialized".into())
837 })?;
838 let demand = DemandBatcher::new(receive_buffer);
839 let pending = match carrier.request_demand(demand.initial()) {
840 Ok(()) => None,
841 Err(error) => match receiver.try_recv() {
842 Ok(response) => Some(response),
843 Err(std_mpsc::TryRecvError::Empty) => return Err(error),
844 Err(std_mpsc::TryRecvError::Disconnected) => {
845 return Err(abrupt_termination());
846 }
847 },
848 };
849 Ok(ReadResource {
850 receiver,
851 carrier: carrier.clone(),
852 demand,
853 pending,
854 })
855 }
856 },
857 read_next_quic_chunk,
858 close_read_resource,
859 )
860}
861
862fn read_next_quic_chunk(resource: &mut ReadResource) -> StreamResult<Option<Vec<u8>>> {
863 let response = match resource.pending.take() {
864 Some(response) => response,
865 None => resource.receiver.recv().map_err(|_| abrupt_termination())?,
866 };
867 match response {
868 DemandResponse::Item(chunk) => {
869 if let Some(demand) = resource.demand.record_consumed() {
870 let _ = resource.carrier.request_demand(demand);
871 }
872 Ok(Some(chunk))
873 }
874 DemandResponse::Complete => Ok(None),
875 DemandResponse::Error(error) => Err(error),
876 }
877}
878
/// Close callback of the read source: asks the carrier to close the read side.
///
/// `ReadResource`'s `Drop` impl issues the same request again when `resource`
/// goes out of scope at the end of this function.
fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
    resource.carrier.close_read();
    Ok(())
}
883
884fn start_quic_carrier(
885 send: quinn::SendStream,
886 recv: quinn::RecvStream,
887 handle: Handle,
888 chunk_size: usize,
889 emit_available: bool,
890 receive_buffer: usize,
891) -> (QuicCarrier, std_mpsc::Receiver<DemandResponse<Vec<u8>>>) {
892 let command_capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
893 let (commands, command_receiver) = async_carrier::command_channel(command_capacity, "QUIC");
894 let (send_error_sender, send_error_receiver) = std_mpsc::channel();
895 let (receive_sender, receive_receiver) =
896 std_mpsc::sync_channel(receive_buffer.saturating_add(1));
897 let command_keepalive = commands.clone();
898 let read_config = QuicReadConfig {
899 chunk_size,
900 emit_available,
901 };
902 let task = handle.spawn(run_quic_carrier_task(
903 send,
904 recv,
905 read_config,
906 receive_sender,
907 send_error_sender,
908 command_keepalive,
909 command_receiver,
910 ));
911 (
912 QuicCarrier {
913 inner: Arc::new(QuicCarrierInner {
914 commands,
915 send_errors: Mutex::new(send_error_receiver),
916 task: Mutex::new(Some(task)),
917 }),
918 },
919 receive_receiver,
920 )
921}
922
/// Event loop of the carrier task that owns both halves of one QUIC stream.
///
/// It executes commands from `commands` and, while read demand is
/// outstanding, reads from `recv` and queues the data on `receive_sender`.
/// The loop ends when both sides are closed, when the command channel is
/// closed, or after a read failure or receive-buffer overflow.
async fn run_quic_carrier_task(
    mut send: quinn::SendStream,
    mut recv: quinn::RecvStream,
    read_config: QuicReadConfig,
    receive_sender: std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    _command_keepalive: AsyncCommandSender<QuicCarrierCommand>,
    mut commands: mpsc::Receiver<QuicCarrierCommand>,
) {
    let mut buffer = vec![0_u8; read_config.chunk_size];
    // Bytes read so far that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(read_config.chunk_size);
    // Number of chunks the consumer has asked for and not yet received.
    let mut requested = 0_usize;
    let mut read_open = true;
    let mut write_open = true;

    loop {
        if !read_open && !write_open {
            return;
        }

        // Read from the stream only while the read side is open and there is
        // outstanding demand; otherwise just wait for the next command.
        if read_open && requested > 0 {
            tokio::select! {
                // `biased` makes `select!` poll the command branch first.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else {
                        return;
                    };
                    if !handle_quic_carrier_command(
                        &mut send,
                        command,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                read = recv.read(&mut buffer) => {
                    match read {
                        Ok(Some(read)) => {
                            match queue_quic_read_chunks(
                                &receive_sender,
                                &send_error_sender,
                                read_config.chunk_size,
                                &mut pending_tail,
                                &buffer[..read],
                                read_config.emit_available,
                            ) {
                                QuicReadQueueResult::Queued(queued) => {
                                    requested = requested.saturating_sub(queued);
                                }
                                QuicReadQueueResult::Closed => {
                                    read_open = false;
                                }
                                QuicReadQueueResult::Failed => {
                                    return;
                                }
                            }
                        }
                        // No more data: flush any partial tail, then queue `Complete`.
                        Ok(None) => {
                            if !pending_tail.is_empty() {
                                match try_send_quic_read_response(
                                    &receive_sender,
                                    DemandResponse::Item(std::mem::take(&mut pending_tail)),
                                ) {
                                    QuicQueueOutcome::Queued => {
                                        requested = requested.saturating_sub(1);
                                    }
                                    QuicQueueOutcome::Closed => {
                                        read_open = false;
                                        continue;
                                    }
                                    QuicQueueOutcome::Full => {
                                        report_quic_read_error(
                                            &receive_sender,
                                            &send_error_sender,
                                            quic_receive_buffer_overflow(),
                                        );
                                        return;
                                    }
                                }
                            }
                            match try_send_quic_read_response(
                                &receive_sender,
                                DemandResponse::Complete,
                            ) {
                                QuicQueueOutcome::Queued | QuicQueueOutcome::Closed => {
                                    read_open = false;
                                }
                                QuicQueueOutcome::Full => {
                                    report_quic_read_error(
                                        &receive_sender,
                                        &send_error_sender,
                                        quic_receive_buffer_overflow(),
                                    );
                                    return;
                                }
                            }
                        }
                        Err(error) => {
                            report_quic_read_error(
                                &receive_sender,
                                &send_error_sender,
                                quic_error(error),
                            );
                            return;
                        }
                    }
                }
            }
        } else {
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_quic_carrier_command(
                &mut send,
                command,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
1053
1054async fn handle_quic_carrier_command(
1055 send: &mut quinn::SendStream,
1056 command: QuicCarrierCommand,
1057 send_error_sender: &std_mpsc::Sender<StreamError>,
1058 read_open: &mut bool,
1059 write_open: &mut bool,
1060 requested: &mut usize,
1061) -> bool {
1062 match command {
1063 QuicCarrierCommand::Demand(demand) => {
1064 *requested = requested.saturating_add(demand);
1065 true
1066 }
1067 QuicCarrierCommand::SendOne(chunk) => {
1068 if !*write_open {
1069 report_quic_write_error(
1070 send_error_sender,
1071 StreamError::Failed("QUIC write side is closed".to_owned()),
1072 );
1073 return *read_open;
1074 }
1075 if write_one_quic_chunk(send, send_error_sender, &chunk).await {
1076 true
1077 } else {
1078 *write_open = false;
1079 *read_open
1080 }
1081 }
1082 QuicCarrierCommand::SendBatch(chunks) => {
1083 if !*write_open {
1084 report_quic_write_error(
1085 send_error_sender,
1086 StreamError::Failed("QUIC write side is closed".to_owned()),
1087 );
1088 return *read_open;
1089 }
1090 for chunk in &chunks {
1091 if let Err(error) = send.write_all(chunk).await.map_err(quic_error) {
1092 report_quic_write_error(send_error_sender, error);
1093 *write_open = false;
1094 return *read_open;
1095 }
1096 }
1097 true
1098 }
1099 QuicCarrierCommand::CloseRead => {
1100 *read_open = false;
1101 true
1102 }
1103 QuicCarrierCommand::CloseWrite { ack } => {
1104 *write_open = false;
1105 let result = close_quic_writer(send).await;
1106 match result {
1107 Ok(()) => {
1108 let _ = ack.send(Ok(()));
1109 true
1110 }
1111 Err(error) => {
1112 report_quic_write_error(send_error_sender, error.clone());
1113 let _ = ack.send(Err(error));
1114 *read_open
1115 }
1116 }
1117 }
1118 }
1119}
1120
1121async fn write_one_quic_chunk(
1122 send: &mut quinn::SendStream,
1123 send_error_sender: &std_mpsc::Sender<StreamError>,
1124 chunk: &[u8],
1125) -> bool {
1126 if let Err(error) = send.write_all(chunk).await.map_err(quic_error) {
1127 report_quic_write_error(send_error_sender, error);
1128 return false;
1129 }
1130 true
1131}
1132
/// Writes an empty slice and then calls `finish` on the send stream, mapping
/// either failure through `quic_error`.
async fn close_quic_writer(send: &mut quinn::SendStream) -> StreamResult<()> {
    send.write_all(&[]).await.map_err(quic_error)?;
    send.finish().map_err(quic_error)
}
1137
/// Outcome of queueing the data from one read.
enum QuicReadQueueResult {
    /// Number of chunks queued for the consumer.
    Queued(usize),
    /// The consumer side of the response channel is gone.
    Closed,
    /// The response channel was full; the error has already been reported.
    Failed,
}

/// Outcome of a single non-blocking send on the response channel.
enum QuicQueueOutcome {
    /// The response was queued.
    Queued,
    /// The channel had no free slot.
    Full,
    /// The receiving side is disconnected.
    Closed,
}
1149
1150fn queue_quic_read_chunks(
1151 sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
1152 send_error_sender: &std_mpsc::Sender<StreamError>,
1153 chunk_size: usize,
1154 pending_tail: &mut Vec<u8>,
1155 read_buffer: &[u8],
1156 emit_available: bool,
1157) -> QuicReadQueueResult {
1158 let mut offset = 0;
1159 let mut queued = 0_usize;
1160 if !pending_tail.is_empty() {
1161 let needed = chunk_size - pending_tail.len();
1162 let take = needed.min(read_buffer.len());
1163 pending_tail.extend_from_slice(&read_buffer[..take]);
1164 offset += take;
1165 if pending_tail.len() == chunk_size {
1166 match try_send_quic_read_response(
1167 sender,
1168 DemandResponse::Item(std::mem::take(pending_tail)),
1169 ) {
1170 QuicQueueOutcome::Queued => queued += 1,
1171 QuicQueueOutcome::Closed => return QuicReadQueueResult::Closed,
1172 QuicQueueOutcome::Full => {
1173 report_quic_read_error(
1174 sender,
1175 send_error_sender,
1176 quic_receive_buffer_overflow(),
1177 );
1178 return QuicReadQueueResult::Failed;
1179 }
1180 }
1181 }
1182 }
1183
1184 while offset + chunk_size <= read_buffer.len() {
1185 let next = offset + chunk_size;
1186 match try_send_quic_read_response(
1187 sender,
1188 DemandResponse::Item(read_buffer[offset..next].to_vec()),
1189 ) {
1190 QuicQueueOutcome::Queued => queued += 1,
1191 QuicQueueOutcome::Closed => return QuicReadQueueResult::Closed,
1192 QuicQueueOutcome::Full => {
1193 report_quic_read_error(sender, send_error_sender, quic_receive_buffer_overflow());
1194 return QuicReadQueueResult::Failed;
1195 }
1196 }
1197 offset = next;
1198 }
1199
1200 if offset < read_buffer.len() {
1201 pending_tail.extend_from_slice(&read_buffer[offset..]);
1202 }
1203 if emit_available && !pending_tail.is_empty() {
1204 match try_send_quic_read_response(
1205 sender,
1206 DemandResponse::Item(std::mem::take(pending_tail)),
1207 ) {
1208 QuicQueueOutcome::Queued => queued += 1,
1209 QuicQueueOutcome::Closed => return QuicReadQueueResult::Closed,
1210 QuicQueueOutcome::Full => {
1211 report_quic_read_error(sender, send_error_sender, quic_receive_buffer_overflow());
1212 return QuicReadQueueResult::Failed;
1213 }
1214 }
1215 }
1216 QuicReadQueueResult::Queued(queued)
1217}
1218
1219fn try_send_quic_read_response(
1220 sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
1221 item: DemandResponse<Vec<u8>>,
1222) -> QuicQueueOutcome {
1223 match sender.try_send(item) {
1224 Ok(()) => QuicQueueOutcome::Queued,
1225 Err(std_mpsc::TrySendError::Full(_)) => QuicQueueOutcome::Full,
1226 Err(std_mpsc::TrySendError::Disconnected(_)) => QuicQueueOutcome::Closed,
1227 }
1228}
1229
/// Reports a read failure on both channels: a copy goes to the send-error
/// channel and the error itself is offered to the read-response channel.
/// Both sends are best effort; their results are ignored.
fn report_quic_read_error(
    receive_sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    error: StreamError,
) {
    let _ = send_error_sender.send(error.clone());
    let _ = receive_sender.try_send(DemandResponse::Error(error));
}
1238
/// Reports a write failure on the send-error channel; best effort, the send
/// result is ignored.
fn report_quic_write_error(send_error_sender: &std_mpsc::Sender<StreamError>, error: StreamError) {
    let _ = send_error_sender.send(error);
}
1242
/// Error reported when a read response cannot be queued because the response
/// channel is full.
fn quic_receive_buffer_overflow() -> StreamError {
    StreamError::Failed("QUIC receive buffer filled without downstream demand".to_owned())
}
1246
/// Builds the write sink for a carrier.
///
/// The carrier can be claimed once only: materializing the sink a second time
/// fails with "QUIC sink already materialized". Each incoming chunk goes
/// through `send_quic_chunk`; `close_quic_send_resource` runs as the close
/// callback of the resource stage.
fn single_use_quic_write_sink_from_carrier(
    carrier: QuicCarrier,
    batch_size: usize,
) -> QuicByteSink {
    let carrier = Arc::new(Mutex::new(Some(carrier)));
    Flow::<Vec<u8>, Vec<u8>>::identity()
        .map_with_resource(
            {
                let carrier = Arc::clone(&carrier);
                move || {
                    let carrier = carrier
                        .lock()
                        .expect("single-use QUIC carrier poisoned")
                        .take()
                        .ok_or_else(|| {
                            StreamError::Failed("QUIC sink already materialized".into())
                        })?;
                    Ok(SendResource {
                        carrier,
                        pending: Vec::with_capacity(batch_size),
                        batch_size,
                    })
                }
            },
            |resource, chunk| {
                send_quic_chunk(resource, chunk)?;
                Ok(NotUsed)
            },
            close_quic_send_resource,
        )
        .to_mat(Sink::ignore(), Keep::right)
}
1279
1280fn close_quic_send_resource(mut resource: SendResource) -> StreamResult<Option<NotUsed>> {
1281 flush_quic_send_resource(&mut resource)?;
1282 resource.carrier.close_write()?;
1283 Ok(None)
1284}
1285
1286fn send_quic_chunk(resource: &mut SendResource, chunk: Vec<u8>) -> StreamResult<()> {
1287 if resource.batch_size <= 1 {
1288 return resource.carrier.send_one(chunk);
1289 }
1290 resource.pending.push(chunk);
1291 if resource.pending.len() >= resource.batch_size {
1292 flush_quic_send_resource(resource)?;
1293 }
1294 Ok(())
1295}
1296
1297fn flush_quic_send_resource(resource: &mut SendResource) -> StreamResult<()> {
1298 if resource.pending.is_empty() {
1299 return resource.carrier.check_send_error();
1300 }
1301 let pending = std::mem::take(&mut resource.pending);
1302 resource.carrier.send_items(pending)
1303}
1304
1305fn quic_bind_source(
1306 endpoint: quinn::Endpoint,
1307 local_addr: SocketAddr,
1308 handle: Handle,
1309 chunk_size: usize,
1310) -> Source<QuicIncomingConnection, QuicBinding> {
1311 let endpoint = Arc::new(Mutex::new(Some(endpoint)));
1312 Source::unfold_resource(
1313 {
1314 let endpoint = Arc::clone(&endpoint);
1315 let handle = handle.clone();
1316 move || {
1317 let endpoint = endpoint
1318 .lock()
1319 .expect("single-use QUIC endpoint poisoned")
1320 .take()
1321 .ok_or_else(|| {
1322 StreamError::Failed("QUIC endpoint already materialized".into())
1323 })?;
1324 let (demand_sender, demand_receiver) = mpsc::channel(1);
1325 let (cancel_sender, cancel_receiver) = watch::channel(false);
1326 let task = handle.spawn(run_quic_bind_task(
1327 endpoint,
1328 local_addr,
1329 chunk_size,
1330 handle.clone(),
1331 demand_receiver,
1332 cancel_receiver,
1333 ));
1334 Ok(BindResource {
1335 demands: demand_sender,
1336 cancel: cancel_sender,
1337 task,
1338 })
1339 }
1340 },
1341 receive_demand_response,
1342 close_bind_resource,
1343 )
1344 .map_materialized_value(move |_| QuicBinding { local_addr })
1345}
1346
1347fn receive_demand_response<T>(resource: &mut impl DemandResource<T>) -> StreamResult<Option<T>>
1348where
1349 T: Send + 'static,
1350{
1351 let (reply_sender, reply_receiver) = std_mpsc::channel();
1352 resource
1353 .demands()
1354 .blocking_send(reply_sender)
1355 .map_err(|_| abrupt_termination())?;
1356 match reply_receiver.recv() {
1357 Ok(DemandResponse::Item(item)) => Ok(Some(item)),
1358 Ok(DemandResponse::Complete) => Ok(None),
1359 Ok(DemandResponse::Error(error)) => Err(error),
1360 Err(_) => Err(abrupt_termination()),
1361 }
1362}
1363
1364trait DemandResource<T>
1365where
1366 T: Send + 'static,
1367{
1368 fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>;
1369}
1370
1371impl DemandResource<QuicIncomingConnection> for BindResource {
1372 fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>> {
1373 &self.demands
1374 }
1375}
1376
1377impl DemandResource<QuicBidirectionalStream> for AcceptBiResource {
1378 fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>> {
1379 &self.demands
1380 }
1381}
1382
1383fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
1384 let _ = resource.cancel.send(true);
1385 resource.task.abort();
1386 Ok(())
1387}
1388
1389fn close_accept_bi_resource(resource: AcceptBiResource) -> StreamResult<()> {
1390 let _ = resource.cancel.send(true);
1391 resource.task.abort();
1392 Ok(())
1393}
1394
1395async fn run_quic_bind_task(
1396 endpoint: quinn::Endpoint,
1397 local_addr: SocketAddr,
1398 chunk_size: usize,
1399 handle: Handle,
1400 mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>>,
1401 mut cancel: watch::Receiver<bool>,
1402) {
1403 loop {
1404 let reply = tokio::select! {
1405 demand = demands.recv() => match demand {
1406 Some(reply) => reply,
1407 None => return,
1408 },
1409 changed = cancel.changed() => {
1410 let _ = changed;
1411 return;
1412 }
1413 };
1414
1415 let incoming = tokio::select! {
1416 incoming = endpoint.accept() => incoming,
1417 changed = cancel.changed() => {
1418 let _ = changed;
1419 return;
1420 }
1421 };
1422
1423 let Some(incoming) = incoming else {
1424 let _ = reply.send(DemandResponse::Complete);
1425 return;
1426 };
1427
1428 let connected = tokio::select! {
1429 connected = incoming => connected,
1430 changed = cancel.changed() => {
1431 let _ = changed;
1432 return;
1433 }
1434 };
1435
1436 match connected {
1437 Ok(connection) => {
1438 let incoming = QuicIncomingConnection {
1439 connection: QuicConnection {
1440 endpoint: endpoint.clone(),
1441 local_addr: connection_local_addr(&connection, local_addr, local_addr.ip()),
1442 remote_addr: connection.remote_address(),
1443 connection,
1444 handle: handle.clone(),
1445 chunk_size,
1446 },
1447 };
1448 if reply.send(DemandResponse::Item(incoming)).is_err() {
1449 return;
1450 }
1451 }
1452 Err(error) => {
1453 let _ = reply.send(DemandResponse::Error(quic_error(error)));
1454 return;
1455 }
1456 }
1457 }
1458}
1459
1460async fn run_accept_bi_task(
1461 connection: quinn::Connection,
1462 chunk_size: usize,
1463 emit_available: bool,
1464 handle: Handle,
1465 mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>>,
1466 mut cancel: watch::Receiver<bool>,
1467) {
1468 loop {
1469 let reply = tokio::select! {
1470 demand = demands.recv() => match demand {
1471 Some(reply) => reply,
1472 None => return,
1473 },
1474 changed = cancel.changed() => {
1475 let _ = changed;
1476 return;
1477 }
1478 };
1479
1480 let accepted = tokio::select! {
1481 accepted = connection.accept_bi() => accepted,
1482 changed = cancel.changed() => {
1483 let _ = changed;
1484 return;
1485 }
1486 };
1487
1488 match accepted {
1489 Ok((send, recv)) => {
1490 let stream = quic_bi_stream_from_halves(
1491 send,
1492 recv,
1493 handle.clone(),
1494 chunk_size,
1495 emit_available,
1496 );
1497 if reply.send(DemandResponse::Item(stream)).is_err() {
1498 return;
1499 }
1500 }
1501 Err(error) => {
1502 let _ = reply.send(DemandResponse::Error(quic_error(error)));
1503 return;
1504 }
1505 }
1506 }
1507}