use anyhow::{Context, Result};
use std::path::PathBuf;
use super::super::path::expand_path_internal;
#[cfg(not(test))]
use super::validation::is_path_allowed;
use super::validation::{validate_glob_pattern, validate_include_path};
use crate::ssh::ssh_config::include::IncludeContext;
/// Parses one SSH config line and, if it is an `Include` directive, returns
/// its whitespace-separated patterns.
///
/// The keyword is matched ASCII case-insensitively and must be the *whole*
/// first token: it has to be followed by whitespace and/or a single `=`
/// separator (`Include path`, `Include=path`, `Include = path`). Only that
/// separator `=` is consumed, so an `=` that is part of a pattern is kept.
///
/// # Returns
///
/// * `Some(patterns)` – the borrowed pattern slices, in order of appearance.
/// * `None` – the line is not an `Include` directive, or it has no patterns.
pub fn parse_include_line(line: &str) -> Option<Vec<&str>> {
    const KEYWORD: &str = "include";

    let line = line.trim();

    // `get` yields `None` when the line is shorter than the keyword or when the
    // cut would land inside a multi-byte character, so this can never panic.
    let keyword = line.get(..KEYWORD.len())?;
    if !keyword.eq_ignore_ascii_case(KEYWORD) {
        return None;
    }
    let rest = &line[KEYWORD.len()..];

    // A real separator must follow the keyword; otherwise this is some other
    // keyword that merely starts with "include" (e.g. `IncludeFoo=bar`).
    let args = match rest.strip_prefix('=') {
        Some(after_eq) => after_eq,
        None => {
            let trimmed = rest.trim_start();
            if trimmed.len() == rest.len() {
                // Nothing was trimmed: either the line ends here or the keyword
                // runs straight into other characters.
                return None;
            }
            // Optional `=` after the whitespace (`Include = path`).
            trimmed.strip_prefix('=').unwrap_or(trimmed)
        }
    };

    let patterns: Vec<&str> = args.split_whitespace().collect();
    if patterns.is_empty() {
        None
    } else {
        Some(patterns)
    }
}
pub async fn resolve_include_pattern(
pattern: &str,
context: &IncludeContext,
) -> Result<Vec<PathBuf>> {
validate_glob_pattern(pattern)?;
let expanded = expand_path_internal(pattern)?;
let search_path = if expanded.is_relative() {
context.base_dir.join(&expanded)
} else {
expanded
};
let pattern_str = search_path
.to_str()
.ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in path: {:?}", search_path))?;
validate_glob_pattern(pattern_str)?;
const MAX_GLOB_RESULTS: usize = 100;
let mut files = Vec::new();
let glob_options = glob::MatchOptions {
case_sensitive: true,
require_literal_separator: true, require_literal_leading_dot: true, };
for entry in glob::glob_with(pattern_str, glob_options)
.with_context(|| format!("Invalid glob pattern: {}", pattern_str))?
{
if files.len() >= MAX_GLOB_RESULTS {
anyhow::bail!(
"Glob pattern '{}' matched too many files (>{MAX_GLOB_RESULTS}). \
Please use a more specific pattern.",
pattern
);
}
match entry {
Ok(path) => {
#[cfg(not(test))]
{
let canonical = match path.canonicalize() {
Ok(c) => c,
Err(_) if !path.exists() => continue, Err(e) => {
tracing::debug!("Failed to canonicalize {}: {}", path.display(), e);
continue;
}
};
if !is_path_allowed(&canonical) {
tracing::warn!(
"Glob result {} escapes allowed directories, skipping",
path.display()
);
continue;
}
}
match std::fs::symlink_metadata(&path) {
Ok(metadata) => {
if metadata.is_file() && !metadata.is_symlink() {
if validate_include_path(&path).is_ok() {
files.push(path);
}
}
}
Err(e) => {
tracing::debug!("Failed to get metadata for {}: {}", path.display(), e);
}
}
}
Err(e) => {
tracing::warn!("Error processing glob pattern '{}': {}", pattern_str, e);
}
}
}
files.sort();
if files.is_empty() && !pattern.contains('*') && !pattern.contains('?') {
tracing::debug!(
"Include pattern '{}' matched no files (this may be intentional)",
pattern
);
}
Ok(files)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Covers both separator styles, multiple patterns, a lowercase keyword,
    /// and lines that are not `Include` directives.
    #[test]
    fn test_parse_include_line() {
        let single_pattern = Some(vec!["~/.ssh/config.d/*"]);

        // Whitespace separator, `=` separator, and lowercase keyword all
        // produce the same single pattern.
        assert_eq!(
            parse_include_line("Include ~/.ssh/config.d/*"),
            single_pattern
        );
        assert_eq!(
            parse_include_line("Include=~/.ssh/config.d/*"),
            single_pattern
        );
        assert_eq!(
            parse_include_line("include ~/.ssh/config.d/*"),
            single_pattern
        );

        // Several patterns on one line are returned in order.
        assert_eq!(
            parse_include_line("Include /etc/ssh/config.d/* ~/.ssh/extra/*"),
            Some(vec!["/etc/ssh/config.d/*", "~/.ssh/extra/*"])
        );

        // Other directives are ignored.
        for other_directive in ["Host example.com", "User testuser"].iter() {
            assert_eq!(parse_include_line(other_directive), None);
        }
    }

    /// A `*.conf` glob over a temp directory returns all three files, sorted.
    #[tokio::test]
    async fn test_resolve_include_pattern_glob() {
        use crate::ssh::ssh_config::include::IncludeContext;

        let temp_dir = TempDir::new().unwrap();
        let config_dir = temp_dir.path().join("config.d");
        fs::create_dir(&config_dir).unwrap();

        let fixtures = [
            ("01-first.conf", "Host first\n"),
            ("02-second.conf", "Host second\n"),
            ("03-third.conf", "Host third\n"),
        ];
        for &(file_name, contents) in fixtures.iter() {
            fs::write(config_dir.join(file_name), contents).unwrap();
        }

        let main_config = temp_dir.path().join("config");
        fs::write(&main_config, "").unwrap();
        let context = IncludeContext::new(&main_config);

        let pattern = format!("{}/*.conf", config_dir.display());
        let files = resolve_include_pattern(&pattern, &context).await.unwrap();
        assert_eq!(files.len(), 3);

        let names: Vec<&str> = files
            .iter()
            .map(|file| file.file_name().unwrap().to_str().unwrap())
            .collect();
        assert!(names[0].contains("01-first"));
        assert!(names[1].contains("02-second"));
        assert!(names[2].contains("03-third"));
    }

    /// Parsing leaves a leading `~/` untouched; expansion happens later.
    #[tokio::test]
    async fn test_include_with_tilde_expansion() {
        let parsed = parse_include_line("Include ~/.ssh/config.d/*.conf");
        assert!(parsed.is_some());

        let patterns = parsed.unwrap();
        assert_eq!(patterns.len(), 1);
        assert!(patterns[0].starts_with("~/"));
    }
}