use anyhow::{Context, Result};
use std::fmt;
/// Default maximum number of jump hosts allowed in one chain.
const DEFAULT_MAX_JUMP_HOSTS: usize = 10;
/// Hard ceiling on the chain length; `BSSH_MAX_JUMP_HOSTS` is capped to this value.
const ABSOLUTE_MAX_JUMP_HOSTS: usize = 30;

/// Returns the maximum number of jump hosts permitted in a single chain.
///
/// The limit is read from the `BSSH_MAX_JUMP_HOSTS` environment variable:
///
/// - unset: the default of 10 is used;
/// - `0`, not a non-negative integer, or not valid Unicode: a warning is logged
///   and the default of 10 is used;
/// - greater than the absolute maximum of 30: a warning is logged and the value
///   is capped at 30;
/// - otherwise: the configured value is returned.
pub fn get_max_jump_hosts() -> usize {
    let raw = match std::env::var("BSSH_MAX_JUMP_HOSTS") {
        Ok(raw) => raw,
        Err(std::env::VarError::NotPresent) => return DEFAULT_MAX_JUMP_HOSTS,
        Err(std::env::VarError::NotUnicode(_)) => {
            tracing::warn!(
                "BSSH_MAX_JUMP_HOSTS is not valid Unicode, using default: {}",
                DEFAULT_MAX_JUMP_HOSTS
            );
            return DEFAULT_MAX_JUMP_HOSTS;
        }
    };

    // Warn (as for the 0 and over-limit cases below) so that a typo in the
    // variable is visible instead of being silently replaced by the default.
    let n = match raw.parse::<usize>() {
        Ok(n) => n,
        Err(_) => {
            tracing::warn!(
                "BSSH_MAX_JUMP_HOSTS='{}' is not a valid number, using default: {}",
                raw,
                DEFAULT_MAX_JUMP_HOSTS
            );
            return DEFAULT_MAX_JUMP_HOSTS;
        }
    };

    match n {
        0 => {
            tracing::warn!(
                "BSSH_MAX_JUMP_HOSTS cannot be 0, using default: {}",
                DEFAULT_MAX_JUMP_HOSTS
            );
            DEFAULT_MAX_JUMP_HOSTS
        }
        n if n > ABSOLUTE_MAX_JUMP_HOSTS => {
            tracing::warn!(
                "BSSH_MAX_JUMP_HOSTS={} exceeds absolute maximum {}, capping at {}",
                n,
                ABSOLUTE_MAX_JUMP_HOSTS,
                ABSOLUTE_MAX_JUMP_HOSTS
            );
            ABSOLUTE_MAX_JUMP_HOSTS
        }
        n => n,
    }
}
/// A single hop in an SSH jump-host (ProxyJump) chain.
///
/// `user` and `port` are optional; [`JumpHost::effective_user`] and
/// [`JumpHost::effective_port`] supply the fallbacks used when they are unset.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JumpHost {
    /// Login user for this hop, if one was given explicitly.
    pub user: Option<String>,
    /// Hostname or IP address. When produced by `parse_jump_hosts`, IPv6
    /// literals are stored without their surrounding brackets.
    pub host: String,
    /// SSH port for this hop, if one was given explicitly.
    pub port: Option<u16>,
}

impl JumpHost {
    /// Creates a jump host from its components.
    pub fn new(host: String, user: Option<String>, port: Option<u16>) -> Self {
        Self { user, host, port }
    }

    /// Returns the explicit user, or the current local username when none was given.
    pub fn effective_user(&self) -> String {
        self.user.clone().unwrap_or_else(whoami::username)
    }

    /// Returns the explicit port, or the standard SSH port 22 when none was given.
    pub fn effective_port(&self) -> u16 {
        self.port.unwrap_or(22)
    }

