1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::Router;
6use axum::body::Body;
7use axum::extract::RawQuery;
8use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
9use axum::extract::{Path, State};
10use axum::http::{HeaderMap, Method, StatusCode};
11use axum::response::{IntoResponse, Response};
12use axum::routing::{any, get};
13use futures_util::{SinkExt, StreamExt};
14use outpunch::protocol::IncomingRequest;
15use outpunch::server::OutpunchServer;
16use tokio::sync::mpsc;
17
18pub fn router(server: OutpunchServer) -> Router {
20 let state = Arc::new(server);
21
22 Router::new()
23 .route("/ws", get(ws_handler))
24 .route("/tunnel/{service}/{*path}", any(tunnel_handler))
25 .route("/tunnel/{service}", any(tunnel_handler_no_path))
26 .with_state(state)
27}
28
29async fn tunnel_handler(
30 State(server): State<Arc<OutpunchServer>>,
31 Path((service, path)): Path<(String, String)>,
32 method: Method,
33 RawQuery(raw_query): RawQuery,
34 headers: HeaderMap,
35 body: Body,
36) -> Response {
37 handle_tunnel(server, service, path, method, raw_query, headers, body).await
38}
39
40async fn tunnel_handler_no_path(
41 State(server): State<Arc<OutpunchServer>>,
42 Path(service): Path<String>,
43 method: Method,
44 RawQuery(raw_query): RawQuery,
45 headers: HeaderMap,
46 body: Body,
47) -> Response {
48 handle_tunnel(
49 server,
50 service,
51 String::new(),
52 method,
53 raw_query,
54 headers,
55 body,
56 )
57 .await
58}
59
60async fn handle_tunnel(
61 server: Arc<OutpunchServer>,
62 service: String,
63 path: String,
64 method: Method,
65 raw_query: Option<String>,
66 headers: HeaderMap,
67 body: Body,
68) -> Response {
69 let body_bytes = match axum::body::to_bytes(body, server.max_body_size()).await {
70 Ok(b) => b,
71 Err(_) => {
72 return (StatusCode::BAD_REQUEST, "request body too large").into_response();
73 }
74 };
75
76 let body_str = if body_bytes.is_empty() {
77 None
78 } else {
79 Some(String::from_utf8_lossy(&body_bytes).into_owned())
80 };
81
82 let query = parse_query(raw_query.as_deref());
83
84 let incoming = IncomingRequest {
85 service,
86 method: method.to_string(),
87 path,
88 query,
89 headers: extract_headers(&headers),
90 body: body_str,
91 };
92
93 let resp = server.handle_request(incoming).await;
94 tunnel_response_to_axum(resp)
95}
96
97async fn ws_handler(State(server): State<Arc<OutpunchServer>>, ws: WebSocketUpgrade) -> Response {
98 ws.on_upgrade(move |socket| handle_ws(server, socket))
99}
100
101async fn handle_ws(server: Arc<OutpunchServer>, socket: WebSocket) {
103 let (mut ws_sink, mut ws_stream) = socket.split();
104
105 let connection = server.create_connection();
106
107 let (msg_tx, mut msg_rx) = mpsc::channel::<String>(64);
109 connection.on_message(move |msg| {
110 let _ = msg_tx.try_send(msg);
111 });
112
113 let write_handle = tokio::spawn(async move {
114 while let Some(msg) = msg_rx.recv().await {
115 if ws_sink.send(WsMessage::text(msg)).await.is_err() {
116 break;
117 }
118 }
119 });
120
121 let conn_for_read = connection.clone();
123 let read_handle = tokio::spawn(async move {
124 while let Some(Ok(msg)) = ws_stream.next().await {
125 match msg {
126 WsMessage::Text(text) => {
127 conn_for_read.push_message(text.to_string()).await;
128 }
129 WsMessage::Close(_) => break,
130 _ => {}
131 }
132 }
133 conn_for_read.close();
134 });
135
136 connection.run().await;
138
139 let _ = tokio::time::timeout(Duration::from_millis(100), write_handle).await;
141 read_handle.abort();
142}
143
/// Parses a raw URL query string into a key/value map.
///
/// The string is split on `&`, and each segment is split on its first `=`.
/// * A parameter without `=` (e.g. `?debug`) is kept with an empty value instead of
///   being silently dropped.
/// * Empty segments, such as those produced by `a=1&&b=2` or a trailing `&`, are skipped.
/// * Keys and values are passed through verbatim, with no percent-decoding.
/// * When a key repeats, the last occurrence wins.
///
/// `None` (no query string) yields an empty map.
fn parse_query(raw: Option<&str>) -> HashMap<String, String> {
    let Some(qs) = raw else {
        return HashMap::new();
    };

    qs.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
            (k.to_string(), v.to_string())
        })
        .collect()
}
156
157fn extract_headers(headers: &HeaderMap) -> HashMap<String, String> {
158 let skip = ["host", "connection", "upgrade", "transfer-encoding"];
159
160 headers
161 .iter()
162 .filter(|(name, _)| !skip.contains(&name.as_str()))
163 .filter_map(|(name, value)| {
164 value
165 .to_str()
166 .ok()
167 .map(|v| (name.to_string(), v.to_string()))
168 })
169 .collect()
170}
171
// Unit tests for this module live in the sibling file `lib_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "lib_tests.rs"]
mod tests;
175
176fn tunnel_response_to_axum(resp: outpunch::protocol::TunnelResponse) -> Response {
177 use base64::Engine;
178 use base64::engine::general_purpose::STANDARD as BASE64;
179
180 let status = StatusCode::from_u16(resp.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
181
182 let body_bytes = match (resp.body, resp.body_encoding.as_deref()) {
183 (Some(encoded), Some("base64")) => BASE64
184 .decode(&encoded)
185 .unwrap_or_else(|_| encoded.into_bytes()),
186 (Some(plain), _) => plain.into_bytes(),
187 (None, _) => Vec::new(),
188 };
189
190 let mut builder = Response::builder().status(status);
191
192 for (key, value) in &resp.headers {
193 builder = builder.header(key.as_str(), value.as_str());
194 }
195
196 builder.body(Body::from(body_bytes)).unwrap_or_else(|_| {
197 Response::builder()
198 .status(StatusCode::INTERNAL_SERVER_ERROR)
199 .body(Body::from("internal error"))
200 .unwrap()
201 })
202}