1use std::{
4 fmt::{Debug, Formatter},
5 sync::Arc,
6};
7
8use http::{Request, Response};
9use hyper::{
10 body::{Body, Incoming},
11 client::conn::http2::SendRequest,
12};
13use hyper_util::rt::{TokioExecutor, tokio::WithHyperIo};
14use tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::Mutex,
17 task::JoinSet,
18};
19
20use crate::{Client, Error};
21
22#[derive(Clone)]
24pub struct Http2<B> {
25 inner: Arc<Inner<B>>,
26}
27
28impl<B> Debug for Http2<B> {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("Http2").finish_non_exhaustive()
31 }
32}
33
/// State shared (through an `Arc`) by every clone of an [`Http2`] handle.
struct Inner<B> {
    /// hyper's request handle for this connection, behind an async mutex.
    client: Mutex<SendRequest<B>>,
    /// Owns the background task spawned in [`connect`] that drives the connection
    /// future. It is never read; keeping the `JoinSet` here ties that task to this
    /// struct, and tokio aborts the task when the set is dropped.
    _runner: JoinSet<()>,
}
38
39impl<B> Client<B> for Http2<B>
40where
41 B: Body + Send + 'static,
42 B::Data: Send,
43 B::Error: Send + Sync + 'static,
44{
45 async fn send(&self, req: Request<B>) -> Result<Response<Incoming>, Error> {
46 let mut client = self.inner.client.lock().await;
47
48 client
49 .send_request(req)
50 .await
51 .inspect_err(|e| {
52 tracing::error!(error = %e, "sending request");
53 })
54 .map_err(Error::from)
55 }
56}
57
58pub async fn connect<B>(
60 io: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
61) -> Result<Http2<B>, Error>
62where
63 B: Body + Send + Unpin + 'static,
64 B::Data: Send,
65 B::Error: core::error::Error + Send + Sync + 'static,
66{
67 let (client, conn) =
68 hyper::client::conn::http2::handshake(TokioExecutor::new(), WithHyperIo::new(io))
69 .await
70 .inspect_err(|e| {
71 tracing::error!(error = %e, "http2 handshake");
72 })
73 .map_err(Error::from)?;
74
75 let mut tasks = JoinSet::new();
76
77 tasks.spawn(async move {
78 if let Err(e) = conn.await {
79 tracing::error!(?e, "error in http/2 connection; closing connection");
80 }
81 });
82
83 Ok(Http2 {
84 inner: Arc::new(Inner {
85 client: Mutex::new(client),
86 _runner: tasks,
87 }),
88 })
89}
90
91pub async fn connect_tcp<B>(url: &url::Url) -> Result<Http2<B>, Error>
93where
94 B: Body + Send + Unpin + 'static,
95 B::Data: Send,
96 B::Error: core::error::Error + Send + Sync + 'static,
97{
98 let conn = crate::dial_tcp(url).await?;
99 connect(conn).await
100}
101
102pub async fn connect_tls<B>(url: &url::Url) -> Result<Http2<B>, Error>
104where
105 B: Body + Send + Unpin + 'static,
106 B::Data: Send,
107 B::Error: core::error::Error + Send + Sync + 'static,
108{
109 let conn = crate::dial_tls(url, [b"h2".to_vec()]).await?;
110 connect(conn).await
111}
112
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Empty;
    use tracing_test::traced_test;

    use super::*;
    use crate::ClientExt;

    /// Live smoke test: connects over TLS with `h2` ALPN to the Tailscale control
    /// plane and issues a GET. Returns early unless network tests are enabled.
    #[tokio::test]
    #[traced_test]
    async fn http2_over_tls_over_tcp() {
        if !ts_test_util::run_net_tests() {
            return;
        }

        let url = url::Url::parse("https://controlplane.tailscale.com/key").unwrap();
        let client = connect_tls::<Empty<Bytes>>(&url).await.unwrap();
        let response = client.get(&url, []).await.unwrap();

        tracing::info!("{:?}", response);
    }
}