msquic_h3/
lib.rs

1use std::{ffi::c_void, pin::Pin, sync::Arc};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use futures::{
5    channel::{mpsc, oneshot},
6    ready, StreamExt,
7};
8use h3::quic::{BidiStream, OpenStreams, RecvStream, SendStream};
9use msquic::{
10    Configuration, ConnectionEvent, ConnectionRef, ConnectionShutdownFlags, ReceiveFlags,
11    Registration, SendFlags, Status, StatusCode, StreamEvent, StreamOpenFlags, StreamRef,
12    StreamShutdownFlags, StreamStartFlags,
13};
14
15mod buffer;
16pub use buffer::*;
17mod listener;
18pub use listener::Listener;
19
/// Re-export of everything in the `msquic` crate, reachable as a module of this crate.
pub mod msquic {
    pub use ::msquic::*;
}
24
/// Error type of this crate's `h3::quic` implementations: an msquic [`Status`] plus an
/// optional error code.
#[derive(Debug)]
pub struct H3Error {
    /// msquic status describing the failure.
    status: Status,
    /// Error code returned by `err_code`, when one is known.
    error_code: Option<u64>,
}
30
31impl H3Error {
32    pub fn new(status: Status, ec: Option<u64>) -> Self {
33        Self {
34            status,
35            error_code: ec,
36        }
37    }
38}
39
40impl h3::quic::Error for H3Error {
41    fn is_timeout(&self) -> bool {
42        self.status
43            .try_as_status_code()
44            .unwrap_or(StatusCode::QUIC_STATUS_SUCCESS)
45            == StatusCode::QUIC_STATUS_CONNECTION_TIMEOUT
46    }
47
48    fn err_code(&self) -> Option<u64> {
49        self.error_code
50    }
51}
52
// Empty impl: no `std::error::Error` method is overridden for `H3Error`.
impl std::error::Error for H3Error {}
54
55impl std::fmt::Display for H3Error {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(f, "{:?}", self)
58    }
59}
60
/// A QUIC connection backed by msquic.
#[derive(Debug)]
pub struct Connection {
    /// Shared handle to the underlying msquic connection.
    conn: Arc<msquic::Connection>,
    /// Receiving ends of the channels fed by the connection callback.
    ctx: ConnCtxReceiver,
    /// Opener used when streams are opened directly through this connection.
    opener: StreamOpener,
}
67
/// Sending half of the connection context: the msquic callback uses it to send
/// events to the front end.
#[derive(Debug)]
struct ConnCtxSender {
    /// One-shot signal fired when the `Connected` event arrives.
    connected: Option<oneshot::Sender<()>>,
    /// Queue of peer-started bidirectional streams; dropped on `ShutdownComplete`.
    bidi: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
    /// Queue of peer-started unidirectional streams; dropped on `ShutdownComplete`.
    uni: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
}
75
/// Receiving half of the connection context, held by the front end.
#[derive(Debug)]
struct ConnCtxReceiver {
    /// One-shot connected signal; taken (left as `None`) by `Connection::connect` once awaited.
    connected: Option<oneshot::Receiver<()>>,
    /// Peer-started bidirectional streams delivered by the callback.
    bidi: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
    /// Peer-started unidirectional streams delivered by the callback.
    uni: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
}
83
84fn conn_ctx_channel() -> (ConnCtxSender, ConnCtxReceiver) {
85    let (conn_tx, conn_rx) = oneshot::channel();
86    let (bidi_tx, bidi_rx) = mpsc::unbounded();
87    let (uni_tx, uni_rx) = mpsc::unbounded();
88    (
89        ConnCtxSender {
90            connected: Some(conn_tx),
91            bidi: Some(bidi_tx),
92            uni: Some(uni_tx),
93        },
94        ConnCtxReceiver {
95            connected: Some(conn_rx),
96            bidi: bidi_rx,
97            uni: uni_rx,
98        },
99    )
100}
101
102#[cfg_attr(
103    feature = "tracing",
104    tracing::instrument(skip(ctx), level = "trace", ret, err)
105)]
106fn connection_callback(ctx: &mut ConnCtxSender, ev: msquic::ConnectionEvent) -> Result<(), Status> {
107    match ev {
108        ConnectionEvent::Connected { .. } => {
109            ctx.connected.take().unwrap().send(()).unwrap();
110        }
111        ConnectionEvent::PeerStreamStarted { stream, flags } => {
112            // TODO: need to set callback
113            let s = unsafe { msquic::Stream::from_raw(stream.as_raw()) };
114            if flags.contains(StreamOpenFlags::UNIDIRECTIONAL) {
115                if let Some(uni) = ctx.uni.as_ref() {
116                    uni.unbounded_send(Some(crate::H3Stream::attach(s)))
117                        .expect("cannot send");
118                }
119            } else if let Some(bidi) = ctx.bidi.as_ref() {
120                bidi.unbounded_send(Some(crate::H3Stream::attach(s)))
121                    .expect("cannot send");
122            }
123        }
124        ConnectionEvent::ShutdownComplete { .. } => {
125            // clear all channels.
126            ctx.connected.take();
127            ctx.uni.take();
128            ctx.bidi.take();
129        }
130        _ => {}
131    }
132    Ok(())
133}
134
135impl Connection {
136    /// Connects to the server
137    pub async fn connect(
138        reg: &Registration,
139        config: &Configuration,
140        server_name: &str,
141        server_port: u16,
142    ) -> Result<Self, Status> {
143        let (mut ctx, mut crx) = conn_ctx_channel();
144        let handler =
145            move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
146        let conn = msquic::Connection::open(reg, handler)?;
147        conn.start(config, server_name, server_port)?;
148        // wait for connection.
149        crx.connected
150            .take()
151            .unwrap()
152            .await
153            .map_err(|_| Status::new(StatusCode::QUIC_STATUS_ABORTED))?;
154
155        let conn = Arc::new(conn);
156
157        let opener = StreamOpener::new(conn.clone());
158
159        Ok(Self {
160            conn,
161            ctx: crx,
162            opener,
163        })
164    }
165
166    /// attach to an accepted connection
167    pub(crate) fn attach(inner: msquic::Connection) -> Self {
168        let (mut ctx, crx) = conn_ctx_channel();
169        let handler =
170            move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
171        inner.set_callback_handler(handler);
172        let conn = Arc::new(inner);
173
174        let opener = StreamOpener::new(conn.clone());
175
176        Self {
177            conn,
178            ctx: crx,
179            opener,
180        }
181    }
182}
183
/// Responsible for opening streams on a connection.
#[derive(Debug)]
pub struct StreamOpener {
    /// Shared handle to the msquic connection streams are opened on.
    conn: Arc<msquic::Connection>,
    /// Bidirectional stream whose start has not completed yet, kept between polls.
    bidi_temp: Option<H3Stream>,
    /// Unidirectional stream whose start has not completed yet, kept between polls.
    uni_temp: Option<H3Stream>,
}
191
192impl Clone for StreamOpener {
193    fn clone(&self) -> Self {
194        Self {
195            conn: self.conn.clone(),
196            bidi_temp: None,
197            uni_temp: None,
198        }
199    }
200}
201
202/// Server accept streams
203impl<B: Buf> h3::quic::Connection<B> for Connection {
204    type RecvStream = H3RecvStream;
205
206    type OpenStreams = StreamOpener;
207
208    type AcceptError = H3Error;
209
210    #[cfg_attr(
211        feature = "tracing",
212        tracing::instrument(skip_all, level = "trace", ret)
213    )]
214    fn poll_accept_recv(
215        &mut self,
216        cx: &mut std::task::Context<'_>,
217    ) -> std::task::Poll<Result<Option<Self::RecvStream>, Self::AcceptError>> {
218        let s = ready!(self.ctx.uni.poll_next_unpin(cx)).unwrap_or(None);
219        // wrap for h3 type. Drop the send stream part
220        std::task::Poll::Ready(Ok(s.map(|s| s.recv)))
221    }
222
223    #[cfg_attr(
224        feature = "tracing",
225        tracing::instrument(skip_all, level = "trace", ret)
226    )]
227    fn poll_accept_bidi(
228        &mut self,
229        cx: &mut std::task::Context<'_>,
230    ) -> std::task::Poll<Result<Option<Self::BidiStream>, Self::AcceptError>> {
231        let s = ready!(self.ctx.bidi.poll_next_unpin(cx)).unwrap_or(None);
232        // wrap for h3 type
233        std::task::Poll::Ready(Ok(s))
234    }
235
236    #[cfg_attr(
237        feature = "tracing",
238        tracing::instrument(skip_all, level = "trace", ret)
239    )]
240    fn opener(&self) -> Self::OpenStreams {
241        StreamOpener::new(self.conn.clone())
242    }
243}
244
245/// Create new streams from connection.
246impl<B: Buf> OpenStreams<B> for StreamOpener {
247    type BidiStream = H3Stream;
248
249    type SendStream = H3SendStream;
250
251    type OpenError = H3Error;
252
253    #[cfg_attr(
254        feature = "tracing",
255        tracing::instrument(skip_all, level = "trace", ret)
256    )]
257    fn poll_open_bidi(
258        &mut self,
259        cx: &mut std::task::Context<'_>,
260    ) -> std::task::Poll<Result<Self::BidiStream, Self::OpenError>> {
261        Self::poll_open_inner(&self.conn, false, &mut self.bidi_temp, cx)
262    }
263
264    #[cfg_attr(
265        feature = "tracing",
266        tracing::instrument(skip_all, level = "trace", ret)
267    )]
268    fn poll_open_send(
269        &mut self,
270        cx: &mut std::task::Context<'_>,
271    ) -> std::task::Poll<Result<Self::SendStream, Self::OpenError>> {
272        let res = ready!(Self::poll_open_inner(
273            &self.conn,
274            true,
275            &mut self.uni_temp,
276            cx
277        ));
278        // get the send part.
279        std::task::Poll::Ready(res.map(|s| s.send))
280    }
281
282    #[cfg_attr(
283        feature = "tracing",
284        tracing::instrument(skip_all, level = "trace", ret)
285    )]
286    fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
287        self.conn
288            .shutdown(ConnectionShutdownFlags::NONE, code.value());
289    }
290}
291
292impl StreamOpener {
293    fn new(conn: Arc<msquic::Connection>) -> Self {
294        Self {
295            conn,
296            bidi_temp: None,
297            uni_temp: None,
298        }
299    }
300
301    /// open a stream and poll it in the holder.
302    fn poll_open_inner(
303        conn: &Arc<msquic::Connection>,
304        uni: bool,
305        stream_holder: &mut Option<H3Stream>,
306        cx: &mut std::task::Context<'_>,
307    ) -> std::task::Poll<Result<H3Stream, H3Error>> {
308        if stream_holder.is_none() {
309            // create new stream
310            let s = match H3Stream::open_and_start(conn, uni) {
311                Ok(s) => s,
312                Err(e) => return std::task::Poll::Ready(Err(H3Error::new(e, None))),
313            };
314            *stream_holder = Some(s);
315        }
316
317        // poll stream start.
318        let res = {
319            let s = stream_holder.as_mut().unwrap();
320            let rx = s.send.sctx.start.as_mut().unwrap();
321            let p = Pin::new(rx);
322            ready!(std::future::Future::poll(p, cx))
323        };
324        // current stream is either ready or error. So ready to be returned or dropped.
325        let s = stream_holder.take().unwrap();
326        let res = res
327            .expect("cannot receive")
328            .map(|_| s)
329            .map_err(|e| H3Error::new(e, None));
330        std::task::Poll::Ready(res)
331    }
332}
333
334/// bypass for StreamOpener
335impl<B: Buf> OpenStreams<B> for Connection {
336    type BidiStream = H3Stream;
337
338    type SendStream = H3SendStream;
339
340    type OpenError = H3Error;
341
342    fn poll_open_bidi(
343        &mut self,
344        cx: &mut std::task::Context<'_>,
345    ) -> std::task::Poll<Result<Self::BidiStream, Self::OpenError>> {
346        OpenStreams::<B>::poll_open_bidi(&mut self.opener, cx)
347    }
348
349    fn poll_open_send(
350        &mut self,
351        cx: &mut std::task::Context<'_>,
352    ) -> std::task::Poll<Result<Self::SendStream, Self::OpenError>> {
353        OpenStreams::<B>::poll_open_send(&mut self.opener, cx)
354    }
355
356    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
357        OpenStreams::<B>::close(&mut self.opener, code, reason)
358    }
359}
360
/// An msquic stream exposed as a send half plus a receive half.
#[derive(Debug)]
pub struct H3Stream {
    /// Sending half of the stream.
    send: H3SendStream,
    /// Receiving half of the stream.
    recv: H3RecvStream,
}
/// Sending half of a stream.
#[derive(Debug)]
pub struct H3SendStream {
    /// Shared handle to the underlying msquic stream.
    stream: Arc<msquic::Stream>,
    /// Front-end ends of the start / send-complete / shutdown channels.
    sctx: SendStreamReceiveCtx,
}
/// Receiving half of a stream.
#[derive(Debug)]
pub struct H3RecvStream {
    /// Shared handle to the underlying msquic stream.
    stream: Arc<msquic::Stream>,
    /// Front-end end of the received-data channel.
    rctx: RecvStreamReceiveCtx,
}
377
/// Raw pointer wrapper used to carry the `client_context` of a `SendComplete` event
/// through a channel.
struct BufPtr(*const c_void);
// NOTE(review): these impls assert the pointer may be moved/shared across threads;
// presumably sound because the pointer is only turned back into its owning buffer once,
// in `poll_ready` — confirm.
unsafe impl Send for BufPtr {}
unsafe impl Sync for BufPtr {}
381
/// Callback-side context of a stream: the sending ends used by `stream_callback` to
/// notify the front end. Every field is dropped on `ShutdownComplete`.
struct StreamSendCtx {
    /// One-shot result of the `StartComplete` event.
    start: Option<oneshot::Sender<Result<(), Status>>>,
    // cancelled, client_context
    /// One message per `SendComplete` event.
    send: Option<mpsc::UnboundedSender<(bool, BufPtr)>>,
    /// One-shot signal fired on `SendShutdownComplete`.
    shutdown: Option<oneshot::Sender<()>>,
    /// Data copied out of `Receive` events; dropped on FIN or on an empty receive.
    receive: Option<mpsc::UnboundedSender<Bytes>>,
}
389
/// Front-end context for receiving data.
#[derive(Debug)]
struct RecvStreamReceiveCtx {
    /// Chunks of received data forwarded by the stream callback.
    receive: mpsc::UnboundedReceiver<Bytes>,
}
395
/// Front-end context for sending data.
#[derive(Debug)]
struct SendStreamReceiveCtx {
    /// Result of the stream start, polled while the stream is being opened.
    start: Option<oneshot::Receiver<Result<(), Status>>>,
    // cancelled, client_context
    /// Send-completion notifications from the stream callback.
    send: mpsc::UnboundedReceiver<(bool, BufPtr)>,
    /// `true` between a successful `send_data` and the completion observed by `poll_ready`.
    send_inprogress: bool,
    /// Resolves when the callback reports `SendShutdownComplete`.
    shutdown: oneshot::Receiver<()>,
}
405
406fn stream_ctx_channel() -> (StreamSendCtx, SendStreamReceiveCtx, RecvStreamReceiveCtx) {
407    let (start_tx, start_rx) = oneshot::channel::<Result<(), Status>>();
408    let (send_tx, send_rx) = mpsc::unbounded();
409    let (shutdown_tx, shutdown_rx) = oneshot::channel();
410    let (receive_tx, receive_rx) = mpsc::unbounded();
411    (
412        StreamSendCtx {
413            start: Some(start_tx),
414            send: Some(send_tx),
415            shutdown: Some(shutdown_tx),
416            receive: Some(receive_tx),
417        },
418        SendStreamReceiveCtx {
419            start: Some(start_rx),
420            send: send_rx,
421            send_inprogress: false,
422            shutdown: shutdown_rx,
423        },
424        RecvStreamReceiveCtx {
425            receive: receive_rx,
426        },
427    )
428}
429
430#[cfg_attr(
431    feature = "tracing",
432    tracing::instrument(skip(ctx), level = "trace", ret)
433)]
434fn stream_callback(ctx: &mut StreamSendCtx, ev: StreamEvent) -> Result<(), Status> {
435    match ev {
436        StreamEvent::StartComplete { status, .. } => {
437            let tx = ctx.start.take().unwrap();
438            if status.is_ok() {
439                tx.send(Ok(())).expect("cannot send");
440            } else {
441                tx.send(Err(status)).expect("cannot send")
442            }
443        }
444        StreamEvent::SendComplete {
445            cancelled,
446            client_context,
447        } => {
448            if let Some(send) = ctx.send.as_ref() {
449                send.unbounded_send((cancelled, BufPtr(client_context)))
450                    .expect("cannot send");
451            } else {
452                debug_assert!(false, "mem leak");
453            }
454        }
455        StreamEvent::Receive { buffers, flags, .. } => {
456            if let Some(receive) = ctx.receive.as_ref() {
457                let mut b = BytesMut::new();
458                for br in buffers {
459                    // skip empty buffs.
460                    if !br.as_bytes().is_empty() {
461                        b.put_slice(br.as_bytes());
462                    }
463                }
464                let b = b.freeze();
465                if !b.is_empty() {
466                    receive.unbounded_send(b).expect("cannot send");
467                } else {
468                    // zero buff can happen. so drop the receiver.
469                    ctx.receive.take();
470                }
471            }
472            if flags.contains(ReceiveFlags::FIN) {
473                // close
474                ctx.receive.take();
475            }
476        }
477        StreamEvent::SendShutdownComplete { graceful: _ } => {
478            // Peer acknowledged shutdown.
479            if let Some(shutdown) = ctx.shutdown.take() {
480                shutdown.send(()).expect("cannot send");
481            }
482        }
483        StreamEvent::ShutdownComplete { .. } => {
484            // close all channels
485            ctx.receive.take();
486            ctx.send.take();
487            ctx.shutdown.take();
488            ctx.start.take();
489        }
490        _ => {}
491    }
492    Ok(())
493}
494
495impl H3Stream {
496    /// attach to accepted stream
497    pub(crate) fn attach(stream: msquic::Stream) -> Self {
498        let (mut ctx, rtx, rrtx) = stream_ctx_channel();
499        let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
500
501        stream.set_callback_handler(handler);
502        let s = Arc::new(stream);
503        Self {
504            send: H3SendStream {
505                stream: s.clone(),
506                sctx: rtx,
507            },
508            recv: H3RecvStream {
509                stream: s,
510                rctx: rrtx,
511            },
512        }
513    }
514
515    #[cfg_attr(
516        feature = "tracing",
517        tracing::instrument(skip_all, level = "trace", err, ret)
518    )]
519    fn open_and_start(conn: &msquic::Connection, uni: bool) -> Result<Self, Status> {
520        let (mut ctx, rtx, rrtx) = stream_ctx_channel();
521        let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
522
523        let flag = match uni {
524            true => StreamOpenFlags::UNIDIRECTIONAL,
525            false => StreamOpenFlags::NONE,
526        };
527
528        let s = msquic::Stream::open(conn, flag, handler)?;
529        s.start(StreamStartFlags::NONE)?;
530        let s = Arc::new(s);
531        Ok(Self {
532            send: H3SendStream {
533                stream: s.clone(),
534                sctx: rtx,
535            },
536            recv: H3RecvStream {
537                stream: s,
538                rctx: rrtx,
539            },
540        })
541    }
542}
543
544impl<B: Buf> SendStream<B> for H3SendStream {
545    type Error = H3Error;
546
547    // Seems like poll_ready is called after send_data is called.
548    // To ensure data is sent.
549    #[cfg_attr(
550        feature = "tracing",
551        tracing::instrument(skip_all, level = "trace", ret)
552    )]
553    fn poll_ready(
554        &mut self,
555        cx: &mut std::task::Context<'_>,
556    ) -> std::task::Poll<Result<(), Self::Error>> {
557        if !self.sctx.send_inprogress {
558            // no send is current so ready to get more.
559            return std::task::Poll::Ready(Ok(()));
560        }
561        match ready!(self.sctx.send.poll_next_unpin(cx)) {
562            Some((cancelled, ptr)) => {
563                self.sctx.send_inprogress = false;
564                // reattach buff
565                let _: H3Buff<h3::quic::WriteBuf<B>> =
566                    unsafe { H3Buff::from_raw(ptr.0 as *mut c_void) };
567                match cancelled {
568                    true => std::task::Poll::Ready(Err(H3Error::new(
569                        Status::from(StatusCode::QUIC_STATUS_ABORTED),
570                        None,
571                    ))),
572                    false => std::task::Poll::Ready(Ok(())),
573                }
574            }
575            // closed.
576            None => std::task::Poll::Ready(Err(H3Error::new(
577                Status::from(StatusCode::QUIC_STATUS_ABORTED),
578                None,
579            ))),
580        }
581    }
582
583    #[cfg_attr(
584        feature = "tracing",
585        tracing::instrument(skip_all, level = "trace", ret, err)
586    )]
587    fn send_data<T: Into<h3::quic::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
588        if self.sctx.send_inprogress {
589            panic!("send while send is in progress.");
590        }
591        let data: h3::quic::WriteBuf<B> = data.into();
592        let buff = H3Buff::new(data);
593        let (buff_ref, ptr) = unsafe { buff.into_raw() };
594        unsafe { self.stream.send(buff_ref, SendFlags::NONE, ptr) }
595            .inspect_err(|_| {
596                // reattach buff
597                let _: H3Buff<h3::quic::WriteBuf<B>> = unsafe { H3Buff::from_raw(ptr) };
598            })
599            .map_err(|e| H3Error::new(e, None))?;
600        self.sctx.send_inprogress = true;
601        Ok(())
602    }
603
604    // Send FIN signal to peer.
605    #[cfg_attr(
606        feature = "tracing",
607        tracing::instrument(skip_all, level = "trace", ret)
608    )]
609    fn poll_finish(
610        &mut self,
611        cx: &mut std::task::Context<'_>,
612    ) -> std::task::Poll<Result<(), Self::Error>> {
613        // Graceful sends a Fin to peer.
614        if let Err(e) = self.stream.shutdown(StreamShutdownFlags::GRACEFUL, 0) {
615            return std::task::Poll::Ready(Err(H3Error::new(e, None)));
616        }
617        // poll the ctx
618        let rx = &mut self.sctx.shutdown;
619        let p = Pin::new(rx);
620        // if backend is closed return error.
621        let res = ready!(std::future::Future::poll(p, cx))
622            .map_err(|_| H3Error::new(Status::from(StatusCode::QUIC_STATUS_ABORTED), None));
623        std::task::Poll::Ready(res)
624    }
625
626    #[cfg_attr(
627        feature = "tracing",
628        tracing::instrument(skip_all, level = "trace", ret)
629    )]
630    fn reset(&mut self, _reset_code: u64) {
631        panic!("reset not supported")
632    }
633
634    fn send_id(&self) -> h3::quic::StreamId {
635        get_id(&self.stream)
636    }
637}
638
639fn get_id(s: &msquic::Stream) -> h3::quic::StreamId {
640    let raw_id = unsafe {
641        msquic::Api::get_param_auto::<u64>(s.as_raw(), msquic::ffi::QUIC_PARAM_STREAM_ID)
642    }
643    .unwrap();
644    raw_id.try_into().expect("cannot parse id")
645}
646
647impl RecvStream for H3RecvStream {
648    type Buf = Bytes;
649
650    type Error = H3Error;
651
652    #[cfg_attr(
653        feature = "tracing",
654        tracing::instrument(skip_all, level = "trace", ret)
655    )]
656    fn poll_data(
657        &mut self,
658        cx: &mut std::task::Context<'_>,
659    ) -> std::task::Poll<Result<Option<Self::Buf>, Self::Error>> {
660        let res = ready!(self.rctx.receive.poll_next_unpin(cx));
661        std::task::Poll::Ready(Ok(res))
662    }
663
664    /// Stop accepting data. Discard unread data, notify peer to not send.
665    #[cfg_attr(
666        feature = "tracing",
667        tracing::instrument(skip_all, level = "trace", ret)
668    )]
669    fn stop_sending(&mut self, error_code: u64) {
670        // Close the send path.
671        let _ = self
672            .stream
673            .shutdown(StreamShutdownFlags::ABORT_RECEIVE, error_code);
674    }
675
676    fn recv_id(&self) -> h3::quic::StreamId {
677        get_id(&self.stream)
678    }
679}
680
681// bidi stream
682
683impl<B: Buf> SendStream<B> for H3Stream {
684    type Error = H3Error;
685
686    fn poll_ready(
687        &mut self,
688        cx: &mut std::task::Context<'_>,
689    ) -> std::task::Poll<Result<(), Self::Error>> {
690        SendStream::<B>::poll_ready(&mut self.send, cx)
691    }
692
693    fn send_data<T: Into<h3::quic::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
694        SendStream::<B>::send_data(&mut self.send, data)
695    }
696
697    fn poll_finish(
698        &mut self,
699        cx: &mut std::task::Context<'_>,
700    ) -> std::task::Poll<Result<(), Self::Error>> {
701        SendStream::<B>::poll_finish(&mut self.send, cx)
702    }
703
704    fn reset(&mut self, reset_code: u64) {
705        SendStream::<B>::reset(&mut self.send, reset_code);
706    }
707
708    fn send_id(&self) -> h3::quic::StreamId {
709        SendStream::<B>::send_id(&self.send)
710    }
711}
712
713impl RecvStream for H3Stream {
714    type Buf = Bytes;
715
716    type Error = H3Error;
717
718    fn poll_data(
719        &mut self,
720        cx: &mut std::task::Context<'_>,
721    ) -> std::task::Poll<Result<Option<Self::Buf>, Self::Error>> {
722        RecvStream::poll_data(&mut self.recv, cx)
723    }
724
725    fn stop_sending(&mut self, error_code: u64) {
726        RecvStream::stop_sending(&mut self.recv, error_code)
727    }
728
729    fn recv_id(&self) -> h3::quic::StreamId {
730        RecvStream::recv_id(&self.recv)
731    }
732}
733
734impl<B: Buf> BidiStream<B> for H3Stream {
735    type SendStream = H3SendStream;
736
737    type RecvStream = H3RecvStream;
738
739    #[cfg_attr(
740        feature = "tracing",
741        tracing::instrument(skip_all, level = "trace", ret)
742    )]
743    fn split(self) -> (Self::SendStream, Self::RecvStream) {
744        (self.send, self.recv)
745    }
746}
747
#[cfg(test)]
mod test {
    use bytes::Buf;
    use http::Uri;
    use msquic::{
        BufferRef, Configuration, CredentialConfig, CredentialFlags, Registration,
        RegistrationConfig, Settings,
    };

