1use std::{ffi::c_void, pin::Pin, sync::Arc};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use futures::{
5 channel::{mpsc, oneshot},
6 ready, StreamExt,
7};
8use h3::quic::{BidiStream, OpenStreams, RecvStream, SendStream};
9use msquic::{
10 Configuration, ConnectionEvent, ConnectionRef, ConnectionShutdownFlags, ReceiveFlags,
11 Registration, SendFlags, Status, StatusCode, StreamEvent, StreamOpenFlags, StreamRef,
12 StreamShutdownFlags, StreamStartFlags,
13};
14
15mod buffer;
16pub use buffer::*;
17mod listener;
18pub use listener::Listener;
19
/// Re-exports everything from the `msquic` crate as `crate::msquic`.
///
/// NOTE(review): presumably this exists so users can build the `Registration` /
/// `Configuration` values needed by `Connection::connect` from the same msquic version this
/// crate links against — confirm.
pub mod msquic {
    pub use ::msquic::*;
}
24
25#[derive(Debug)]
26pub struct H3Error {
27 status: Status,
28 error_code: Option<u64>,
29}
30
31impl H3Error {
32 pub fn new(status: Status, ec: Option<u64>) -> Self {
33 Self {
34 status,
35 error_code: ec,
36 }
37 }
38}
39
40impl h3::quic::Error for H3Error {
41 fn is_timeout(&self) -> bool {
42 self.status
43 .try_as_status_code()
44 .unwrap_or(StatusCode::QUIC_STATUS_SUCCESS)
45 == StatusCode::QUIC_STATUS_CONNECTION_TIMEOUT
46 }
47
48 fn err_code(&self) -> Option<u64> {
49 self.error_code
50 }
51}
52
53impl std::error::Error for H3Error {}
54
55impl std::fmt::Display for H3Error {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{:?}", self)
58 }
59}
60
/// An msquic connection adapted to the `h3::quic` connection traits.
#[derive(Debug)]
pub struct Connection {
    // The msquic handle; cloned into every `StreamOpener` handed out by `opener()`.
    conn: Arc<msquic::Connection>,
    // Receiving side of the channels fed by `connection_callback`.
    ctx: ConnCtxReceiver,
    // Opener used by this type's own `OpenStreams` impl.
    opener: StreamOpener,
}

/// Sending side of the connection event channels; moved into the msquic callback closure.
#[derive(Debug)]
struct ConnCtxSender {
    // Signalled on `Connected`; taken (dropped) on `ShutdownComplete`.
    connected: Option<oneshot::Sender<()>>,
    // Peer-initiated bidirectional streams; taken on `ShutdownComplete`, which closes the channel.
    bidi: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
    // Peer-initiated unidirectional streams; taken on `ShutdownComplete`.
    uni: Option<mpsc::UnboundedSender<Option<crate::H3Stream>>>,
}

