// ombrac_client/endpoint/http.rs

1use std::sync::Arc;
2use std::{io, net::SocketAddr};
3
4use bytes::Bytes;
5use http_body_util::{BodyExt, combinators::BoxBody};
6use hyper::{Method, Request, Response};
7use hyper_util::rt::TokioIo;
8use ombrac::prelude::{Address, Client, Secret};
9use ombrac_macros::{error, info};
10use ombrac_transport::Initiator;
11use tokio::net::TcpListener;
12
13type ClientBuilder = hyper::client::conn::http1::Builder;
14type ServerBuilder = hyper::server::conn::http1::Builder;
15
/// HTTP proxy endpoint that forwards client traffic through an ombrac [`Client`].
#[derive(Debug, Clone, Copy)]
pub struct Server;
17
18impl Server {
19    pub async fn run<I>(
20        listener: TcpListener,
21        secret: Secret,
22        ombrac_client: Arc<Client<I>>,
23        shutdown_signal: impl Future<Output = ()>,
24    ) -> io::Result<()>
25    where
26        I: Initiator,
27    {
28        let ombrac = Arc::clone(&ombrac_client);
29
30        tokio::pin!(shutdown_signal);
31
32        loop {
33            tokio::select! {
34                biased;
35                _ = &mut shutdown_signal => return Ok(()),
36
37                result = listener.accept() => {
38                    let (stream, addr) = match result {
39                        Ok(res) => res,
40                        Err(_err) => {
41                            error!("Failed to accept connection: {}", _err);
42                            continue;
43                        }
44                    };
45
46                    let ombrac = ombrac.clone();
47                    tokio::spawn(async move {
48                        let io = TokioIo::new(stream);
49                        if let Err(_error) = ServerBuilder::new()
50                            .preserve_header_case(true)
51                            .title_case_headers(true)
52                            .serve_connection(
53                                io,
54                                hyper::service::service_fn(|req| async {
55                                    Self::tunnel(req, ombrac.clone(), secret, addr).await
56                                }),
57                            )
58                            .with_upgrades()
59                            .await
60                        {
61                            error!("Failed to serve connection: {}", _error);
62                        }
63                    });
64                }
65            }
66        }
67    }
68
69    async fn tunnel<I>(
70        req: Request<hyper::body::Incoming>,
71        conn: Arc<Client<I>>,
72        secret: Secret,
73        _from_addr: SocketAddr,
74    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error>
75    where
76        I: Initiator,
77    {
78        use ombrac::io::util::copy_bidirectional;
79
80        let host = match req.uri().host() {
81            Some(addr) => addr,
82            None => {
83                error!("Connect host is not socket addr: {:?}", req.uri());
84                let mut resp = Response::default();
85                *resp.status_mut() = http::StatusCode::BAD_REQUEST;
86
87                return Ok(resp);
88            }
89        };
90
91        let port = req.uri().port_u16().unwrap_or(80);
92
93        let target_addr = match Address::try_from(format!("{host}:{port}")) {
94            Ok(addr) => addr,
95            Err(_error) => {
96                error!("{_error}");
97                let mut resp = Response::default();
98                *resp.status_mut() = http::StatusCode::BAD_REQUEST;
99
100                return Ok(resp);
101            }
102        };
103
104        let mut outbound = match conn.connect(target_addr.clone(), secret).await {
105            Ok(conn) => conn,
106            Err(_error) => {
107                let mut resp = Response::default();
108                *resp.status_mut() = http::StatusCode::BAD_REQUEST;
109
110                return Ok(resp);
111            }
112        };
113
114        if Method::CONNECT == req.method() {
115            tokio::spawn(async move {
116                match hyper::upgrade::on(req).await {
117                    Ok(upgraded) => {
118                        let mut stream = TokioIo::new(upgraded);
119
120                        match copy_bidirectional(&mut stream, &mut outbound).await {
121                            Ok(_copy) => {
122                                info!(
123                                    "{} Connect {}, Send: {}, Recv: {}",
124                                    _from_addr, target_addr, _copy.0, _copy.1
125                                );
126                            }
127
128                            Err(_error) => {
129                                error!("{_error}")
130                            }
131                        }
132                    }
133                    Err(_error) => {
134                        error!("Upgrade error: {}", _error);
135                    }
136                }
137            });
138        } else {
139            let io = TokioIo::new(outbound);
140
141            let (mut sender, conn) = ClientBuilder::new()
142                .preserve_header_case(true)
143                .title_case_headers(true)
144                .handshake(io)
145                .await?;
146
147            tokio::spawn(async move {
148                info!("{_from_addr } Connect {target_addr}");
149                if let Err(err) = conn.await {
150                    error!("Connection failed: {:?}", err);
151                }
152            });
153
154            let resp = sender.send_request(req).await?;
155
156            return Ok(resp.map(|b| b.boxed()));
157        }
158
159        Ok(Response::default())
160    }
161}