// aster_server/tunnel/mod.rs

1pub mod lapstone;
2
3#[cfg(test)]
4mod lapstone_test;
5
6use crate::configuration::Settings;
7use aster::config::{paths::Paths, Config};
8use fs2::FileExt as _;
9use serde::{Deserialize, Serialize};
10use std::fs::{File, OpenOptions};
11use std::io::Write;
12use std::sync::Arc;
13use tokio::sync::{mpsc, RwLock};
14use utoipa::ToSchema;
15
16fn get_server_port() -> anyhow::Result<u16> {
17    let settings = Settings::new()?;
18    Ok(settings.port)
19}
20
21fn get_lock_path() -> std::path::PathBuf {
22    Paths::config_dir().join("tunnel.lock")
23}
24
25fn try_acquire_tunnel_lock() -> anyhow::Result<File> {
26    let lock_path = get_lock_path();
27
28    if let Some(parent) = lock_path.parent() {
29        std::fs::create_dir_all(parent)?;
30    }
31
32    let mut file = OpenOptions::new()
33        .write(true)
34        .create(true)
35        .truncate(true)
36        .open(&lock_path)?;
37
38    file.try_lock_exclusive()
39        .map_err(|_| anyhow::anyhow!("Another aster instance is already running the tunnel"))?;
40
41    writeln!(file, "{}", std::process::id())?;
42    file.sync_all()?;
43
44    Ok(file)
45}
46
47fn is_tunnel_locked_by_another() -> bool {
48    let lock_path = get_lock_path();
49
50    let file = match OpenOptions::new()
51        .write(true)
52        .create(true)
53        .truncate(false)
54        .open(&lock_path)
55    {
56        Ok(f) => f,
57        Err(_) => return false,
58    };
59
60    if file.try_lock_exclusive().is_err() {
61        return true;
62    }
63
64    // Lock released when file is dropped
65    false
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, ToSchema)]
69#[serde(rename_all = "lowercase")]
70pub enum TunnelState {
71    #[default]
72    Idle,
73    Starting,
74    Running,
75    Error,
76    Disabled,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
80pub struct TunnelInfo {
81    pub state: TunnelState,
82    pub url: String,
83    pub hostname: String,
84    pub secret: String,
85}
86
87pub struct TunnelManager {
88    state: Arc<RwLock<TunnelState>>,
89    info: Arc<RwLock<Option<TunnelInfo>>>,
90    lapstone_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
91    restart_tx: Arc<RwLock<Option<mpsc::Sender<()>>>>,
92    watchdog_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
93    lock_file: Arc<std::sync::Mutex<Option<File>>>,
94}
95
96impl Default for TunnelManager {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102impl TunnelManager {
103    pub fn new() -> Self {
104        TunnelManager {
105            state: Arc::new(RwLock::new(TunnelState::Idle)),
106            info: Arc::new(RwLock::new(None)),
107            lapstone_handle: Arc::new(RwLock::new(None)),
108            restart_tx: Arc::new(RwLock::new(None)),
109            watchdog_handle: Arc::new(RwLock::new(None)),
110            lock_file: Arc::new(std::sync::Mutex::new(None)),
111        }
112    }
113
114    fn get_auto_start() -> bool {
115        Config::global()
116            .get_param("tunnel_auto_start")
117            .unwrap_or(false)
118    }
119
120    fn get_secret() -> Option<String> {
121        Config::global().get_secret("tunnel_secret").ok()
122    }
123
124    fn get_agent_id() -> Option<String> {
125        Config::global().get_secret("tunnel_agent_id").ok()
126    }
127
128    pub async fn check_auto_start(&self) {
129        let auto_start = Self::get_auto_start();
130        let state = self.state.read().await.clone();
131
132        if auto_start && state == TunnelState::Idle {
133            if is_tunnel_locked_by_another() {
134                tracing::info!(
135                    "Tunnel already running on another aster instance, skipping auto-start"
136                );
137                return;
138            }
139
140            tracing::info!("Auto-starting tunnel");
141            match self.start().await {
142                Ok(info) => {
143                    tracing::info!("Tunnel auto-started successfully: {:?}", info.url);
144                }
145                Err(e) => {
146                    tracing::info!("Tunnel auto-start skipped: {}", e);
147                }
148            }
149        }
150    }
151
152    fn is_tunnel_disabled() -> bool {
153        if let Ok(val) = std::env::var("ASTER_TUNNEL") {
154            let val = val.to_lowercase();
155            val == "no" || val == "none"
156        } else {
157            false
158        }
159    }
160
161    pub async fn get_info(&self) -> TunnelInfo {
162        if Self::is_tunnel_disabled() {
163            return TunnelInfo {
164                state: TunnelState::Disabled,
165                url: String::new(),
166                hostname: String::new(),
167                secret: String::new(),
168            };
169        }
170
171        let state = self.state.read().await.clone();
172        let info = self.info.read().await.clone();
173
174        match info {
175            Some(mut tunnel_info) => {
176                tunnel_info.state = state;
177                tunnel_info
178            }
179            None => {
180                let effective_state = if state == TunnelState::Idle && is_tunnel_locked_by_another()
181                {
182                    TunnelState::Running
183                } else {
184                    state
185                };
186                TunnelInfo {
187                    state: effective_state,
188                    url: String::new(),
189                    hostname: String::new(),
190                    secret: String::new(),
191                }
192            }
193        }
194    }
195
196    pub fn set_auto_start(auto_start: bool) -> anyhow::Result<()> {
197        Config::global()
198            .set_param("tunnel_auto_start", auto_start)
199            .map_err(|e| anyhow::anyhow!("Failed to save tunnel config: {}", e))
200    }
201
202    pub fn set_secret(secret: &str) -> anyhow::Result<()> {
203        Config::global()
204            .set_secret("tunnel_secret", &secret.to_string())
205            .map_err(|e| anyhow::anyhow!("Failed to save tunnel secret: {}", e))
206    }
207
208    pub fn set_agent_id(agent_id: &str) -> anyhow::Result<()> {
209        Config::global()
210            .set_secret("tunnel_agent_id", &agent_id.to_string())
211            .map_err(|e| anyhow::anyhow!("Failed to save tunnel agent_id: {}", e))
212    }
213
214    async fn start_tunnel_internal(&self) -> anyhow::Result<(TunnelInfo, mpsc::Receiver<()>)> {
215        let server_port = get_server_port()?;
216        let tunnel_secret = Self::get_secret().unwrap_or_else(generate_secret);
217        let server_secret =
218            std::env::var("ASTER_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
219        let agent_id = Self::get_agent_id().unwrap_or_else(generate_agent_id);
220
221        Self::set_secret(&tunnel_secret)?;
222        Self::set_agent_id(&agent_id)?;
223
224        let (restart_tx, restart_rx) = mpsc::channel::<()>(1);
225        *self.restart_tx.write().await = Some(restart_tx.clone());
226
227        let result = lapstone::start(
228            server_port,
229            tunnel_secret,
230            server_secret,
231            agent_id,
232            self.lapstone_handle.clone(),
233            restart_tx,
234        )
235        .await;
236
237        match result {
238            Ok(info) => Ok((info, restart_rx)),
239            Err(e) => Err(e),
240        }
241    }
242
243    pub async fn start(&self) -> anyhow::Result<TunnelInfo> {
244        if Self::is_tunnel_disabled() {
245            anyhow::bail!("Tunnel is disabled via ASTER_TUNNEL environment variable");
246        }
247
248        let mut state = self.state.write().await;
249        if *state != TunnelState::Idle {
250            anyhow::bail!("Tunnel is already running or starting");
251        }
252
253        let lock = try_acquire_tunnel_lock()?;
254        *self.lock_file.lock().unwrap() = Some(lock);
255
256        *state = TunnelState::Starting;
257        drop(state);
258
259        match self.start_tunnel_internal().await {
260            Ok((info, mut restart_rx)) => {
261                *self.state.write().await = TunnelState::Running;
262                *self.info.write().await = Some(info.clone());
263                let _ = Self::set_auto_start(true);
264
265                let state = self.state.clone();
266                let lapstone_handle = self.lapstone_handle.clone();
267                let watchdog_handle_arc = self.watchdog_handle.clone();
268                let manager = Arc::new(self.clone_for_watchdog());
269
270                let watchdog = tokio::spawn(async move {
271                    while restart_rx.recv().await.is_some() {
272                        let auto_start = Self::get_auto_start();
273                        if !auto_start {
274                            tracing::info!("Tunnel connection lost but auto_start is disabled");
275                            break;
276                        }
277
278                        tracing::warn!("Tunnel connection lost, initiating restart...");
279                        lapstone::stop(lapstone_handle.clone()).await;
280                        *state.write().await = TunnelState::Idle;
281                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
282                        *state.write().await = TunnelState::Starting;
283
284                        match manager.start_tunnel_internal().await {
285                            Ok((_, new_restart_rx)) => {
286                                *state.write().await = TunnelState::Running;
287                                tracing::info!("Tunnel restarted successfully");
288                                restart_rx = new_restart_rx;
289                            }
290                            Err(e) => {
291                                tracing::error!("Failed to restart tunnel: {}", e);
292                                *state.write().await = TunnelState::Error;
293                                break;
294                            }
295                        }
296                    }
297                });
298
299                *watchdog_handle_arc.write().await = Some(watchdog);
300
301                Ok(info)
302            }
303            Err(e) => {
304                self.release_lock();
305                *self.state.write().await = TunnelState::Error;
306                Err(e)
307            }
308        }
309    }
310
311    fn clone_for_watchdog(&self) -> Self {
312        TunnelManager {
313            state: self.state.clone(),
314            info: self.info.clone(),
315            lapstone_handle: self.lapstone_handle.clone(),
316            restart_tx: self.restart_tx.clone(),
317            watchdog_handle: self.watchdog_handle.clone(),
318            lock_file: self.lock_file.clone(),
319        }
320    }
321
322    fn release_lock(&self) {
323        if let Ok(mut guard) = self.lock_file.lock() {
324            // Dropping the file releases the lock
325            guard.take();
326        }
327    }
328
329    pub async fn stop(&self, clear_auto_start: bool) {
330        if let Some(handle) = self.watchdog_handle.write().await.take() {
331            handle.abort();
332        }
333
334        *self.restart_tx.write().await = None;
335
336        lapstone::stop(self.lapstone_handle.clone()).await;
337
338        self.release_lock();
339
340        *self.state.write().await = TunnelState::Idle;
341        *self.info.write().await = None;
342
343        if clear_auto_start {
344            let _ = Self::set_auto_start(false);
345        }
346    }
347}
348
349fn generate_secret() -> String {
350    let bytes: [u8; 32] = rand::random();
351    hex::encode(bytes)
352}
353
354pub(super) fn generate_agent_id() -> String {
355    let bytes: [u8; 32] = rand::random();
356    hex::encode(bytes)
357}