/// Receiving side of the connection event channels; owned by `Connection`.
#[derive(Debug)]
struct ConnCtxReceiver {
    // Awaited (and taken) by `Connection::connect`; left untouched by `Connection::attach`.
    connected: Option<oneshot::Receiver<()>>,
    // Polled by `poll_accept_bidi`.
    bidi: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
    // Polled by `poll_accept_recv`.
    uni: mpsc::UnboundedReceiver<Option<crate::H3Stream>>,
}
83
84fn conn_ctx_channel() -> (ConnCtxSender, ConnCtxReceiver) {
85 let (conn_tx, conn_rx) = oneshot::channel();
86 let (bidi_tx, bidi_rx) = mpsc::unbounded();
87 let (uni_tx, uni_rx) = mpsc::unbounded();
88 (
89 ConnCtxSender {
90 connected: Some(conn_tx),
91 bidi: Some(bidi_tx),
92 uni: Some(uni_tx),
93 },
94 ConnCtxReceiver {
95 connected: Some(conn_rx),
96 bidi: bidi_rx,
97 uni: uni_rx,
98 },
99 )
100}
101
102#[cfg_attr(
103 feature = "tracing",
104 tracing::instrument(skip(ctx), level = "trace", ret, err)
105)]
106fn connection_callback(ctx: &mut ConnCtxSender, ev: msquic::ConnectionEvent) -> Result<(), Status> {
107 match ev {
108 ConnectionEvent::Connected { .. } => {
109 ctx.connected.take().unwrap().send(()).unwrap();
110 }
111 ConnectionEvent::PeerStreamStarted { stream, flags } => {
112 let s = unsafe { msquic::Stream::from_raw(stream.as_raw()) };
114 if flags.contains(StreamOpenFlags::UNIDIRECTIONAL) {
115 if let Some(uni) = ctx.uni.as_ref() {
116 uni.unbounded_send(Some(crate::H3Stream::attach(s)))
117 .expect("cannot send");
118 }
119 } else if let Some(bidi) = ctx.bidi.as_ref() {
120 bidi.unbounded_send(Some(crate::H3Stream::attach(s)))
121 .expect("cannot send");
122 }
123 }
124 ConnectionEvent::ShutdownComplete { .. } => {
125 ctx.connected.take();
127 ctx.uni.take();
128 ctx.bidi.take();
129 }
130 _ => {}
131 }
132 Ok(())
133}
134
135impl Connection {
136 pub async fn connect(
138 reg: &Registration,
139 config: &Configuration,
140 server_name: &str,
141 server_port: u16,
142 ) -> Result<Self, Status> {
143 let (mut ctx, mut crx) = conn_ctx_channel();
144 let handler =
145 move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
146 let conn = msquic::Connection::open(reg, handler)?;
147 conn.start(config, server_name, server_port)?;
148 crx.connected
150 .take()
151 .unwrap()
152 .await
153 .map_err(|_| Status::new(StatusCode::QUIC_STATUS_ABORTED))?;
154
155 let conn = Arc::new(conn);
156
157 let opener = StreamOpener::new(conn.clone());
158
159 Ok(Self {
160 conn,
161 ctx: crx,
162 opener,
163 })
164 }
165
166 pub(crate) fn attach(inner: msquic::Connection) -> Self {
168 let (mut ctx, crx) = conn_ctx_channel();
169 let handler =
170 move |_: ConnectionRef, ev: ConnectionEvent| connection_callback(&mut ctx, ev);
171 inner.set_callback_handler(handler);
172 let conn = Arc::new(inner);
173
174 let opener = StreamOpener::new(conn.clone());
175
176 Self {
177 conn,
178 ctx: crx,
179 opener,
180 }
181 }
182}
183
184#[derive(Debug)]
186pub struct StreamOpener {
187 conn: Arc<msquic::Connection>,
188 bidi_temp: Option<H3Stream>,
189 uni_temp: Option<H3Stream>,
190}
191
192impl Clone for StreamOpener {
193 fn clone(&self) -> Self {
194 Self {
195 conn: self.conn.clone(),
196 bidi_temp: None,
197 uni_temp: None,
198 }
199 }
200}
201
202impl<B: Buf> h3::quic::Connection<B> for Connection {
204 type RecvStream = H3RecvStream;
205
206 type OpenStreams = StreamOpener;
207
208 type AcceptError = H3Error;
209
210 #[cfg_attr(
211 feature = "tracing",
212 tracing::instrument(skip_all, level = "trace", ret)
213 )]
214 fn poll_accept_recv(
215 &mut self,
216 cx: &mut std::task::Context<'_>,
217 ) -> std::task::Poll<Result<Option<Self::RecvStream>, Self::AcceptError>> {
218 let s = ready!(self.ctx.uni.poll_next_unpin(cx)).unwrap_or(None);
219 std::task::Poll::Ready(Ok(s.map(|s| s.recv)))
221 }
222
223 #[cfg_attr(
224 feature = "tracing",
225 tracing::instrument(skip_all, level = "trace", ret)
226 )]
227 fn poll_accept_bidi(
228 &mut self,
229 cx: &mut std::task::Context<'_>,
230 ) -> std::task::Poll<Result<Option<Self::BidiStream>, Self::AcceptError>> {
231 let s = ready!(self.ctx.bidi.poll_next_unpin(cx)).unwrap_or(None);
232 std::task::Poll::Ready(Ok(s))
234 }
235
236 #[cfg_attr(
237 feature = "tracing",
238 tracing::instrument(skip_all, level = "trace", ret)
239 )]
240 fn opener(&self) -> Self::OpenStreams {
241 StreamOpener::new(self.conn.clone())
242 }
243}
244
245impl<B: Buf> OpenStreams<B> for StreamOpener {
247 type BidiStream = H3Stream;
248
249 type SendStream = H3SendStream;
250
251 type OpenError = H3Error;
252
253 #[cfg_attr(
254 feature = "tracing",
255 tracing::instrument(skip_all, level = "trace", ret)
256 )]
257 fn poll_open_bidi(
258 &mut self,
259 cx: &mut std::task::Context<'_>,
260 ) -> std::task::Poll<Result<Self::BidiStream, Self::OpenError>> {
261 Self::poll_open_inner(&self.conn, false, &mut self.bidi_temp, cx)
262 }
263
264 #[cfg_attr(
265 feature = "tracing",
266 tracing::instrument(skip_all, level = "trace", ret)
267 )]
268 fn poll_open_send(
269 &mut self,
270 cx: &mut std::task::Context<'_>,
271 ) -> std::task::Poll<Result<Self::SendStream, Self::OpenError>> {
272 let res = ready!(Self::poll_open_inner(
273 &self.conn,
274 true,
275 &mut self.uni_temp,
276 cx
277 ));
278 std::task::Poll::Ready(res.map(|s| s.send))
280 }
281
282 #[cfg_attr(
283 feature = "tracing",
284 tracing::instrument(skip_all, level = "trace", ret)
285 )]
286 fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
287 self.conn
288 .shutdown(ConnectionShutdownFlags::NONE, code.value());
289 }
290}
291
292impl StreamOpener {
293 fn new(conn: Arc<msquic::Connection>) -> Self {
294 Self {
295 conn,
296 bidi_temp: None,
297 uni_temp: None,
298 }
299 }
300
301 fn poll_open_inner(
303 conn: &Arc<msquic::Connection>,
304 uni: bool,
305 stream_holder: &mut Option<H3Stream>,
306 cx: &mut std::task::Context<'_>,
307 ) -> std::task::Poll<Result<H3Stream, H3Error>> {
308 if stream_holder.is_none() {
309 let s = match H3Stream::open_and_start(conn, uni) {
311 Ok(s) => s,
312 Err(e) => return std::task::Poll::Ready(Err(H3Error::new(e, None))),
313 };
314 *stream_holder = Some(s);
315 }
316
317 let res = {
319 let s = stream_holder.as_mut().unwrap();
320 let rx = s.send.sctx.start.as_mut().unwrap();
321 let p = Pin::new(rx);
322 ready!(std::future::Future::poll(p, cx))
323 };
324 let s = stream_holder.take().unwrap();
326 let res = res
327 .expect("cannot receive")
328 .map(|_| s)
329 .map_err(|e| H3Error::new(e, None));
330 std::task::Poll::Ready(res)
331 }
332}
333
334impl<B: Buf> OpenStreams<B> for Connection {
336 type BidiStream = H3Stream;
337
338 type SendStream = H3SendStream;
339
340 type OpenError = H3Error;
341
342 fn poll_open_bidi(
343 &mut self,
344 cx: &mut std::task::Context<'_>,
345 ) -> std::task::Poll<Result<Self::BidiStream, Self::OpenError>> {
346 OpenStreams::<B>::poll_open_bidi(&mut self.opener, cx)
347 }
348
349 fn poll_open_send(
350 &mut self,
351 cx: &mut std::task::Context<'_>,
352 ) -> std::task::Poll<Result<Self::SendStream, Self::OpenError>> {
353 OpenStreams::<B>::poll_open_send(&mut self.opener, cx)
354 }
355
356 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
357 OpenStreams::<B>::close(&mut self.opener, code, reason)
358 }
359}
360
/// An msquic stream exposed as both halves; each half holds a clone of the same
/// `Arc<msquic::Stream>`.
#[derive(Debug)]
pub struct H3Stream {
    send: H3SendStream,
    recv: H3RecvStream,
}
/// Send half of an msquic stream plus the receivers for its send-side callback events.
///
/// NOTE(review): fields drop in declaration order (`stream` before `sctx`). This looks
/// deliberate, so the handle is released before the callback's channels close; confirm
/// before reordering.
#[derive(Debug)]
pub struct H3SendStream {
    stream: Arc<msquic::Stream>,
    sctx: SendStreamReceiveCtx,
}
/// Receive half of an msquic stream plus the receiver for data forwarded by the callback.
///
/// NOTE(review): same field drop-order consideration as `H3SendStream`.
#[derive(Debug)]
pub struct H3RecvStream {
    stream: Arc<msquic::Stream>,
    rctx: RecvStreamReceiveCtx,
}
377
/// Type-erased send-buffer pointer, carried from the `SendComplete` stream event back to
/// `H3SendStream::poll_ready`. `poll_ready` rebuilds it with `H3Buff::from_raw` and drops it.
struct BufPtr(*const c_void);
// SAFETY (review): the pointer is only moved between threads as an opaque token; it is
// dereferenced solely through `H3Buff::from_raw` by the send stream that created it.
// This assumes the buffered `WriteBuf<B>` payload is itself safe to move across threads —
// TODO confirm.
unsafe impl Send for BufPtr {}
unsafe impl Sync for BufPtr {}
381
/// Sending side of the per-stream channels; moved into the msquic stream callback closure.
struct StreamSendCtx {
    // Completed with the status reported by `StartComplete`.
    start: Option<oneshot::Sender<Result<(), Status>>>,
    // Each `SendComplete` as (cancelled, buffer pointer handed to msquic by `send_data`).
    send: Option<mpsc::UnboundedSender<(bool, BufPtr)>>,
    // Signalled on `SendShutdownComplete`.
    shutdown: Option<oneshot::Sender<()>>,
    // Received data; the callback drops it at end of stream (FIN) or stream shutdown.
    receive: Option<mpsc::UnboundedSender<Bytes>>,
}

