use anyhow::Result;
use std::collections::HashMap;
use super::pattern::matches_pattern;
mod exec;
#[allow(unused_imports)]
pub use exec::{execute_match_command, expand_variables, validate_exec_command};
/// A single criterion on a `Match` line of an SSH client config.
///
/// All payloads are plain strings, so the type is `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchCondition {
    /// `host <patterns>` — matches when the target host name matches any pattern.
    Host(Vec<String>),
    /// `user <patterns>` — matches when the remote user name matches any pattern.
    User(Vec<String>),
    /// `localuser <patterns>` — matches when the local user name matches any pattern.
    LocalUser(Vec<String>),
    /// `exec <command>` — evaluated via `execute_match_command`; `%` tokens in the
    /// command are stored unexpanded at parse time.
    Exec(String),
    /// `all` — always matches.
    All,
}
/// A `Match` section of an SSH config: the conditions on its `Match` line plus the
/// host configuration associated with the block.
#[derive(Clone, Debug)]
pub struct MatchBlock {
    /// Conditions that must all be satisfied for the block to apply.
    pub conditions: Vec<MatchCondition>,
    /// Host configuration associated with this block.
    pub config: super::types::SshHostConfig,
    /// Source line number supplied at construction.
    // Not read anywhere in this module yet, hence the lint allowance.
    #[allow(dead_code)]
    pub line_number: usize,
}

impl MatchBlock {
    /// Creates an empty block (no conditions, default config) for a `Match`
    /// directive found at `line_number`.
    pub fn new(line_number: usize) -> Self {
        let conditions = Vec::new();
        let config = super::types::SshHostConfig::default();
        Self {
            conditions,
            config,
            line_number,
        }
    }

    /// Returns `true` when every condition matches `context`.
    ///
    /// Conditions are evaluated in order. Evaluation stops at the first condition
    /// that does not match or that errors, so later conditions (including `exec`
    /// commands) are not evaluated. A block without conditions matches everything.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a condition.
    pub fn matches(&self, context: &MatchContext) -> Result<bool> {
        self.conditions
            .iter()
            .map(|condition| condition.matches(context))
            // Stop at the first outcome that is not a definite match: either
            // `Ok(false)` or an error. Both are returned as-is.
            .find(|outcome| !matches!(outcome, Ok(true)))
            .unwrap_or(Ok(true))
    }
}
/// The facts a `Match` block is evaluated against for one connection attempt.
#[derive(Debug, Clone)]
pub struct MatchContext {
    /// Target host name, tested by `Match host` criteria.
    pub hostname: String,
    /// Remote user name, if known; `Match user` criteria are tested against it.
    pub remote_user: Option<String>,
    /// Local user name, tested by `Match localuser` criteria.
    pub local_user: String,
    /// Token values keyed by token name without the leading `%` (e.g. `"h"` → host
    /// name). Presumably consumed by `expand_variables` for `Match exec` — verify there.
    pub variables: HashMap<String, String>,
}