    use crate::Connection;

    /// Helpers shared by tests: tracing setup and test credentials.
    pub mod util {
        use msquic::Credential;
        /// Trace level used by the tests; adjust while debugging.
        pub const DEVEL_TRACE_LEVEL: tracing::Level = tracing::Level::TRACE;

        /// Tries to install a tracing subscriber; the result of `try_init` is ignored.
        pub fn try_setup_tracing() {
            let subscriber = tracing_subscriber::fmt().with_max_level(DEVEL_TRACE_LEVEL);
            let _ = subscriber.try_init();
        }

        /// Use pwsh to get the test cert hash
        #[cfg(target_os = "windows")]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateHash;

            let query = "Get-ChildItem Cert:\\CurrentUser\\My | Where-Object -Property FriendlyName -EQ -Value MsQuic-Test | Select-Object -ExpandProperty Thumbprint -First 1";
            let output = std::process::Command::new("pwsh.exe")
                .args(["-Command", query])
                .output()
                .expect("Failed to execute command");
            assert!(output.status.success());
            let stdout = String::from_utf8(output.stdout).unwrap();
            // Strip exactly one trailing line ending (`\n` or `\r\n`), nothing more.
            let thumbprint = match stdout.strip_suffix('\n') {
                Some(line) => line.strip_suffix('\r').unwrap_or(line),
                None => stdout.as_str(),
            };
            Credential::CertificateHash(CertificateHash::from_str(thumbprint).unwrap())
        }

