use crate::protocol::SessionType;
use std::collections::HashMap;
/// Common interface for the session document descriptors defined in this file
/// (for example `PortForwardingSession` and `ShellSession`).
///
/// Implementors must be `Send + Sync`.
pub trait SsmDocument: Send + Sync {
/// Returns the document name as a static string.
fn document_name(&self) -> &'static str;
/// Returns the [`SessionType`] this document reports.
fn session_type(&self) -> SessionType;
/// Returns the document parameters: each parameter name maps to a list of string values.
fn parameters(&self) -> HashMap<String, Vec<String>>;
}
/// Descriptor for a session that uses the `AWS-StartPortForwardingSession` document.
#[derive(Debug, Clone)]
pub struct PortForwardingSession {
    /// Remote port; emitted as the `portNumber` document parameter.
    pub remote_port: u16,
    /// Optional local port; emitted as `localPortNumber` only when it is `Some`.
    pub local_port: Option<u16>,
}

impl PortForwardingSession {
    /// Name of the document used by this session descriptor.
    pub const DOCUMENT_NAME: &'static str = "AWS-StartPortForwardingSession";

    /// Creates a descriptor for `remote_port` with no local port set.
    pub fn new(remote_port: u16) -> Self {
        Self {
            remote_port,
            local_port: None,
        }
    }

    /// Sets the local port and returns the updated descriptor for chaining.
    ///
    /// Provided for parity with `PortForwardingToRemoteHost::with_local_port`, so a
    /// local port can be chosen without going through the builder.
    pub fn with_local_port(mut self, port: u16) -> Self {
        self.local_port = Some(port);
        self
    }

    /// Returns an empty [`PortForwardingSessionBuilder`].
    pub fn builder() -> PortForwardingSessionBuilder {
        PortForwardingSessionBuilder::default()
    }
}
impl SsmDocument for PortForwardingSession {
fn document_name(&self) -> &'static str {
Self::DOCUMENT_NAME
}
fn session_type(&self) -> SessionType {
SessionType::Port
}
fn parameters(&self) -> HashMap<String, Vec<String>> {
let mut params = HashMap::new();
params.insert("portNumber".to_string(), vec![self.remote_port.to_string()]);
if let Some(local) = self.local_port {
params.insert("localPortNumber".to_string(), vec![local.to_string()]);
}
params
}
}
/// Builder for [`PortForwardingSession`].
///
/// `remote_port` must be supplied before building; `local_port` is optional.
#[derive(Debug, Default)]
pub struct PortForwardingSessionBuilder {
    remote_port: Option<u16>,
    local_port: Option<u16>,
}

impl PortForwardingSessionBuilder {
    /// Sets the remote port, replacing any previously supplied value.
    pub fn remote_port(self, port: u16) -> Self {
        Self {
            remote_port: Some(port),
            ..self
        }
    }

    /// Sets the local port, replacing any previously supplied value.
    pub fn local_port(self, port: u16) -> Self {
        Self {
            local_port: Some(port),
            ..self
        }
    }

    /// Builds the session descriptor.
    ///
    /// # Panics
    ///
    /// Panics with `"remote_port is required"` if no remote port was set.
    /// Use [`Self::try_build`] for a non-panicking variant.
    pub fn build(self) -> PortForwardingSession {
        let built = self.try_build();
        built.expect("remote_port is required")
    }

    /// Builds the session descriptor, returning `None` if no remote port was set.
    pub fn try_build(self) -> Option<PortForwardingSession> {
        let Self {
            remote_port,
            local_port,
        } = self;
        remote_port.map(|remote_port| PortForwardingSession {
            remote_port,
            local_port,
        })
    }
}
/// Descriptor for a session that uses the
/// `AWS-StartPortForwardingSessionToRemoteHost` document.
#[derive(Debug, Clone)]
pub struct PortForwardingToRemoteHost {
    /// Target host; emitted as the `host` document parameter.
    pub host: String,
    /// Remote port; emitted as the `portNumber` document parameter.
    pub remote_port: u16,
    /// Optional local port; emitted as `localPortNumber` only when it is `Some`.
    pub local_port: Option<u16>,
}

impl PortForwardingToRemoteHost {
    /// Name of the document used by this session descriptor.
    pub const DOCUMENT_NAME: &'static str = "AWS-StartPortForwardingSessionToRemoteHost";

    /// Creates a descriptor for `host` and `remote_port` with no local port set.
    pub fn new(host: impl Into<String>, remote_port: u16) -> Self {
        let host = host.into();
        PortForwardingToRemoteHost {
            host,
            remote_port,
            local_port: None,
        }
    }

    /// Sets the local port and returns the updated descriptor for chaining.
    pub fn with_local_port(self, port: u16) -> Self {
        PortForwardingToRemoteHost {
            local_port: Some(port),
            ..self
        }
    }
}
impl SsmDocument for PortForwardingToRemoteHost {
fn document_name(&self) -> &'static str {
Self::DOCUMENT_NAME
}
fn session_type(&self) -> SessionType {
SessionType::Port
}
fn parameters(&self) -> HashMap<String, Vec<String>> {
let mut params = HashMap::new();
params.insert("host".to_string(), vec![self.host.clone()]);
params.insert("portNumber".to_string(), vec![self.remote_port.to_string()]);
if let Some(local) = self.local_port {
params.insert("localPortNumber".to_string(), vec![local.to_string()]);
}
params
}
}
/// Descriptor for a shell session; a unit type that carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct ShellSession;

