use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::tools::strip_ansi;
use crate::{Result, RuntimeError};
use super::config::ShellConfig;
use super::pty::PtyHandle;
use super::readiness::{ReadinessDetector, ReadinessResult, ReadinessStrategy};
/// Tracks interactive PTY-backed shell sessions keyed by a string id (`shell_NN`).
pub struct SessionManager {
    /// Sessions currently stored in the manager. `send_input` and `close_session`
    /// remove a session from this map while they operate on it, so it is invisible
    /// to other calls during that time.
    sessions: Mutex<HashMap<String, ShellSession>>,
    /// Limits and defaults: max sessions, terminal size, idle/readiness timeouts and
    /// prompt patterns used by the readiness detector.
    config: ShellConfig,
    /// Counter used to derive the numeric suffix of new session ids.
    next_id: AtomicU32,
}
/// A single PTY-backed shell process owned by the manager.
struct ShellSession {
    /// Handle used to write input to, and read output from, the shell process.
    pty: PtyHandle,
    /// Detector built at creation from the session's silence timeout and the configured
    /// prompt patterns; used for every `send_input` that does not pass a timeout override.
    detector: ReadinessDetector,
    /// When the session was created.
    created_at: Instant,
    /// Updated after each completed `send_input`; compared against `idle_timeout` by
    /// the idle reaper.
    last_active: Instant,
    /// Idle time after which `reap_idle` may remove the session (subject to its
    /// 5-second minimum grace period).
    idle_timeout: Duration,
    /// Last known lifecycle state.
    status: SessionStatus,
}
/// Lifecycle state of a shell session as tracked by the manager.
#[derive(Debug, Clone, PartialEq)]
pub enum SessionStatus {
    /// The PTY process was alive at the last check; only sessions in this state accept
    /// input via `send_input`.
    Active,
    /// The PTY process has exited. Holds the exit code when known; the code visible in
    /// this file always records `None`.
    Exited(Option<i32>),
    /// NOTE(review): not constructed anywhere in the visible code — presumably reserved
    /// for explicitly closed sessions; confirm before relying on it.
    Closed,
}
/// Options for `SessionManager::create_session`. Every `None` falls back to the
/// manager's `ShellConfig` (or to `$SHELL`/`bash` for the command).
///
/// Derives `Default` so callers can write
/// `SessionOpts { command: Some(cmd), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct SessionOpts {
    /// Program to run in the PTY; defaults to `$SHELL`, then `bash`.
    pub command: Option<String>,
    /// Working directory for the spawned process.
    pub working_directory: Option<String>,
    /// Extra environment variables passed to the spawned process.
    pub env: HashMap<String, String>,
    /// Terminal rows; defaults to `ShellConfig::default_rows`.
    pub rows: Option<u16>,
    /// Terminal columns; defaults to `ShellConfig::default_cols`.
    pub cols: Option<u16>,
    /// Silence timeout (milliseconds) used by the session's readiness detector.
    pub readiness_timeout_ms: Option<u64>,
    /// Idle timeout in seconds before the session may be reaped.
    pub idle_timeout: Option<u64>,
}
/// Result of a single `SessionManager::send_input` call.
#[derive(Debug)]
pub struct SendResult {
    /// Output produced after the input was written, passed through `normalize_output`
    /// (ANSI stripped, `\r\n` turned into `\n`, remaining `\r` removed).
    pub output: String,
    /// Readiness outcome: `"active"`, `"timeout"`, or `"exited"`.
    pub status: String,
}
/// Snapshot of one stored session, as returned by `SessionManager::list_sessions`.
///
/// Derives `Debug` and `Clone` so callers can log and copy snapshots freely.
#[derive(Debug, Clone)]
pub struct ShellSessionInfo {
    /// Session id (`shell_NN`).
    pub id: String,
    /// Lifecycle state at the time of the snapshot.
    pub status: SessionStatus,
    /// When the session was created.
    pub created_at: Instant,
    /// When the session last completed a `send_input`.
    pub last_active: Instant,
}
/// Turns raw terminal output into plain text.
///
/// ANSI sequences are removed via `strip_ansi`, then `\r\n` becomes `\n` and any
/// remaining `\r` characters are dropped.
fn normalize_output(raw: &str) -> String {
    let without_ansi = strip_ansi(raw);
    without_ansi
        .replace("\r\n", "\n")
        .chars()
        .filter(|&c| c != '\r')
        .collect()
}
/// Expands backslash escape sequences written by the caller into the characters the
/// shell should receive.
///
/// Supported: `\n`, `\r`, `\t`, `\\`, `\a` (BEL), `\b` (BS), `\0` (NUL) and `\xHH`
/// with one or two hex digits. `\e`, `\x1b`, and `\xHH` values >= 0x80 are dropped and
/// logged with a warning. Unknown escapes, a trailing backslash, and `\x` without hex
/// digits are copied through literally.
///
/// Characters that are not part of an escape sequence are copied unchanged.
/// NOTE(review): this includes a literal ESC character already present in `input`, so
/// only the *escaped* forms of ESC are blocked here.
fn process_input_escapes(input: &str) -> String {
    // Single-character escapes and the character each one produces.
    fn simple_escape(code: char) -> Option<char> {
        match code {
            'n' => Some('\n'),
            'r' => Some('\r'),
            't' => Some('\t'),
            '\\' => Some('\\'),
            'a' => Some('\x07'),
            'b' => Some('\x08'),
            '0' => Some('\0'),
            _ => None,
        }
    }

    let mut out = String::with_capacity(input.len());
    let mut iter = input.chars().peekable();
    while let Some(c) = iter.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        let code = match iter.peek() {
            Some(&code) => code,
            None => {
                // Trailing backslash: nothing follows, keep it as-is.
                out.push(c);
                continue;
            }
        };
        if let Some(mapped) = simple_escape(code) {
            iter.next();
            out.push(mapped);
            continue;
        }
        match code {
            'e' => {
                iter.next();
                tracing::warn!("blocked \\e escape sequence (raw ESC) in shell input");
            }
            'x' => {
                iter.next();
                // Consume at most two hex digits; stop at the first non-hex char.
                let mut hex = String::with_capacity(2);
                while hex.len() < 2 {
                    match iter.next_if(char::is_ascii_hexdigit) {
                        Some(digit) => hex.push(digit),
                        None => break,
                    }
                }
                match u8::from_str_radix(&hex, 16) {
                    Ok(0x1b) => {
                        tracing::warn!("blocked \\x1b escape sequence (raw ESC) in shell input");
                    }
                    Ok(byte) if byte >= 0x80 => {
                        tracing::warn!("blocked \\x{hex:} high byte (>= 0x80) in shell input");
                    }
                    Ok(byte) => out.push(char::from(byte)),
                    Err(_) => {
                        // No digits: emit the sequence literally.
                        out.push_str("\\x");
                        out.push_str(&hex);
                    }
                }
            }
            // Unknown escape: keep the backslash; the next char is copied on the
            // following iteration.
            _ => out.push(c),
        }
    }
    out
}
/// Renders a session status as the lowercase label used in tool results:
/// `active`, `exited`, `exited(<code>)`, or `closed`.
fn status_string(status: &SessionStatus) -> String {
    if let SessionStatus::Exited(Some(code)) = status {
        return format!("exited({code})");
    }
    let label = match status {
        SessionStatus::Active => "active",
        SessionStatus::Exited(_) => "exited",
        SessionStatus::Closed => "closed",
    };
    String::from(label)
}
async fn wait_for_output(
pty: &mut PtyHandle,
detector: &ReadinessDetector,
timeout_override: Option<u64>,
tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
max_output: usize,
) -> (String, String) {
let override_detector;
let effective_detector = if let Some(ms) = timeout_override {
override_detector = ReadinessDetector::new(
ReadinessStrategy::Hybrid,
&[], ms,
ms.saturating_mul(10).max(10_000), );
&override_detector
} else {
detector
};
let mut output = String::new();
let start = Instant::now();
let mut last_output_time = Instant::now();
let poll_interval = Duration::from_millis(50);
loop {
let bytes = pty.try_read_output(poll_interval).await;
if !bytes.is_empty() {
let text = String::from_utf8_lossy(&bytes);
output.push_str(&text);
last_output_time = Instant::now();
if let Some(tx) = tx_delta {
let _ = tx.send(normalize_output(&text));
}
}
if output.len() > max_output {
let mut trunc = max_output;
while trunc > 0 && !output.is_char_boundary(trunc) {
trunc -= 1;
}
output.truncate(trunc);
return (normalize_output(&output), "active".into());
}
if !pty.is_alive() {
tokio::time::sleep(Duration::from_millis(50)).await;
let remaining = pty.try_read_output(Duration::from_millis(100)).await;
if !remaining.is_empty() {
let remaining_text = String::from_utf8_lossy(&remaining);
output.push_str(&remaining_text);
if let Some(tx) = tx_delta {
let _ = tx.send(normalize_output(&remaining_text));
}
}
return (normalize_output(&output), status_string(&SessionStatus::Exited(None)));
}
let silence_elapsed = last_output_time.elapsed();
let total_elapsed = start.elapsed();
match effective_detector.check(&output, silence_elapsed, total_elapsed) {
ReadinessResult::Ready => return (normalize_output(&output), "active".into()),
ReadinessResult::SilenceTimeout => return (normalize_output(&output), "active".into()),
ReadinessResult::MaxTimeout => return (normalize_output(&output), "timeout".into()),
ReadinessResult::Waiting => continue,
}
}
}
impl SessionManager {
pub fn new(config: ShellConfig) -> Arc<Self> {
Arc::new(Self {
sessions: Mutex::new(HashMap::new()),
config,
next_id: AtomicU32::new(0),
})
}
pub async fn create_session(
&self,
opts: SessionOpts,
tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
) -> Result<(String, String, String)> {
{
let sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
if sessions.len() >= self.config.max_sessions {
return Err(RuntimeError::Tool(format!(
"maximum session limit reached ({})",
self.config.max_sessions
)));
}
}
let seq = self.next_id.fetch_add(1, Ordering::SeqCst) + 1;
let id = format!("shell_{:02}", seq);
let command = opts.command.unwrap_or_else(|| {
std::env::var("SHELL").unwrap_or_else(|_| "bash".into())
});
let rows = opts.rows.unwrap_or(self.config.default_rows);
let cols = opts.cols.unwrap_or(self.config.default_cols);
let idle_timeout = opts
.idle_timeout
.map(Duration::from_secs)
.unwrap_or(self.config.idle_timeout);
let mut pty = PtyHandle::spawn(
&command,
opts.working_directory.as_deref(),
opts.env,
rows,
cols,
)?;
let silence_ms = opts
.readiness_timeout_ms
.unwrap_or(self.config.readiness_timeout_ms);
let detector = ReadinessDetector::new(
super::readiness::ReadinessStrategy::Hybrid,
&self.config.prompt_patterns,
silence_ms,
self.config.max_readiness_timeout_ms,
);
tokio::time::sleep(Duration::from_millis(200)).await;
let (initial_output, status_str) =
wait_for_output(&mut pty, &detector, opts.readiness_timeout_ms, tx_delta, 30000).await;
let now = Instant::now();
let status = if status_str.starts_with("exited") {
SessionStatus::Exited(None)
} else {
SessionStatus::Active
};
let session = ShellSession {
pty,
detector,
created_at: now,
last_active: now,
idle_timeout,
status,
};
{
let mut sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
sessions.insert(id.clone(), session);
}
Ok((id, initial_output, status_str))
}
pub async fn send_input(
&self,
id: &str,
input: &str,
timeout_ms: Option<u64>,
tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
) -> Result<SendResult> {
let mut session = {
let mut sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
sessions.remove(id).ok_or_else(|| {
RuntimeError::Tool(format!(
"session {id} not found — it may have been closed, reaped, or is currently in use by another call"
))
})?
};
if session.status != SessionStatus::Active {
let s_str = status_string(&session.status);
let mut sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
sessions.insert(id.to_string(), session);
return Err(RuntimeError::Tool(format!(
"session {id} is not active (status: {s_str})"
)));
}
let processed = process_input_escapes(input);
session.pty.write(processed.as_bytes())?;
let (output, status_str) =
wait_for_output(&mut session.pty, &session.detector, timeout_ms, tx_delta, 30000).await;
session.last_active = Instant::now();
if !session.pty.is_alive() {
session.status = SessionStatus::Exited(None);
}
let result = SendResult {
output,
status: status_str,
};
{
let mut sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
sessions.insert(id.to_string(), session);
}
Ok(result)
}
pub async fn close_session(&self, id: &str) -> Result<String> {
let mut session = {
let mut sessions = self.sessions.lock().map_err(|e| {
RuntimeError::Tool(format!("session lock poisoned: {e}"))
})?;
match sessions.remove(id) {
Some(s) => s,
None => return Ok(String::new()),
}
};
let remaining = session
.pty
.try_read_output(Duration::from_millis(100))
.await;
let final_output = if remaining.is_empty() {
String::new()
} else {
strip_ansi(&String::from_utf8_lossy(&remaining))
};
drop(session);
Ok(final_output)
}
pub fn reap_idle(&self) -> Vec<String> {
let mut sessions = match self.sessions.lock() {
Ok(s) => s,
Err(e) => {
tracing::error!("session lock poisoned: {e}");
return Vec::new();
}
};
let grace_period = Duration::from_secs(5);
let ids_to_reap: Vec<String> = sessions
.iter()
.filter(|(_, s)| {
let elapsed = s.last_active.elapsed();
elapsed > s.idle_timeout && elapsed > grace_period
})
.map(|(id, _)| id.clone())
.collect();
for id in &ids_to_reap {
sessions.remove(id);
}
ids_to_reap
}
pub fn shutdown_all(&self) {
match self.sessions.lock() {
Ok(mut sessions) => {
sessions.drain();
}
Err(e) => {
tracing::error!("session lock poisoned: {e}");
}
}
}
pub fn active_count(&self) -> usize {
match self.sessions.lock() {
Ok(s) => s.len(),
Err(e) => {
tracing::error!("session lock poisoned: {e}");
0
}
}
}
pub fn list_sessions(&self) -> Vec<ShellSessionInfo> {
match self.sessions.lock() {
Ok(sessions) => {
sessions
.iter()
.map(|(id, s)| ShellSessionInfo {
id: id.clone(),
status: s.status.clone(),
created_at: s.created_at,
last_active: s.last_active,
})
.collect()
}
Err(e) => {
tracing::error!("session lock poisoned: {e}");
Vec::new()
}
}
}
}
impl Drop for SessionManager {
    /// Removes every tracked session when the manager itself is dropped.
    fn drop(&mut self) {
        Self::shutdown_all(self)
    }
}
/// Spawns a background task that calls `reap_idle` every 30 seconds and logs each
/// reaped session id. The task stops once `cancel` is triggered.
pub fn start_reaper(
    manager: Arc<SessionManager>,
    cancel: tokio_util::sync::CancellationToken,
) -> tokio::task::JoinHandle<()> {
    const REAP_EVERY: Duration = Duration::from_secs(30);
    tokio::spawn(async move {
        loop {
            tokio::select! {
                _ = cancel.cancelled() => break,
                _ = tokio::time::sleep(REAP_EVERY) => {
                    for id in manager.reap_idle() {
                        tracing::info!(session_id = %id, "reaped idle shell session");
                    }
                }
            }
        }
    })
}
#[cfg(test)]
mod tests {
    // Integration tests below spawn real shells and assume `bash` is on PATH; the
    // remaining tests cover the pure escape/normalization/status helpers.
    use super::*;