        /// Generate a test cert if not present using openssl cli.
        #[cfg(not(target_os = "windows"))]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateFile;

            let cert_dir = std::env::temp_dir().join("msquic_h3_test_rs");
            let key_path = cert_dir.join("key.pem");
            let cert_path = cert_dir.join("cert.pem");

            let have_both = key_path.exists() && cert_path.exists();
            if !have_both {
                // Start from an empty directory so key and cert always match.
                let _ = std::fs::remove_dir_all(&cert_dir);
                std::fs::create_dir_all(&cert_dir).expect("cannot create cert dir");

                // generate test cert using openssl cli
                let openssl_args = [
                    "req",
                    "-x509",
                    "-newkey",
                    "rsa:4096",
                    "-keyout",
                    "key.pem",
                    "-out",
                    "cert.pem",
                    "-sha256",
                    "-days",
                    "3650",
                    "-nodes",
                    "-subj",
                    "/CN=localhost",
                ];
                let output = std::process::Command::new("openssl")
                    .args(openssl_args)
                    .current_dir(cert_dir)
                    .stderr(std::process::Stdio::inherit())
                    .stdout(std::process::Stdio::inherit())
                    .output()
                    .expect("cannot generate cert");
                assert!(output.status.success(), "generate cert failed");
            }

