use bytes::Bytes;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use crate::binary_protocol::PayloadType;
use crate::errors::{Error, Result, SessionError};
use crate::protocol::SessionType;
use crate::terminal::TerminalSize;
use crate::{
channels::ChannelMultiplexer,
connection::{ConnectionManager, ManagerCommand},
OutputStream,
};
/// Identifier of an SSM session, as returned in the `StartSession` response and stored on
/// each [`Session`].
pub type SessionId = String;
/// Lifecycle phase of a [`Session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Initial state while the session object is being constructed.
    Initializing,
    /// The connection manager was created; data may be queued for sending.
    Connected,
    /// Termination has been requested but the manager task has not been reaped yet.
    Disconnecting,
    /// The session has been torn down.
    Terminated,
}

impl SessionState {
    /// Returns `true` only in [`SessionState::Connected`], the single state that accepts data.
    pub fn can_send(&self) -> bool {
        *self == SessionState::Connected
    }

    /// Returns `true` while the session is starting up or connected, i.e. before any teardown.
    pub fn is_active(&self) -> bool {
        match self {
            SessionState::Initializing | SessionState::Connected => true,
            SessionState::Disconnecting | SessionState::Terminated => false,
        }
    }
}
/// Parameters for starting an SSM session through `SessionManager::start_session`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Target passed to `StartSession`; must be non-empty.
    pub target: String,
    /// Optional region override. NOTE(review): not read by the session code in this file —
    /// confirm where it is consumed.
    pub region: Option<String>,
    /// Kind of session; selects the default SSM document when `document_name` is `None`.
    pub session_type: SessionType,
    /// Explicit SSM document name; takes precedence over the `session_type` default.
    pub document_name: Option<String>,
    /// Document parameters, forwarded to `StartSession` only when non-empty.
    pub parameters: std::collections::HashMap<String, Vec<String>>,
    /// Optional reason forwarded to `StartSession`.
    pub reason: Option<String>,
    /// Connection timeout (default 30 s). NOTE(review): not enforced in this file.
    pub connect_timeout: std::time::Duration,
    /// Idle timeout (default 20 min). NOTE(review): not enforced in this file.
    pub idle_timeout: std::time::Duration,
    /// Optional cap on total session length (default: none). NOTE(review): not enforced in
    /// this file.
    pub max_duration: Option<std::time::Duration>,
}

