use std::os::fd::{AsRawFd, FromRawFd};
use std::process::Stdio;
use std::sync::Arc;
use anyhow::{Context, Result};
use russh::server::{Handle, Msg};
use russh::{ChannelId, ChannelStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::Child;
use tokio::sync::{mpsc, RwLock};
use super::pty::{PtyConfig, PtyMaster};
use crate::shared::auth_types::UserInfo;
/// Size in bytes of each scratch buffer used to move data between the PTY and SSH.
const IO_BUFFER_SIZE: usize = 8 * 1024;
/// One interactive shell attached to an SSH channel: the PTY master plus,
/// once spawned, the shell process.
pub struct ShellSession {
/// SSH channel this shell belongs to.
channel_id: ChannelId,
/// PTY master, behind a lock because the I/O loops read/write it while
/// `resize` takes the write lock.
pty: Arc<RwLock<PtyMaster>>,
/// Shell process; `None` before `spawn_shell_process` or after `take_child`.
child: Option<Child>,
}
impl ShellSession {
pub fn new(channel_id: ChannelId, config: PtyConfig) -> Result<Self> {
let pty = PtyMaster::open(config).context("Failed to create PTY")?;
Ok(Self {
channel_id,
pty: Arc::new(RwLock::new(pty)),
child: None,
})
}
async fn spawn_shell(&self, user_info: &UserInfo) -> Result<Child> {
let pty = self.pty.read().await;
let slave_path = pty.slave_path().clone();
let term = pty.config().term.clone();
drop(pty);
let shell = user_info.shell.clone();
let home_dir = user_info.home_dir.clone();
let username = user_info.username.clone();
if !shell.exists() {
anyhow::bail!("Shell does not exist: {}", shell.display());
}
let slave_file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.open(&slave_path)
.context("Failed to open slave PTY")?;
let slave_fd = slave_file.as_raw_fd();
let stdin_fd = unsafe { nix::libc::dup(slave_fd) };
let stdout_fd = unsafe { nix::libc::dup(slave_fd) };
let stderr_fd = unsafe { nix::libc::dup(slave_fd) };
if stdin_fd < 0 || stdout_fd < 0 || stderr_fd < 0 {
unsafe {
if stdin_fd >= 0 {
nix::libc::close(stdin_fd);
}
if stdout_fd >= 0 {
nix::libc::close(stdout_fd);
}
if stderr_fd >= 0 {
nix::libc::close(stderr_fd);
}
}
anyhow::bail!("Failed to duplicate slave PTY file descriptor");
}
drop(slave_file);
let mut cmd = tokio::process::Command::new(&shell);
cmd.arg("-l");
cmd.env_clear();
cmd.env("HOME", &home_dir);
cmd.env("USER", &username);
cmd.env("LOGNAME", &username);
cmd.env("SHELL", &shell);
cmd.env("TERM", &term);
cmd.env("PATH", "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin");
cmd.current_dir(&home_dir);
unsafe {
cmd.stdin(Stdio::from_raw_fd(stdin_fd));
cmd.stdout(Stdio::from_raw_fd(stdout_fd));
cmd.stderr(Stdio::from_raw_fd(stderr_fd));
}
cmd.kill_on_drop(true);
unsafe {
cmd.pre_exec(|| {
nix::unistd::setsid().map_err(|e| std::io::Error::other(e.to_string()))?;
if nix::libc::ioctl(0, nix::libc::TIOCSCTTY as _, 0) < 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
});
}
let child = cmd.spawn().context("Failed to spawn shell process")?;
tracing::info!(
shell = %shell.display(),
home = %home_dir.display(),
user = %username,
"Shell process spawned"
);
Ok(child)
}
pub fn take_child(&mut self) -> Option<Child> {
self.child.take()
}
pub fn pty(&self) -> &Arc<RwLock<PtyMaster>> {
&self.pty
}
pub fn channel_id(&self) -> ChannelId {
self.channel_id
}
pub async fn spawn_shell_process(&mut self, user_info: &UserInfo) -> Result<()> {
let child = self.spawn_shell(user_info).await?;
self.child = Some(child);
Ok(())
}
pub async fn resize(&self, cols: u32, rows: u32) -> Result<()> {
let mut pty = self.pty.write().await;
pty.resize(cols, rows)
}
}
/// Bridges a shell's PTY and an SSH [`ChannelStream`] until the session ends.
///
/// Each iteration first polls `child` for exit. It then races a PTY read against
/// a read from `channel_stream` and forwards whichever completes to the other side.
///
/// Returns the shell's exit code. It returns 1 when there is no child, when the
/// child has no exit code (e.g. it was killed by a signal), or when waiting on it
/// fails.
///
/// Termination paths:
/// - child exited: remaining PTY output is drained to the stream first;
/// - PTY EOF, PTY read error, or failed write to the stream: waits for the
///   child without killing it;
/// - SSH EOF: drains PTY output, kills the child, then reaps it;
/// - SSH read error: kills the child and reaps it (no drain).
pub async fn run_shell_io_loop(
channel_id: ChannelId,
pty: Arc<RwLock<PtyMaster>>,
mut child: Option<Child>,
mut channel_stream: ChannelStream<Msg>,
) -> i32 {
// Separate buffers per direction: PTY -> SSH and SSH -> PTY.
let mut pty_buf = vec![0u8; IO_BUFFER_SIZE];
let mut ssh_buf = vec![0u8; IO_BUFFER_SIZE];
tracing::debug!(channel = ?channel_id, "Starting shell I/O loop (ChannelStream)");
// Only used to correlate debug log lines.
let mut iteration = 0u64;
loop {
iteration += 1;
tracing::debug!(channel = ?channel_id, iter = iteration, "I/O loop iteration start");
// Non-blocking exit check before waiting on I/O.
if let Some(ref mut c) = child {
match c.try_wait() {
Ok(Some(status)) => {
tracing::debug!(
channel = ?channel_id,
exit_code = ?status.code(),
"Shell process exited"
);
drain_pty_output_to_stream(channel_id, &pty, &mut channel_stream, &mut pty_buf)
.await;
return status.code().unwrap_or(1);
}
Ok(None) => {
// Still running.
}
Err(e) => {
// Logged only; the loop keeps running.
tracing::warn!(
channel = ?channel_id,
error = %e,
"Error checking child process status"
);
}
}
}
tracing::debug!(channel = ?channel_id, iter = iteration, "About to enter select! (PTY read vs SSH read)");
// NOTE(review): the PTY branch holds a read guard on `pty` while it waits for
// data. A concurrent `resize()` (write lock) likely waits until this read
// completes or the SSH branch wins the race. Confirm against tokio RwLock
// fairness.
tokio::select! {
read_result = async {
let pty_guard = pty.read().await;
pty_guard.read(&mut pty_buf).await
} => {
tracing::debug!(channel = ?channel_id, iter = iteration, result = ?read_result.as_ref().map(|n| *n), "PTY read branch triggered");
match read_result {
Ok(0) => {
tracing::debug!(channel = ?channel_id, "PTY EOF");
return wait_for_child(&mut child).await;
}
Ok(n) => {
tracing::debug!(channel = ?channel_id, bytes = n, "Read from PTY, writing to SSH");
if let Err(e) = channel_stream.write_all(&pty_buf[..n]).await {
tracing::debug!(
channel = ?channel_id,
error = %e,
"Failed to write to channel stream"
);
return wait_for_child(&mut child).await;
}
// Flush failures are logged only; the loop keeps going.
if let Err(e) = channel_stream.flush().await {
tracing::debug!(
channel = ?channel_id,
error = %e,
"Failed to flush channel stream"
);
}
}
Err(e) => {
// Treated as transient: retry on the next iteration, which
// re-checks the child first.
if e.kind() == std::io::ErrorKind::WouldBlock {
continue;
}
tracing::debug!(
channel = ?channel_id,
error = %e,
"PTY read error"
);
return wait_for_child(&mut child).await;
}
}
}
read_result = channel_stream.read(&mut ssh_buf) => {
tracing::debug!(channel = ?channel_id, iter = iteration, result = ?read_result.as_ref().map(|n| *n), "SSH read branch triggered");
match read_result {
Ok(0) => {
tracing::debug!(channel = ?channel_id, "SSH channel stream EOF");
drain_pty_output_to_stream(channel_id, &pty, &mut channel_stream, &mut pty_buf)
.await;
// The client sent EOF: terminate the shell and reap it.
if let Some(ref mut c) = child {
let _ = c.kill().await;
}
return wait_for_child(&mut child).await;
}
Ok(n) => {
tracing::debug!(channel = ?channel_id, bytes = n, "Read from SSH, writing to PTY");
let pty_guard = pty.read().await;
// On failure this input is dropped, but the session continues.
if let Err(e) = pty_guard.write_all(&ssh_buf[..n]).await {
tracing::debug!(
channel = ?channel_id,
error = %e,
"PTY write error"
);
}
}
Err(e) => {
tracing::debug!(
channel = ?channel_id,
error = %e,
"SSH channel stream read error"
);
if let Some(ref mut c) = child {
let _ = c.kill().await;
}
return wait_for_child(&mut child).await;
}
}
}
}
}
}
/// Best-effort flush of output still buffered in the PTY to `channel_stream`.
///
/// Sleeps 100 ms first so the shell can finish writing. It then forwards up to
/// 100 chunks and stops early on any of these:
/// - PTY EOF;
/// - a PTY read error;
/// - a failed write to the channel;
/// - three consecutive 100 ms idle periods.
///
/// A `WouldBlock` from the PTY counts as one idle period instead of aborting the
/// drain, which matches how the I/O loops treat it. Flush errors are ignored.
async fn drain_pty_output_to_stream(
    channel_id: ChannelId,
    pty: &Arc<RwLock<PtyMaster>>,
    channel_stream: &mut ChannelStream<Msg>,
    buf: &mut [u8],
) {
    const SETTLE_DELAY: std::time::Duration = std::time::Duration::from_millis(100);
    const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100);
    const MAX_READS: usize = 100;
    const MAX_IDLE_PERIODS: u32 = 3;

    tracing::debug!(channel = ?channel_id, "Starting PTY drain");
    tokio::time::sleep(SETTLE_DELAY).await;
    let mut idle_periods = 0u32;
    for _ in 0..MAX_READS {
        // Scope the guard to the read so the PTY lock is not held while writing to SSH.
        let read_result = {
            let pty_guard = pty.read().await;
            tokio::time::timeout(READ_TIMEOUT, pty_guard.read(buf)).await
        };
        let was_idle = match read_result {
            Ok(Ok(0)) => break,
            Ok(Ok(n)) => {
                if channel_stream.write_all(&buf[..n]).await.is_err() {
                    break;
                }
                let _ = channel_stream.flush().await;
                false
            }
            Ok(Err(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                // No data right now. Wait as long as a timed-out read would, so
                // repeated WouldBlocks cannot end the drain instantly.
                tokio::time::sleep(READ_TIMEOUT).await;
                true
            }
            Ok(Err(_)) => break,
            Err(_elapsed) => true,
        };
        if was_idle {
            idle_periods += 1;
            if idle_periods >= MAX_IDLE_PERIODS {
                break;
            }
        } else {
            idle_periods = 0;
        }
    }
    tracing::trace!(channel = ?channel_id, "Drained PTY output");
}
async fn wait_for_child(child: &mut Option<Child>) -> i32 {
if let Some(ref mut c) = child {
match c.wait().await {
Ok(status) => status.code().unwrap_or(1),
Err(e) => {
tracing::warn!(error = %e, "Error waiting for shell process");
1
}
}
} else {
1
}
}
pub async fn run_shell_io_loop_with_handle(
channel_id: ChannelId,
pty: Arc<RwLock<PtyMaster>>,
mut child: Option<Child>,
handle: Handle,
mut data_rx: mpsc::Receiver<Vec<u8>>,
) -> i32 {
tracing::debug!(channel = ?channel_id, "Starting shell I/O loop (Handle-based, spawned output task)");
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
let pty_clone = Arc::clone(&pty);
let handle_clone = handle.clone();
let output_task = tokio::spawn(async move {
let mut buf = vec![0u8; IO_BUFFER_SIZE];
loop {
tokio::select! {
biased;
_ = shutdown_rx.recv() => {
tracing::trace!(channel = ?channel_id, "Output task received shutdown signal");
break;
}
read_result = async {
let pty_guard = pty_clone.read().await;
tokio::time::timeout(
std::time::Duration::from_millis(50),
pty_guard.read(&mut buf)
).await
} => {
match read_result {
Err(_elapsed) => {
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
continue;
}
Ok(Ok(0)) => {
tracing::trace!(channel = ?channel_id, "PTY EOF in output task");
break;
}
Ok(Ok(n)) => {
tracing::trace!(channel = ?channel_id, bytes = n, "Read from PTY, calling handle.data()");
let data = bytes::Bytes::copy_from_slice(&buf[..n]);
match handle_clone.data(channel_id, data).await {
Ok(_) => {
tracing::trace!(channel = ?channel_id, "handle.data() returned successfully");
tokio::task::yield_now().await;
}
Err(e) => {
tracing::debug!(
channel = ?channel_id,
error = ?e,
"Output task: failed to send data"
);
break;
}
}
}
Ok(Err(e)) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
tracing::debug!(
channel = ?channel_id,
error = %e,
"Output task: PTY read error"
);
break;
}
}
}
}
}
}
});
let exit_code = loop {
if let Some(ref mut c) = child {
match c.try_wait() {
Ok(Some(status)) => {
tracing::debug!(
channel = ?channel_id,
exit_code = ?status.code(),
"Shell process exited"
);
break status.code().unwrap_or(1);
}
Ok(None) => {
}
Err(e) => {
tracing::warn!(
channel = ?channel_id,
error = %e,
"Error checking child process status"
);
}
}
}
tokio::select! {
Some(data) = data_rx.recv() => {
tracing::debug!(
channel = ?channel_id,
bytes = data.len(),
"Received data from SSH via mpsc, writing to PTY"
);
let pty_guard = pty.read().await;
if let Err(e) = pty_guard.write_all(&data).await {
tracing::debug!(
channel = ?channel_id,
error = %e,
"Failed to write to PTY"
);
} else {
tracing::debug!(
channel = ?channel_id,
bytes = data.len(),
"Successfully wrote data to PTY"
);
}
}
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
}
}
};
let _ = shutdown_tx.send(()).await;
match tokio::time::timeout(std::time::Duration::from_secs(1), output_task).await {
Ok(Ok(())) => tracing::debug!(channel = ?channel_id, "Output task completed"),
Ok(Err(e)) => tracing::warn!(channel = ?channel_id, error = %e, "Output task panicked"),
Err(_) => tracing::warn!(channel = ?channel_id, "Output task timed out"),
}
exit_code
}
impl Drop for ShellSession {
fn drop(&mut self) {
if let Some(ref mut child) = self.child {
let _ = child.start_kill();
}
tracing::debug!(channel = ?self.channel_id, "Shell session dropped");
}
}
impl std::fmt::Debug for ShellSession {
    /// Shows the channel id and whether a shell process is attached; the PTY is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_child = self.child.is_some();
        let mut out = f.debug_struct("ShellSession");
        out.field("channel_id", &self.channel_id);
        out.field("has_child", &has_child);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_io_buffer_size() {
        // Evaluated at compile time: an out-of-range constant breaks the build.
        const _: () = {
            assert!(IO_BUFFER_SIZE >= 4096);
            assert!(IO_BUFFER_SIZE <= 65536);
        };
    }