    fn default_manager() -> Arc<SessionManager> {
        SessionManager::new(ShellConfig::default())
    }

    fn opts_for(command: &str) -> SessionOpts {
        SessionOpts {
            command: Some(command.to_string()),
            working_directory: None,
            env: HashMap::new(),
            rows: None,
            cols: None,
            readiness_timeout_ms: None,
            idle_timeout: None,
        }
    }

    #[tokio::test]
    async fn test_create_session_echo_hello() {
        let mgr = default_manager();
        let (id, output, _status) = mgr
            .create_session(opts_for("echo hello"), None)
            .await
            .expect("failed to create session");
        assert!(id.starts_with("shell_"));
        assert!(
            output.contains("hello"),
            "expected 'hello' in output, got: {output:?}"
        );
        let _ = mgr.close_session(&id).await;
    }

    #[tokio::test]
    async fn test_send_input_echo() {
        let mgr = default_manager();
        let (id, _initial, _status) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("failed to create session");
        let result = mgr
            .send_input(&id, "echo test\n", None, None)
            .await
            .expect("failed to send input");
        assert!(
            result.output.contains("test"),
            "expected 'test' in output, got: {:?}",
            result.output
        );
        let _ = mgr.close_session(&id).await;
    }

    #[tokio::test]
    async fn test_close_session_idempotent() {
        let mgr = default_manager();
        let (id, _, _status) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("failed to create session");
        let result1 = mgr.close_session(&id).await;
        assert!(result1.is_ok(), "first close should succeed");
        let result2 = mgr.close_session(&id).await;
        assert!(result2.is_ok(), "second close should also succeed (idempotent)");
        assert_eq!(result2.unwrap(), "", "second close returns empty string");
    }