    /// Formats this hop as `[user@]host[:port]`.
    ///
    /// A host containing `:` (an IPv6 literal) is wrapped in brackets, e.g.
    /// `user@[::1]:2222`. Without brackets the result would be ambiguous
    /// (`::1:2222`) and could not be parsed back into the same host and port.
    pub fn to_connection_string(&self) -> String {
        let host: std::borrow::Cow<'_, str> =
            if self.host.contains(':') && !self.host.starts_with('[') {
                format!("[{}]", self.host).into()
            } else {
                self.host.as_str().into()
            };

        match (&self.user, self.port) {
            (Some(user), Some(port)) => format!("{user}@{host}:{port}"),
            (Some(user), None) => format!("{user}@{host}"),
            (None, Some(port)) => format!("{host}:{port}"),
            (None, None) => host.into_owned(),
        }
    }
}

impl fmt::Display for JumpHost {
    /// Writes the same text as [`JumpHost::to_connection_string`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_connection_string())
    }
}
/// Parses a comma-separated jump-host chain such as `admin@bastion:2222,[::1],gw`.
///
/// Whitespace around each entry is trimmed and empty entries are skipped.
/// An input that is empty or consists only of whitespace yields an empty chain.
///
/// # Errors
///
/// Returns an error if any entry fails to parse (the first failure wins),
/// if the input contains only separators (e.g. `",,"`), or if the chain is
/// longer than the limit returned by [`get_max_jump_hosts`].
pub fn parse_jump_hosts(jump_spec: &str) -> Result<Vec<JumpHost>> {
    if jump_spec.trim().is_empty() {
        return Ok(Vec::new());
    }

    // Collecting into `Result<Vec<_>>` stops at the first entry that fails.
    let chain = jump_spec
        .split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(|entry| {
            parse_single_jump_host(entry).with_context(|| {
                format!("Failed to parse jump host specification: '{}'", entry)
            })
        })
        .collect::<Result<Vec<_>>>()?;

    if chain.is_empty() {
        anyhow::bail!(
            "No valid jump hosts found in specification: '{}'",
            jump_spec
        );
    }

    let limit = get_max_jump_hosts();
    if chain.len() > limit {
        anyhow::bail!(
            "Too many jump hosts specified: {} (maximum allowed: {}). Reduce the number of jump hosts in your chain or set BSSH_MAX_JUMP_HOSTS environment variable.",
            chain.len(),
            limit
        );
    }

    Ok(chain)
}
/// Parses one `[user@]host[:port]` entry of a jump-host chain.
///
/// The entry is split at its first `@`. The user part, if present, is
/// validated with `crate::utils::sanitize_username`. The remainder is split
/// into host and port by [`parse_host_port`], and the host is validated with
/// `crate::utils::sanitize_hostname`.
///
/// # Errors
///
/// Returns an error for an empty entry, or when the username, host/port,
/// or hostname fails validation.
fn parse_single_jump_host(host_spec: &str) -> Result<JumpHost> {
    if host_spec.is_empty() {
        anyhow::bail!("Empty jump host specification");
    }

    let (raw_user, host_port) = match host_spec.split_once('@') {
        Some((name, rest)) => (Some(name), rest),
        None => (None, host_spec),
    };

    let user = raw_user
        .map(|name| {
            crate::utils::sanitize_username(name).with_context(|| {
                format!("Invalid username in jump host specification: '{}'", host_spec)
            })
        })
        .transpose()?;

    let (raw_host, port) = parse_host_port(host_port)
        .with_context(|| format!("Invalid host:port specification: '{}'", host_port))?;

    let host = crate::utils::sanitize_hostname(&raw_host).with_context(|| {
        format!("Invalid hostname in jump host specification: '{}'", raw_host)
    })?;

    Ok(JumpHost::new(host, user, port))
}
/// Splits `host[:port]` or `[ipv6][:port]` into a host and an optional port.
///
/// Bracketed IPv6 literals are returned without their brackets. An
/// unbracketed string containing more than one `:` is treated as a bare IPv6
/// literal (e.g. `::1`, `fe80::1`) and returned whole with no port. A port
/// can only be attached to an IPv6 address using the bracketed form.
///
/// If the text after the only `:` is not purely decimal digits, the whole
/// input is returned as the host with no port.
///
/// # Errors
///
/// Returns an error for empty input, an empty or unclosed bracket pair,
/// characters other than `:port` after `]`, an empty hostname before `:`,
/// or a port that is empty, zero, or out of range for `u16`.
fn parse_host_port(host_port: &str) -> Result<(String, Option<u16>)> {
    if host_port.is_empty() {
        anyhow::bail!("Empty host specification");
    }

    // Bracketed IPv6 literal: `[addr]` or `[addr]:port`.
    if host_port.starts_with('[') {
        if let Some(bracket_end) = host_port.find(']') {
            let ipv6_addr = &host_port[1..bracket_end];
            if ipv6_addr.is_empty() {
                anyhow::bail!("Empty IPv6 address in brackets");
            }
            let remaining = &host_port[bracket_end + 1..];
            if remaining.is_empty() {
                return Ok((ipv6_addr.to_string(), None));
            } else if let Some(port_str) = remaining.strip_prefix(':') {
                if port_str.is_empty() {
                    anyhow::bail!("Empty port specification after IPv6 address");
                }
                let port = port_str
                    .parse::<u16>()
                    .with_context(|| format!("Invalid port number: '{port_str}'"))?;
                if port == 0 {
                    anyhow::bail!("Port number cannot be zero");
                }
                return Ok((ipv6_addr.to_string(), Some(port)));
            } else {
                anyhow::bail!("Invalid characters after IPv6 address: '{}'", remaining);
            }
        } else {
            anyhow::bail!("Unclosed bracket in IPv6 address");
        }
    }

    if let Some(colon_pos) = host_port.rfind(':') {
        let host_part = &host_port[..colon_pos];
        let port_part = &host_port[colon_pos + 1..];

        // More than one colon without brackets means a bare IPv6 literal.
        // Splitting it at the last colon would mangle it (`::1` -> host `:`,
        // port 1), so the whole string is kept as the host.
        if host_part.contains(':') {
            return Ok((host_port.to_string(), None));
        }

        if host_part.is_empty() {
            anyhow::bail!("Empty hostname");
        }
        if port_part.is_empty() {
            anyhow::bail!("Empty port specification");
        }
        match port_part.parse::<u16>() {
            Ok(port) => {
                if port == 0 {
                    anyhow::bail!("Port number cannot be zero");
                }
                Ok((host_part.to_string(), Some(port)))
            }
            Err(e) => {
                // All digits but unparseable means out of range: reject it.
                // Anything else is not a port, so keep the input as the host.
                if port_part.chars().all(|c| c.is_ascii_digit()) {
                    anyhow::bail!("Invalid port number: '{}' ({})", port_part, e);
                } else {
                    Ok((host_port.to_string(), None))
                }
            }
        }
    } else {
        Ok((host_port.to_string(), None))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable read by `get_max_jump_hosts`.
    const MAX_JUMP_HOSTS_VAR: &str = "BSSH_MAX_JUMP_HOSTS";

    /// Sets or clears `BSSH_MAX_JUMP_HOSTS` for the guard's lifetime and
    /// restores the previous value on drop. A failing assertion therefore
    /// cannot leak the setting into later tests.
    ///
    /// Only use this from `#[serial_test::serial]` tests.
    struct EnvVarGuard {
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(value: Option<&str>) -> Self {
            let previous = std::env::var_os(MAX_JUMP_HOSTS_VAR);
            // SAFETY: every test that reads or writes BSSH_MAX_JUMP_HOSTS is
            // marked #[serial_test::serial], so no other test accesses the
            // variable while it is being modified.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(MAX_JUMP_HOSTS_VAR, v),
                    None => std::env::remove_var(MAX_JUMP_HOSTS_VAR),
                }
            }
            Self { previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`; the guard only
            // lives inside #[serial_test::serial] tests.
            unsafe {
                match self.previous.take() {
                    Some(v) => std::env::set_var(MAX_JUMP_HOSTS_VAR, v),
                    None => std::env::remove_var(MAX_JUMP_HOSTS_VAR),
                }
            }
        }
    }

    /// Builds the spec `host0,host1,...` with `count` entries.
    fn numbered_hosts(count: usize) -> String {
        (0..count)
            .map(|i| format!("host{i}"))
            .collect::<Vec<_>>()
            .join(",")
    }

    #[test]
    fn test_parse_single_jump_host_hostname_only() {
        let result = parse_single_jump_host("example.com").unwrap();
        assert_eq!(result.host, "example.com");
        assert_eq!(result.user, None);
        assert_eq!(result.port, None);
    }

    #[test]
    fn test_parse_single_jump_host_with_user() {
        let result = parse_single_jump_host("admin@example.com").unwrap();
        assert_eq!(result.host, "example.com");
        assert_eq!(result.user, Some("admin".to_string()));
        assert_eq!(result.port, None);
    }

    #[test]
    fn test_parse_single_jump_host_with_port() {
        let result = parse_single_jump_host("example.com:2222").unwrap();
        assert_eq!(result.host, "example.com");
        assert_eq!(result.user, None);
        assert_eq!(result.port, Some(2222));
    }

    #[test]
    fn test_parse_single_jump_host_with_user_and_port() {
        let result = parse_single_jump_host("admin@example.com:2222").unwrap();
        assert_eq!(result.host, "example.com");
        assert_eq!(result.user, Some("admin".to_string()));
        assert_eq!(result.port, Some(2222));
    }

    #[test]
    fn test_parse_single_jump_host_ipv6_brackets() {
        let result = parse_single_jump_host("[::1]").unwrap();
        assert_eq!(result.host, "::1");
        assert_eq!(result.user, None);
        assert_eq!(result.port, None);
    }

    #[test]
    fn test_parse_single_jump_host_ipv6_with_port() {
        let result = parse_single_jump_host("[::1]:2222").unwrap();
        assert_eq!(result.host, "::1");
        assert_eq!(result.user, None);
        assert_eq!(result.port, Some(2222));
    }

    #[test]
    fn test_parse_single_jump_host_ipv6_with_user_and_port() {
        let result = parse_single_jump_host("admin@[::1]:2222").unwrap();
        assert_eq!(result.host, "::1");
        assert_eq!(result.user, Some("admin".to_string()));
        assert_eq!(result.port, Some(2222));
    }

    // parse_jump_hosts reads BSSH_MAX_JUMP_HOSTS for non-empty chains, so the
    // tests below must not run concurrently with tests that modify it.
    #[test]
    #[serial_test::serial]
    fn test_parse_jump_hosts_multiple() {
        let result = parse_jump_hosts("jump1@host1,user@host2:2222,host3").unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].host, "host1");
        assert_eq!(result[0].user, Some("jump1".to_string()));
        assert_eq!(result[0].port, None);
        assert_eq!(result[1].host, "host2");
        assert_eq!(result[1].user, Some("user".to_string()));
        assert_eq!(result[1].port, Some(2222));
        assert_eq!(result[2].host, "host3");
        assert_eq!(result[2].user, None);
        assert_eq!(result[2].port, None);
    }

    #[test]
    #[serial_test::serial]
    fn test_parse_jump_hosts_whitespace_handling() {
        let result = parse_jump_hosts(" host1 , user@host2:2222 , host3 ").unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].host, "host1");
        assert_eq!(result[1].host, "host2");
        assert_eq!(result[2].host, "host3");
    }

    #[test]
    fn test_parse_jump_hosts_empty_string() {
        let result = parse_jump_hosts("").unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_parse_jump_hosts_only_commas() {
        let result = parse_jump_hosts(",,");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_single_jump_host_errors() {
        assert!(parse_single_jump_host("").is_err());
        assert!(parse_single_jump_host("@host").is_err());
        assert!(parse_single_jump_host("user@").is_err());
        assert!(parse_single_jump_host("host:").is_err());
        assert!(parse_single_jump_host("host:0").is_err());
        assert!(parse_single_jump_host("host:99999").is_err());
        assert!(parse_single_jump_host("[::1").is_err());
        assert!(parse_single_jump_host("[]").is_err());
    }

    #[test]
    fn test_jump_host_display() {
        let host = JumpHost::new("example.com".to_string(), None, None);
        assert_eq!(format!("{host}"), "example.com");
        let host = JumpHost::new("example.com".to_string(), Some("user".to_string()), None);
        assert_eq!(format!("{host}"), "user@example.com");
        let host = JumpHost::new("example.com".to_string(), None, Some(2222));
        assert_eq!(format!("{host}"), "example.com:2222");
        let host = JumpHost::new(
            "example.com".to_string(),
            Some("user".to_string()),
            Some(2222),
        );
        assert_eq!(format!("{host}"), "user@example.com:2222");
    }

    #[test]
    fn test_jump_host_effective_values() {
        let host = JumpHost::new("example.com".to_string(), None, None);
        assert_eq!(host.effective_port(), 22);
        assert!(!host.effective_user().is_empty());
        let host = JumpHost::new(
            "example.com".to_string(),
            Some("testuser".to_string()),
            Some(2222),
        );
        assert_eq!(host.effective_port(), 2222);
        assert_eq!(host.effective_user(), "testuser");
    }

    #[test]
    #[serial_test::serial]
    fn test_max_jump_hosts_limit_exactly_10() {
        // Pin the default limit regardless of the caller's environment.
        let _env = EnvVarGuard::set(None);
        let result = parse_jump_hosts(&numbered_hosts(10));
        assert!(result.is_ok(), "Should accept exactly 10 jump hosts");
        assert_eq!(result.unwrap().len(), 10);
    }

    #[test]
    #[serial_test::serial]
    fn test_max_jump_hosts_limit_11_rejected() {
        let _env = EnvVarGuard::set(None);
        let result = parse_jump_hosts(&numbered_hosts(11));
        assert!(result.is_err(), "Should reject 11 jump hosts");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Too many jump hosts"),
            "Error should mention 'Too many jump hosts', got: {err_msg}"
        );
        assert!(
            err_msg.contains("11"),
            "Error should mention the actual count (11), got: {err_msg}"
        );
        assert!(
            err_msg.contains("10"),
            "Error should mention the maximum (10), got: {err_msg}"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_max_jump_hosts_limit_excessive() {
        let _env = EnvVarGuard::set(None);
        let result = parse_jump_hosts(&numbered_hosts(100));
        assert!(
            result.is_err(),
            "Should reject excessive number of jump hosts"
        );
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Too many jump hosts"),
            "Error should be about too many hosts, got: {err_msg}"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_get_max_jump_hosts_default() {
        let _env = EnvVarGuard::set(None);
        assert_eq!(get_max_jump_hosts(), 10, "Default should be 10");
    }

    #[test]
    #[serial_test::serial]
    fn test_get_max_jump_hosts_custom_value() {
        let _env = EnvVarGuard::set(Some("15"));
        assert_eq!(
            get_max_jump_hosts(),
            15,
            "Should use custom value from environment"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_get_max_jump_hosts_capped_at_absolute_max() {
        let _env = EnvVarGuard::set(Some("50"));
        assert_eq!(
            get_max_jump_hosts(),
            30,
            "Should be capped at absolute maximum of 30 for security"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_get_max_jump_hosts_zero_falls_back() {
        let _env = EnvVarGuard::set(Some("0"));
        assert_eq!(
            get_max_jump_hosts(),
            10,
            "Zero should fall back to default (10)"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_get_max_jump_hosts_invalid_value() {
        let _env = EnvVarGuard::set(Some("invalid"));
        assert_eq!(
            get_max_jump_hosts(),
            10,
            "Invalid value should fall back to default (10)"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_max_jump_hosts_respects_environment() {
        let _env = EnvVarGuard::set(Some("15"));
        let result = parse_jump_hosts(&numbered_hosts(15));
        assert!(
            result.is_ok(),
            "Should accept 15 hosts when BSSH_MAX_JUMP_HOSTS=15"
        );
        assert_eq!(result.unwrap().len(), 15);
        let result = parse_jump_hosts(&numbered_hosts(16));
        assert!(
            result.is_err(),
            "Should reject 16 hosts when BSSH_MAX_JUMP_HOSTS=15"
        );
    }
}