Skip to main content

msquic_h3/
lib.rs

1use std::{ffi::c_void, pin::Pin, sync::Arc};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use futures::{
5    StreamExt,
6    channel::{mpsc, oneshot},
7    ready,
8};
9use h3::quic::{
10    BidiStream, ConnectionErrorIncoming, OpenStreams, RecvStream, SendStream, StreamErrorIncoming,
11};
12use msquic::{
13    Configuration, ConnectionEvent, ConnectionRef, ConnectionShutdownFlags, ReceiveFlags,
14    Registration, SendFlags, Status, StatusCode, StreamEvent, StreamOpenFlags, StreamRef,
15    StreamShutdownFlags, StreamStartFlags,
16};
17
18mod buffer;
19pub use buffer::*;
20mod listener;
21pub use listener::Listener;
22
/// Re-export of the `msquic` crate, so callers can name the msquic types that
/// appear in this crate's public API (e.g. `Registration`, `Configuration` and
/// `Status` in `Connection::connect`).
pub mod msquic {
    pub use ::msquic::*;
}
27
/// An msquic connection adapted to the `h3::quic` traits.
///
/// Created either as a client via `Connection::connect` or from an accepted
/// server connection via `Connection::attach`.
#[derive(Debug)]
pub struct Connection {
    // Shared with every `StreamOpener` created for this connection.
    conn: Arc<msquic::Connection>,
    // Front-end ends of the channels fed by the connection callback.
    ctx: ConnCtxReceiver,
    // Opener used when h3 opens streams directly on the `Connection`.
    opener: StreamOpener,
}
34
/// Sending halves owned by the msquic connection callback; events are
/// forwarded to the front end (`ConnCtxReceiver`) through these channels.
///
/// Each field is `Option` so the callback can `take()` (and thereby close) a
/// channel once it is finished with it.
#[derive(Debug)]
struct ConnCtxSender {
    // Signalled once on `ConnectionEvent::Connected`.
    connected: Option<oneshot::Sender<()>>,
    // Signalled once on `ConnectionEvent::ShutdownComplete`.
    shutdown: Option<oneshot::Sender<()>>,
    // Peer-initiated bidirectional streams.
    bidi: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
    // Peer-initiated unidirectional streams.
    uni: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
}
43
/// Front-end receiving halves matching `ConnCtxSender`.
///
/// The oneshot receivers are `Option` because they are handed out once
/// (`connected` is awaited by `connect`; `shutdown` goes to
/// `get_shutdown_waiter`).
#[derive(Debug)]
struct ConnCtxReceiver {
    connected: Option<oneshot::Receiver<()>>,
    shutdown: Option<oneshot::Receiver<()>>,
    // Polled by `poll_accept_bidi`.
    bidi: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
    // Polled by `poll_accept_recv`.
    uni: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
}
52
53fn conn_ctx_channel() -> (ConnCtxSender, ConnCtxReceiver) {
54    let (conn_tx, conn_rx) = oneshot::channel();
55    let (shutdown_tx, shutdown_rx) = oneshot::channel();
56    let (bidi_tx, bidi_rx) = mpsc::unbounded();
57    let (uni_tx, uni_rx) = mpsc::unbounded();
58    (
59        ConnCtxSender {
60            connected: Some(conn_tx),
61            shutdown: Some(shutdown_tx),
62            bidi: Some(bidi_tx),
63            uni: Some(uni_tx),
64        },
65        ConnCtxReceiver {
66            connected: Some(conn_rx),
67            shutdown: Some(shutdown_rx),
68            bidi: bidi_rx,
69            uni: uni_rx,
70        },
71    )
72}
73
74#[cfg_attr(
75    feature = "tracing",
76    tracing::instrument(skip(ctx), level = "trace", ret, err)
77)]
78fn connection_callback(ctx: &mut ConnCtxSender, ev: msquic::ConnectionEvent) -> Result<(), Status> {
79    match ev {
80        ConnectionEvent::Connected { .. } => {
81            ctx.connected.take().unwrap().send(()).unwrap();
82        }
83        ConnectionEvent::PeerStreamStarted { stream, flags } => {
84            // TODO: need to set callback
85            let s = unsafe { msquic::Stream::from_raw(stream.as_raw()) };
86            if flags.contains(StreamOpenFlags::UNIDIRECTIONAL) {
87                if let Some(uni) = ctx.uni.as_ref() {
88                    uni.unbounded_send(Some(crate::H3Stream::attach(s)))
89                        .expect("cannot send");
90                }
91            } else if let Some(bidi) = ctx.bidi.as_ref() {
92                bidi.unbounded_send(Some(crate::H3Stream::attach(s)))
93                    .expect("cannot send");
94            }
95        }
96        ConnectionEvent::ShutdownComplete { .. } => {
97            // clear all channels.
98            ctx.connected.take();
99            ctx.uni.take();
100            ctx.bidi.take();
101            if let Some(shutdown) = ctx.shutdown.take() {
102                let _ = shutdown.send(());
103            }
104        }
105        _ => {}
106    }
107    Ok(())
108}
109
110impl Connection {
111    /// Connects to the server
112    /// Note: Registration must be kept outside of connection
113    /// and must wait for all connections to finish before closing,
114    /// else registration close will wait on system lock, and block
115    /// rust runtime.
116    pub async fn connect(
117        reg: &Registration,
118        config: &Configuration,
119        server_name: &str,
120        server_port: u16,
121    ) -> Result<Self, Status> {
122        let (mut ctx, mut crx) = conn_ctx_channel();
123        let handler =
124            move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
125        let conn = msquic::Connection::open(reg, handler)?;
126        conn.start(config, server_name, server_port)?;
127        // wait for connection.
128        crx.connected
129            .take()
130            .unwrap()
131            .await
132            .map_err(|_| Status::new(StatusCode::QUIC_STATUS_ABORTED))?;
133
134        let conn = Arc::new(conn);
135
136        let opener = StreamOpener::new(conn.clone());
137
138        Ok(Self {
139            conn,
140            ctx: crx,
141            opener,
142        })
143    }
144
145    /// attach to an accepted connection
146    pub(crate) fn attach(inner: msquic::Connection) -> Self {
147        let (mut ctx, crx) = conn_ctx_channel();
148        let handler =
149            move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
150        inner.set_callback_handler(handler);
151        let conn = Arc::new(inner);
152
153        let opener = StreamOpener::new(conn.clone());
154
155        Self {
156            conn,
157            ctx: crx,
158            opener,
159        }
160    }
161
162    /// Can only be called once after construction.
163    pub fn get_shutdown_waiter(&mut self) -> ConnectionShutdownWaiter {
164        ConnectionShutdownWaiter {
165            rx: self.ctx.shutdown.take().unwrap(),
166        }
167    }
168}
169
170/// wait for connection to be fully shutdown.
171pub struct ConnectionShutdownWaiter {
172    rx: oneshot::Receiver<()>,
173}
174impl ConnectionShutdownWaiter {
175    /// wait for connection to be fully shutdown.
176    pub async fn wait(self) {
177        self.rx
178            .await
179            .expect("failed to wait for connection shutdown");
180    }
181}
182
/// Opens outgoing streams on a connection.
///
/// `bidi_temp` / `uni_temp` hold a stream whose start is still pending, so a
/// re-poll after `Pending` waits on the same stream instead of opening a new one.
#[derive(Debug)]
pub struct StreamOpener {
    conn: Arc<msquic::Connection>,
    bidi_temp: Option<H3Stream>,
    uni_temp: Option<H3Stream>,
}
190
191impl Clone for StreamOpener {
192    fn clone(&self) -> Self {
193        Self {
194            conn: self.conn.clone(),
195            bidi_temp: None,
196            uni_temp: None,
197        }
198    }
199}
200
201/// Server accept streams
202impl<B: Buf> h3::quic::Connection<B> for Connection {
203    type RecvStream = H3RecvStream;
204
205    type OpenStreams = StreamOpener;
206
207    #[cfg_attr(
208        feature = "tracing",
209        tracing::instrument(skip_all, level = "trace", ret)
210    )]
211    fn poll_accept_recv(
212        &mut self,
213        cx: &mut std::task::Context<'_>,
214    ) -> std::task::Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
215        let s = ready!(self.ctx.uni.poll_next_unpin(cx)).unwrap_or(None);
216        // wrap for h3 type. Drop the send stream part
217        std::task::Poll::Ready(
218            s.map(|s| s.recv)
219                .ok_or(ConnectionErrorIncoming::ApplicationClose { error_code: 0 }),
220        )
221    }
222
223    #[cfg_attr(
224        feature = "tracing",
225        tracing::instrument(skip_all, level = "trace", ret)
226    )]
227    fn poll_accept_bidi(
228        &mut self,
229        cx: &mut std::task::Context<'_>,
230    ) -> std::task::Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
231        let s = ready!(self.ctx.bidi.poll_next_unpin(cx)).unwrap_or(None);
232        // wrap for h3 type
233        std::task::Poll::Ready(s.ok_or(ConnectionErrorIncoming::ApplicationClose { error_code: 0 }))
234    }
235
236    #[cfg_attr(
237        feature = "tracing",
238        tracing::instrument(skip_all, level = "trace", ret)
239    )]
240    fn opener(&self) -> Self::OpenStreams {
241        StreamOpener::new(self.conn.clone())
242    }
243}
244
245/// Create new streams from connection.
246impl<B: Buf> OpenStreams<B> for StreamOpener {
247    type BidiStream = H3Stream;
248
249    type SendStream = H3SendStream;
250
251    #[cfg_attr(
252        feature = "tracing",
253        tracing::instrument(skip_all, level = "trace", ret)
254    )]
255    fn poll_open_bidi(
256        &mut self,
257        cx: &mut std::task::Context<'_>,
258    ) -> std::task::Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
259        Self::poll_open_inner(&self.conn, false, &mut self.bidi_temp, cx)
260    }
261
262    #[cfg_attr(
263        feature = "tracing",
264        tracing::instrument(skip_all, level = "trace", ret)
265    )]
266    fn poll_open_send(
267        &mut self,
268        cx: &mut std::task::Context<'_>,
269    ) -> std::task::Poll<Result<Self::SendStream, StreamErrorIncoming>> {
270        let res = ready!(Self::poll_open_inner(
271            &self.conn,
272            true,
273            &mut self.uni_temp,
274            cx
275        ));
276        // get the send part.
277        std::task::Poll::Ready(res.map(|s| s.send))
278    }
279
280    #[cfg_attr(
281        feature = "tracing",
282        tracing::instrument(skip_all, level = "trace", ret)
283    )]
284    fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
285        self.conn
286            .shutdown(ConnectionShutdownFlags::NONE, code.value());
287    }
288}
289
290impl StreamOpener {
291    fn new(conn: Arc<msquic::Connection>) -> Self {
292        Self {
293            conn,
294            bidi_temp: None,
295            uni_temp: None,
296        }
297    }
298
299    /// open a stream and poll it in the holder.
300    fn poll_open_inner(
301        conn: &Arc<msquic::Connection>,
302        uni: bool,
303        stream_holder: &mut Option<H3Stream>,
304        cx: &mut std::task::Context<'_>,
305    ) -> std::task::Poll<Result<H3Stream, StreamErrorIncoming>> {
306        if stream_holder.is_none() {
307            // create new stream
308            let s = match H3Stream::open_and_start(conn, uni) {
309                Ok(s) => s,
310                Err(e) => {
311                    return std::task::Poll::Ready(Err(StreamErrorIncoming::Unknown(e.into())));
312                }
313            };
314            *stream_holder = Some(s);
315        }
316
317        // poll stream start.
318        let res = {
319            let s = stream_holder.as_mut().unwrap();
320            let rx = s.send.sctx.start.as_mut().unwrap();
321            let p = Pin::new(rx);
322            ready!(std::future::Future::poll(p, cx))
323        };
324        // current stream is either ready or error. So ready to be returned or dropped.
325        let s = stream_holder.take().unwrap();
326        let res = res
327            .expect("cannot receive")
328            .map(|_| s)
329            .map_err(|e| StreamErrorIncoming::Unknown(e.into()));
330        std::task::Poll::Ready(res)
331    }
332}
333
334/// bypass for StreamOpener
335impl<B: Buf> OpenStreams<B> for Connection {
336    type BidiStream = H3Stream;
337
338    type SendStream = H3SendStream;
339
340    fn poll_open_bidi(
341        &mut self,
342        cx: &mut std::task::Context<'_>,
343    ) -> std::task::Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
344        OpenStreams::<B>::poll_open_bidi(&mut self.opener, cx)
345    }
346
347    fn poll_open_send(
348        &mut self,
349        cx: &mut std::task::Context<'_>,
350    ) -> std::task::Poll<Result<Self::SendStream, StreamErrorIncoming>> {
351        OpenStreams::<B>::poll_open_send(&mut self.opener, cx)
352    }
353
354    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
355        OpenStreams::<B>::close(&mut self.opener, code, reason)
356    }
357}
358
/// An msquic stream adapted to h3, made of a send half and a receive half
/// that share the same underlying `msquic::Stream`.
#[derive(Debug)]
pub struct H3Stream {
    send: H3SendStream,
    recv: H3RecvStream,
}
/// Send half of an `H3Stream`.
#[derive(Debug)]
pub struct H3SendStream {
    // Shared with the matching `H3RecvStream`.
    stream: Arc<msquic::Stream>,
    // Front-end ends of the start/send-complete/shutdown channels.
    sctx: SendStreamReceiveCtx,
}
/// Receive half of an `H3Stream`.
#[derive(Debug)]
pub struct H3RecvStream {
    // Shared with the matching `H3SendStream`.
    stream: Arc<msquic::Stream>,
    // Front-end end of the received-data channel.
    rctx: RecvStreamReceiveCtx,
}
375
/// Raw `client_context` pointer from msquic's `SendComplete`, carried over the
/// send-complete channel back to `H3SendStream::poll_ready`, which rebuilds
/// the buffer from it.
struct BufPtr(*const c_void);
// SAFETY: `BufPtr` never dereferences the pointer; it only moves it through the
// channel to the front end, which turns it back into an owned buffer.
// NOTE(review): the pointee is an `H3Buff<WriteBuf<B>>` with only `B: Buf` bounds,
// so these impls assume the buffer type is safe to move across threads — confirm.
unsafe impl Send for BufPtr {}
unsafe impl Sync for BufPtr {}
379
/// Callback-side senders for one stream, owned by the msquic stream callback
/// closure. Each is `Option` so the callback can `take()` it to close the channel.
struct StreamSendCtx {
    // Signalled once on `StartComplete` with the start status.
    start: Option<oneshot::Sender<Result<(), Status>>>,
    // Forwards (cancelled, client_context) from each `SendComplete`.
    send: Option<mpsc::UnboundedSender<(bool, BufPtr)>>,
    // Signalled once on `SendShutdownComplete`.
    shutdown: Option<oneshot::Sender<()>>,
    // Data received from the peer.
    receive: Option<mpsc::UnboundedSender<Bytes>>,
}
387
/// Front-end context for receiving data: chunks copied out by the stream
/// callback, ending when the callback drops its sender.
#[derive(Debug)]
struct RecvStreamReceiveCtx {
    receive: mpsc::UnboundedReceiver<Bytes>,
}
393
/// Front-end context for the send side of a stream.
#[derive(Debug)]
struct SendStreamReceiveCtx {
    // Start result; polled by `StreamOpener::poll_open_inner`.
    start: Option<oneshot::Receiver<Result<(), Status>>>,
    // (cancelled, client_context) per completed send; drained by `poll_ready`.
    send: mpsc::UnboundedReceiver<(bool, BufPtr)>,
    // True between a successful `send_data` and the matching completion in `poll_ready`.
    send_inprogress: bool,
    // Resolved on `SendShutdownComplete`; polled by `poll_finish`.
    shutdown: oneshot::Receiver<()>,
}
403
404fn stream_ctx_channel() -> (StreamSendCtx, SendStreamReceiveCtx, RecvStreamReceiveCtx) {
405    let (start_tx, start_rx) = oneshot::channel::<Result<(), Status>>();
406    let (send_tx, send_rx) = mpsc::unbounded();
407    let (shutdown_tx, shutdown_rx) = oneshot::channel();
408    let (receive_tx, receive_rx) = mpsc::unbounded();
409    (
410        StreamSendCtx {
411            start: Some(start_tx),
412            send: Some(send_tx),
413            shutdown: Some(shutdown_tx),
414            receive: Some(receive_tx),
415        },
416        SendStreamReceiveCtx {
417            start: Some(start_rx),
418            send: send_rx,
419            send_inprogress: false,
420            shutdown: shutdown_rx,
421        },
422        RecvStreamReceiveCtx {
423            receive: receive_rx,
424        },
425    )
426}
427
428#[cfg_attr(
429    feature = "tracing",
430    tracing::instrument(skip(ctx), level = "trace", ret)
431)]
432fn stream_callback(ctx: &mut StreamSendCtx, ev: StreamEvent) -> Result<(), Status> {
433    match ev {
434        StreamEvent::StartComplete { status, .. } => {
435            let tx = ctx.start.take().unwrap();
436            if status.is_ok() {
437                tx.send(Ok(())).expect("cannot send");
438            } else {
439                tx.send(Err(status)).expect("cannot send")
440            }
441        }
442        StreamEvent::SendComplete {
443            cancelled,
444            client_context,
445        } => {
446            if let Some(send) = ctx.send.as_ref() {
447                send.unbounded_send((cancelled, BufPtr(client_context)))
448                    .expect("cannot send");
449            } else {
450                debug_assert!(false, "mem leak");
451            }
452        }
453        StreamEvent::Receive { buffers, flags, .. } => {
454            if let Some(receive) = ctx.receive.as_ref() {
455                let mut b = BytesMut::new();
456                for br in buffers {
457                    // skip empty buffs.
458                    if !br.as_bytes().is_empty() {
459                        b.put_slice(br.as_bytes());
460                    }
461                }
462                let b = b.freeze();
463                if !b.is_empty() {
464                    receive.unbounded_send(b).expect("cannot send");
465                } else {
466                    // zero buff can happen. so drop the receiver.
467                    ctx.receive.take();
468                }
469            }
470            if flags.contains(ReceiveFlags::FIN) {
471                // close
472                ctx.receive.take();
473            }
474        }
475        StreamEvent::SendShutdownComplete { graceful: _ } => {
476            // Peer acknowledged shutdown.
477            if let Some(shutdown) = ctx.shutdown.take() {
478                shutdown.send(()).expect("cannot send");
479            }
480        }
481        StreamEvent::ShutdownComplete { .. } => {
482            // close all channels
483            ctx.receive.take();
484            ctx.send.take();
485            ctx.shutdown.take();
486            ctx.start.take();
487        }
488        _ => {}
489    }
490    Ok(())
491}
492
493impl H3Stream {
494    /// attach to accepted stream
495    pub(crate) fn attach(stream: msquic::Stream) -> Self {
496        let (mut ctx, rtx, rrtx) = stream_ctx_channel();
497        let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
498
499        stream.set_callback_handler(handler);
500        let s = Arc::new(stream);
501        Self {
502            send: H3SendStream {
503                stream: s.clone(),
504                sctx: rtx,
505            },
506            recv: H3RecvStream {
507                stream: s,
508                rctx: rrtx,
509            },
510        }
511    }
512
513    #[cfg_attr(
514        feature = "tracing",
515        tracing::instrument(skip_all, level = "trace", err, ret)
516    )]
517    fn open_and_start(conn: &msquic::Connection, uni: bool) -> Result<Self, Status> {
518        let (mut ctx, rtx, rrtx) = stream_ctx_channel();
519        let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
520
521        let flag = match uni {
522            true => StreamOpenFlags::UNIDIRECTIONAL,
523            false => StreamOpenFlags::NONE,
524        };
525
526        let s = msquic::Stream::open(conn, flag, handler)?;
527        s.start(StreamStartFlags::NONE)?;
528        let s = Arc::new(s);
529        Ok(Self {
530            send: H3SendStream {
531                stream: s.clone(),
532                sctx: rtx,
533            },
534            recv: H3RecvStream {
535                stream: s,
536                rctx: rrtx,
537            },
538        })
539    }
540}
541
542impl<B: Buf> SendStream<B> for H3SendStream {
543    // Seems like poll_ready is called after send_data is called.
544    // To ensure data is sent.
545    #[cfg_attr(
546        feature = "tracing",
547        tracing::instrument(skip_all, level = "trace", ret)
548    )]
549    fn poll_ready(
550        &mut self,
551        cx: &mut std::task::Context<'_>,
552    ) -> std::task::Poll<Result<(), StreamErrorIncoming>> {
553        if !self.sctx.send_inprogress {
554            // no send is current so ready to get more.
555            return std::task::Poll::Ready(Ok(()));
556        }
557        match ready!(self.sctx.send.poll_next_unpin(cx)) {
558            Some((cancelled, ptr)) => {
559                self.sctx.send_inprogress = false;
560                // reattach buff
561                let _: H3Buff<h3::quic::WriteBuf<B>> =
562                    unsafe { H3Buff::from_raw(ptr.0 as *mut c_void) };
563                match cancelled {
564                    true => std::task::Poll::Ready(Err(StreamErrorIncoming::Unknown(
565                        Status::from(StatusCode::QUIC_STATUS_ABORTED).into(),
566                    ))),
567                    false => std::task::Poll::Ready(Ok(())),
568                }
569            }
570            // closed.
571            None => std::task::Poll::Ready(Err(StreamErrorIncoming::Unknown(
572                Status::from(StatusCode::QUIC_STATUS_ABORTED).into(),
573            ))),
574        }
575    }
576
577    #[cfg_attr(
578        feature = "tracing",
579        tracing::instrument(skip_all, level = "trace", ret, err)
580    )]
581    fn send_data<T: Into<h3::quic::WriteBuf<B>>>(
582        &mut self,
583        data: T,
584    ) -> Result<(), StreamErrorIncoming> {
585        if self.sctx.send_inprogress {
586            panic!("send while send is in progress.");
587        }
588        let data: h3::quic::WriteBuf<B> = data.into();
589        let buff = H3Buff::new(data);
590        let (buff_ref, ptr) = unsafe { buff.into_raw() };
591        unsafe { self.stream.send(buff_ref, SendFlags::NONE, ptr) }
592            .inspect_err(|_| {
593                // reattach buff
594                let _: H3Buff<h3::quic::WriteBuf<B>> = unsafe { H3Buff::from_raw(ptr) };
595            })
596            .map_err(|e| StreamErrorIncoming::Unknown(e.into()))?;
597        self.sctx.send_inprogress = true;
598        Ok(())
599    }
600
601    // Send FIN signal to peer.
602    #[cfg_attr(
603        feature = "tracing",
604        tracing::instrument(skip_all, level = "trace", ret)
605    )]
606    fn poll_finish(
607        &mut self,
608        cx: &mut std::task::Context<'_>,
609    ) -> std::task::Poll<Result<(), StreamErrorIncoming>> {
610        // Graceful sends a Fin to peer.
611        if let Err(e) = self.stream.shutdown(StreamShutdownFlags::GRACEFUL, 0) {
612            return std::task::Poll::Ready(Err(StreamErrorIncoming::Unknown(e.into())));
613        }
614        // poll the ctx
615        let rx = &mut self.sctx.shutdown;
616        let p = Pin::new(rx);
617        // if backend is closed return error.
618        let res = ready!(std::future::Future::poll(p, cx)).map_err(|_| {
619            StreamErrorIncoming::Unknown(Status::from(StatusCode::QUIC_STATUS_ABORTED).into())
620        });
621        std::task::Poll::Ready(res)
622    }
623
624    #[cfg_attr(
625        feature = "tracing",
626        tracing::instrument(skip_all, level = "trace", ret)
627    )]
628    fn reset(&mut self, _reset_code: u64) {
629        panic!("reset not supported")
630    }
631
632    fn send_id(&self) -> h3::quic::StreamId {
633        get_id(&self.stream)
634    }
635}
636
637fn get_id(s: &msquic::Stream) -> h3::quic::StreamId {
638    let raw_id = unsafe {
639        msquic::Api::get_param_auto::<u64>(s.as_raw(), msquic::ffi::QUIC_PARAM_STREAM_ID)
640    }
641    .unwrap();
642    raw_id.try_into().expect("cannot parse id")
643}
644
645impl RecvStream for H3RecvStream {
646    type Buf = Bytes;
647
648    #[cfg_attr(
649        feature = "tracing",
650        tracing::instrument(skip_all, level = "trace", ret)
651    )]
652    fn poll_data(
653        &mut self,
654        cx: &mut std::task::Context<'_>,
655    ) -> std::task::Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
656        let res = ready!(self.rctx.receive.poll_next_unpin(cx));
657        std::task::Poll::Ready(Ok(res))
658    }
659
660    /// Stop accepting data. Discard unread data, notify peer to not send.
661    #[cfg_attr(
662        feature = "tracing",
663        tracing::instrument(skip_all, level = "trace", ret)
664    )]
665    fn stop_sending(&mut self, error_code: u64) {
666        // Close the send path.
667        let _ = self
668            .stream
669            .shutdown(StreamShutdownFlags::ABORT_RECEIVE, error_code);
670    }
671
672    fn recv_id(&self) -> h3::quic::StreamId {
673        get_id(&self.stream)
674    }
675}
676
677// bidi stream
678
679impl<B: Buf> SendStream<B> for H3Stream {
680    fn poll_ready(
681        &mut self,
682        cx: &mut std::task::Context<'_>,
683    ) -> std::task::Poll<Result<(), StreamErrorIncoming>> {
684        SendStream::<B>::poll_ready(&mut self.send, cx)
685    }
686
687    fn send_data<T: Into<h3::quic::WriteBuf<B>>>(
688        &mut self,
689        data: T,
690    ) -> Result<(), StreamErrorIncoming> {
691        SendStream::<B>::send_data(&mut self.send, data)
692    }
693
694    fn poll_finish(
695        &mut self,
696        cx: &mut std::task::Context<'_>,
697    ) -> std::task::Poll<Result<(), StreamErrorIncoming>> {
698        SendStream::<B>::poll_finish(&mut self.send, cx)
699    }
700
701    fn reset(&mut self, reset_code: u64) {
702        SendStream::<B>::reset(&mut self.send, reset_code);
703    }
704
705    fn send_id(&self) -> h3::quic::StreamId {
706        SendStream::<B>::send_id(&self.send)
707    }
708}
709
710impl RecvStream for H3Stream {
711    type Buf = Bytes;
712
713    fn poll_data(
714        &mut self,
715        cx: &mut std::task::Context<'_>,
716    ) -> std::task::Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
717        RecvStream::poll_data(&mut self.recv, cx)
718    }
719
720    fn stop_sending(&mut self, error_code: u64) {
721        RecvStream::stop_sending(&mut self.recv, error_code)
722    }
723
724    fn recv_id(&self) -> h3::quic::StreamId {
725        RecvStream::recv_id(&self.recv)
726    }
727}
728
729impl<B: Buf> BidiStream<B> for H3Stream {
730    type SendStream = H3SendStream;
731
732    type RecvStream = H3RecvStream;
733
734    #[cfg_attr(
735        feature = "tracing",
736        tracing::instrument(skip_all, level = "trace", ret)
737    )]
738    fn split(self) -> (Self::SendStream, Self::RecvStream) {
739        (self.send, self.recv)
740    }
741}
742
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use bytes::Buf;
    use h3::error::ConnectionError;
    use http::Uri;
    use msquic::{
        BufferRef, Configuration, CredentialConfig, CredentialFlags, Registration,
        RegistrationConfig, Settings,
    };

    use crate::Connection;

    /// Helpers shared by the crate's tests: tracing setup and a local test credential.
    pub mod util {
        use msquic::Credential;
        /// Tracing level installed by `try_setup_tracing`; adjust when debugging.
        pub const DEVEL_TRACE_LEVEL: tracing::Level = tracing::Level::TRACE;

        /// Installs a `tracing_subscriber` fmt subscriber at `DEVEL_TRACE_LEVEL`.
        /// Errors (e.g. a subscriber is already installed) are ignored, so every
        /// test may call this.
        pub fn try_setup_tracing() {
            let _ = tracing_subscriber::fmt()
                .with_max_level(DEVEL_TRACE_LEVEL)
                .try_init();
        }

        /// Looks up the thumbprint of the `MsQuic-Test` certificate in
        /// `Cert:\CurrentUser\My` using `pwsh.exe` and returns it as a credential.
        ///
        /// Panics if pwsh fails or its output is not a valid certificate hash.
        #[cfg(target_os = "windows")]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateHash;

            let output = std::process::Command::new("pwsh.exe")
                .args(["-Command", "Get-ChildItem Cert:\\CurrentUser\\My | Where-Object -Property FriendlyName -EQ -Value MsQuic-Test | Select-Object -ExpandProperty Thumbprint -First 1"])
                .output()
                .expect("Failed to execute command");
            assert!(output.status.success());
            let mut s = String::from_utf8(output.stdout).unwrap();
            // Strip a trailing "\n" or "\r\n" from the command output.
            if s.ends_with('\n') {
                s.pop();
                if s.ends_with('\r') {
                    s.pop();
                }
            };
            Credential::CertificateHash(CertificateHash::from_str(&s).unwrap())
        }

        /// Returns a file-based credential for a self-signed `localhost` cert
        /// kept under `<temp_dir>/msquic_h3_test_rs`.
        ///
        /// If either file is missing, the directory is recreated and a new
        /// key/cert pair is generated with the `openssl` CLI. Panics if that fails.
        #[cfg(not(target_os = "windows"))]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateFile;

            let cert_dir = std::env::temp_dir().join("msquic_h3_test_rs");
            let key = "key.pem";
            let cert = "cert.pem";
            let key_path = cert_dir.join(key);
            let cert_path = cert_dir.join(cert);
            if !key_path.exists() || !cert_path.exists() {
                // Start from a clean directory so a half-written pair is not reused.
                let _ = std::fs::remove_dir_all(&cert_dir);
                std::fs::create_dir_all(&cert_dir).expect("cannot create cert dir");
                // Generate the test cert with the openssl CLI, writing to the same
                // file names that `key_path` / `cert_path` check above.
                let output = std::process::Command::new("openssl")
                    .args([
                        "req",
                        "-x509",
                        "-newkey",
                        "rsa:4096",
                        "-keyout",
                        key,
                        "-out",
                        cert,
                        "-sha256",
                        "-days",
                        "3650",
                        "-nodes",
                        "-subj",
                        "/CN=localhost",
                    ])
                    .current_dir(&cert_dir)
                    .stderr(std::process::Stdio::inherit())
                    .stdout(std::process::Stdio::inherit())
                    .output()
                    .expect("cannot generate cert");
                if !output.status.success() {
                    panic!("generate cert failed");
                }
            }
            Credential::CertificateFile(CertificateFile::new(
                key_path.display().to_string(),
                cert_path.display().to_string(),
            ))
        }
    }

    /// Connects to `uri` over HTTP/3 with certificate validation disabled,
    /// sends a GET, and logs the response status, headers and body.
    ///
    /// NOTE(review): request and driver errors are only logged, not returned or
    /// asserted, so callers do not fail when the request fails. Only the setup
    /// steps that `unwrap` can panic.
    pub(crate) async fn send_get_request(uri: Uri) {
        let app_name = String::from("testapp");
        let config = RegistrationConfig::new().set_app_name(app_name);
        let reg = Arc::new(Registration::new(&config).unwrap());

        let alpn = BufferRef::from("h3");
        // open client configuration
        let client_settings = Settings::new().set_IdleTimeoutMs(2000);
        let client_config = Configuration::open(&reg, &[alpn], Some(&client_settings)).unwrap();
        {
            let cred_config = CredentialConfig::new_client()
                .set_credential_flags(CredentialFlags::NO_CERTIFICATE_VALIDATION);
            client_config.load_credential(&cred_config).unwrap();
        }

        tracing::info!("client conn open and start");
        let conn = Connection::connect(
            &reg,
            &client_config,
            uri.host().unwrap(),
            uri.port_u16().unwrap(),
        )
        .await
        .unwrap();

        tracing::info!("client create h3 client");
        let (mut driver, mut send_request) = h3::client::new(conn).await.unwrap();

        tracing::info!("client start driver");
        let drive = async move {
            Err::<(), ConnectionError>(futures::future::poll_fn(|cx| driver.poll_close(cx)).await)
        };

        // `send_request` is moved into this future: the connection is closed
        // only once all `SendRequest` instances are dropped.
        let request = async move {
            tracing::info!("sending request ...");

            let req = http::Request::builder().uri(uri).body(())?;

            // sending request results in a bidirectional stream,
            // which is also used for receiving response
            let mut stream = send_request.send_request(req).await?;

            // finish on the sending side
            stream.finish().await?;

            tracing::info!("receiving response ...");

            let resp = stream.recv_response().await?;

            tracing::info!("response: {:?} {}", resp.version(), resp.status());
            tracing::info!("headers: {:#?}", resp.headers());

            // `recv_data()` must be called after `recv_response()` for
            // receiving potential response body
            let mut data = vec![];
            while let Some(mut chunk) = stream.recv_data().await? {
                let mut dst = vec![0; chunk.remaining()];
                chunk.copy_to_slice(&mut dst[..]);
                data.extend_from_slice(&dst);
            }
            let body = String::from_utf8_lossy(&data);
            tracing::info!("client got body: {}", body);
            Ok::<_, Box<dyn std::error::Error>>(())
        };

        let (req_res, drive_res) = tokio::join!(request, drive);
        if let Err(e) = req_res {
            tracing::error!("req_err {e:?}");
        }
        if let Err(e) = drive_res {
            tracing::error!("drive_res {e:?}");
        }
        tracing::info!("client ended success");
    }

    /// Sends a GET to a public HTTP/3 server; requires outbound network access.
    #[test]
    fn client_test_apache() {
        util::try_setup_tracing();
        // This does not work (cloudflare servers):
        // let uri = http::Uri::from_static("https://quic.tech:8443/");
        // let uri = http::Uri::from_static("https://cloudflare-quic.com:443/");

        // These work:
        let uri = http::Uri::from_static("https://h2o.examp1e.net:443");
        // let uri = http::Uri::from_static("https://docs.trafficserver.apache.org:443/");
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap()
            .block_on(send_get_request(uri));
    }
}