            let key_file = key_path.display().to_string();
            let cert_file = cert_path.display().to_string();
            Credential::CertificateFile(CertificateFile::new(key_file, cert_file))
        }
    }

    /// Connects to the host/port of `uri`, sends one HTTP/3 GET request and logs the
    /// response; request and driver errors are logged rather than propagated.
    pub(crate) async fn send_get_request(uri: Uri) {
        let reg_config = RegistrationConfig::new().set_app_name(String::from("testapp"));
        let reg = Registration::new(&reg_config).unwrap();

        // Client configuration: "h3" ALPN, idle timeout, no certificate validation.
        let alpn = BufferRef::from("h3");
        let client_settings = Settings::new().set_IdleTimeoutMs(2000);
        let client_config = Configuration::open(&reg, &[alpn], Some(&client_settings)).unwrap();
        {
            let cred_config = CredentialConfig::new_client()
                .set_credential_flags(CredentialFlags::NO_CERTIFICATE_VALIDATION);
            client_config.load_credential(&cred_config).unwrap();
        }

        tracing::info!("client conn open and start");
        let host = uri.host().unwrap();
        let port = uri.port_u16().unwrap();
        let conn = Connection::connect(&reg, &client_config, host, port)
            .await
            .unwrap();

        tracing::info!("client create h3 client");
        let (mut driver, mut send_request) = h3::client::new(conn).await.unwrap();

        tracing::info!("client start driver");
        let drive = async move {
            std::future::poll_fn(|cx| driver.poll_close(cx)).await?;
            Ok::<(), Box<dyn std::error::Error>>(())
        };

        // `send_request` is moved into this block on purpose: the connection is only
        // closed once every `SendRequest` instance has been dropped.
        let request = async move {
            tracing::info!("sending request ...");

            let req = http::Request::builder().uri(uri).body(())?;

            // sending request results in a bidirectional stream,
            // which is also used for receiving response
            let mut stream = send_request.send_request(req).await?;

            // finish on the sending side
            stream.finish().await?;

            tracing::info!("receiving response ...");

            let resp = stream.recv_response().await?;

            tracing::info!("response: {:?} {}", resp.version(), resp.status());
            tracing::info!("headers: {:#?}", resp.headers());

            // `recv_data()` must be called after `recv_response()` for
            // receiving potential response body
            let mut data: Vec<u8> = Vec::new();
            while let Some(mut chunk) = stream.recv_data().await? {
                let piece = chunk.copy_to_bytes(chunk.remaining());
                data.extend_from_slice(&piece);
            }
            let body = String::from_utf8_lossy(&data);
            tracing::info!("client got body: {}", body);
            Ok::<_, Box<dyn std::error::Error>>(())
        };

        let (req_res, drive_res) = tokio::join!(request, drive);
        if let Err(e) = req_res {
            tracing::error!("req_err {e:?}");
        }
        if let Err(e) = drive_res {
            tracing::error!("drive_res {e:?}");
        }
        tracing::info!("client ended success");
    }

    #[test]
    fn client_test_apache() {
        util::try_setup_tracing();
        // This does not work (cloudflare servers):
        // let uri = http::Uri::from_static("https://quic.tech:8443/");
        // let uri = http::Uri::from_static("https://cloudflare-quic.com:443/");

        // These works
        let uri = http::Uri::from_static("https://h2o.examp1e.net:443");
        // let uri = http::Uri::from_static("https://docs.trafficserver.apache.org:443/");

        // Drive the async request on a single-threaded tokio runtime.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap();
        runtime.block_on(send_get_request(uri));
    }
}