1use std::{
4 fmt::{Debug, Formatter},
5 sync::Arc,
6};
7
8use http::{Request, Response};
9use hyper::{
10 body::{Body, Incoming},
11 client::conn::http2::SendRequest,
12};
13use hyper_util::rt::{TokioExecutor, TokioTimer, tokio::WithHyperIo};
14use tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::Mutex,
17 task::JoinSet,
18};
19
20use crate::{Client, Error};
21
22#[derive(Clone)]
24pub struct Http2<B> {
25 inner: Arc<Inner<B>>,
26}
27
28impl<B> Debug for Http2<B> {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("Http2").finish_non_exhaustive()
31 }
32}
33
/// State shared by every clone of an [`Http2`] client.
struct Inner<B> {
    /// Handle used to issue requests on the connection. It sits behind an async mutex
    /// because hyper's `send_request` takes `&mut self`, while callers only hold `&Http2`.
    client: Mutex<SendRequest<B>>,
    /// Owns the task, spawned in `connect`, that drives the hyper connection future.
    /// It is never read; it exists so the task lives exactly as long as this struct,
    /// because a `JoinSet` aborts its tasks when dropped. Fields drop in declaration
    /// order, so this is dropped after `client`.
    _runner: JoinSet<()>,
}
38
39impl<B> Client<B> for Http2<B>
40where
41 B: Body + Send + 'static,
42 B::Data: Send,
43 B::Error: Send + Sync + 'static,
44{
45 async fn send(&self, req: Request<B>) -> Result<Response<Incoming>, Error> {
46 let mut client = self.inner.client.lock().await;
47
48 client
49 .send_request(req)
50 .await
51 .inspect_err(|e| {
52 tracing::error!(error = %e, "sending request");
53 })
54 .map_err(Error::from)
55 }
56}
57
58pub async fn connect<B>(
60 io: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
61) -> Result<Http2<B>, Error>
62where
63 B: Body + Send + Unpin + 'static,
64 B::Data: Send,
65 B::Error: core::error::Error + Send + Sync + 'static,
66{
67 let (client, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
84 .timer(TokioTimer::new())
85 .keep_alive_interval(core::time::Duration::from_secs(60))
86 .keep_alive_timeout(core::time::Duration::from_secs(20))
87 .keep_alive_while_idle(true)
88 .handshake(WithHyperIo::new(io))
89 .await
90 .inspect_err(|e| {
91 tracing::error!(error = %e, "http2 handshake");
92 })
93 .map_err(Error::from)?;
94
95 let mut tasks = JoinSet::new();
96
97 tasks.spawn(async move {
98 if let Err(e) = conn.await {
99 tracing::error!(?e, "error in http/2 connection; closing connection");
100 }
101 });
102
103 Ok(Http2 {
104 inner: Arc::new(Inner {
105 client: Mutex::new(client),
106 _runner: tasks,
107 }),
108 })
109}
110
111pub async fn connect_tcp<B>(url: &url::Url) -> Result<Http2<B>, Error>
113where
114 B: Body + Send + Unpin + 'static,
115 B::Data: Send,
116 B::Error: core::error::Error + Send + Sync + 'static,
117{
118 let conn = crate::dial_tcp(url).await?;
119 connect(conn).await
120}
121
122pub async fn connect_tls<B>(url: &url::Url) -> Result<Http2<B>, Error>
124where
125 B: Body + Send + Unpin + 'static,
126 B::Data: Send,
127 B::Error: core::error::Error + Send + Sync + 'static,
128{
129 let conn = crate::dial_tls(url, [b"h2".to_vec()]).await?;
130 connect(conn).await
131}
132
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Empty;
    use tracing_test::traced_test;

    use super::*;
    use crate::ClientExt;

    /// Live-network smoke test: performs a GET over HTTP/2-over-TLS and logs the response.
    /// It does nothing unless network tests are enabled.
    #[tokio::test]
    #[traced_test]
    async fn http2_over_tls_over_tcp() {
        if ts_test_util::run_net_tests() {
            let target = "https://controlplane.tailscale.com/key"
                .parse::<url::Url>()
                .unwrap();
            let http2 = connect_tls::<Empty<Bytes>>(&target).await.unwrap();

            let response = http2.get(&target, []).await.unwrap();
            tracing::info!("{:?}", response);
        }
    }
}
155}