    #[tokio::test]
    async fn test_max_sessions_limit() {
        let mut config = ShellConfig::default();
        config.max_sessions = 2;
        let mgr = SessionManager::new(config);
        let (id1, _, _s) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("session 1");
        let (id2, _, _s) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("session 2");
        let result = mgr.create_session(opts_for("bash"), None).await;
        assert!(result.is_err(), "third session should fail");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("maximum session limit"),
            "error should mention limit, got: {err_msg}"
        );
        let _ = mgr.close_session(&id1).await;
        let _ = mgr.close_session(&id2).await;
    }

    #[tokio::test]
    async fn test_session_not_found() {
        let mgr = default_manager();
        let result = mgr.send_input("shell_99", "hello\n", None, None).await;
        assert!(result.is_err(), "send to non-existent session should fail");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("not found"),
            "error should mention 'not found', got: {err_msg}"
        );
    }

    #[test]
    fn test_normalize_output_crlf() {
        assert_eq!(normalize_output("hello\r\nworld\r\n"), "hello\nworld\n");
    }

    #[test]
    fn test_normalize_output_lone_cr() {
        assert_eq!(normalize_output("abc\rdef"), "abcdef");
    }

    #[test]
    fn test_escape_newline() {
        assert_eq!(process_input_escapes(r"hello\n"), "hello\n");
    }

    #[test]
    fn test_escape_tab() {
        assert_eq!(process_input_escapes(r"a\tb"), "a\tb");
    }

    #[test]
    fn test_escape_ctrl_c() {
        assert_eq!(process_input_escapes(r"\x03"), "\x03");
    }

    #[test]
    fn test_escape_ctrl_d() {
        assert_eq!(process_input_escapes(r"\x04"), "\x04");
    }

    #[test]
    fn test_escape_literal_backslash() {
        assert_eq!(process_input_escapes(r"a\\b"), "a\\b");
    }

    #[test]
    fn test_escape_real_newline_passthrough() {
        assert_eq!(process_input_escapes("hello\n"), "hello\n");
    }

    #[test]
    fn test_escape_mixed() {
        assert_eq!(process_input_escapes(r"ls -la\n"), "ls -la\n");
        assert_eq!(process_input_escapes(r"124\n"), "124\n");
    }

    #[test]
    fn test_escape_unknown_sequence() {
        assert_eq!(process_input_escapes(r"\q"), "\\q");
    }

    #[test]
    fn test_escape_hex_partial() {
        assert_eq!(process_input_escapes(r"\xZZ"), "\\xZZ");
    }

    #[test]
    fn test_escape_bell() {
        assert_eq!(process_input_escapes(r"\a"), "\x07");
    }

    #[test]
    fn test_escape_backspace() {
        assert_eq!(process_input_escapes(r"\b"), "\x08");
    }

    #[test]
    fn test_escape_null() {
        assert_eq!(process_input_escapes(r"\0"), "\0");
    }

    #[test]
    fn test_escape_esc_blocked() {
        assert_eq!(process_input_escapes(r"\e"), "");
    }

    #[test]
    fn test_escape_hex_1b_blocked() {
        assert_eq!(process_input_escapes(r"\x1b"), "");
    }

    #[test]
    fn test_escape_hex_high_byte_blocked() {
        assert_eq!(process_input_escapes(r"\x80"), "");
        assert_eq!(process_input_escapes(r"\xff"), "");
    }

    #[test]
    fn test_escape_hex_del_allowed() {
        assert_eq!(process_input_escapes(r"\x7f"), "\x7f");
    }

    // Edge cases of the escape parser.

    #[test]
    fn test_escape_trailing_backslash_kept() {
        assert_eq!(process_input_escapes("abc\\"), "abc\\");
    }

    #[test]
    fn test_escape_hex_single_digit() {
        assert_eq!(process_input_escapes(r"\x3"), "\x03");
        assert_eq!(process_input_escapes(r"\x3g"), "\x03g");
    }

    #[test]
    fn test_escape_hex_reads_at_most_two_digits() {
        assert_eq!(process_input_escapes(r"\x414"), "A4");
    }

    #[test]
    fn test_escape_hex_uppercase_esc_blocked() {
        assert_eq!(process_input_escapes(r"\x1B"), "");
    }

    #[test]
    fn test_escape_hex_without_digits() {
        assert_eq!(process_input_escapes(r"\x"), "\\x");
    }

    #[test]
    fn test_status_string_variants() {
        assert_eq!(status_string(&SessionStatus::Active), "active");
        assert_eq!(status_string(&SessionStatus::Exited(Some(2))), "exited(2)");
        assert_eq!(status_string(&SessionStatus::Exited(None)), "exited");
        assert_eq!(status_string(&SessionStatus::Closed), "closed");
    }
}