Skip to main content

outpunch_axum/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::Router;
6use axum::body::Body;
7use axum::extract::RawQuery;
8use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
9use axum::extract::{Path, State};
10use axum::http::{HeaderMap, Method, StatusCode};
11use axum::response::{IntoResponse, Response};
12use axum::routing::{any, get};
13use futures_util::{SinkExt, StreamExt};
14use outpunch::protocol::IncomingRequest;
15use outpunch::server::OutpunchServer;
16use tokio::sync::mpsc;
17
18/// Build an axum Router with outpunch tunnel routes.
19pub fn router(server: OutpunchServer) -> Router {
20    let state = Arc::new(server);
21
22    Router::new()
23        .route("/ws", get(ws_handler))
24        .route("/tunnel/{service}/{*path}", any(tunnel_handler))
25        .route("/tunnel/{service}", any(tunnel_handler_no_path))
26        .with_state(state)
27}
28
29async fn tunnel_handler(
30    State(server): State<Arc<OutpunchServer>>,
31    Path((service, path)): Path<(String, String)>,
32    method: Method,
33    RawQuery(raw_query): RawQuery,
34    headers: HeaderMap,
35    body: Body,
36) -> Response {
37    handle_tunnel(server, service, path, method, raw_query, headers, body).await
38}
39
40async fn tunnel_handler_no_path(
41    State(server): State<Arc<OutpunchServer>>,
42    Path(service): Path<String>,
43    method: Method,
44    RawQuery(raw_query): RawQuery,
45    headers: HeaderMap,
46    body: Body,
47) -> Response {
48    handle_tunnel(
49        server,
50        service,
51        String::new(),
52        method,
53        raw_query,
54        headers,
55        body,
56    )
57    .await
58}
59
60async fn handle_tunnel(
61    server: Arc<OutpunchServer>,
62    service: String,
63    path: String,
64    method: Method,
65    raw_query: Option<String>,
66    headers: HeaderMap,
67    body: Body,
68) -> Response {
69    let body_bytes = match axum::body::to_bytes(body, server.max_body_size()).await {
70        Ok(b) => b,
71        Err(_) => {
72            return (StatusCode::BAD_REQUEST, "request body too large").into_response();
73        }
74    };
75
76    let body_str = if body_bytes.is_empty() {
77        None
78    } else {
79        Some(String::from_utf8_lossy(&body_bytes).into_owned())
80    };
81
82    let query = parse_query(raw_query.as_deref());
83
84    let incoming = IncomingRequest {
85        service,
86        method: method.to_string(),
87        path,
88        query,
89        headers: extract_headers(&headers),
90        body: body_str,
91    };
92
93    let resp = server.handle_request(incoming).await;
94    tunnel_response_to_axum(resp)
95}
96
97async fn ws_handler(State(server): State<Arc<OutpunchServer>>, ws: WebSocketUpgrade) -> Response {
98    ws.on_upgrade(move |socket| handle_ws(server, socket))
99}
100
101/// Bridge a WebSocket to the core's Connection interface.
102async fn handle_ws(server: Arc<OutpunchServer>, socket: WebSocket) {
103    let (mut ws_sink, mut ws_stream) = socket.split();
104
105    let connection = server.create_connection();
106
107    // Outgoing: core → WS sink (via channel bridged from on_message callback)
108    let (msg_tx, mut msg_rx) = mpsc::channel::<String>(64);
109    connection.on_message(move |msg| {
110        let _ = msg_tx.try_send(msg);
111    });
112
113    let write_handle = tokio::spawn(async move {
114        while let Some(msg) = msg_rx.recv().await {
115            if ws_sink.send(WsMessage::text(msg)).await.is_err() {
116                break;
117            }
118        }
119    });
120
121    // Incoming: WS stream → connection.push_message
122    let conn_for_read = connection.clone();
123    let read_handle = tokio::spawn(async move {
124        while let Some(Ok(msg)) = ws_stream.next().await {
125            match msg {
126                WsMessage::Text(text) => {
127                    conn_for_read.push_message(text.to_string()).await;
128                }
129                WsMessage::Close(_) => break,
130                _ => {}
131            }
132        }
133        conn_for_read.close();
134    });
135
136    // Core handles the connection lifecycle.
137    connection.run().await;
138
139    // Give the write task time to flush remaining messages before closing
140    let _ = tokio::time::timeout(Duration::from_millis(100), write_handle).await;
141    read_handle.abort();
142}
143
/// Split a raw query string (the part after `?`) into key/value pairs.
///
/// - `None` (no query in the URI) and `Some("")` yield an empty map.
/// - Pairs are separated by `&`. Empty segments, such as those in `a=1&&b=2`,
///   are skipped.
/// - A segment without `=` (e.g. `?debug`) is kept as a key with an empty
///   value. Previously such keys were silently dropped.
/// - Only the first `=` separates key from value, so `a=b=c` maps `a` → `b=c`.
/// - Repeated keys keep the last occurrence, because the map holds one value
///   per key.
/// - Keys and values are passed through verbatim, with no percent- or
///   `+`-decoding.
///   NOTE(review): confirm the tunnel client expects still-encoded components.
fn parse_query(raw: Option<&str>) -> HashMap<String, String> {
    raw.into_iter()
        .flat_map(|qs| qs.split('&'))
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (key.to_owned(), value.to_owned())
        })
        .collect()
}
156
157fn extract_headers(headers: &HeaderMap) -> HashMap<String, String> {
158    let skip = ["host", "connection", "upgrade", "transfer-encoding"];
159
160    headers
161        .iter()
162        .filter(|(name, _)| !skip.contains(&name.as_str()))
163        .filter_map(|(name, value)| {
164            value
165                .to_str()
166                .ok()
167                .map(|v| (name.to_string(), v.to_string()))
168        })
169        .collect()
170}
171
172#[cfg(test)]
173#[path = "lib_tests.rs"]
174mod tests;
175
176fn tunnel_response_to_axum(resp: outpunch::protocol::TunnelResponse) -> Response {
177    use base64::Engine;
178    use base64::engine::general_purpose::STANDARD as BASE64;
179
180    let status = StatusCode::from_u16(resp.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
181
182    let body_bytes = match (resp.body, resp.body_encoding.as_deref()) {
183        (Some(encoded), Some("base64")) => BASE64
184            .decode(&encoded)
185            .unwrap_or_else(|_| encoded.into_bytes()),
186        (Some(plain), _) => plain.into_bytes(),
187        (None, _) => Vec::new(),
188    };
189
190    let mut builder = Response::builder().status(status);
191
192    for (key, value) in &resp.headers {
193        builder = builder.header(key.as_str(), value.as_str());
194    }
195
196    builder.body(Body::from(body_bytes)).unwrap_or_else(|_| {
197        Response::builder()
198            .status(StatusCode::INTERNAL_SERVER_ERROR)
199            .body(Body::from("internal error"))
200            .unwrap()
201    })
202}