/// Receiving side for stream data; owned by `H3RecvStream`.
#[derive(Debug)]
struct RecvStreamReceiveCtx {
    receive: mpsc::UnboundedReceiver<Bytes>,
}

/// Receiving side for send-related stream events; owned by `H3SendStream`.
#[derive(Debug)]
struct SendStreamReceiveCtx {
    // Polled by `StreamOpener::poll_open_inner` while a locally opened stream starts.
    start: Option<oneshot::Receiver<Result<(), Status>>>,
    // Drained by `poll_ready` to reclaim the outstanding send buffer.
    send: mpsc::UnboundedReceiver<(bool, BufPtr)>,
    // Set by `send_data`, cleared by `poll_ready`; at most one send is outstanding.
    send_inprogress: bool,
    // Awaited by `poll_finish`.
    shutdown: oneshot::Receiver<()>,
}
405
406fn stream_ctx_channel() -> (StreamSendCtx, SendStreamReceiveCtx, RecvStreamReceiveCtx) {
407 let (start_tx, start_rx) = oneshot::channel::<Result<(), Status>>();
408 let (send_tx, send_rx) = mpsc::unbounded();
409 let (shutdown_tx, shutdown_rx) = oneshot::channel();
410 let (receive_tx, receive_rx) = mpsc::unbounded();
411 (
412 StreamSendCtx {
413 start: Some(start_tx),
414 send: Some(send_tx),
415 shutdown: Some(shutdown_tx),
416 receive: Some(receive_tx),
417 },
418 SendStreamReceiveCtx {
419 start: Some(start_rx),
420 send: send_rx,
421 send_inprogress: false,
422 shutdown: shutdown_rx,
423 },
424 RecvStreamReceiveCtx {
425 receive: receive_rx,
426 },
427 )
428}
429
430#[cfg_attr(
431 feature = "tracing",
432 tracing::instrument(skip(ctx), level = "trace", ret)
433)]
434fn stream_callback(ctx: &mut StreamSendCtx, ev: StreamEvent) -> Result<(), Status> {
435 match ev {
436 StreamEvent::StartComplete { status, .. } => {
437 let tx = ctx.start.take().unwrap();
438 if status.is_ok() {
439 tx.send(Ok(())).expect("cannot send");
440 } else {
441 tx.send(Err(status)).expect("cannot send")
442 }
443 }
444 StreamEvent::SendComplete {
445 cancelled,
446 client_context,
447 } => {
448 if let Some(send) = ctx.send.as_ref() {
449 send.unbounded_send((cancelled, BufPtr(client_context)))
450 .expect("cannot send");
451 } else {
452 debug_assert!(false, "mem leak");
453 }
454 }
455 StreamEvent::Receive { buffers, flags, .. } => {
456 if let Some(receive) = ctx.receive.as_ref() {
457 let mut b = BytesMut::new();
458 for br in buffers {
459 if !br.as_bytes().is_empty() {
461 b.put_slice(br.as_bytes());
462 }
463 }
464 let b = b.freeze();
465 if !b.is_empty() {
466 receive.unbounded_send(b).expect("cannot send");
467 } else {
468 ctx.receive.take();
470 }
471 }
472 if flags.contains(ReceiveFlags::FIN) {
473 ctx.receive.take();
475 }
476 }
477 StreamEvent::SendShutdownComplete { graceful: _ } => {
478 if let Some(shutdown) = ctx.shutdown.take() {
480 shutdown.send(()).expect("cannot send");
481 }
482 }
483 StreamEvent::ShutdownComplete { .. } => {
484 ctx.receive.take();
486 ctx.send.take();
487 ctx.shutdown.take();
488 ctx.start.take();
489 }
490 _ => {}
491 }
492 Ok(())
493}
494
495impl H3Stream {
496 pub(crate) fn attach(stream: msquic::Stream) -> Self {
498 let (mut ctx, rtx, rrtx) = stream_ctx_channel();
499 let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
500
501 stream.set_callback_handler(handler);
502 let s = Arc::new(stream);
503 Self {
504 send: H3SendStream {
505 stream: s.clone(),
506 sctx: rtx,
507 },
508 recv: H3RecvStream {
509 stream: s,
510 rctx: rrtx,
511 },
512 }
513 }
514
515 #[cfg_attr(
516 feature = "tracing",
517 tracing::instrument(skip_all, level = "trace", err, ret)
518 )]
519 fn open_and_start(conn: &msquic::Connection, uni: bool) -> Result<Self, Status> {
520 let (mut ctx, rtx, rrtx) = stream_ctx_channel();
521 let handler = move |_: StreamRef, ev: StreamEvent| stream_callback(&mut ctx, ev);
522
523 let flag = match uni {
524 true => StreamOpenFlags::UNIDIRECTIONAL,
525 false => StreamOpenFlags::NONE,
526 };
527
528 let s = msquic::Stream::open(conn, flag, handler)?;
529 s.start(StreamStartFlags::NONE)?;
530 let s = Arc::new(s);
531 Ok(Self {
532 send: H3SendStream {
533 stream: s.clone(),
534 sctx: rtx,
535 },
536 recv: H3RecvStream {
537 stream: s,
538 rctx: rrtx,
539 },
540 })
541 }
542}
543
544impl<B: Buf> SendStream<B> for H3SendStream {
545 type Error = H3Error;
546
547 #[cfg_attr(
550 feature = "tracing",
551 tracing::instrument(skip_all, level = "trace", ret)
552 )]
553 fn poll_ready(
554 &mut self,
555 cx: &mut std::task::Context<'_>,
556 ) -> std::task::Poll<Result<(), Self::Error>> {
557 if !self.sctx.send_inprogress {
558 return std::task::Poll::Ready(Ok(()));
560 }
561 match ready!(self.sctx.send.poll_next_unpin(cx)) {
562 Some((cancelled, ptr)) => {
563 self.sctx.send_inprogress = false;
564 let _: H3Buff<h3::quic::WriteBuf<B>> =
566 unsafe { H3Buff::from_raw(ptr.0 as *mut c_void) };
567 match cancelled {
568 true => std::task::Poll::Ready(Err(H3Error::new(
569 Status::from(StatusCode::QUIC_STATUS_ABORTED),
570 None,
571 ))),
572 false => std::task::Poll::Ready(Ok(())),
573 }
574 }
575 None => std::task::Poll::Ready(Err(H3Error::new(
577 Status::from(StatusCode::QUIC_STATUS_ABORTED),
578 None,
579 ))),
580 }
581 }
582
583 #[cfg_attr(
584 feature = "tracing",
585 tracing::instrument(skip_all, level = "trace", ret, err)
586 )]
587 fn send_data<T: Into<h3::quic::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
588 if self.sctx.send_inprogress {
589 panic!("send while send is in progress.");
590 }
591 let data: h3::quic::WriteBuf<B> = data.into();
592 let buff = H3Buff::new(data);
593 let (buff_ref, ptr) = unsafe { buff.into_raw() };
594 unsafe { self.stream.send(buff_ref, SendFlags::NONE, ptr) }
595 .inspect_err(|_| {
596 let _: H3Buff<h3::quic::WriteBuf<B>> = unsafe { H3Buff::from_raw(ptr) };
598 })
599 .map_err(|e| H3Error::new(e, None))?;
600 self.sctx.send_inprogress = true;
601 Ok(())
602 }
603
604 #[cfg_attr(
606 feature = "tracing",
607 tracing::instrument(skip_all, level = "trace", ret)
608 )]
609 fn poll_finish(
610 &mut self,
611 cx: &mut std::task::Context<'_>,
612 ) -> std::task::Poll<Result<(), Self::Error>> {
613 if let Err(e) = self.stream.shutdown(StreamShutdownFlags::GRACEFUL, 0) {
615 return std::task::Poll::Ready(Err(H3Error::new(e, None)));
616 }
617 let rx = &mut self.sctx.shutdown;
619 let p = Pin::new(rx);
620 let res = ready!(std::future::Future::poll(p, cx))
622 .map_err(|_| H3Error::new(Status::from(StatusCode::QUIC_STATUS_ABORTED), None));
623 std::task::Poll::Ready(res)
624 }
625
626 #[cfg_attr(
627 feature = "tracing",
628 tracing::instrument(skip_all, level = "trace", ret)
629 )]
630 fn reset(&mut self, _reset_code: u64) {
631 panic!("reset not supported")
632 }
633
634 fn send_id(&self) -> h3::quic::StreamId {
635 get_id(&self.stream)
636 }
637}
638
639fn get_id(s: &msquic::Stream) -> h3::quic::StreamId {
640 let raw_id = unsafe {
641 msquic::Api::get_param_auto::<u64>(s.as_raw(), msquic::ffi::QUIC_PARAM_STREAM_ID)
642 }
643 .unwrap();
644 raw_id.try_into().expect("cannot parse id")
645}
646
647impl RecvStream for H3RecvStream {
648 type Buf = Bytes;
649
650 type Error = H3Error;
651
652 #[cfg_attr(
653 feature = "tracing",
654 tracing::instrument(skip_all, level = "trace", ret)
655 )]
656 fn poll_data(
657 &mut self,
658 cx: &mut std::task::Context<'_>,
659 ) -> std::task::Poll<Result<Option<Self::Buf>, Self::Error>> {
660 let res = ready!(self.rctx.receive.poll_next_unpin(cx));
661 std::task::Poll::Ready(Ok(res))
662 }
663
664 #[cfg_attr(
666 feature = "tracing",
667 tracing::instrument(skip_all, level = "trace", ret)
668 )]
669 fn stop_sending(&mut self, error_code: u64) {
670 let _ = self
672 .stream
673 .shutdown(StreamShutdownFlags::ABORT_RECEIVE, error_code);
674 }
675
676 fn recv_id(&self) -> h3::quic::StreamId {
677 get_id(&self.stream)
678 }
679}
680
681impl<B: Buf> SendStream<B> for H3Stream {
684 type Error = H3Error;
685
686 fn poll_ready(
687 &mut self,
688 cx: &mut std::task::Context<'_>,
689 ) -> std::task::Poll<Result<(), Self::Error>> {
690 SendStream::<B>::poll_ready(&mut self.send, cx)
691 }
692
693 fn send_data<T: Into<h3::quic::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
694 SendStream::<B>::send_data(&mut self.send, data)
695 }
696
697 fn poll_finish(
698 &mut self,
699 cx: &mut std::task::Context<'_>,
700 ) -> std::task::Poll<Result<(), Self::Error>> {
701 SendStream::<B>::poll_finish(&mut self.send, cx)
702 }
703
704 fn reset(&mut self, reset_code: u64) {
705 SendStream::<B>::reset(&mut self.send, reset_code);
706 }
707
708 fn send_id(&self) -> h3::quic::StreamId {
709 SendStream::<B>::send_id(&self.send)
710 }
711}
712
713impl RecvStream for H3Stream {
714 type Buf = Bytes;
715
716 type Error = H3Error;
717
718 fn poll_data(
719 &mut self,
720 cx: &mut std::task::Context<'_>,
721 ) -> std::task::Poll<Result<Option<Self::Buf>, Self::Error>> {
722 RecvStream::poll_data(&mut self.recv, cx)
723 }
724
725 fn stop_sending(&mut self, error_code: u64) {
726 RecvStream::stop_sending(&mut self.recv, error_code)
727 }
728
729 fn recv_id(&self) -> h3::quic::StreamId {
730 RecvStream::recv_id(&self.recv)
731 }
732}
733
734impl<B: Buf> BidiStream<B> for H3Stream {
735 type SendStream = H3SendStream;
736
737 type RecvStream = H3RecvStream;
738
739 #[cfg_attr(
740 feature = "tracing",
741 tracing::instrument(skip_all, level = "trace", ret)
742 )]
743 fn split(self) -> (Self::SendStream, Self::RecvStream) {
744 (self.send, self.recv)
745 }
746}
747
#[cfg(test)]
mod test {
    use bytes::Buf;
    use http::Uri;
    use msquic::{
        BufferRef, Configuration, CredentialConfig, CredentialFlags, Registration,
        RegistrationConfig, Settings,
    };