    #[test]
    fn test_io_buffer_size_value() {
        assert_eq!(IO_BUFFER_SIZE, 8192);
    }

    /// `PtyConfig`'s `Debug` output names its fields. (Previously misnamed
    /// `test_shell_session_debug`, though it never touched `ShellSession`.)
    #[test]
    fn test_pty_config_debug() {
        let config = PtyConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("term"));
        assert!(debug_str.contains("col_width"));
        assert!(debug_str.contains("row_height"));
    }

    #[test]
    fn test_pty_config_default_values() {
        let config = PtyConfig::default();
        assert_eq!(config.term, "xterm-256color");
        assert_eq!(config.col_width, 80);
        assert_eq!(config.row_height, 24);
    }

    #[test]
    fn test_pty_config_custom_values() {
        let config = PtyConfig::new("vt100".to_string(), 120, 40, 800, 600);
        assert_eq!(config.term, "vt100");
        assert_eq!(config.col_width, 120);
        assert_eq!(config.row_height, 40);
        assert_eq!(config.pix_width, 800);
        assert_eq!(config.pix_height, 600);
    }

    /// Sanity check of the test environment: a nonexistent path is reported as
    /// missing and at least one common shell is installed. Nothing here needs
    /// an async runtime.
    #[test]
    fn test_shell_path_validation() {
        let nonexistent_path = PathBuf::from("/nonexistent/shell/path");
        assert!(!nonexistent_path.exists());
        let common_shells = ["/bin/sh", "/bin/bash", "/usr/bin/bash"];
        let has_valid_shell = common_shells.iter().any(|s| PathBuf::from(s).exists());
        assert!(has_valid_shell, "No common shell found on system");
    }
}