1use std::{
16 collections::HashMap,
17 env,
18 ffi::OsString,
19 fs, io, net,
20 ops::Add,
21 os,
22 os::unix::{
23 fs::PermissionsExt as _,
24 net::{UnixListener, UnixStream},
25 process::CommandExt as _,
26 },
27 path::{Path, PathBuf},
28 process,
29 sync::{Arc, Mutex},
30 thread, time,
31 time::{Duration, Instant},
32};
33
34use anyhow::{anyhow, Context};
35use nix::unistd;
36use shpool_protocol::{
37 AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest,
38 KillReply, KillRequest, ListReply, LogLevel, ResizeReply, Session, SessionMessageDetachReply,
39 SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, SessionStatus,
40 SetLogLevelReply, SetLogLevelRequest, VersionHeader,
41};
42use tracing::{debug, error, info, instrument, span, warn, Level};
43
44use crate::{
45 config,
46 config::MotdDisplayMode,
47 consts,
48 daemon::{
49 etc_environment, exit_notify::ExitNotifier, hooks, pager, pager::PagerError, prompt, shell,
50 show_motd, ttl_reaper,
51 },
52 protocol, test_hooks, tty, user,
53};
54
/// `PATH` handed to new shells when the config does not set `initial_path`.
const DEFAULT_INITIAL_SHELL_PATH: &str = "/usr/bin:/bin:/usr/sbin:/sbin";
/// Scrollback size used when neither `output_spool_lines` nor a
/// `SessionRestoreMode::Lines` restore mode is configured.
const DEFAULT_OUTPUT_SPOOL_LINES: usize = 500;
/// Prompt prefix injected when the config does not set `prompt_prefix`.
const DEFAULT_PROMPT_PREFIX: &str = "shpool:$SHPOOL_SESSION_NAME ";