    use crate::Connection;

    /// Shared test helpers: tracing setup and test credentials.
    pub mod util {
        use msquic::Credential;
        pub const DEVEL_TRACE_LEVEL: tracing::Level = tracing::Level::TRACE;

        /// Installs a global tracing subscriber; the error from a second install is ignored.
        pub fn try_setup_tracing() {
            let _ = tracing_subscriber::fmt()
                .with_max_level(DEVEL_TRACE_LEVEL)
                .try_init();
        }

        /// Looks up the thumbprint of the `MsQuic-Test` certificate in the current user's
        /// certificate store via PowerShell.
        #[cfg(target_os = "windows")]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateHash;

            let output = std::process::Command::new("pwsh.exe")
                .args(["-Command", "Get-ChildItem Cert:\\CurrentUser\\My | Where-Object -Property FriendlyName -EQ -Value MsQuic-Test | Select-Object -ExpandProperty Thumbprint -First 1"])
                .output()
                .expect("Failed to execute command");
            assert!(output.status.success());
            let mut s = String::from_utf8(output.stdout).unwrap();
            // Strip a trailing "\n" (and a preceding "\r") from the command output.
            if s.ends_with('\n') {
                s.pop();
                if s.ends_with('\r') {
                    s.pop();
                }
            }
            Credential::CertificateHash(CertificateHash::from_str(&s).unwrap())
        }

