ombrac_client/endpoint/
http.rs1use std::sync::Arc;
2use std::{io, net::SocketAddr};
3
4use bytes::Bytes;
5use http_body_util::{BodyExt, combinators::BoxBody};
6use hyper::{Method, Request, Response};
7use hyper_util::rt::TokioIo;
8use ombrac::prelude::{Address, Client, Secret};
9use ombrac_macros::{error, info};
10use ombrac_transport::Initiator;
11use tokio::net::TcpListener;
12
13type ClientBuilder = hyper::client::conn::http1::Builder;
14type ServerBuilder = hyper::server::conn::http1::Builder;
15
16pub struct Server;
17
18impl Server {
19 pub async fn run<I>(
20 listener: TcpListener,
21 secret: Secret,
22 ombrac_client: Arc<Client<I>>,
23 shutdown_signal: impl Future<Output = ()>,
24 ) -> io::Result<()>
25 where
26 I: Initiator,
27 {
28 let ombrac = Arc::clone(&ombrac_client);
29
30 tokio::pin!(shutdown_signal);
31
32 loop {
33 tokio::select! {
34 biased;
35 _ = &mut shutdown_signal => return Ok(()),
36
37 result = listener.accept() => {
38 let (stream, addr) = match result {
39 Ok(res) => res,
40 Err(_err) => {
41 error!("Failed to accept connection: {}", _err);
42 continue;
43 }
44 };
45
46 let ombrac = ombrac.clone();
47 tokio::spawn(async move {
48 let io = TokioIo::new(stream);
49 if let Err(_error) = ServerBuilder::new()
50 .preserve_header_case(true)
51 .title_case_headers(true)
52 .serve_connection(
53 io,
54 hyper::service::service_fn(|req| async {
55 Self::tunnel(req, ombrac.clone(), secret, addr).await
56 }),
57 )
58 .with_upgrades()
59 .await
60 {
61 error!("Failed to serve connection: {}", _error);
62 }
63 });
64 }
65 }
66 }
67 }
68
69 async fn tunnel<I>(
70 req: Request<hyper::body::Incoming>,
71 conn: Arc<Client<I>>,
72 secret: Secret,
73 _from_addr: SocketAddr,
74 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error>
75 where
76 I: Initiator,
77 {
78 use ombrac::io::util::copy_bidirectional;
79
80 let host = match req.uri().host() {
81 Some(addr) => addr,
82 None => {
83 error!("Connect host is not socket addr: {:?}", req.uri());
84 let mut resp = Response::default();
85 *resp.status_mut() = http::StatusCode::BAD_REQUEST;
86
87 return Ok(resp);
88 }
89 };
90
91 let port = req.uri().port_u16().unwrap_or(80);
92
93 let target_addr = match Address::try_from(format!("{host}:{port}")) {
94 Ok(addr) => addr,
95 Err(_error) => {
96 error!("{_error}");
97 let mut resp = Response::default();
98 *resp.status_mut() = http::StatusCode::BAD_REQUEST;
99
100 return Ok(resp);
101 }
102 };
103
104 let mut outbound = match conn.connect(target_addr.clone(), secret).await {
105 Ok(conn) => conn,
106 Err(_error) => {
107 let mut resp = Response::default();
108 *resp.status_mut() = http::StatusCode::BAD_REQUEST;
109
110 return Ok(resp);
111 }
112 };
113
114 if Method::CONNECT == req.method() {
115 tokio::spawn(async move {
116 match hyper::upgrade::on(req).await {
117 Ok(upgraded) => {
118 let mut stream = TokioIo::new(upgraded);
119
120 match copy_bidirectional(&mut stream, &mut outbound).await {
121 Ok(_copy) => {
122 info!(
123 "{} Connect {}, Send: {}, Recv: {}",
124 _from_addr, target_addr, _copy.0, _copy.1
125 );
126 }
127
128 Err(_error) => {
129 error!("{_error}")
130 }
131 }
132 }
133 Err(_error) => {
134 error!("Upgrade error: {}", _error);
135 }
136 }
137 });
138 } else {
139 let io = TokioIo::new(outbound);
140
141 let (mut sender, conn) = ClientBuilder::new()
142 .preserve_header_case(true)
143 .title_case_headers(true)
144 .handshake(io)
145 .await?;
146
147 tokio::spawn(async move {
148 info!("{_from_addr } Connect {target_addr}");
149 if let Err(err) = conn.await {
150 error!("Connection failed: {:?}", err);
151 }
152 });
153
154 let resp = sender.send_request(req).await?;
155
156 return Ok(resp.map(|b| b.boxed()));
157 }
158
159 Ok(Response::default())
160 }
161}