impl Default for SessionConfig {
    /// A standard-stream session with an empty target, a 30 s connect timeout, a 20 min idle
    /// timeout and no other options set.
    fn default() -> Self {
        use std::time::Duration;

        const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
        const IDLE_TIMEOUT: Duration = Duration::from_secs(20 * 60);

        Self {
            session_type: SessionType::StandardStream,
            connect_timeout: CONNECT_TIMEOUT,
            idle_timeout: IDLE_TIMEOUT,
            target: Default::default(),
            region: Default::default(),
            document_name: Default::default(),
            parameters: Default::default(),
            reason: Default::default(),
            max_duration: Default::default(),
        }
    }
}
pub struct Session {
session_id: SessionId,
config: SessionConfig,
state: Arc<RwLock<SessionState>>,
command_tx: mpsc::UnboundedSender<ManagerCommand>,
channels: Arc<ChannelMultiplexer>,
manager_task: Option<JoinHandle<Result<()>>>,
protocol_can_send: Arc<std::sync::atomic::AtomicBool>,
}
impl Session {
pub(crate) async fn new(
session_id: SessionId,
config: SessionConfig,
stream_url: String,
token_value: String,
) -> Result<Self> {
info!(session_id = %session_id, "Creating new session");
let (command_tx, command_rx) = mpsc::unbounded_channel();
let manager =
ConnectionManager::connect(session_id.clone(), stream_url, token_value, command_rx)
.await?;
let channels = manager.channels();
let protocol_can_send = manager.can_send();
let manager_task = tokio::spawn(async move { manager.run().await });
let session = Self {
session_id,
config,
state: Arc::new(RwLock::new(SessionState::Initializing)),
command_tx,
channels,
manager_task: Some(manager_task),
protocol_can_send,
};
session.set_state(SessionState::Connected).await;
Ok(session)
}
pub fn id(&self) -> &str {
&self.session_id
}
pub fn config(&self) -> &SessionConfig {
&self.config
}
pub async fn state(&self) -> SessionState {
*self.state.read().await
}
pub fn is_ready(&self) -> bool {
self.protocol_can_send
.load(std::sync::atomic::Ordering::SeqCst)
}
pub async fn wait_for_ready(&self, timeout: std::time::Duration) -> bool {
let deadline = tokio::time::Instant::now() + timeout;
let check_interval = std::time::Duration::from_millis(100);
while tokio::time::Instant::now() < deadline {
if self
.protocol_can_send
.load(std::sync::atomic::Ordering::SeqCst)
{
return true;
}
tokio::time::sleep(check_interval).await;
}
false
}
async fn set_state(&self, new_state: SessionState) {
let mut state = self.state.write().await;
debug!(
session_id = %self.session_id,
old_state = ?*state,
new_state = ?new_state,
"Session state transition"
);
*state = new_state;
}
pub fn output(&self) -> OutputStream {
self.channels.output_stream()
}
pub async fn send(&self, data: Bytes) -> Result<()> {
let state = self.state().await;
if !state.can_send() {
return Err(SessionError::InvalidState {
expected: "Connected".to_string(),
actual: format!("{:?}", state),
}
.into());
}
self.command_tx
.send(ManagerCommand::SendData(data))
.map_err(|_| Error::InvalidState("Session command channel closed".to_string()))?;
Ok(())
}
pub async fn send_size(&self, size: TerminalSize) -> Result<()> {
let state = self.state().await;
if !state.can_send() {
return Err(SessionError::InvalidState {
expected: "Connected".to_string(),
actual: format!("{:?}", state),
}
.into());
}
let data = size.to_json()?;
debug!(cols = size.cols, rows = size.rows, "Sending terminal size");
self.command_tx
.send(ManagerCommand::SendMessage {
data,
payload_type: PayloadType::Size,
})
.map_err(|_| Error::InvalidState("Session command channel closed".to_string()))?;
Ok(())
}
pub async fn terminate(&mut self) -> Result<()> {
info!(session_id = %self.session_id, "Terminating session");
self.set_state(SessionState::Disconnecting).await;
self.command_tx
.send(ManagerCommand::Terminate)
.map_err(|_| Error::InvalidState("Session command channel closed".to_string()))?;
if let Some(task) = self.manager_task.take() {
match task.await {
Ok(Ok(())) => debug!("Manager task completed successfully"),
Ok(Err(e)) => warn!(error = ?e, "Manager task completed with error"),
Err(e) => warn!(error = ?e, "Manager task panicked"),
}
}
self.set_state(SessionState::Terminated).await;
Ok(())
}
pub async fn wait_terminated(&self) {
loop {
let state = self.state().await;
if state == SessionState::Terminated {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
}
impl Drop for Session {
    /// Aborts the background manager task if it is still held.
    ///
    /// Stops the task from outliving a `Session` that was dropped without `terminate`. A
    /// handle already taken by `terminate` is `None` here and is skipped.
    fn drop(&mut self) {
        if let Some(handle) = self.manager_task.as_ref() {
            handle.abort();
        }
    }
}
pub struct SessionManager {
ssm_client: Arc<aws_sdk_ssm::Client>,
}
impl SessionManager {
pub async fn new() -> Result<Self> {
let config = aws_config::load_from_env().await;
let ssm_client = aws_sdk_ssm::Client::new(&config);
Ok(Self {
ssm_client: Arc::new(ssm_client),
})
}
pub fn with_config(config: &aws_config::SdkConfig) -> Self {
let ssm_client = aws_sdk_ssm::Client::new(config);
Self {
ssm_client: Arc::new(ssm_client),
}
}
pub async fn start_session(&self, config: SessionConfig) -> Result<Session> {
info!(
target = %config.target,
session_type = ?config.session_type,
"Starting SSM session"
);
if config.target.is_empty() {
return Err(Error::Config("Target cannot be empty".to_string()));
}
let document_name = config.document_name.clone().or_else(|| {
match config.session_type {
SessionType::StandardStream => None, SessionType::Port => Some("AWS-StartPortForwardingSession".to_string()),
SessionType::InteractiveCommands => Some("AWS-StartInteractiveCommand".to_string()),
}
});
let mut request = self.ssm_client.start_session().target(&config.target);
if let Some(ref doc) = document_name {
request = request.document_name(doc);
}
if !config.parameters.is_empty() {
request = request.set_parameters(Some(config.parameters.clone()));
}
if let Some(ref reason) = config.reason {
request = request.reason(reason);
}
let response = request.send().await.map_err(|e| {
warn!(error = ?e, "Failed to start SSM session");
Error::from(e)
})?;
let session_id = response
.session_id()
.ok_or_else(|| Error::AwsSdk("No session ID in response".to_string()))?
.to_string();
let stream_url = response
.stream_url()
.ok_or_else(|| Error::AwsSdk("No stream URL in response".to_string()))?
.to_string();
let token_value = response
.token_value()
.ok_or_else(|| Error::AwsSdk("No token value in response".to_string()))?
.to_string();
info!(session_id = %session_id, "SSM session started");
let session = Session::new(session_id, config, stream_url, token_value).await?;
Ok(session)
}
pub async fn terminate_session(&self, session_id: &str) -> Result<()> {
info!(session_id = %session_id, "Terminating session via AWS API");
self.ssm_client
.terminate_session()
.session_id(session_id)
.send()
.await
.map_err(|e| {
warn!(error = ?e, "Failed to terminate session");
Error::from(e)
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every state is checked against both predicates, not only the two end states.
    #[test]
    fn test_session_state_transitions() {
        // Only an established session may send.
        assert!(!SessionState::Initializing.can_send());
        assert!(SessionState::Connected.can_send());
        assert!(!SessionState::Disconnecting.can_send());
        assert!(!SessionState::Terminated.can_send());

        // "Active" covers the start-up and connected phases only.
        assert!(SessionState::Initializing.is_active());
        assert!(SessionState::Connected.is_active());
        assert!(!SessionState::Disconnecting.is_active());
        assert!(!SessionState::Terminated.is_active());
    }

    /// Pins every default value of `SessionConfig`.
    #[test]
    fn test_session_config_default() {
        let config = SessionConfig::default();
        assert_eq!(config.session_type, SessionType::StandardStream);
        assert!(config.document_name.is_none());
        assert!(config.target.is_empty());
        assert!(config.region.is_none());
        assert!(config.parameters.is_empty());
        assert!(config.reason.is_none());
        assert_eq!(config.connect_timeout, std::time::Duration::from_secs(30));
        assert_eq!(config.idle_timeout, std::time::Duration::from_secs(20 * 60));
        assert!(config.max_duration.is_none());
    }
}