use anyhow::{Context, Result};
use bssh::{cli::Cli, config::Config, hostlist, node::Node, ssh::SshConfig};
use glob::Pattern;
/// Parses a node specification of the form `[user@]host[:port]` into a [`Node`],
/// resolving anything the string leaves out through `ssh_config`.
///
/// Resolution performed here:
/// - **host**: the validated host is mapped through `ssh_config.get_effective_hostname`.
/// - **user**: the explicit `user@` prefix wins; otherwise `ssh_config.get_effective_user`;
///   otherwise the first readable of the `USER`, `USERNAME`, `LOGNAME` environment
///   variables; otherwise `whoami::username()`; otherwise the literal `"user"`.
/// - **port**: `ssh_config.get_effective_port` receives the explicit `:port` suffix, if any.
///
/// # Errors
///
/// Returns an error when `node_str` is empty, contains one of the rejected characters
/// (`;`, `&`, `|`, `` ` ``, `$`, newline), carries a port that does not parse as `u16`,
/// or when `bssh::security` rejects the hostname or the username.
pub fn parse_node_with_ssh_config(node_str: &str, ssh_config: &SshConfig) -> Result<Node> {
    // Characters that are never accepted anywhere in a node string.
    const FORBIDDEN: [char; 6] = [';', '&', '|', '`', '$', '\n'];
    // Environment variables consulted, in order, when no user is configured.
    const USER_ENV_VARS: [&str; 3] = ["USER", "USERNAME", "LOGNAME"];

    if node_str.is_empty() {
        anyhow::bail!("Node string cannot be empty");
    }
    if node_str.contains(&FORBIDDEN[..]) {
        anyhow::bail!("Node string contains invalid characters");
    }

    // Everything before the first '@' is the user; the remainder is host[:port].
    let (user_part, host_part) = match node_str.split_once('@') {
        Some((user, rest)) => (Some(user), rest),
        None => (None, node_str),
    };

    // The port is whatever follows the last ':' in the host part.
    let (raw_host, cli_port) = match host_part.rsplit_once(':') {
        Some((host, port_str)) => {
            let port = port_str.parse::<u16>().context("Invalid port number")?;
            (host, Some(port))
        }
        None => (host_part, None),
    };

    // Hostname is validated before the username, so a node with both invalid
    // reports the hostname problem first.
    let validated_host = bssh::security::validate_hostname(raw_host)
        .with_context(|| format!("Invalid hostname in node: {raw_host}"))?;
    if let Some(user) = user_part {
        bssh::security::validate_username(user)
            .with_context(|| format!("Invalid username in node: {user}"))?;
    }

    let effective_hostname = ssh_config.get_effective_hostname(&validated_host);

    let effective_user = match user_part {
        Some(user) => user.to_string(),
        None => ssh_config
            .get_effective_user(raw_host, None)
            .unwrap_or_else(|| {
                USER_ENV_VARS
                    .iter()
                    .find_map(|key| std::env::var(key).ok())
                    .unwrap_or_else(|| {
                        whoami::username().unwrap_or_else(|_| "user".to_string())
                    })
            }),
    };

    let effective_port = ssh_config.get_effective_port(raw_host, cli_port);

    Ok(Node::new(effective_hostname, effective_port, effective_user))
}
/// Builds the list of target nodes for this invocation and applies the optional
/// host filter and exclusion patterns from the command line.
///
/// Nodes come from the first source that applies:
/// 1. SSH mode (`cli.is_ssh_mode()`): the single destination from
///    `cli.parse_destination()`. The user is taken from the destination, then
///    `cli.get_effective_user()`, then `ssh_config`, then the `USER` environment
///    variable, and finally `"root"`. The port is the destination port or
///    `cli.get_effective_port()`, passed through `ssh_config.get_effective_port`.
/// 2. `cli.hosts`: every entry is expanded with `hostlist::expander::expand_host_specs`
///    and each resulting host is parsed by `parse_node_with_ssh_config`.
/// 3. `cli.cluster`: the named cluster is resolved through `config.resolve_nodes`.
/// 4. Otherwise, the cluster named `"bai_auto"` if `config.clusters` contains it.
///
/// Returns the resolved nodes together with the cluster name, which is `Some` only
/// when the nodes came from a configured cluster (cases 3 and 4). The node list may
/// be empty when no source produced any host.
///
/// # Errors
///
/// Fails when the SSH destination cannot be parsed, when a host expression or node
/// string is invalid, when cluster resolution fails, when a filter or exclusion
/// pattern is invalid, when the filter leaves no hosts, or when the exclusion
/// patterns remove every host that had been resolved.
pub async fn resolve_nodes(
    cli: &Cli,
    config: &Config,
    ssh_config: &SshConfig,
) -> Result<(Vec<Node>, Option<String>)> {
    let mut nodes = Vec::new();
    let mut cluster_name = None;

    if cli.is_ssh_mode() {
        let (user, host, port) = cli
            .parse_destination()
            .ok_or_else(|| anyhow::anyhow!("Invalid destination format"))?;
        let effective_hostname = ssh_config.get_effective_hostname(&host);
        let effective_user = if let Some(u) = user {
            u
        } else if let Some(cli_user) = cli.get_effective_user() {
            cli_user
        } else if let Some(ssh_user) = ssh_config.get_effective_user(&host, None) {
            ssh_user
        } else if let Ok(env_user) = std::env::var("USER") {
            env_user
        } else {
            "root".to_string()
        };
        let effective_port =
            ssh_config.get_effective_port(&host, port.or_else(|| cli.get_effective_port()));
        nodes.push(Node::new(effective_hostname, effective_port, effective_user));
    } else if let Some(hosts) = &cli.hosts {
        for host_str in hosts {
            let expanded_hosts = hostlist::expander::expand_host_specs(host_str)
                .with_context(|| format!("Failed to expand host expression: {host_str}"))?;
            for single_host in expanded_hosts {
                nodes.push(parse_node_with_ssh_config(&single_host, ssh_config)?);
            }
        }
    } else if let Some(cli_cluster_name) = &cli.cluster {
        nodes = config.resolve_nodes(cli_cluster_name)?;
        cluster_name = Some(cli_cluster_name.clone());
    } else if config.clusters.contains_key("bai_auto") {
        nodes = config.resolve_nodes("bai_auto")?;
        cluster_name = Some("bai_auto".to_string());
    }

    if let Some(filter) = cli.get_host_filter() {
        nodes = filter_nodes_with_hostlist(nodes, filter)?;
        if nodes.is_empty() {
            anyhow::bail!("No hosts matched the filter pattern: {filter}");
        }
    }

    if let Some(exclude_patterns) = cli.get_exclude_patterns() {
        let node_count_before = nodes.len();
        // Still run the exclusion on an empty list so invalid patterns are reported.
        nodes = exclude_nodes_with_hostlist(nodes, exclude_patterns)?;
        // Only blame the exclusion patterns when they actually removed hosts. If
        // nothing was resolved to begin with, "All 0 hosts were excluded" would
        // point at the wrong cause; the empty list is returned instead.
        if node_count_before > 0 && nodes.is_empty() {
            let patterns_str = exclude_patterns.join(", ");
            anyhow::bail!(
                "All {node_count_before} hosts were excluded by pattern(s): {patterns_str}"
            );
        }
    }

    Ok((nodes, cluster_name))
}
/// Returns `true` when `pattern` matches either the node's bare host or the
/// node's full string form (`node.to_string()`).
fn pattern_matches_node(pattern: &Pattern, node: &Node) -> bool {
    // Try the bare host first so the string form is only built when needed.
    if pattern.matches(&node.host) {
        return true;
    }
    pattern.matches(&node.to_string())
}
pub fn exclude_nodes(nodes: Vec<Node>, patterns: &[String]) -> Result<Vec<Node>> {
if patterns.is_empty() {
return Ok(nodes);
}
let mut compiled_patterns = Vec::with_capacity(patterns.len());
for pattern in patterns {
const MAX_PATTERN_LENGTH: usize = 256;
if pattern.len() > MAX_PATTERN_LENGTH {
anyhow::bail!("Exclusion pattern too long (max {MAX_PATTERN_LENGTH} characters)");
}
if pattern.is_empty() {
anyhow::bail!("Exclusion pattern cannot be empty");
}
let wildcard_count = pattern.chars().filter(|c| *c == '*' || *c == '?').count();
const MAX_WILDCARDS: usize = 10;
if wildcard_count > MAX_WILDCARDS {
anyhow::bail!("Exclusion pattern contains too many wildcards (max {MAX_WILDCARDS})");
}
if pattern.contains("..") || pattern.contains("//") {
anyhow::bail!("Exclusion pattern contains invalid sequences");
}
let valid_chars = pattern.chars().all(|c| {
c.is_ascii_alphanumeric()
|| c == '.'
|| c == '-'
|| c == '_'
|| c == '@'
|| c == ':'
|| c == '*'
|| c == '?'
|| c == '['
|| c == ']'
|| c == '!'
});
if !valid_chars {
anyhow::bail!("Exclusion pattern contains invalid characters for hostname matching");
}
let glob_pattern = Pattern::new(pattern)
.with_context(|| format!("Invalid exclusion pattern: {pattern}"))?;
compiled_patterns.push((pattern.clone(), glob_pattern));
}
let filtered: Vec<Node> = nodes
.into_iter()
.filter(|node| {
!compiled_patterns.iter().any(|(raw_pattern, glob_pattern)| {
if !raw_pattern.contains('*')
&& !raw_pattern.contains('?')
&& !raw_pattern.contains('[')
{
node.host == *raw_pattern
|| node.to_string() == *raw_pattern
|| node.host.contains(raw_pattern.as_str())
} else {
pattern_matches_node(glob_pattern, node)
}
})
})
.collect();
Ok(filtered)
}
/// Keeps only the nodes that match `pattern`.
///
/// A pattern containing `*`, `?` or `[` is compiled as a glob and matched against
/// each node's host and, failing that, its full string form. Any other pattern is
/// a literal: a node is kept when the literal equals its host, equals its full
/// string form, or occurs anywhere inside its host.
///
/// # Errors
///
/// The pattern is rejected when it is longer than 256 bytes, empty, contains more
/// than 10 `*`/`?` wildcards, contains `..` or `//`, contains a character outside
/// `[A-Za-z0-9._@:*?\[\]!-]`, or (for globs) does not compile.
pub fn filter_nodes(nodes: Vec<Node>, pattern: &str) -> Result<Vec<Node>> {
    const MAX_PATTERN_LENGTH: usize = 256;
    const MAX_WILDCARDS: usize = 10;
    // Non-alphanumeric characters permitted in a pattern.
    const EXTRA_ALLOWED: &str = ".-_@:*?[]!";

    if pattern.len() > MAX_PATTERN_LENGTH {
        anyhow::bail!("Filter pattern too long (max {MAX_PATTERN_LENGTH} characters)");
    }
    if pattern.is_empty() {
        anyhow::bail!("Filter pattern cannot be empty");
    }

    let wildcard_count = pattern.chars().filter(|&c| c == '*' || c == '?').count();
    if wildcard_count > MAX_WILDCARDS {
        anyhow::bail!("Filter pattern contains too many wildcards (max {MAX_WILDCARDS})");
    }

    // Checked before the character whitelist so these report "invalid sequences".
    if pattern.contains("..") || pattern.contains("//") {
        anyhow::bail!("Filter pattern contains invalid sequences");
    }

    let only_allowed_chars = pattern
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || EXTRA_ALLOWED.contains(c));
    if !only_allowed_chars {
        anyhow::bail!("Filter pattern contains invalid characters for hostname matching");
    }

    let uses_glob_syntax = pattern.chars().any(|c| matches!(c, '*' | '?' | '['));
    if !uses_glob_syntax {
        // Literal pattern: exact host, exact full form, or substring of the host.
        let kept = nodes
            .into_iter()
            .filter(|node| {
                node.host == pattern || node.to_string() == pattern || node.host.contains(pattern)
            })
            .collect();
        return Ok(kept);
    }

    let glob_pattern =
        Pattern::new(pattern).with_context(|| format!("Invalid filter pattern: {pattern}"))?;
    let kept = nodes
        .into_iter()
        // The full string form is only built when the bare host did not match.
        .filter(|node| glob_pattern.matches(&node.host) || glob_pattern.matches(&node.to_string()))
        .collect();
    Ok(kept)
}
/// Keeps only the nodes selected by `pattern`, which may be either a hostlist
/// expression or a pattern understood by [`filter_nodes`].
///
/// When `hostlist::is_hostlist_expression` accepts the pattern, it is expanded and
/// a node is kept only if its host or its full string form is exactly one of the
/// expanded names. Otherwise the work is delegated to [`filter_nodes`].
///
/// # Errors
///
/// Fails when `pattern` is empty, when the hostlist expression cannot be
/// expanded, or with whatever error [`filter_nodes`] reports.
pub fn filter_nodes_with_hostlist(nodes: Vec<Node>, pattern: &str) -> Result<Vec<Node>> {
    if pattern.is_empty() {
        anyhow::bail!("Filter pattern cannot be empty");
    }
    if !hostlist::is_hostlist_expression(pattern) {
        return filter_nodes(nodes, pattern);
    }

    let expanded_patterns = hostlist::expander::expand_host_specs(pattern)
        .with_context(|| format!("Failed to expand filter pattern: {pattern}"))?;
    let wanted: std::collections::HashSet<&str> = expanded_patterns
        .iter()
        .map(|name| name.as_str())
        .collect();

    let mut kept = nodes;
    kept.retain(|node| {
        wanted.contains(node.host.as_str()) || wanted.contains(node.to_string().as_str())
    });
    Ok(kept)
}
/// Removes the nodes selected by `patterns`, where each pattern may be either a
/// hostlist expression or a pattern understood by [`exclude_nodes`].
///
/// Hostlist expressions are expanded first and drop every node whose host or full
/// string form is exactly one of the expanded names. The remaining patterns are
/// then handed to [`exclude_nodes`] in a single call.
///
/// An empty `patterns` slice returns `nodes` unchanged.
///
/// # Errors
///
/// Fails when a hostlist expression cannot be expanded, or with whatever error
/// [`exclude_nodes`] reports for the remaining patterns.
pub fn exclude_nodes_with_hostlist(nodes: Vec<Node>, patterns: &[String]) -> Result<Vec<Node>> {
    if patterns.is_empty() {
        return Ok(nodes);
    }

    // Split the input into explicit host names and patterns for `exclude_nodes`.
    let mut expanded_patterns = Vec::new();
    let mut glob_patterns = Vec::new();
    for pattern in patterns {
        if !hostlist::is_hostlist_expression(pattern) {
            glob_patterns.push(pattern.clone());
            continue;
        }
        let expanded = hostlist::expander::expand_host_specs(pattern)
            .with_context(|| format!("Failed to expand exclusion pattern: {pattern}"))?;
        expanded_patterns.extend(expanded);
    }

    let unwanted: std::collections::HashSet<&str> = expanded_patterns
        .iter()
        .map(|name| name.as_str())
        .collect();

    let mut remaining = nodes;
    remaining.retain(|node| {
        !(unwanted.contains(node.host.as_str())
            || unwanted.contains(node.to_string().as_str()))
    });

    if glob_patterns.is_empty() {
        Ok(remaining)
    } else {
        exclude_nodes(remaining, &glob_patterns)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bssh::hostlist::is_hostlist_expression;

    /// Builds a node for `host` on port 22 with user `admin`.
    fn admin_node(host: &str) -> Node {
        Node::new(host.to_string(), 22, "admin".to_string())
    }

    /// Builds one `admin` node per host name, preserving order.
    fn admin_nodes(hosts: &[&str]) -> Vec<Node> {
        hosts.iter().copied().map(admin_node).collect()
    }

    /// The default fixture: two web hosts, two db hosts and one cache-backup host.
    fn create_test_nodes() -> Vec<Node> {
        admin_nodes(&[
            "web1.example.com",
            "web2.example.com",
            "db1.example.com",
            "db2.example.com",
            "cache-backup.example.com",
        ])
    }

    /// Fixture with hosts `node1` through `node5`.
    fn five_numbered_nodes() -> Vec<Node> {
        admin_nodes(&["node1", "node2", "node3", "node4", "node5"])
    }

    /// Converts string literals into the owned pattern list the API expects.
    fn owned(patterns: &[&str]) -> Vec<String> {
        patterns.iter().map(|p| p.to_string()).collect()
    }

    /// Whether any node in `nodes` has exactly this host.
    fn has_host(nodes: &[Node], host: &str) -> bool {
        nodes.iter().any(|n| n.host == host)
    }

    /// Runs `exclude_nodes` on the default fixture with a single pattern that is
    /// expected to be rejected, and returns the error text.
    fn exclusion_error(pattern: String) -> String {
        let result = exclude_nodes(create_test_nodes(), &[pattern]);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[test]
    fn test_exclude_single_host_exact() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["web1.example.com"])).unwrap();
        assert_eq!(result.len(), 4);
        assert!(!has_host(&result, "web1.example.com"));
    }

    #[test]
    fn test_exclude_multiple_hosts() {
        let patterns = owned(&["web1.example.com", "db1.example.com"]);
        let result = exclude_nodes(create_test_nodes(), &patterns).unwrap();
        assert_eq!(result.len(), 3);
        assert!(!has_host(&result, "web1.example.com"));
        assert!(!has_host(&result, "db1.example.com"));
    }

    #[test]
    fn test_exclude_with_wildcard_prefix() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["db*"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|n| !n.host.starts_with("db")));
    }

    #[test]
    fn test_exclude_with_wildcard_suffix() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["*-backup*"])).unwrap();
        assert_eq!(result.len(), 4);
        assert!(result.iter().all(|n| !n.host.contains("-backup")));
    }

    #[test]
    fn test_exclude_with_question_mark_wildcard() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["web?.example.com"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|n| !n.host.starts_with("web")));
    }

    #[test]
    fn test_exclude_multiple_patterns_with_wildcards() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["web*", "db*"])).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].host, "cache-backup.example.com");
    }

    #[test]
    fn test_exclude_empty_patterns() {
        let nodes = create_test_nodes();
        let expected_len = nodes.len();
        let no_patterns: Vec<String> = Vec::new();
        let result = exclude_nodes(nodes, &no_patterns).unwrap();
        assert_eq!(result.len(), expected_len);
    }

    #[test]
    fn test_exclude_no_matches() {
        let nodes = create_test_nodes();
        let expected_len = nodes.len();
        let result = exclude_nodes(nodes, &owned(&["nonexistent*"])).unwrap();
        assert_eq!(result.len(), expected_len);
    }

    #[test]
    fn test_exclude_all_hosts_returns_empty() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["*"])).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_exclude_pattern_too_long() {
        let message = exclusion_error("a".repeat(300));
        assert!(message.contains("too long"));
    }

    #[test]
    fn test_exclude_empty_pattern() {
        let message = exclusion_error(String::new());
        assert!(message.contains("cannot be empty"));
    }

    #[test]
    fn test_exclude_too_many_wildcards() {
        let message = exclusion_error("*a*b*c*d*e*f*g*h*i*j*k*".to_string());
        assert!(message.contains("too many wildcards"));
    }

    #[test]
    fn test_exclude_invalid_characters() {
        let message = exclusion_error("host;rm -rf /".to_string());
        assert!(message.contains("invalid characters"));
    }

    #[test]
    fn test_exclude_path_traversal_attempt() {
        let message = exclusion_error("../etc/passwd".to_string());
        assert!(message.contains("invalid sequences"));
    }

    #[test]
    fn test_exclude_partial_hostname_match() {
        // A literal pattern also excludes hosts that merely contain it.
        let result = exclude_nodes(create_test_nodes(), &owned(&["web"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|n| !n.host.contains("web")));
    }

    #[test]
    fn test_filter_and_exclude_combined() {
        let filtered = filter_nodes(create_test_nodes(), "*.example.com").unwrap();
        assert_eq!(filtered.len(), 5);

        let result = exclude_nodes(filtered, &owned(&["db*"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|n| !n.host.starts_with("db")));
    }

    #[test]
    fn test_exclude_with_bracket_pattern() {
        let result = exclude_nodes(create_test_nodes(), &owned(&["db[12].example.com"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(!has_host(&result, "db1.example.com"));
        assert!(!has_host(&result, "db2.example.com"));
        assert!(has_host(&result, "web1.example.com"));
    }

    #[test]
    fn test_filter_with_bracket_pattern() {
        let result = filter_nodes(create_test_nodes(), "web[12].example.com").unwrap();
        assert_eq!(result.len(), 2);
        assert!(has_host(&result, "web1.example.com"));
        assert!(has_host(&result, "web2.example.com"));
    }

    #[test]
    fn test_exclude_with_bracket_negation_pattern() {
        let nodes = admin_nodes(&[
            "web1.example.com",
            "web2.example.com",
            "web3.example.com",
            "weba.example.com",
        ]);
        let result = exclude_nodes(nodes, &owned(&["web[!12].example.com"])).unwrap();
        assert_eq!(result.len(), 2);
        assert!(has_host(&result, "web1.example.com"));
        assert!(has_host(&result, "web2.example.com"));
    }

    #[test]
    fn test_is_hostlist_expression_numeric_range() {
        let hostlist_expressions = [
            "node[1-5]",
            "node[01-05]",
            "node[1,2,3]",
            "node[1-3,5-7]",
            "rack[1-2]-node[1-3]",
        ];
        for expr in hostlist_expressions.iter().copied() {
            assert!(is_hostlist_expression(expr));
        }
    }

    #[test]
    fn test_is_hostlist_expression_glob_pattern() {
        let not_hostlist = [
            "web*",
            "web[abc]",
            "web[a-z]",
            "web[!12]",
            "simple.host.com",
        ];
        for expr in not_hostlist.iter().copied() {
            assert!(!is_hostlist_expression(expr));
        }
    }

    #[test]
    fn test_filter_nodes_with_hostlist_range() {
        let result = filter_nodes_with_hostlist(five_numbered_nodes(), "node[1-3]").unwrap();
        assert_eq!(result.len(), 3);
        assert!(has_host(&result, "node1"));
        assert!(has_host(&result, "node2"));
        assert!(has_host(&result, "node3"));
        assert!(!has_host(&result, "node4"));
        assert!(!has_host(&result, "node5"));
    }

    #[test]
    fn test_filter_nodes_with_hostlist_comma_separated() {
        let result = filter_nodes_with_hostlist(five_numbered_nodes(), "node[1,3,5]").unwrap();
        assert_eq!(result.len(), 3);
        assert!(has_host(&result, "node1"));
        assert!(has_host(&result, "node3"));
        assert!(has_host(&result, "node5"));
    }

    #[test]
    fn test_filter_nodes_with_hostlist_falls_back_to_glob() {
        let result = filter_nodes_with_hostlist(create_test_nodes(), "web*").unwrap();
        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|n| n.host.starts_with("web")));
    }

    #[test]
    fn test_exclude_nodes_with_hostlist_range() {
        let result =
            exclude_nodes_with_hostlist(five_numbered_nodes(), &owned(&["node[2-4]"])).unwrap();
        assert_eq!(result.len(), 2);
        assert!(has_host(&result, "node1"));
        assert!(has_host(&result, "node5"));
    }

    #[test]
    fn test_exclude_nodes_with_hostlist_mixed_patterns() {
        let nodes = admin_nodes(&["node1", "node2", "node3", "web1", "web2"]);
        let result = exclude_nodes_with_hostlist(nodes, &owned(&["node[1-2]", "web*"])).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].host, "node3");
    }

    #[test]
    fn test_exclude_nodes_with_hostlist_falls_back_to_glob() {
        let result = exclude_nodes_with_hostlist(create_test_nodes(), &owned(&["db*"])).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|n| !n.host.starts_with("db")));
    }

    #[test]
    fn test_filter_nodes_with_hostlist_zero_padded() {
        let nodes = admin_nodes(&["node01", "node02", "node03", "node04", "node05"]);
        let result = filter_nodes_with_hostlist(nodes, "node[01-03]").unwrap();
        assert_eq!(result.len(), 3);
        assert!(has_host(&result, "node01"));
        assert!(has_host(&result, "node02"));
        assert!(has_host(&result, "node03"));
    }

    #[test]
    fn test_exclude_nodes_with_hostlist_cartesian_product() {
        let nodes = admin_nodes(&[
            "rack1-node1",
            "rack1-node2",
            "rack2-node1",
            "rack2-node2",
            "rack3-node1",
        ]);
        let result =
            exclude_nodes_with_hostlist(nodes, &owned(&["rack[1-2]-node[1-2]"])).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].host, "rack3-node1");
    }
}