/// Bound on each send/recv over a session's control channels while handling
/// a session message (resize / detach).
const SESSION_MSG_TIMEOUT: Duration = Duration::from_millis(500);
64
/// The shpool daemon's connection server.
///
/// Owns the table of live shell sessions plus the shared state each
/// connection handler needs (config, hooks, runtime dir, motd state and the
/// log level reload handle).
pub struct Server {
    /// Daemon configuration, read through `get()` at each use site.
    config: config::Manager,
    /// Live sessions keyed by session name; also handed to the ttl reaper thread.
    shells: Arc<Mutex<HashMap<String, Box<shell::Session>>>>,
    /// Root for per-session state; session dirs live under `<runtime_dir>/sessions/`.
    runtime_dir: PathBuf,
    /// Channel to the ttl reaper: `(session name, deadline)` for sessions
    /// created with a ttl.
    register_new_reapable_session: crossbeam_channel::Sender<(String, Instant)>,
    /// Lifecycle callbacks (new session, reattach, busy, client/shell disconnect).
    hooks: Box<dyn hooks::Hooks + Send + Sync>,
    /// Message-of-the-day state, shared with every spawned session.
    daily_messenger: Arc<show_motd::DailyMessenger>,
    /// Reload handle that `SetLogLevel` requests use to swap the active level.
    log_level_handle: tracing_subscriber::reload::Handle<
        tracing_subscriber::filter::LevelFilter,
        tracing_subscriber::registry::Registry,
    >,
}
83
84impl Server {
85 #[instrument(skip_all)]
86 pub fn new(
87 config: config::Manager,
88 hooks: Box<dyn hooks::Hooks + Send + Sync>,
89 runtime_dir: PathBuf,
90 log_level_handle: tracing_subscriber::reload::Handle<
91 tracing_subscriber::filter::LevelFilter,
92 tracing_subscriber::registry::Registry,
93 >,
94 ) -> anyhow::Result<Arc<Self>> {
95 let shells = Arc::new(Mutex::new(HashMap::new()));
96 let (new_sess_tx, new_sess_rx) = crossbeam_channel::bounded(10);
99 let shells_tab = Arc::clone(&shells);
100 thread::spawn(move || {
101 if let Err(e) = ttl_reaper::run(new_sess_rx, shells_tab) {
102 warn!("ttl reaper exited with error: {:?}", e);
103 }
104 });
105
106 let daily_messenger = Arc::new(show_motd::DailyMessenger::new(config.clone())?);
107 Ok(Arc::new(Server {
108 config,
109 shells,
110 runtime_dir,
111 register_new_reapable_session: new_sess_tx,
112 hooks,
113 daily_messenger,
114 log_level_handle,
115 }))
116 }
117
118 #[instrument(skip_all)]
119 pub fn serve(server: Arc<Self>, listener: UnixListener) -> anyhow::Result<()> {
120 test_hooks::emit("daemon-about-to-listen");
121 let mut conn_counter = 0;
122 for stream in listener.incoming() {
123 info!("socket got a new connection");
124 match stream {
125 Ok(stream) => {
126 conn_counter += 1;
127 let conn_id = conn_counter;
128 let server = Arc::clone(&server);
129 thread::spawn(move || {
130 if let Err(err) = server.handle_conn(stream, conn_id) {
131 error!("handling new connection: {:?}", err)
132 }
133 });
134 }
135 Err(err) => {
136 error!("accepting stream: {:?}", err);
137 }
138 }
139 }
140
141 Ok(())
142 }
143
144 #[instrument(skip_all, fields(cid = conn_id))]
145 fn handle_conn(&self, mut stream: UnixStream, conn_id: usize) -> anyhow::Result<()> {
146 if let Err(e) = stream.set_read_timeout(Some(consts::SOCK_STREAM_TIMEOUT)) {
151 #[cfg(target_os = "macos")]
152 if e.raw_os_error() == Some(libc::EINVAL) {
153 info!("EINVAL setting read timeout, peer already closed (presence probe)");
154 return Ok(());
155 }
156 return Err(e).context("setting read timeout on inbound session");
157 }
158
159 match protocol::encode_to(
162 &VersionHeader {
163 version: match env::var("SHPOOL_TEST__OVERRIDE_VERSION") {
167 Ok(fake_version) => fake_version,
168 Err(_) => String::from(shpool_protocol::VERSION),
169 },
170 },
171 &mut stream,
172 ) {
173 Ok(_) => {}
174 Err(e)
175 if e.root_cause()
176 .downcast_ref::<io::Error>()
177 .map(|ioe| ioe.kind() == io::ErrorKind::BrokenPipe)
178 .unwrap_or(false) =>
179 {
180 info!("broken pipe while writing version, likely just a daemon presence probe");
181 return Ok(());
182 }
183 Err(e) => return Err(e).context("while writing version"),
184 }
185
186 let header = parse_connect_header(&mut stream).context("parsing connect header")?;
187
188 if let Err(err) = check_peer(&stream) {
189 if let ConnectHeader::Attach(_) = header {
190 write_reply(
191 &mut stream,
192 AttachReplyHeader { status: AttachStatus::Forbidden(format!("{err:?}")) },
193 )?;
194 }
195 stream.shutdown(net::Shutdown::Both).context("closing stream")?;
196 return Err(err);
197 };
198
199 stream.set_read_timeout(None).context("unsetting read timout on inbound session")?;
204
205 match header {
206 ConnectHeader::Attach(h) => self.handle_attach(stream, conn_id, h),
207 ConnectHeader::Detach(r) => self.handle_detach(stream, r),
208 ConnectHeader::Kill(r) => self.handle_kill(stream, r),
209 ConnectHeader::List => self.handle_list(stream),
210 ConnectHeader::SessionMessage(header) => self.handle_session_message(stream, header),
211 ConnectHeader::SetLogLevel(r) => self.handle_set_log_level(stream, r),
212 }
213 }
214
215 #[instrument(skip_all)]
216 fn handle_attach(
217 &self,
218 stream: UnixStream,
219 conn_id: usize,
220 header: AttachHeader,
221 ) -> anyhow::Result<()> {
222 let user_info = user::info().context("resolving user info")?;
223 let shell_env = self.build_shell_env(&user_info, &header).context("building shell env")?;
224
225 let (child_exit_notifier, inner_to_stream, pager_ctl_slot, status) =
226 match self.select_shell_desc(stream, conn_id, &header, &user_info, &shell_env) {
227 Ok(t) => t,
228 Err(err)
229 if err
230 .downcast_ref::<ShellSelectionError>()
231 .map(|e| e == &ShellSelectionError::BusyShellSession)
232 .unwrap_or(false) =>
233 {
234 return Ok(());
235 }
236 Err(err) => return Err(err)?,
237 };
238 info!("released lock on shells table");
239
240 self.link_ssh_auth_sock(&header).context("linking SSH_AUTH_SOCK")?;
241 self.populate_session_env_file(&header).context("populating session env file")?;
242
243 if let (Some(child_exit_notifier), Some(inner), Some(pager_ctl_slot)) =
244 (child_exit_notifier, inner_to_stream, pager_ctl_slot)
245 {
246 let mut child_done = false;
247 let mut inner = inner.lock().unwrap();
248 let client_stream = match inner.client_stream.as_mut() {
249 Some(s) => s,
250 None => {
251 return Err(anyhow!("no client stream, should be impossible"));
252 }
253 };
254
255 let reply_status =
256 write_reply(client_stream, AttachReplyHeader { status: status.clone() });
257 if let Err(e) = reply_status {
258 error!("error writing reply status: {:?}", e);
259 }
260
261 let motd_mode = self.config.get().motd.clone().unwrap_or_default();
265 let init_tty_size = if matches!(motd_mode, MotdDisplayMode::Pager { .. }) {
266 match self.daily_messenger.display_in_pager(
267 client_stream,
268 pager_ctl_slot,
269 header.local_tty_size.clone(),
270 &shell_env,
271 ) {
272 Ok(Some(new_size)) => {
273 info!("motd pager finished, reporting new tty size: {:?}", new_size);
274 new_size
275 }
276 Ok(None) => {
277 info!("not time to show the motd in the pager yet");
278 header.local_tty_size.clone()
279 }
280 Err(e) => match e.downcast::<PagerError>() {
281 Ok(PagerError::ClientHangup) => {
282 info!("client hung up while talking to pager, bailing");
283 return Ok(());
284 }
285 Err(e) => {
286 return Err(e).context("showing motd in pager")?;
287 }
288 },
289 }
290 } else {
291 header.local_tty_size.clone()
292 };
293
294 info!("starting bidi stream loop");
295 match inner.bidi_stream(conn_id, init_tty_size, child_exit_notifier) {
296 Ok(done) => {
297 child_done = done;
298 }
299 Err(e) => {
300 error!("error shuffling bytes: {:?}", e);
301 }
302 }
303 info!("bidi stream loop finished child_done={}", child_done);
304
305 if child_done {
306 info!("'{}' exited, removing from session table", header.name);
307 if let Err(err) = self.hooks.on_shell_disconnect(&header.name) {
308 warn!("shell_disconnect hook: {:?}", err);
309 }
310
311 {
312 let _s = span!(Level::INFO, "2_lock(shells)").entered();
313 let mut shells = self.shells.lock().unwrap();
314 shells.remove(&header.name);
315 }
316
317 if let Some(h) = inner.shell_to_client_join_h.take() {
323 h.join()
324 .map_err(|e| anyhow!("joining shell->client after child exit: {:?}", e))?
325 .context("within shell->client thread after child exit")?;
326 }
327 } else {
328 {
330 let _s = span!(Level::INFO, "disconnect_lock(shells)").entered();
331 let shells = self.shells.lock().unwrap();
332 if let Some(session) = shells.get(&header.name) {
333 session.lifecycle_timestamps.lock().unwrap().last_disconnected_at =
334 Some(time::SystemTime::now());
335 }
336 }
337 if let Err(err) = self.hooks.on_client_disconnect(&header.name) {
338 warn!("client_disconnect hook: {:?}", err);
339 }
340 }
341
342 info!("finished attach streaming section");
343 } else {
344 error!("internal error: failed to fetch just inserted session");
345 }
346
347 Ok(())
348 }
349
350 #[allow(clippy::type_complexity)]
351 fn select_shell_desc(
352 &self,
353 mut stream: UnixStream,
354 conn_id: usize,
355 header: &AttachHeader,
356 user_info: &user::Info,
357 shell_env: &[(OsString, OsString)],
358 ) -> anyhow::Result<(
359 Option<Arc<ExitNotifier>>,
360 Option<Arc<Mutex<shell::SessionInner>>>,
361 Option<Arc<Mutex<Option<pager::PagerCtl>>>>,
362 AttachStatus,
363 )> {
364 let warnings = vec![];
365
366 let _s = span!(Level::INFO, "1_lock(shells)").entered();
368 let mut shells = self.shells.lock().unwrap();
369
370 let mut status = AttachStatus::Attached { warnings: warnings.clone() };
371 if let Some(session) = shells.get(&header.name) {
372 info!("found entry for '{}'", header.name);
373 if let Ok(mut inner) = session.inner.try_lock() {
374 let _s =
375 span!(Level::INFO, "aquired_lock(session.inner)", s = header.name).entered();
376 match session.child_exit_notifier.wait(Some(time::Duration::from_millis(0))) {
387 None => {
388 info!("taking over existing session inner");
390 inner.client_stream = Some(stream.try_clone()?);
391 session.lifecycle_timestamps.lock().unwrap().last_connected_at =
392 Some(time::SystemTime::now());
393
394 if inner
395 .shell_to_client_join_h
396 .as_ref()
397 .map(|h| h.is_finished())
398 .unwrap_or(false)
399 {
400 warn!(
401 "child_exited chan unclosed, but shell->client thread has exited, clobbering with new subshell"
402 );
403 status = AttachStatus::Created { warnings };
404 }
405
406 }
408 Some(exit_status) => {
409 info!(
411 "stale inner, (child exited with status {}) clobbering with new subshell",
412 exit_status
413 );
414 status = AttachStatus::Created { warnings };
415 }
416 }
417
418 if inner.shell_to_client_join_h.as_ref().map(|h| h.is_finished()).unwrap_or(false) {
419 info!("shell->client thread finished, joining");
420 if let Some(h) = inner.shell_to_client_join_h.take() {
421 h.join()
422 .map_err(|e| anyhow!("joining shell->client on reattach: {:?}", e))?
423 .context("within shell->client thread on reattach")?;
424 }
425 assert!(matches!(status, AttachStatus::Created { .. }));
426 }
427
428 } else {
430 info!("busy shell session, doing nothing");
431 write_reply(&mut stream, AttachReplyHeader { status: AttachStatus::Busy })?;
433 stream.shutdown(net::Shutdown::Both).context("closing stream")?;
434 if let Err(err) = self.hooks.on_busy(&header.name) {
435 warn!("busy hook: {:?}", err);
436 }
437 return Err(ShellSelectionError::BusyShellSession)?;
438 }
439 } else {
440 info!("no existing '{}' session, creating new one", &header.name);
441 status = AttachStatus::Created { warnings };
442 }
443
444 if matches!(status, AttachStatus::Created { .. }) {
445 info!("creating new subshell");
446 if let Err(err) = self.hooks.on_new_session(&header.name) {
447 warn!("new_session hook: {:?}", err);
448 }
449 let motd = self.config.get().motd.clone().unwrap_or_default();
450 let session = self.spawn_subshell(
451 conn_id,
452 stream,
453 header,
454 user_info,
455 shell_env,
456 matches!(motd, MotdDisplayMode::Dump),
457 )?;
458
459 session.lifecycle_timestamps.lock().unwrap().last_connected_at =
460 Some(time::SystemTime::now());
461 shells.insert(header.name.clone(), Box::new(session));
462 } else if let Err(err) = self.hooks.on_reattach(&header.name) {
464 warn!("reattach hook: {:?}", err);
465 }
466
467 if let Some(session) = shells.get(&header.name) {
471 Ok((
472 Some(Arc::clone(&session.child_exit_notifier)),
473 Some(Arc::clone(&session.inner)),
474 Some(Arc::clone(&session.pager_ctl)),
475 status,
476 ))
477 } else {
478 Ok((None, None, None, status))
479 }
480 }
481
482 #[instrument(skip_all)]
483 fn link_ssh_auth_sock(&self, header: &AttachHeader) -> anyhow::Result<()> {
484 if self.config.get().nosymlink_ssh_auth_sock.unwrap_or(false) {
485 return Ok(());
486 }
487
488 if let Some(ssh_auth_sock) = header.local_env_get("SSH_AUTH_SOCK") {
489 let symlink = self.ssh_auth_sock_symlink(PathBuf::from(&header.name));
490 fs::create_dir_all(symlink.parent().ok_or(anyhow!("no symlink parent dir"))?)
491 .context("could not create directory for SSH_AUTH_SOCK symlink")?;
492
493 let sessions_dir =
494 symlink.parent().and_then(|d| d.parent()).ok_or(anyhow!("no sessions dir"))?;
495 let sessions_meta = fs::metadata(sessions_dir).context("stating sessions dir")?;
496
497 let mut sessions_perm = sessions_meta.permissions();
499 if sessions_perm.mode() != 0o700 {
500 sessions_perm.set_mode(0o700);
501 fs::set_permissions(sessions_dir, sessions_perm)
502 .context("locking down permissions for sessions dir")?;
503 }
504
505 let _ = fs::remove_file(&symlink); os::unix::fs::symlink(ssh_auth_sock, &symlink).context(format!(
507 "could not symlink '{symlink:?}' to point to '{ssh_auth_sock:?}'"
508 ))?;
509 } else {
510 info!("no SSH_AUTH_SOCK in client env, leaving it unlinked");
511 }
512
513 Ok(())
514 }
515
516 #[instrument(skip_all)]
517 fn populate_session_env_file(&self, header: &AttachHeader) -> anyhow::Result<()> {
518 let session_name = PathBuf::from(&header.name);
519 fs::create_dir_all(self.session_dir(session_name.clone()))
520 .context("creating session dir")?;
521
522 let session_env_file = self.session_env_file(session_name);
523 info!("populating {:?}", session_env_file);
524 fs::write(
525 session_env_file,
526 header.local_env.iter().map(|(k, v)| format!("{k}={v}")).collect::<Vec<_>>().join("\n"),
527 )
528 .context("writing session env")?;
529
530 Ok(())
531 }
532
533 #[instrument(skip_all)]
534 fn handle_detach(&self, mut stream: UnixStream, request: DetachRequest) -> anyhow::Result<()> {
535 let mut not_found_sessions = vec![];
536 let mut not_attached_sessions = vec![];
537 {
538 let _s = span!(Level::INFO, "lock(shells)").entered();
539 let shells = self.shells.lock().unwrap();
540 for session in request.sessions.into_iter() {
541 if let Some(s) = shells.get(&session) {
542 let _s = span!(Level::INFO, "lock(shell_to_client_ctl)", s = session).entered();
543 let shell_to_client_ctl = s.shell_to_client_ctl.lock().unwrap();
544 shell_to_client_ctl
545 .client_connection
546 .send(shell::ClientConnectionMsg::Disconnect)
547 .context("sending client detach to shell->client")?;
548 let status = shell_to_client_ctl
549 .client_connection_ack
550 .recv()
551 .context("getting client conn ack")?;
552 info!("detached session({}), status = {:?}", session, status);
553 if let shell::ClientConnectionStatus::DetachNone = status {
554 not_attached_sessions.push(session);
555 } else {
556 s.lifecycle_timestamps.lock().unwrap().last_disconnected_at =
557 Some(time::SystemTime::now());
558 }
559 } else {
560 not_found_sessions.push(session);
561 }
562 }
563 }
564
565 write_reply(&mut stream, DetachReply { not_found_sessions, not_attached_sessions })
566 .context("writing detach reply")?;
567
568 Ok(())
569 }
570
571 #[instrument(skip_all)]
572 fn handle_set_log_level(
573 &self,
574 mut stream: UnixStream,
575 request: SetLogLevelRequest,
576 ) -> anyhow::Result<()> {
577 let level_filter = match request.level {
578 LogLevel::Off => tracing_subscriber::filter::LevelFilter::OFF,
579 LogLevel::Error => tracing_subscriber::filter::LevelFilter::ERROR,
580 LogLevel::Warn => tracing_subscriber::filter::LevelFilter::WARN,
581 LogLevel::Info => tracing_subscriber::filter::LevelFilter::INFO,
582 LogLevel::Debug => tracing_subscriber::filter::LevelFilter::DEBUG,
583 LogLevel::Trace => tracing_subscriber::filter::LevelFilter::TRACE,
584 };
585 if let Err(e) = self.log_level_handle.modify(|filter| *filter = level_filter) {
586 error!("modifying log level: {}", e);
587 }
588
589 write_reply(&mut stream, SetLogLevelReply {}).context("writing set log level reply")?;
590 Ok(())
591 }
592
593 #[instrument(skip_all)]
594 fn handle_kill(&self, mut stream: UnixStream, request: KillRequest) -> anyhow::Result<()> {
595 let mut not_found_sessions = vec![];
596 {
597 let _s = span!(Level::INFO, "lock(shells)").entered();
598 let mut shells = self.shells.lock().unwrap();
599
600 let mut to_remove = Vec::with_capacity(request.sessions.len());
601 for session in request.sessions.into_iter() {
602 if let Some(s) = shells.get(&session) {
603 s.kill().context("killing shell proc")?;
604
605 to_remove.push(session);
608 } else {
609 not_found_sessions.push(session);
610 }
611 }
612
613 for session in to_remove.iter() {
614 shells.remove(session);
615 }
616 if !to_remove.is_empty() {
617 test_hooks::emit("daemon-handle-kill-removed-shells");
618 }
619 }
620
621 write_reply(&mut stream, KillReply { not_found_sessions }).context("writing kill reply")?;
622
623 Ok(())
624 }
625
626 #[instrument(skip_all)]
627 fn handle_list(&self, mut stream: UnixStream) -> anyhow::Result<()> {
628 let _s = span!(Level::INFO, "lock(shells)").entered();
629 let shells = self.shells.lock().unwrap();
630
631 let sessions: anyhow::Result<Vec<Session>> = shells
632 .iter()
633 .map(|(k, v)| {
634 let status = match v.inner.try_lock() {
635 Ok(_) => SessionStatus::Disconnected,
636 Err(_) => SessionStatus::Attached,
637 };
638
639 let timestamps = v.lifecycle_timestamps.lock().unwrap();
640 let last_connected_at_unix_ms = timestamps
641 .last_connected_at
642 .map(|t| t.duration_since(time::UNIX_EPOCH).map(|d| d.as_millis() as i64))
643 .transpose()?;
644
645 let last_disconnected_at_unix_ms = timestamps
646 .last_disconnected_at
647 .map(|t| t.duration_since(time::UNIX_EPOCH).map(|d| d.as_millis() as i64))
648 .transpose()?;
649
650 Ok(Session {
651 name: k.to_string(),
652 started_at_unix_ms: v.started_at.duration_since(time::UNIX_EPOCH)?.as_millis()
653 as i64,
654 last_connected_at_unix_ms,
655 last_disconnected_at_unix_ms,
656 status,
657 })
658 })
659 .collect();
660 let sessions = sessions.context("collecting running session metadata")?;
661
662 write_reply(&mut stream, ListReply { sessions })?;
663
664 Ok(())
665 }
666
667 #[instrument(skip_all, fields(s = &header.session_name))]
668 fn handle_session_message(
669 &self,
670 mut stream: UnixStream,
671 header: SessionMessageRequest,
672 ) -> anyhow::Result<()> {
673 let reply = {
676 let _s = span!(Level::INFO, "lock(shells)").entered();
677 let shells = self.shells.lock().unwrap();
678 if let Some(session) = shells.get(&header.session_name) {
679 match header.payload {
680 SessionMessageRequestPayload::Resize(resize_request) => {
681 let _s = span!(Level::INFO, "lock(pager_ctl)").entered();
682 let pager_ctl = session.pager_ctl.lock().unwrap();
683 if let Some(pager_ctl) = pager_ctl.as_ref() {
684 info!("resizing pager");
685 pager_ctl
686 .tty_size_change
687 .send_timeout(resize_request.tty_size.clone(), SESSION_MSG_TIMEOUT)
688 .context("sending tty size change to pager")?;
689 pager_ctl
690 .tty_size_change_ack
691 .recv_timeout(SESSION_MSG_TIMEOUT)
692 .context("recving tty size change ack from pager")?;
693 } else {
694 let _s =
695 span!(Level::INFO, "resize_lock(shell_to_client_ctl)").entered();
696 let shell_to_client_ctl = session.shell_to_client_ctl.lock().unwrap();
697 shell_to_client_ctl
698 .tty_size_change
699 .send_timeout(resize_request.tty_size, SESSION_MSG_TIMEOUT)
700 .context("sending tty size change to shell->client")?;
701 shell_to_client_ctl
702 .tty_size_change_ack
703 .recv_timeout(SESSION_MSG_TIMEOUT)
704 .context("recving tty size ack")?;
705 }
706
707 SessionMessageReply::Resize(ResizeReply::Ok)
708 }
709 SessionMessageRequestPayload::Detach => {
710 let _s = span!(Level::INFO, "detach_lock(shell_to_client_ctl)").entered();
711 let shell_to_client_ctl = session.shell_to_client_ctl.lock().unwrap();
712 shell_to_client_ctl
713 .client_connection
714 .send_timeout(
715 shell::ClientConnectionMsg::Disconnect,
716 SESSION_MSG_TIMEOUT,
717 )
718 .context("sending client detach to shell->client")?;
719 let status = shell_to_client_ctl
720 .client_connection_ack
721 .recv_timeout(SESSION_MSG_TIMEOUT)
722 .context("getting client conn ack")?;
723 info!("detached session({}), status = {:?}", header.session_name, status);
724 SessionMessageReply::Detach(SessionMessageDetachReply::Ok)
725 }
726 }
727 } else {
728 SessionMessageReply::NotFound
729 }
730 };
731
732 write_reply(&mut stream, reply).context("handle_session_message: writing reply")?;
733
734 Ok(())
735 }
736
737 #[instrument(skip_all)]
741 fn spawn_subshell(
742 &self,
743 conn_id: usize,
744 client_stream: UnixStream,
745 header: &AttachHeader,
746 user_info: &user::Info,
747 shell_env: &[(OsString, OsString)],
748 dump_motd_on_new_session: bool,
749 ) -> anyhow::Result<shell::Session> {
750 let shell = if let Some(s) = &self.config.get().shell {
751 s.clone()
752 } else {
753 user_info.default_shell.clone()
754 };
755 info!("user_info={:?}", user_info);
756
757 let mut cmd = if let Some(cmd_str) = &header.cmd {
762 let cmd_parts = shell_words::split(cmd_str).context("parsing cmd")?;
763 info!("running cmd: {:?}", cmd_parts);
764 if cmd_parts.is_empty() {
765 return Err(anyhow!("no command to run"));
766 }
767 let mut cmd = process::Command::new(&cmd_parts[0]);
768 cmd.args(&cmd_parts[1..]);
769 cmd
770 } else {
771 let mut cmd = process::Command::new(&shell);
772 if self.config.get().norc.unwrap_or(false) {
773 if shell.ends_with("bash") {
774 cmd.arg("--norc").arg("--noprofile");
775 } else if shell.ends_with("zsh") {
776 cmd.arg("--no-rcs");
777 } else if shell.ends_with("fish") {
778 cmd.arg("--no-config");
779 }
780 }
781 cmd
782 };
783
784 let start_dir = match header.dir.as_deref() {
785 None => user_info.home_dir.clone(),
786 Some(path) => String::from(path),
787 };
788 info!("spawning shell in '{}'", start_dir);
789 cmd.current_dir(start_dir)
790 .stdin(process::Stdio::inherit())
791 .stdout(process::Stdio::inherit())
792 .stderr(process::Stdio::inherit())
793 .env_clear();
798
799 let term = shell_env.iter().filter(|(k, _)| k == "TERM").map(|(_, v)| v).next();
800 cmd.envs(shell_env.to_vec());
801 let fallback_terminfo = || match termini::TermInfo::from_name("xterm") {
802 Ok(db) => Ok(db),
803 Err(err) => {
804 warn!("could not get xterm terminfo: {:?}", err);
805 let empty_db = io::Cursor::new(vec![]);
806 termini::TermInfo::parse(empty_db).context("getting terminfo db")
807 }
808 };
809 let term_db = Arc::new(if let Some(term) = &term {
810 match termini::TermInfo::from_name(term.to_string_lossy().as_ref())
811 .context("resolving terminfo")
812 {
813 Ok(ti) => ti,
814 Err(err) => {
815 warn!("could not get terminfo for '{:?}': {:?}", term, err);
816 fallback_terminfo()?
817 }
818 }
819 } else {
820 warn!("no $TERM, using default terminfo");
821 match termini::TermInfo::from_env() {
822 Ok(db) => db,
823 Err(err) => {
824 warn!("could not get terminfo from env: {:?}", err);
825 fallback_terminfo()?
826 }
827 }
828 });
829
830 if header.cmd.is_none() {
831 let shell_basename = Path::new(&shell)
837 .file_name()
838 .ok_or(anyhow!("error building login shell indicator"))?
839 .to_str()
840 .ok_or(anyhow!("error parsing shell name as utf8"))?;
841 cmd.arg0(format!("-{shell_basename}"));
842 };
843
844 let noecho = self.config.get().noecho.unwrap_or(false);
845 info!("about to fork subshell noecho={}", noecho);
846 let mut fork = shpool_pty::fork::Fork::from_ptmx().context("forking pty")?;
847 if let Ok(slave) = fork.is_child() {
848 if noecho {
849 if let Some(fd) = slave.borrow_fd() {
850 tty::disable_echo(fd).context("disabling echo on pty")?;
851 }
852 }
853 for fd in consts::STDERR_FD + 1..(nix::unistd::SysconfVar::OPEN_MAX as i32) {
854 let _ = nix::unistd::close(fd);
855 }
856 let err = cmd.exec();
857 eprintln!("shell exec err: {err:?}");
858 std::process::exit(1);
859 }
860
861 let child_exit_notifier = Arc::new(ExitNotifier::new());
864
865 let waitable_child_pid = fork.child_pid().ok_or(anyhow!("missing child pid"))?;
876 let session_name = header.name.clone();
877 let notifiable_child_exit_notifier = Arc::clone(&child_exit_notifier);
878 let shell_to_client_child_exit_notifier = Arc::clone(&child_exit_notifier);
879 thread::spawn(move || {
880 let _s = span!(Level::INFO, "child_watcher", s = session_name, cid = conn_id).entered();
881
882 let mut err = None;
883 let mut status = 0;
884 let mut unpacked_status = None;
885 loop {
886 unsafe {
888 match libc::waitpid(waitable_child_pid, &mut status, 0) {
889 0 => continue,
890 -1 => {
891 err = Some("waitpid failed");
892 break;
893 }
894 _ => {
895 if libc::WIFEXITED(status) {
896 unpacked_status = Some(libc::WEXITSTATUS(status));
897 }
898 break;
899 }
900 }
901 }
902 }
903 if let Some(status) = unpacked_status {
904 info!("child exited with status {}", status);
905 notifiable_child_exit_notifier.notify_exit(status);
906 } else {
907 if let Some(e) = err {
908 info!("child exited without status, using 1: {:?}", e);
909 } else {
910 info!("child exited without status, using 1");
911 }
912 notifiable_child_exit_notifier.notify_exit(1);
913 }
914 });
915
916 let prompt_prefix_is_blank =
917 self.config.get().prompt_prefix.as_ref().map(|p| p.is_empty()).unwrap_or(false);
918 let supports_sentinels =
919 header.cmd.is_none() && !prompt_prefix_is_blank && !does_not_support_sentinels(&shell);
920 info!("supports_sentianls={}", supports_sentinels);
921
922 if supports_sentinels {
926 info!("injecting prompt prefix");
927 let prompt_prefix = self
928 .config
929 .get()
930 .prompt_prefix
931 .clone()
932 .unwrap_or(String::from(DEFAULT_PROMPT_PREFIX));
933 if let Err(err) = prompt::maybe_inject_prefix(&mut fork, &prompt_prefix, &header.name) {
934 warn!("issue injecting prefix: {:?}", err);
935 }
936 }
937
938 let (client_connection_tx, client_connection_rx) = crossbeam_channel::bounded(0);
939 let (client_connection_ack_tx, client_connection_ack_rx) = crossbeam_channel::bounded(0);
940 let (tty_size_change_tx, tty_size_change_rx) = crossbeam_channel::bounded(0);
941 let (tty_size_change_ack_tx, tty_size_change_ack_rx) = crossbeam_channel::bounded(0);
942
943 let (heartbeat_tx, heartbeat_rx) = crossbeam_channel::bounded(0);
944 let (heartbeat_ack_tx, heartbeat_ack_rx) = crossbeam_channel::bounded(0);
945
946 let shell_to_client_ctl = Arc::new(Mutex::new(shell::ReaderCtl {
947 client_connection: client_connection_tx,
948 client_connection_ack: client_connection_ack_rx,
949 tty_size_change: tty_size_change_tx,
950 tty_size_change_ack: tty_size_change_ack_rx,
951 heartbeat: heartbeat_tx,
952 heartbeat_ack: heartbeat_ack_rx,
953 }));
954
955 let mut session_inner = shell::SessionInner {
956 name: header.name.clone(),
957 shell_to_client_ctl: Arc::clone(&shell_to_client_ctl),
958 pty_master: fork,
959 client_stream: Some(client_stream),
960 config: self.config.clone(),
961 shell_to_client_join_h: None,
962 term_db,
963 daily_messenger: Arc::clone(&self.daily_messenger),
964 needs_initial_motd_dump: dump_motd_on_new_session,
965 supports_sentinels,
966 };
967 let child_pid = session_inner.pty_master.child_pid().ok_or(anyhow!("no child pid"))?;
968 session_inner.shell_to_client_join_h =
969 Some(session_inner.spawn_shell_to_client(shell::ShellToClientArgs {
970 conn_id,
971 tty_size: header.local_tty_size.clone(),
972 scrollback_lines: match (
973 self.config.get().output_spool_lines,
974 &self.config.get().session_restore_mode,
975 ) {
976 (Some(l), _) => l,
977 (None, Some(config::SessionRestoreMode::Lines(l))) => *l as usize,
978 (None, _) => DEFAULT_OUTPUT_SPOOL_LINES,
979 },
980 client_connection: client_connection_rx,
981 client_connection_ack: client_connection_ack_tx,
982 tty_size_change: tty_size_change_rx,
983 tty_size_change_ack: tty_size_change_ack_tx,
984 heartbeat: heartbeat_rx,
985 heartbeat_ack: heartbeat_ack_tx,
986 child_exit_notifier: shell_to_client_child_exit_notifier,
987 })?);
988
989 if let Some(ttl_secs) = header.ttl_secs {
990 info!("registering session with ttl with the reaper");
991 self.register_new_reapable_session
992 .send((header.name.clone(), Instant::now().add(Duration::from_secs(ttl_secs))))
993 .context("sending reapable session registration msg")?;
994 }
995
996 Ok(shell::Session {
997 shell_to_client_ctl,
998 pager_ctl: Arc::new(Mutex::new(None)),
999 child_pid,
1000 child_exit_notifier,
1001 started_at: time::SystemTime::now(),
1002 lifecycle_timestamps: Mutex::new(shell::SessionLifecycleTimestamps::default()),
1003 inner: Arc::new(Mutex::new(session_inner)),
1004 })
1005 }
1006
1007 #[instrument(skip_all)]
1009 fn build_shell_env(
1010 &self,
1011 user_info: &user::Info,
1012 header: &AttachHeader,
1013 ) -> anyhow::Result<Vec<(OsString, OsString)>> {
1014 let s = OsString::from;
1015 let config = self.config.get();
1016 let auth_sock = self.ssh_auth_sock_symlink(PathBuf::from(&header.name));
1017 let mut env = vec![
1018 (s("HOME"), s(&user_info.home_dir)),
1019 (
1020 s("PATH"),
1021 s(config
1022 .initial_path
1023 .as_ref()
1024 .map(|x| x.as_ref())
1025 .unwrap_or(DEFAULT_INITIAL_SHELL_PATH)),
1026 ),
1027 (s("SHPOOL_SESSION_NAME"), s(&header.name)),
1028 (
1029 s("SHPOOL_SESSION_DIR"),
1030 self.session_dir(PathBuf::from(&header.name)).into_os_string(),
1031 ),
1032 (s("SHELL"), s(&user_info.default_shell)),
1033 (s("USER"), s(&user_info.user)),
1034 (
1035 s("SSH_AUTH_SOCK"),
1036 s(auth_sock.to_str().ok_or(anyhow!("failed to convert auth sock symlink"))?),
1037 ),
1038 ];
1039
1040 if let Some(xdg_runtime_dir) = env::var_os("XDG_RUNTIME_DIR") {
1041 env.push((s("XDG_RUNTIME_DIR"), xdg_runtime_dir));
1042 }
1043
1044 let mut term = None;
1050 if let Some(t) = header.local_env_get("TERM") {
1051 term = Some(String::from(t));
1052 }
1053 let filtered_env_pin;
1054 if let Some(extra_env) = config.env.as_ref() {
1055 term = match extra_env.get("TERM") {
1056 None => term,
1057 Some(t) if t.is_empty() => None,
1058 Some(t) => Some(String::from(t)),
1059 };
1060
1061 let extra_env = if term.is_none() {
1068 let mut e = extra_env.clone();
1069 e.remove("TERM");
1070 filtered_env_pin = Some(e);
1071 filtered_env_pin.as_ref().unwrap()
1072 } else {
1073 extra_env
1074 };
1075
1076 if !env.is_empty() {
1077 env.extend(extra_env.iter().map(|(k, v)| (s(k), s(v))));
1078 }
1079 }
1080 info!("injecting TERM into shell {:?}", term);
1081 if let Some(t) = &term {
1082 env.push((s("TERM"), s(t)));
1083 }
1084
1085 for (var, val) in &header.local_env {
1087 if var == "TERM" || var == "SSH_AUTH_SOCK" {
1088 continue;
1089 }
1090 env.push((s(var), s(val)));
1091 }
1092
1093 if !self.config.get().noread_etc_environment.unwrap_or(false) {
1095 match fs::File::open("/etc/environment") {
1096 Ok(f) => {
1097 let pairs = etc_environment::parse_compat(io::BufReader::new(f))?;
1098 for (var, val) in pairs.into_iter() {
1099 env.push((var.into(), val.into()));
1100 }
1101 }
1102 Err(e) => {
1103 warn!("could not open /etc/environment to load env vars: {:?}", e);
1104 }
1105 }
1106 }
1107 debug!("ENV: {env:?}");
1108
1109 Ok(env)
1110 }
1111
1112 fn session_env_file<P: AsRef<Path>>(&self, session_name: P) -> PathBuf {
1115 self.session_dir(session_name).join("forward.env")
1116 }
1117
1118 fn ssh_auth_sock_symlink<P: AsRef<Path>>(&self, session_name: P) -> PathBuf {
1119 self.session_dir(session_name).join("ssh-auth-sock.socket")
1120 }
1121
1122 fn session_dir<P: AsRef<Path>>(&self, session_name: P) -> PathBuf {
1123 self.runtime_dir.join("sessions").join(session_name)
1124 }
1125}
1126
/// Reports whether `shell` names a shell that cannot be driven with the prompt
/// sentinel mechanism.
///
/// Currently only nushell (`nu`) is treated as unsupported. The check compares the
/// final path component exactly, so both `/usr/bin/nu` and a bare `nu` match. An
/// unrelated executable whose name merely ends in "nu" (e.g. `/usr/bin/menu`) does
/// not match, whereas the old suffix check matched it.
fn does_not_support_sentinels(shell: &str) -> bool {
    matches!(Path::new(shell).file_name(), Some(name) if name == "nu")
}
1133
1134#[instrument(skip_all)]
1135fn parse_connect_header(stream: &mut UnixStream) -> anyhow::Result<ConnectHeader> {
1136 let header: ConnectHeader = protocol::decode_from(stream).context("parsing header")?;
1137 Ok(header)
1138}
1139
1140#[instrument(skip_all)]
1141fn write_reply<H>(stream: &mut UnixStream, header: H) -> anyhow::Result<()>
1142where
1143 H: serde::Serialize,
1144{
1145 stream
1146 .set_write_timeout(Some(consts::SOCK_STREAM_TIMEOUT))
1147 .context("setting write timout on inbound session")?;
1148
1149 let serializeable_stream = stream.try_clone().context("cloning stream handle")?;
1150 protocol::encode_to(&header, serializeable_stream).context("writing reply")?;
1151
1152 stream.set_write_timeout(None).context("unsetting write timout on inbound session")?;
1153
1154 Ok(())
1155}
1156
1157#[cfg(target_os = "linux")]
1161fn check_peer(sock: &UnixStream) -> anyhow::Result<()> {
1162 use nix::sys::socket;
1163
1164 let peer_creds = socket::getsockopt(sock, socket::sockopt::PeerCredentials)
1165 .context("could not get peer creds from socket")?;
1166 let peer_uid = unistd::Uid::from_raw(peer_creds.uid());
1167 let self_uid = unistd::Uid::current();
1168 if peer_uid != self_uid {
1169 return Err(anyhow!("shpool prohibits connections across users"));
1170 }
1171
1172 let peer_pid = unistd::Pid::from_raw(peer_creds.pid());
1173 let self_pid = unistd::Pid::this();
1174 let peer_exe = exe_for_pid(peer_pid).context("could not resolve exe from the pid")?;
1175 let self_exe = exe_for_pid(self_pid).context("could not resolve our own exe")?;
1176 if peer_exe != self_exe {
1177 warn!("attach binary differs from daemon binary");
1178 }
1179
1180 Ok(())
1181}
1182
/// Verifies that the process on the other end of `sock` may talk to the daemon (macOS).
///
/// The peer's uid is read with `getpeereid`, and a peer whose uid differs from the
/// daemon's is rejected. The peer pid is then read via the `LOCAL_PEERPID` socket
/// option, and the peer's executable is compared with the daemon's own. That
/// comparison is advisory only: a mismatch, or a failure to resolve either path, is
/// logged as a warning and the connection is still accepted.
///
/// # Errors
/// Returns an error if the peer uid or pid cannot be read from the socket, or if
/// the peer runs as a different user.
#[cfg(target_os = "macos")]
fn check_peer(sock: &UnixStream) -> anyhow::Result<()> {
    use std::os::unix::io::AsRawFd;

    let fd = sock.as_raw_fd();

    let mut peer_uid: libc::uid_t = 0;
    let mut peer_gid: libc::gid_t = 0;
    // SAFETY: `fd` is an open socket borrowed from `sock` for the duration of the
    // call, and both out-pointers refer to live, correctly typed locals.
    let rc = unsafe { libc::getpeereid(fd, &mut peer_uid, &mut peer_gid) };
    if rc != 0 {
        return Err(anyhow!(
            "could not get peer uid from socket: {}",
            io::Error::last_os_error()
        ));
    }
    let peer_uid = unistd::Uid::from_raw(peer_uid);
    let self_uid = unistd::Uid::current();
    if peer_uid != self_uid {
        return Err(anyhow!("shpool prohibits connections across users"));
    }

    let mut peer_pid: libc::pid_t = 0;
    let mut len = std::mem::size_of::<libc::pid_t>() as libc::socklen_t;
    // SAFETY: `peer_pid` is a live `pid_t` and `len` holds exactly its size, so the
    // kernel writes at most `len` bytes into valid memory; `fd` is borrowed from `sock`.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_LOCAL,
            libc::LOCAL_PEERPID,
            &mut peer_pid as *mut _ as *mut libc::c_void,
            &mut len,
        )
    };
    if rc != 0 {
        return Err(anyhow!(
            "could not get peer pid from socket: {}",
            io::Error::last_os_error()
        ));
    }

    // The uid check above is the access-control decision. The binary comparison
    // below is purely diagnostic, so a lookup failure must not turn into a
    // rejected connection.
    let peer_pid = unistd::Pid::from_raw(peer_pid);
    match (exe_for_pid(peer_pid), exe_for_pid(unistd::Pid::this())) {
        (Ok(peer_exe), Ok(self_exe)) => {
            if peer_exe != self_exe {
                warn!("attach binary differs from daemon binary");
            }
        }
        (Err(e), _) => {
            warn!("could not resolve exe from the pid, skipping binary check: {:?}", e);
        }
        (_, Err(e)) => {
            warn!("could not resolve our own exe, skipping binary check: {:?}", e);
        }
    }

    Ok(())
}
1233
1234#[cfg(target_os = "linux")]
1235fn exe_for_pid(pid: unistd::Pid) -> anyhow::Result<PathBuf> {
1236 let path = std::fs::read_link(format!("/proc/{pid}/exe"))?;
1237 Ok(path)
1238}
1239
/// Resolves the executable of process `pid` using libproc's `pidpath`.
///
/// A libproc failure is reported as an error mentioning the pid and the
/// underlying error value.
#[cfg(target_os = "macos")]
fn exe_for_pid(pid: unistd::Pid) -> anyhow::Result<PathBuf> {
    use libproc::proc_pid::pidpath;
    match pidpath(pid.as_raw()) {
        Ok(exe) => Ok(PathBuf::from(exe)),
        Err(e) => Err(anyhow!("could not get exe path for pid {}: {:?}", pid, e)),
    }
}
1247
/// Error returned when a shell session cannot be selected.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShellSelectionError {
    /// The selected shell session is busy.
    BusyShellSession,
}

impl std::fmt::Display for ShellSelectionError {
    /// Renders the error as its variant name, i.e. the same text as its `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        std::fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for ShellSelectionError {}