1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::io;
6use std::os::fd::{AsRawFd, OwnedFd};
7use std::process::Stdio;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::{Arc, OnceLock};
10use tokio::io::unix::AsyncFd;
11use tokio::net::UnixStream;
12use tokio::process::Command;
13use tokio::sync::mpsc;
14use tokio_util::codec::Framed;
15use tracing::{debug, info};
16
/// Publicly visible facts about one running shell session.
///
/// `run` publishes a single instance through a shared `OnceLock`. The two atomic fields
/// are then updated in place for as long as the session is live.
pub struct SessionMetadata {
    /// Path of the pty slave device as reported by `ttyname`; empty if that lookup failed.
    pub pty_path: String,
    /// PID of the spawned shell, or `0` if it could not be determined.
    pub shell_pid: u32,
    /// Session creation time, in whole seconds since the UNIX epoch.
    pub created_at: u64,
    /// Whether a client is currently attached to the session.
    pub attached: AtomicBool,
    /// Time of the most recent client `Ping`, in seconds since the UNIX epoch (`0` = none yet).
    pub last_heartbeat: AtomicU64,
}
24
25struct ManagedChild {
28 child: tokio::process::Child,
29 pgid: nix::unistd::Pid,
30}
31
32impl ManagedChild {
33 fn new(child: tokio::process::Child) -> Self {
34 let pid = child.id().expect("child should have pid") as i32;
35 Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
36 }
37}
38
39impl Drop for ManagedChild {
40 fn drop(&mut self) {
41 let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
42 let _ = self.child.try_wait();
43 }
44}
45
/// Why the inner relay loop of a session stopped.
enum RelayExit {
    /// The attached client went away (disconnect or `Exit` frame); the shell is still running.
    ClientGone,
    /// The shell side ended. Carries the exit code, which is `0` when only a pty EOF/EIO was
    /// observed.
    ShellExited(i32),
}
53
54pub async fn run(
55 mut client_rx: mpsc::UnboundedReceiver<Framed<UnixStream, FrameCodec>>,
56 metadata_slot: Arc<OnceLock<SessionMetadata>>,
57) -> anyhow::Result<()> {
58 let pty = openpty(None, None)?;
60 let master: OwnedFd = pty.master;
61 let slave: OwnedFd = pty.slave;
62
63 let pty_path =
65 nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();
66
67 let slave_fd = slave.as_raw_fd();
69 let stdin_fd = crate::security::checked_dup(slave_fd)?;
70 let stdout_fd = crate::security::checked_dup(slave_fd)?;
71 let stderr_fd = crate::security::checked_dup(slave_fd)?;
72 let raw_stdin = stdin_fd.as_raw_fd();
73 drop(slave);
74
75 let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
77 let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
78 oflags |= nix::fcntl::OFlag::O_NONBLOCK;
79 nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;
80
81 let async_master = AsyncFd::new(master)?;
82 let mut buf = vec![0u8; 4096];
83
84 let mut framed = match client_rx.recv().await {
86 Some(framed) => {
87 info!("first client connected via channel");
88 framed
89 }
90 None => {
91 info!("client channel closed before first client");
92 return Ok(());
93 }
94 };
95
96 let env_vars =
98 match tokio::time::timeout(std::time::Duration::from_millis(100), framed.next()).await {
99 Ok(Some(Ok(Frame::Env { vars }))) => {
100 debug!(count = vars.len(), "received env vars from client");
101 vars
102 }
103 _ => Vec::new(),
104 };
105
106 let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
108 let home = std::env::var("HOME").ok();
109 let mut cmd = Command::new(&shell);
110 cmd.arg("-l");
111 if let Some(ref dir) = home {
112 cmd.current_dir(dir);
113 }
114 for (k, v) in &env_vars {
115 cmd.env(k, v);
116 }
117 let mut managed = ManagedChild::new(unsafe {
118 cmd.pre_exec(move || {
119 nix::unistd::setsid().map_err(io::Error::other)?;
120 libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
121 Ok(())
122 })
123 .stdin(Stdio::from(stdin_fd))
124 .stdout(Stdio::from(stdout_fd))
125 .stderr(Stdio::from(stderr_fd))
126 .spawn()?
127 });
128
129 let shell_pid = managed.child.id().unwrap_or(0);
130 let created_at = std::time::SystemTime::now()
131 .duration_since(std::time::UNIX_EPOCH)
132 .unwrap_or_default()
133 .as_secs();
134
135 let _ = metadata_slot.set(SessionMetadata {
136 pty_path,
137 shell_pid,
138 created_at,
139 attached: AtomicBool::new(false),
140 last_heartbeat: AtomicU64::new(0),
141 });
142
143 metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);
145
146 let mut first_client = true;
149 loop {
150 if !first_client {
151 framed = tokio::select! {
152 client = client_rx.recv() => {
153 match client {
154 Some(f) => {
155 info!("client connected via channel");
156 f
157 }
158 None => {
159 info!("client channel closed");
160 break;
161 }
162 }
163 }
164 status = managed.child.wait() => {
165 let code = status?.code().unwrap_or(1);
166 info!(code, "shell exited while awaiting client");
167 break;
168 }
169 };
170
171 if let Some(meta) = metadata_slot.get() {
172 meta.attached.store(true, Ordering::Relaxed);
173 }
174 }
175 first_client = false;
176
177 let exit = loop {
179 tokio::select! {
180 frame = framed.next() => {
181 match frame {
182 Some(Ok(Frame::Data(data))) => {
183 debug!(len = data.len(), "socket -> pty");
184 let mut guard = async_master.writable().await?;
185 match guard.try_io(|inner| {
186 nix::unistd::write(inner, &data).map_err(io::Error::from)
187 }) {
188 Ok(Ok(_)) => {}
189 Ok(Err(e)) => return Err(e.into()),
190 Err(_would_block) => continue,
191 }
192 }
193 Some(Ok(Frame::Resize { cols, rows })) => {
194 let (cols, rows) = crate::security::clamp_winsize(cols, rows);
195 debug!(cols, rows, "resize pty");
196 let ws = libc::winsize {
197 ws_row: rows,
198 ws_col: cols,
199 ws_xpixel: 0,
200 ws_ypixel: 0,
201 };
202 unsafe {
203 libc::ioctl(
204 async_master.as_raw_fd(),
205 libc::TIOCSWINSZ,
206 &ws as *const _,
207 );
208 }
209 if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
210 let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
211 }
212 }
213 Some(Ok(Frame::Ping)) => {
214 if let Some(meta) = metadata_slot.get() {
215 let now = std::time::SystemTime::now()
216 .duration_since(std::time::UNIX_EPOCH)
217 .unwrap_or_default()
218 .as_secs();
219 meta.last_heartbeat.store(now, Ordering::Relaxed);
220 }
221 let _ = framed.send(Frame::Pong).await;
222 }
223 Some(Ok(Frame::Exit { .. })) | None => {
225 break RelayExit::ClientGone;
226 }
227 Some(Ok(_)) => {}
229 Some(Err(e)) => return Err(e.into()),
230 }
231 }
232
233 ready = async_master.readable() => {
234 let mut guard = ready?;
235 match guard.try_io(|inner| {
236 nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
237 }) {
238 Ok(Ok(0)) => {
239 debug!("pty EOF");
240 break RelayExit::ShellExited(0);
241 }
242 Ok(Ok(n)) => {
243 debug!(len = n, "pty -> socket");
244 framed.send(Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await?;
245 }
246 Ok(Err(e)) => {
247 if e.raw_os_error() == Some(libc::EIO) {
248 debug!("pty EIO (shell exited)");
249 break RelayExit::ShellExited(0);
250 }
251 return Err(e.into());
252 }
253 Err(_would_block) => continue,
254 }
255 }
256
257 new_client = client_rx.recv() => {
259 if let Some(new_framed) = new_client {
260 info!("new client via channel, detaching old client");
261 let _ = framed.send(Frame::Detached).await;
262 framed = new_framed;
263 }
264 }
265
266 status = managed.child.wait() => {
267 let code = status?.code().unwrap_or(1);
268 info!(code, "shell exited");
269 break RelayExit::ShellExited(code);
270 }
271 }
272 };
273
274 match exit {
275 RelayExit::ClientGone => {
276 if let Some(meta) = metadata_slot.get() {
277 meta.attached.store(false, Ordering::Relaxed);
278 }
279 info!("client disconnected, waiting for reconnect");
280 continue;
281 }
282 RelayExit::ShellExited(mut code) => {
283 if let Ok(Some(status)) = managed.child.try_wait() {
286 code = status.code().unwrap_or(code);
287 }
288 let _ = framed.send(Frame::Exit { code }).await;
289 info!(code, "session ended");
290 break;
291 }
292 }
293 }
294
295 Ok(())
296}