localtunnel_client/
lib.rs

1use std::sync::Arc;
2use tokio::sync::Semaphore;
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use socket2::{SockRef, TcpKeepalive};
7use tokio::io;
8use tokio::net::TcpStream;
9pub use tokio::sync::broadcast;
10use tokio::time::{sleep, Duration};
11
/// Tunnel server queried for an endpoint when `ClientConfig::server` is `None`.
///
/// NOTE(review): this looks like a placeholder domain — confirm the real deployment URL.
pub const PROXY_SERVER: &str = "https://your-domain.com";

/// Local address traffic is forwarded to when `ClientConfig::local_host` is `None`.
pub const LOCAL_HOST: &str = "127.0.0.1";
14
// TCP keepalive tuning for the server-facing socket.
// Background on how the kernel uses these values:
// https://tldp.org/HOWTO/html_single/TCP-Keepalive-HOWTO

/// Idle time before the first keepalive probe is sent.
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(30);
/// Delay between consecutive unanswered keepalive probes.
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(10);
/// Unanswered probes tolerated before the connection is considered dead
/// (not configured on Windows).
#[cfg(not(target_os = "windows"))]
const TCP_KEEPALIVE_RETRIES: u32 = 5;
20
21#[derive(Debug, Serialize, Deserialize)]
22struct ProxyResponse {
23    id: String,
24    port: u16,
25    max_conn_count: u8,
26    url: String,
27}
28
/// Where and how the client should connect to the tunnel server.
#[derive(Debug, Clone)]
pub struct TunnelServerInfo {
    /// Host to dial for proxy connections (the public URL's host minus its first label).
    pub host: String,
    /// Port to dial for proxy connections.
    pub port: u16,
    /// Connection limit advertised by the server.
    pub max_conn_count: u8,
    /// Public URL assigned to the tunnel.
    pub url: String,
}
37
38pub struct ClientConfig {
39    pub server: Option<String>,
40    pub subdomain: Option<String>,
41    pub local_host: Option<String>,
42    pub local_port: u16,
43    pub shutdown_signal: broadcast::Sender<()>,
44    pub max_conn: u8,
45    pub credential: Option<String>,
46}
47
48/// Open tunnels directly between server and localhost
49pub async fn open_tunnel(config: ClientConfig) -> Result<String> {
50    let ClientConfig {
51        server,
52        subdomain,
53        local_host,
54        local_port,
55        shutdown_signal,
56        max_conn,
57        credential,
58    } = config;
59    let tunnel_info = get_tunnel_endpoint(server, subdomain, credential).await?;
60
61    // TODO check the connect is failed and restart the proxy.
62    tunnel_to_endpoint(
63        tunnel_info.clone(),
64        local_host,
65        local_port,
66        shutdown_signal,
67        max_conn,
68    )
69    .await;
70
71    Ok(tunnel_info.url)
72}
73
74async fn get_tunnel_endpoint(
75    server: Option<String>,
76    subdomain: Option<String>,
77    credential: Option<String>,
78) -> Result<TunnelServerInfo> {
79    let server = server.as_deref().unwrap_or(PROXY_SERVER);
80    let assigned_domain = subdomain.as_deref().unwrap_or("?new");
81    let mut uri = format!("{}/{}", server, assigned_domain);
82    if let Some(credential) = credential {
83        uri = format!("{}?credential={}", uri, credential);
84    }
85    log::info!("Request for assign domain: {}", uri);
86
87    let resp = reqwest::get(uri).await?.json::<ProxyResponse>().await?;
88    log::info!("Response from server: {:#?}", resp);
89
90    let parts = resp.url.split("//").collect::<Vec<&str>>();
91    let mut host = parts[1].split(':').collect::<Vec<&str>>()[0];
92    host = match host.split_once('.') {
93        Some((_, base)) => base,
94        None => host,
95    };
96
97    let tunnel_info = TunnelServerInfo {
98        host: host.to_string(),
99        port: resp.port,
100        max_conn_count: resp.max_conn_count,
101        url: resp.url,
102    };
103
104    Ok(tunnel_info)
105}
106
107async fn tunnel_to_endpoint(
108    server: TunnelServerInfo,
109    local_host: Option<String>,
110    local_port: u16,
111    shutdown_signal: broadcast::Sender<()>,
112    max_conn: u8,
113) {
114    log::info!("Tunnel server info: {:?}", server);
115    let server_host = server.host;
116    let server_port = server.port;
117    let local_host = local_host.unwrap_or(LOCAL_HOST.to_string());
118
119    let count = std::cmp::min(server.max_conn_count, max_conn);
120    log::info!("Max connection count: {}", count);
121    let limit_connection = Arc::new(Semaphore::new(count.into()));
122
123    let mut shutdown_receiver = shutdown_signal.subscribe();
124
125    tokio::spawn(async move {
126        loop {
127            tokio::select! {
128                res = limit_connection.clone().acquire_owned() => {
129                    let permit = match res {
130                        Ok(permit) => permit,
131                        Err(err) => {
132                            log::error!("Acquire limit connection failed: {:?}", err);
133                            return;
134                        },
135                    };
136                    let server_host = server_host.clone();
137                    let local_host = local_host.clone();
138
139                    let mut shutdown_receiver = shutdown_signal.subscribe();
140
141                    tokio::spawn(async move {
142                        log::info!("Create a new proxy connection.");
143                        tokio::select! {
144                            res = handle_connection(server_host, server_port, local_host, local_port) => {
145                                match res {
146                                    Ok(_) => log::info!("Connection result: {:?}", res),
147                                    Err(err) => {
148                                        log::error!("Failed to connect to proxy or local server: {:?}", err);
149                                        sleep(Duration::from_secs(10)).await;
150                                    }
151                                }
152                            }
153                            _ = shutdown_receiver.recv() => {
154                                log::info!("Shutting down the connection immediately");
155                            }
156                        }
157
158                        drop(permit);
159                    });
160                }
161                _ = shutdown_receiver.recv() => {
162                    log::info!("Shuttign down the loop immediately");
163                    return;
164                }
165            };
166        }
167    });
168}
169
170async fn handle_connection(
171    remote_host: String,
172    remote_port: u16,
173    local_host: String,
174    local_port: u16,
175) -> Result<()> {
176    log::debug!("Connect to remote: {}, {}", remote_host, remote_port);
177    let mut remote_stream = TcpStream::connect(format!("{}:{}", remote_host, remote_port)).await?;
178    log::debug!("Connect to local: {}, {}", local_host, local_port);
179    let mut local_stream = TcpStream::connect(format!("{}:{}", local_host, local_port)).await?;
180
181    // configure keepalive on remote socket to early detect network issues and attempt to re-establish the connection.
182    let ka = TcpKeepalive::new()
183        .with_time(TCP_KEEPALIVE_TIME)
184        .with_interval(TCP_KEEPALIVE_INTERVAL);
185    #[cfg(not(target_os = "windows"))]
186    let ka = ka.with_retries(TCP_KEEPALIVE_RETRIES);
187    let sf = SockRef::from(&remote_stream);
188    sf.set_tcp_keepalive(&ka)?;
189
190    io::copy_bidirectional(&mut remote_stream, &mut local_stream).await?;
191    Ok(())
192}