use crate::service_error::ServiceError;
use crate::sessions::connection::SshConnection;
use crate::sessions::terminal::TerminalEmulator;
use crate::sessions::types::*;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Upper bound, in bytes, on the output retained by an [`OutputBuffer`] (1 MiB).
const MAX_BUFFER_SIZE: usize = 1024 * 1024;

/// Bounded byte buffer holding raw session output.
///
/// When more than `MAX_BUFFER_SIZE` bytes accumulate, the oldest bytes are discarded and the
/// buffer remembers that a discard happened until the next `take`.
pub struct OutputBuffer {
    /// Buffered bytes, oldest first; never longer than `MAX_BUFFER_SIZE`.
    data: Vec<u8>,
    /// `true` when bytes were discarded since the last `take`.
    truncated: bool,
}

impl OutputBuffer {
    /// Creates an empty, non-truncated buffer.
    fn new() -> Self {
        Self {
            data: Vec::new(),
            truncated: false,
        }
    }

    /// Appends `new_data`, discarding the oldest bytes if the total would exceed
    /// `MAX_BUFFER_SIZE`.
    ///
    /// After a discard, any UTF-8 continuation bytes left at the front are also dropped (at most
    /// three). This keeps the retained output from starting in the middle of a multi-byte
    /// character, which would otherwise decode as a replacement character when read as text.
    fn append(&mut self, new_data: &[u8]) {
        let total = self.data.len() + new_data.len();
        if total <= MAX_BUFFER_SIZE {
            self.data.extend_from_slice(new_data);
            return;
        }
        self.truncated = true;
        if new_data.len() >= MAX_BUFFER_SIZE {
            // The incoming chunk alone fills the buffer: keep only its tail rather than copying
            // the old contents in just to drain them again.
            self.data.clear();
            self.data
                .extend_from_slice(&new_data[new_data.len() - MAX_BUFFER_SIZE..]);
        } else {
            // Here `total - MAX_BUFFER_SIZE < self.data.len()` because `new_data` is shorter than
            // the cap. Draining first keeps the Vec from ever growing past the cap.
            self.data.drain(..total - MAX_BUFFER_SIZE);
            self.data.extend_from_slice(new_data);
        }
        // Continuation bytes have the bit pattern 10xxxxxx; a UTF-8 sequence has at most three.
        let partial_char = self
            .data
            .iter()
            .take(3)
            .take_while(|&&b| b & 0xC0 == 0x80)
            .count();
        self.data.drain(..partial_char);
    }

    /// Removes and returns all buffered bytes together with the truncation flag.
    ///
    /// The flag is reset, so a subsequent `take` reports truncation only for new discards.
    fn take(&mut self) -> (Vec<u8>, bool) {
        let truncated = self.truncated;
        self.truncated = false;
        (std::mem::take(&mut self.data), truncated)
    }

    /// Borrows the buffered bytes and the truncation flag without modifying either.
    fn peek(&self) -> (&[u8], bool) {
        (&self.data, self.truncated)
    }