impl ShellSession {
    /// Creates a shell session descriptor.
    pub fn new() -> Self {
        ShellSession
    }
}
impl SsmDocument for ShellSession {
    /// Returns an empty document name.
    // NOTE(review): presumably an empty name lets the service pick its default
    // shell document — confirm against the code that consumes `document_name`.
    fn document_name(&self) -> &'static str {
        ""
    }

    /// Always reports `SessionType::StandardStream`.
    fn session_type(&self) -> SessionType {
        SessionType::StandardStream
    }

    /// A shell session has no parameters, so the map is always empty.
    fn parameters(&self) -> HashMap<String, Vec<String>> {
        HashMap::default()
    }
}
/// Descriptor for a session that uses the `AWS-StartInteractiveCommand` document.
#[derive(Debug, Clone)]
pub struct InteractiveCommand {
    /// Values emitted, in order, as the `command` document parameter.
    pub commands: Vec<String>,
}

impl InteractiveCommand {
    /// Name of the document used by this session descriptor.
    pub const DOCUMENT_NAME: &'static str = "AWS-StartInteractiveCommand";

    /// Creates a descriptor holding exactly one command.
    pub fn new(command: impl Into<String>) -> Self {
        Self::with_commands(vec![command.into()])
    }

    /// Creates a descriptor from an already-built list of commands (may be empty).
    pub fn with_commands(commands: Vec<String>) -> Self {
        InteractiveCommand { commands }
    }
}
impl SsmDocument for InteractiveCommand {
fn document_name(&self) -> &'static str {
Self::DOCUMENT_NAME
}
fn session_type(&self) -> SessionType {
SessionType::InteractiveCommands
}
fn parameters(&self) -> HashMap<String, Vec<String>> {
let mut params = HashMap::new();
params.insert("command".to_string(), self.commands.clone());
params
}
}
/// Descriptor for a session that uses the `AWS-StartSSHSession` document.
#[derive(Debug, Clone)]
pub struct SshSession {
    /// Port emitted as the `portNumber` document parameter.
    pub port: u16,
}

impl SshSession {
    /// Name of the document used by this session descriptor.
    pub const DOCUMENT_NAME: &'static str = "AWS-StartSSHSession";

    /// Creates a descriptor using port 22.
    pub fn new() -> Self {
        Self::with_port(22)
    }

    /// Creates a descriptor using the given port.
    pub fn with_port(port: u16) -> Self {
        SshSession { port }
    }
}

impl Default for SshSession {
    /// Same as [`SshSession::new`]: port 22.
    fn default() -> Self {
        SshSession::new()
    }
}
impl SsmDocument for SshSession {
fn document_name(&self) -> &'static str {
Self::DOCUMENT_NAME
}
fn session_type(&self) -> SessionType {
SessionType::Port
}
fn parameters(&self) -> HashMap<String, Vec<String>> {
let mut params = HashMap::new();
params.insert("portNumber".to_string(), vec![self.port.to_string()]);
params
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected value list for a parameter that carries one string.
    fn single(value: &str) -> Vec<String> {
        vec![value.to_string()]
    }

    /// `new` yields the port-forwarding document with only `portNumber` set.
    #[test]
    fn test_port_forwarding_simple() {
        let session = PortForwardingSession::new(3306);
        assert_eq!(session.document_name(), "AWS-StartPortForwardingSession");
        assert_eq!(session.session_type(), SessionType::Port);

        let parameters = session.parameters();
        assert_eq!(parameters.get("portNumber"), Some(&single("3306")));
        assert!(!parameters.contains_key("localPortNumber"));
    }

    /// The builder carries both ports through to the parameters.
    #[test]
    fn test_port_forwarding_builder() {
        let session = PortForwardingSession::builder()
            .remote_port(3306)
            .local_port(13306)
            .build();

        let parameters = session.parameters();
        assert_eq!(parameters.get("portNumber"), Some(&single("3306")));
        assert_eq!(parameters.get("localPortNumber"), Some(&single("13306")));
    }

    /// Remote-host forwarding emits host, remote port and local port.
    #[test]
    fn test_port_forwarding_to_remote_host() {
        let session =
            PortForwardingToRemoteHost::new("mydb.rds.amazonaws.com", 5432).with_local_port(15432);
        assert_eq!(
            session.document_name(),
            "AWS-StartPortForwardingSessionToRemoteHost"
        );

        let parameters = session.parameters();
        assert_eq!(
            parameters.get("host"),
            Some(&single("mydb.rds.amazonaws.com"))
        );
        assert_eq!(parameters.get("portNumber"), Some(&single("5432")));
        assert_eq!(parameters.get("localPortNumber"), Some(&single("15432")));
    }

    /// A shell session is a standard stream with no parameters.
    #[test]
    fn test_shell_session() {
        let session = ShellSession::new();
        assert_eq!(session.session_type(), SessionType::StandardStream);
        assert!(session.parameters().is_empty());
    }

    /// A single command ends up under the `command` parameter.
    #[test]
    fn test_interactive_command() {
        let session = InteractiveCommand::new("top");
        assert_eq!(session.document_name(), "AWS-StartInteractiveCommand");
        assert_eq!(session.session_type(), SessionType::InteractiveCommands);

        let parameters = session.parameters();
        assert_eq!(parameters.get("command"), Some(&single("top")));
    }

    /// SSH sessions default to port 22 and honour an explicit port.
    #[test]
    fn test_ssh_session() {
        let default_session = SshSession::new();
        assert_eq!(default_session.document_name(), "AWS-StartSSHSession");
        assert_eq!(default_session.session_type(), SessionType::Port);
        assert_eq!(
            default_session.parameters().get("portNumber"),
            Some(&single("22"))
        );

        let custom_session = SshSession::with_port(2222);
        assert_eq!(
            custom_session.parameters().get("portNumber"),
            Some(&single("2222"))
        );
    }
}