impl MatchContext {
    /// Builds a context for connecting to `hostname`, optionally as `remote_user`.
    ///
    /// The local user name comes from `whoami::username()`, falling back to `"user"`
    /// if it cannot be determined. `variables` is populated with:
    /// * `h`, `host` — the target host name;
    /// * `l`, `localuser` — the local user name;
    /// * `r`, `u`, `user` — the remote user name (only when `remote_user` is `Some`).
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn new(hostname: String, remote_user: Option<String>) -> Result<Self> {
        let local_user = whoami::username().unwrap_or_else(|_| "user".to_string());

        let mut variables = HashMap::new();
        for key in ["h", "host"] {
            variables.insert(key.to_string(), hostname.clone());
        }
        for key in ["l", "localuser"] {
            variables.insert(key.to_string(), local_user.clone());
        }
        if let Some(user) = &remote_user {
            // `%r` is the remote-user token documented in ssh_config(5) TOKENS.
            // `u`/`user` keep their existing meaning for current callers.
            // NOTE(review): in OpenSSH `%u` is the *local* user and `%l` the local
            // host name; this mapping differs — confirm against `expand_variables`.
            for key in ["r", "u", "user"] {
                variables.insert(key.to_string(), user.clone());
            }
        }

        Ok(Self {
            hostname,
            remote_user,
            local_user,
            variables,
        })
    }
}
impl MatchCondition {
    /// Parses a `Match` directive line into the conditions it lists.
    ///
    /// The `Match` keyword is compared case-insensitively. It may be separated from
    /// its criteria by any whitespace (spaces or tabs) and/or a single `=`, e.g.
    /// `Match host *.example.com`, `match\tall` or `Match = user admin`.
    ///
    /// Recognised criteria (case-insensitive):
    /// * `host`, `user`, `localuser` — each followed by one or more
    ///   whitespace-separated patterns, collected until the next criterion keyword;
    /// * `exec` — followed either by a double-quoted command or by an unquoted
    ///   command.
    ///   - For a quoted command, the closing quote is the first `"` not preceded by a
    ///     backslash. Escapes are kept verbatim, and further criteria may follow.
    ///   - An unquoted command consumes the rest of the line, with whitespace runs
    ///     collapsed to single spaces.
    /// * `all`.
    ///
    /// `line_number` is used only in error messages.
    ///
    /// # Errors
    ///
    /// Returns an error if the line is not a `Match` directive, lists no criteria,
    /// names an unknown criterion, omits a criterion's required argument, or has an
    /// unterminated quoted `exec` command.
    pub fn parse_match_line(line: &str, line_number: usize) -> Result<Vec<MatchCondition>> {
        let conditions_str = match Self::split_match_keyword(line) {
            Some(rest) => rest,
            None => anyhow::bail!("Invalid Match directive at line {line_number}"),
        };
        if conditions_str.is_empty() {
            anyhow::bail!("Match directive requires conditions at line {line_number}");
        }

        let mut conditions = Vec::new();
        let mut parts = conditions_str.split_whitespace();
        while let Some(keyword) = parts.next() {
            match keyword.to_lowercase().as_str() {
                "host" => {
                    let patterns = collect_patterns(&mut parts)?;
                    if patterns.is_empty() {
                        anyhow::bail!("Match host requires patterns at line {line_number}");
                    }
                    conditions.push(MatchCondition::Host(patterns));
                }
                "user" => {
                    let patterns = collect_patterns(&mut parts)?;
                    if patterns.is_empty() {
                        anyhow::bail!("Match user requires patterns at line {line_number}");
                    }
                    conditions.push(MatchCondition::User(patterns));
                }
                "localuser" => {
                    let patterns = collect_patterns(&mut parts)?;
                    if patterns.is_empty() {
                        anyhow::bail!("Match localuser requires patterns at line {line_number}");
                    }
                    conditions.push(MatchCondition::LocalUser(patterns));
                }
                "exec" => {
                    // `keyword` is a subslice of `conditions_str`, so its pointer offset
                    // is an exact, char-aligned position of *this* token. Searching the
                    // text for "exec" instead could hit a pattern such as
                    // `exec.example.com`, or drift after lowercasing non-ASCII text.
                    let keyword_start = keyword.as_ptr() as usize - conditions_str.as_ptr() as usize;
                    let argument = conditions_str[keyword_start + keyword.len()..].trim();
                    if argument.is_empty() {
                        anyhow::bail!("Match exec requires a command at line {line_number}");
                    }
                    match argument.strip_prefix('"') {
                        Some(quoted) => {
                            let close = match Self::find_closing_quote(quoted) {
                                Some(index) => index,
                                None => anyhow::bail!(
                                    "Unterminated quote in Match exec command at line {line_number}"
                                ),
                            };
                            conditions.push(MatchCondition::Exec(quoted[..close].to_string()));
                            // Resume parsing after the closing quote so that criteria
                            // following a quoted command are not swallowed into it.
                            parts = quoted[close + 1..].split_whitespace();
                        }
                        None => {
                            // An unquoted command takes the remainder of the line.
                            let command = argument.split_whitespace().collect::<Vec<_>>().join(" ");
                            conditions.push(MatchCondition::Exec(command));
                            break;
                        }
                    }
                }
                "all" => {
                    conditions.push(MatchCondition::All);
                }
                _ => {
                    anyhow::bail!("Unknown Match condition '{keyword}' at line {line_number}");
                }
            }
        }
        if conditions.is_empty() {
            anyhow::bail!("Match directive requires at least one condition at line {line_number}");
        }
        Ok(conditions)
    }