    /// Number of bytes currently buffered.
    fn len(&self) -> usize {
        self.data.len()
    }
}
/// One interactive shell running over an SSH connection.
///
/// Remote output is collected by the task spawned from `start_output_reader`. Raw bytes go
/// into a bounded output buffer and are also fed to a terminal emulator, so both the byte
/// stream and the rendered screen can be queried.
pub struct ShellSession {
    /// Unique session identifier (the key used by `ShellSessionRegistry`).
    pub id: String,
    /// Identifier of the target this session is connected to.
    pub target_id: String,
    /// Optional human-readable label.
    pub name: Option<String>,
    /// Optional identifier of the client that owns the session; used for filtering in listings.
    pub client_id: Option<String>,
    /// NOTE(review): only stored in this file; presumably the name of the remote tmux session.
    pub tmux_session: String,
    /// Current terminal width in columns.
    pub cols: u16,
    /// Current terminal height in rows.
    pub rows: u16,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Time of the last write or non-empty read; shared with the output reader task.
    pub last_activity: Arc<RwLock<DateTime<Utc>>>,
    /// Lifecycle state; shared with the output reader task.
    pub state: Arc<RwLock<SessionState>>,
    /// Bounded raw output not yet consumed by `read(true)`.
    buffer: Arc<RwLock<OutputBuffer>>,
    /// Terminal emulator that renders the output stream into a screen.
    terminal: Arc<RwLock<TerminalEmulator>>,
    /// Live connection, or `None` once disconnected or closed.
    connection: Arc<RwLock<Option<SshConnection>>>,
    /// Signals the output reader task to stop; `None` until the reader is started.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
impl ShellSession {
pub fn new(
id: String,
target_id: String,
name: Option<String>,
client_id: Option<String>,
tmux_session: String,
cols: u16,
rows: u16,
connection: SshConnection,
) -> Self {
let now = Utc::now();
Self {
id,
target_id,
name,
client_id,
tmux_session,
cols,
rows,
created_at: now,
last_activity: Arc::new(RwLock::new(now)),
state: Arc::new(RwLock::new(SessionState::Active)),
buffer: Arc::new(RwLock::new(OutputBuffer::new())),
terminal: Arc::new(RwLock::new(TerminalEmulator::new(cols, rows))),
connection: Arc::new(RwLock::new(Some(connection))),
shutdown_tx: None,
}
}
pub fn start_output_reader(&mut self) {
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let connection = Arc::clone(&self.connection);
let buffer = Arc::clone(&self.buffer);
let terminal = Arc::clone(&self.terminal);
let state = Arc::clone(&self.state);
let last_activity = Arc::clone(&self.last_activity);
let session_id = self.id.clone();
tokio::spawn(async move {
debug!("Starting output reader for session {}", session_id);
loop {
if shutdown_rx.try_recv().is_ok() {
debug!("Output reader shutting down for session {}", session_id);
break;
}
let read_result = {
let mut conn_guard = connection.write().await;
if let Some(conn) = conn_guard.as_mut() {
if !conn.is_alive() {
warn!("Connection died for session {}", session_id);
*state.write().await = SessionState::Disconnected;
*conn_guard = None;
None
} else {
conn.try_recv()
}
} else {
None
}
};
if let Some(data) = read_result {
if !data.is_empty() {
*last_activity.write().await = Utc::now();
buffer.write().await.append(&data);
terminal.read().await.process(&data).await;
}
} else {
let conn_guard = connection.read().await;
if conn_guard.is_none() {
debug!("Connection removed for session {}", session_id);
break;
}
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
debug!("Output reader finished for session {}", session_id);
});
}
pub async fn info(&self) -> ShellSessionInfo {
ShellSessionInfo {
id: self.id.clone(),
target_id: self.target_id.clone(),
name: self.name.clone(),
client_id: self.client_id.clone(),
state: *self.state.read().await,
created_at: self.created_at.to_rfc3339(),
last_activity: self.last_activity.read().await.to_rfc3339(),
size: (self.cols, self.rows),
}
}
pub async fn write(&self, data: &[u8]) -> Result<usize, ServiceError> {
let mut conn_guard = self.connection.write().await;
if let Some(conn) = conn_guard.as_mut() {
conn.send(data).await?;
*self.last_activity.write().await = Utc::now();
Ok(data.len())
} else {
Err(ServiceError::Internal(
"Session is disconnected".to_string(),
))
}
}
pub async fn read(&self, consume: bool) -> (String, usize, bool) {
let mut buffer = self.buffer.write().await;
if consume {
let (data, truncated) = buffer.take();
let text = String::from_utf8_lossy(&data).to_string();
(text, 0, truncated)
} else {
let (data, truncated) = buffer.peek();
let text = String::from_utf8_lossy(data).to_string();
(text, data.len(), truncated)
}
}
pub async fn screen_state(&self) -> ScreenState {
self.terminal.read().await.screen_state().await
}
pub async fn resize(&mut self, cols: u16, rows: u16) -> Result<(), ServiceError> {
self.terminal.write().await.resize(cols, rows).await;
let mut conn_guard = self.connection.write().await;
if let Some(conn) = conn_guard.as_mut() {
conn.resize(cols, rows)?;
}
self.cols = cols;
self.rows = rows;
Ok(())
}
pub async fn wait_for_output(&self, timeout_ms: u64, min_bytes: usize) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
while start.elapsed() < timeout {
if self.buffer.read().await.len() >= min_bytes {
return true;
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
self.buffer.read().await.len() >= min_bytes
}
pub async fn get_screen_text(&self) -> String {
let screen_state = self.terminal.read().await.screen_state().await;
screen_state.lines.join("\n")
}
pub async fn wait_for_pattern(&self, pattern: &str, timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
while start.elapsed() < timeout {
let screen_text = self.get_screen_text().await;
if screen_text.contains(pattern) {
return true;
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
self.get_screen_text().await.contains(pattern)
}
pub async fn wait_for_stable(&self, stable_ms: u64, timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
let stable_duration = std::time::Duration::from_millis(stable_ms);
let mut last_screen = self.get_screen_text().await;
let mut last_change = std::time::Instant::now();
while start.elapsed() < timeout {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let current_screen = self.get_screen_text().await;
if current_screen != last_screen {
last_screen = current_screen;
last_change = std::time::Instant::now();
} else if last_change.elapsed() >= stable_duration {
return true;
}
}
false
}
pub async fn close(mut self, _force: bool) -> Result<bool, ServiceError> {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
let connection = self.connection.write().await.take();
if let Some(conn) = connection {
conn.close().await?;
}
*self.state.write().await = SessionState::Closed;
Ok(true)
}
}
pub struct ShellSessionRegistry {
sessions: Arc<RwLock<HashMap<String, Arc<RwLock<ShellSession>>>>>,
}
impl ShellSessionRegistry {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn add(&self, session: ShellSession) {
let id = session.id.clone();
let session = Arc::new(RwLock::new(session));
self.sessions.write().await.insert(id, session);
}
pub async fn get(&self, id: &str) -> Option<Arc<RwLock<ShellSession>>> {
self.sessions.read().await.get(id).cloned()
}
pub async fn remove(&self, id: &str) -> Option<Arc<RwLock<ShellSession>>> {
self.sessions.write().await.remove(id)
}
pub async fn list(
&self,
target_id: Option<&str>,
client_id: Option<&str>,
include_disconnected: bool,
) -> Vec<ShellSessionInfo> {
let sessions = self.sessions.read().await;
let mut result = Vec::new();
for session_lock in sessions.values() {
let session = session_lock.read().await;
let state = *session.state.read().await;
if !include_disconnected && state == SessionState::Disconnected {
continue;
}
if let Some(tid) = target_id {
if session.target_id != tid {
continue;
}
}
if let Some(cid) = client_id {
match &session.client_id {
Some(session_cid) if session_cid == cid => {}
_ => continue,
}
}
result.push(session.info().await);
}
result
}
pub async fn active_count(&self) -> usize {
let sessions = self.sessions.read().await;
let mut count = 0;
for session_lock in sessions.values() {
let session = session_lock.read().await;
if *session.state.read().await == SessionState::Active {
count += 1;
}
}
count
}
}
impl Default for ShellSessionRegistry {
fn default() -> Self {
Self::new()
}
}