        /// Returns a self-signed `localhost` key/cert pair from a temp directory,
        /// generating it with `openssl` if either file is missing.
        #[cfg(not(target_os = "windows"))]
        pub fn get_test_cred() -> Credential {
            use msquic::CertificateFile;

            let cert_dir = std::env::temp_dir().join("msquic_h3_test_rs");
            let key = "key.pem";
            let cert = "cert.pem";
            let key_path = cert_dir.join(key);
            let cert_path = cert_dir.join(cert);
            if !key_path.exists() || !cert_path.exists() {
                // Start from a clean directory so a half-written pair is not reused.
                let _ = std::fs::remove_dir_all(&cert_dir);
                std::fs::create_dir_all(&cert_dir).expect("cannot create cert dir");
                let output = std::process::Command::new("openssl")
                    .args([
                        "req",
                        "-x509",
                        "-newkey",
                        "rsa:4096",
                        "-keyout",
                        "key.pem",
                        "-out",
                        "cert.pem",
                        "-sha256",
                        "-days",
                        "3650",
                        "-nodes",
                        "-subj",
                        "/CN=localhost",
                    ])
                    .current_dir(cert_dir)
                    .stderr(std::process::Stdio::inherit())
                    .stdout(std::process::Stdio::inherit())
                    .output()
                    .expect("cannot generate cert");
                if !output.status.success() {
                    panic!("generate cert failed");
                }
            }
            Credential::CertificateFile(CertificateFile::new(
                key_path.display().to_string(),
                cert_path.display().to_string(),
            ))
        }
    }

    /// Connects to `uri` over HTTP/3 (certificate validation disabled), sends a GET
    /// request and logs the response. Request/driver errors are logged, not propagated.
    pub(crate) async fn send_get_request(uri: Uri) {
        let app_name = String::from("testapp");
        let config = RegistrationConfig::new().set_app_name(app_name);
        let reg = Registration::new(&config).unwrap();

        let alpn = BufferRef::from("h3");
        let client_settings = Settings::new().set_IdleTimeoutMs(2000);
        let client_config = Configuration::open(&reg, &[alpn], Some(&client_settings)).unwrap();
        {
            let cred_config = CredentialConfig::new_client()
                .set_credential_flags(CredentialFlags::NO_CERTIFICATE_VALIDATION);
            client_config.load_credential(&cred_config).unwrap();
        }

        tracing::info!("client conn open and start");
        let conn = Connection::connect(
            &reg,
            &client_config,
            uri.host().unwrap(),
            uri.port_u16().unwrap(),
        )
        .await
        .unwrap();

        tracing::info!("client create h3 client");
        let (mut driver, mut send_request) = h3::client::new(conn).await.unwrap();

        tracing::info!("client start driver");
        let drive = async move {
            std::future::poll_fn(|cx| driver.poll_close(cx)).await?;
            Ok::<(), Box<dyn std::error::Error>>(())
        };

        let request = async move {
            tracing::info!("sending request ...");

            let req = http::Request::builder().uri(uri).body(())?;

            let mut stream = send_request.send_request(req).await?;

            stream.finish().await?;

            tracing::info!("receiving response ...");

            let resp = stream.recv_response().await?;

            tracing::info!("response: {:?} {}", resp.version(), resp.status());
            tracing::info!("headers: {:#?}", resp.headers());

            let mut data = vec![];
            while let Some(mut chunk) = stream.recv_data().await? {
                let mut dst = vec![0; chunk.remaining()];
                chunk.copy_to_slice(&mut dst[..]);
                data.extend_from_slice(&dst);
            }
            let body = String::from_utf8_lossy(&data);
            tracing::info!("client got body: {}", body);
            Ok::<_, Box<dyn std::error::Error>>(())
        };

        let (req_res, drive_res) = tokio::join!(request, drive);
        if let Err(e) = req_res {
            tracing::error!("req_err {e:?}");
        }
        if let Err(e) = drive_res {
            tracing::error!("drive_res {e:?}");
        }
        tracing::info!("client ended success");
    }

    /// End-to-end GET against a public HTTP/3 server. NOTE: requires outbound network access.
    #[test]
    fn client_test_apache() {
        util::try_setup_tracing();
        let uri = http::Uri::from_static("https://h2o.examp1e.net:443");
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap()
            .block_on(send_get_request(uri));
    }
}