    /// Strips a leading, case-insensitive `Match` keyword plus its separator
    /// (whitespace and/or one `=`) from `line`.
    ///
    /// Returns the trimmed criteria text, or `None` if the line does not start with
    /// the `Match` keyword.
    fn split_match_keyword(line: &str) -> Option<&str> {
        let line = line.trim();
        let keyword_end = line
            .find(|c: char| c.is_whitespace() || c == '=')
            .unwrap_or(line.len());
        if !line[..keyword_end].eq_ignore_ascii_case("match") {
            return None;
        }
        let rest = line[keyword_end..].trim_start();
        let rest = rest.strip_prefix('=').unwrap_or(rest);
        Some(rest.trim())
    }

    /// Returns the byte index of the first `"` in `s` that is not escaped by a
    /// preceding backslash, or `None` if there is none.
    fn find_closing_quote(s: &str) -> Option<usize> {
        let mut chars = s.char_indices();
        while let Some((index, ch)) = chars.next() {
            match ch {
                // Skip whatever the backslash escapes (typically a quote).
                '\\' => {
                    chars.next();
                }
                '"' => return Some(index),
                _ => {}
            }
        }
        None
    }

    /// Evaluates this condition against `context`.
    ///
    /// `Host`, `User` and `LocalUser` succeed when any of their patterns matches, as
    /// decided by `matches_pattern`. `User` never matches when the context has no
    /// remote user. `Exec` delegates to `execute_match_command`. `All` always
    /// matches.
    ///
    /// # Errors
    ///
    /// Only `Exec` can fail; it propagates errors from `execute_match_command`.
    // NOTE(review): patterns are OR-ed, so with several patterns a negated one
    // (e.g. `!db*`) cannot veto a positive match from another pattern. This is
    // unlike OpenSSH pattern-list semantics. Confirm this is intended, together with
    // how `matches_pattern` treats `!` and comma-separated lists.
    pub fn matches(&self, context: &MatchContext) -> Result<bool> {
        match self {
            MatchCondition::Host(patterns) => Ok(patterns
                .iter()
                .any(|pattern| matches_pattern(&context.hostname, pattern))),
            MatchCondition::User(patterns) => Ok(match context.remote_user.as_ref() {
                Some(user) => patterns.iter().any(|pattern| matches_pattern(user, pattern)),
                None => false,
            }),
            MatchCondition::LocalUser(patterns) => Ok(patterns
                .iter()
                .any(|pattern| matches_pattern(&context.local_user, pattern))),
            MatchCondition::Exec(command) => execute_match_command(command, context),
            MatchCondition::All => Ok(true),
        }
    }
}
/// Consumes pattern tokens from `parts` up to, but not including, the next `Match`
/// criterion keyword (`host`, `user`, `localuser`, `exec`, `all`, case-insensitive).
///
/// The keyword that stops collection is left in `parts` for the caller. Currently
/// always returns `Ok`.
fn collect_patterns(parts: &mut std::str::SplitWhitespace) -> Result<Vec<String>> {
    let is_criterion_keyword = |token: &&str| {
        matches!(
            token.to_lowercase().as_str(),
            "host" | "user" | "localuser" | "exec" | "all"
        )
    };

    // Peek ahead on a copy of the iterator so the real one is not advanced past a keyword.
    let patterns: Vec<String> = parts
        .clone()
        .take_while(|token| !is_criterion_keyword(token))
        .map(str::to_owned)
        .collect();

    // Advance the caller's iterator past exactly the tokens that were taken.
    for _ in 0..patterns.len() {
        parts.next();
    }

    Ok(patterns)
}
// Unit tests for Match-line parsing and condition evaluation.
#[cfg(test)]
mod tests {
use super::*;
// Parsing: single host, host + user, `all`, and a quoted `exec` command whose
// surrounding quotes are stripped.
#[test]
fn test_parse_match_conditions() {
let conditions = MatchCondition::parse_match_line("Match host *.example.com", 1).unwrap();
assert_eq!(conditions.len(), 1);
match &conditions[0] {
MatchCondition::Host(patterns) => assert_eq!(patterns, &["*.example.com"]),
_ => panic!("Expected Host condition"),
}
let conditions =
MatchCondition::parse_match_line("Match host *.example.com user admin", 1).unwrap();
assert_eq!(conditions.len(), 2);
let conditions = MatchCondition::parse_match_line("Match all", 1).unwrap();
assert_eq!(conditions.len(), 1);
assert_eq!(conditions[0], MatchCondition::All);
let conditions =
MatchCondition::parse_match_line("Match exec \"test -f /tmp/vpn\"", 1).unwrap();
assert_eq!(conditions.len(), 1);
match &conditions[0] {
MatchCondition::Exec(cmd) => assert_eq!(cmd, "test -f /tmp/vpn"),
_ => panic!("Expected Exec condition"),
}
}
// Host patterns are matched against the context's hostname.
#[test]
fn test_match_host_condition() {
let context =
MatchContext::new("web1.example.com".to_string(), Some("testuser".to_string()))
.unwrap();
let condition = MatchCondition::Host(vec!["*.example.com".to_string()]);
assert!(condition.matches(&context).unwrap());
let condition = MatchCondition::Host(vec!["*.test.com".to_string()]);
assert!(!condition.matches(&context).unwrap());
}
// User patterns are matched against the remote user; with no remote user the
// condition never matches.
#[test]
fn test_match_user_condition() {
let context =
MatchContext::new("example.com".to_string(), Some("admin".to_string())).unwrap();
let condition = MatchCondition::User(vec!["admin".to_string()]);
assert!(condition.matches(&context).unwrap());
let condition = MatchCondition::User(vec!["root".to_string()]);
assert!(!condition.matches(&context).unwrap());
let context_no_user = MatchContext::new("example.com".to_string(), None).unwrap();
let condition = MatchCondition::User(vec!["admin".to_string()]);
assert!(!condition.matches(&context_no_user).unwrap());
}
// Uses the real local user name, so this relies on `whoami::username()` succeeding
// in the test environment.
#[test]
fn test_match_localuser_condition() {
let context = MatchContext::new("example.com".to_string(), None).unwrap();
let local_user = whoami::username().unwrap();
let condition = MatchCondition::LocalUser(vec![local_user.clone()]);
assert!(condition.matches(&context).unwrap());
let condition = MatchCondition::LocalUser(vec!["nonexistentuser12345".to_string()]);
assert!(!condition.matches(&context).unwrap());
}
// `all` matches any context.
#[test]
fn test_match_all_condition() {
let context = MatchContext::new("example.com".to_string(), None).unwrap();
let condition = MatchCondition::All;
assert!(condition.matches(&context).unwrap());
}
// A block matches only when every one of its conditions matches.
#[test]
fn test_match_block() {
let mut block = MatchBlock::new(10);
block
.conditions
.push(MatchCondition::Host(vec!["*.example.com".to_string()]));
block
.conditions
.push(MatchCondition::User(vec!["admin".to_string()]));
let context =
MatchContext::new("web.example.com".to_string(), Some("admin".to_string())).unwrap();
assert!(block.matches(&context).unwrap());
let context =
MatchContext::new("web.example.com".to_string(), Some("guest".to_string())).unwrap();
assert!(!block.matches(&context).unwrap());
let context =
MatchContext::new("web.test.com".to_string(), Some("admin".to_string())).unwrap();
assert!(!block.matches(&context).unwrap());
}
// A lone negated pattern matches every host except those its inner pattern matches.
#[test]
fn test_match_host_with_negation() {
let context_internal =
MatchContext::new("web.internal.com".to_string(), Some("testuser".to_string()))
.unwrap();
let context_external = MatchContext::new("web.example.com".to_string(), None).unwrap();
let condition = MatchCondition::Host(vec!["!*.internal.com".to_string()]);
assert!(!condition.matches(&context_internal).unwrap());
assert!(condition.matches(&context_external).unwrap());
let condition = MatchCondition::Host(vec!["!db*.example.com".to_string()]);
let context_db = MatchContext::new("db1.example.com".to_string(), None).unwrap();
let context_web = MatchContext::new("web.example.com".to_string(), None).unwrap();
assert!(!condition.matches(&context_db).unwrap());
assert!(condition.matches(&context_web).unwrap());
let condition = MatchCondition::Host(vec!["!production.example.com".to_string()]);
let context_prod = MatchContext::new("production.example.com".to_string(), None).unwrap();
let context_staging = MatchContext::new("staging.example.com".to_string(), None).unwrap();
assert!(!condition.matches(&context_prod).unwrap());
assert!(condition.matches(&context_staging).unwrap());
}
// Multiple patterns in one condition are OR-ed.
#[test]
fn test_match_user_multiple_patterns() {
let context =
MatchContext::new("example.com".to_string(), Some("admin".to_string())).unwrap();
let condition = MatchCondition::User(vec!["admin".to_string(), "root".to_string()]);
assert!(condition.matches(&context).unwrap());
let condition = MatchCondition::User(vec!["root".to_string(), "operator".to_string()]);
assert!(!condition.matches(&context).unwrap());
}
// Wildcard and negated local-user patterns.
// NOTE(review): `&local_user[..2]` slices by bytes and would panic if the user
// name's first two bytes do not end on a character boundary (multi-byte names).
#[test]
fn test_match_localuser_with_wildcards() {
let context = MatchContext::new("example.com".to_string(), None).unwrap();
let local_user = whoami::username().unwrap();
if local_user.len() > 2 {
let pattern = format!("{}*", &local_user[..2]);
let condition = MatchCondition::LocalUser(vec![pattern]);
assert!(condition.matches(&context).unwrap());
}
let condition = MatchCondition::LocalUser(vec!["!nonexistent*".to_string()]);
assert!(condition.matches(&context).unwrap());
}
// Comma-separated lists stay as a single token per criterion, and `%` tokens in an
// exec command are left unexpanded at parse time.
#[test]
fn test_parse_match_complex_conditions() {
let conditions = MatchCondition::parse_match_line(
"Match host *.example.com,!db*.example.com user admin,root",
1,
)
.unwrap();
assert_eq!(conditions.len(), 2);
let conditions =
MatchCondition::parse_match_line("Match exec \"test -f /tmp/%h.lock\"", 1).unwrap();
assert_eq!(conditions.len(), 1);
match &conditions[0] {
MatchCondition::Exec(cmd) => assert!(cmd.contains("%h")),
_ => panic!("Expected Exec condition"),
}
}
// `all` alone matches everything; combined with another condition, that other
// condition must still hold.
#[test]
fn test_match_block_all_conditions() {
let mut block = MatchBlock::new(10);
block.conditions.push(MatchCondition::All);
let context1 = MatchContext::new("anything.com".to_string(), None).unwrap();
let context2 =
MatchContext::new("example.com".to_string(), Some("admin".to_string())).unwrap();
assert!(block.matches(&context1).unwrap());
assert!(block.matches(&context2).unwrap());
let mut block2 = MatchBlock::new(10);
block2.conditions.push(MatchCondition::All);
block2
.conditions
.push(MatchCondition::Host(vec!["*.example.com".to_string()]));
let context_match = MatchContext::new("web.example.com".to_string(), None).unwrap();
let context_nomatch = MatchContext::new("web.other.com".to_string(), None).unwrap();
assert!(block2.matches(&context_match).unwrap());
assert!(!block2.matches(&context_nomatch).unwrap());
}
}