use super::types::ServerFileConfig;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
pub fn load_config(config_path: Option<&Path>) -> Result<ServerFileConfig> {
let mut config = ServerFileConfig::default();
if let Some(path) = config_path {
config = load_config_file(path).context("Failed to load configuration file")?;
tracing::info!(path = %path.display(), "Loaded configuration from file");
} else {
for path in default_config_paths() {
if path.exists() {
config = load_config_file(&path).context("Failed to load configuration file")?;
tracing::info!(path = %path.display(), "Loaded configuration from file");
break;
}
}
}
config = apply_env_overrides(config)?;
validate_config(&config)?;
Ok(config)
}
/// Produces a YAML configuration template.
///
/// The template is a `#`-commented header describing configuration precedence,
/// a blank line, and then `ServerFileConfig::default()` serialized as YAML.
/// If serialization fails, only the header is returned.
pub fn generate_config_template() -> String {
    const HEADER_LINES: [&str; 10] = [
        "# bssh-server configuration file",
        "#",
        "# This is a comprehensive configuration template showing all available options.",
        "# Uncomment and modify options as needed.",
        "#",
        "# Configuration hierarchy (highest to lowest precedence):",
        "# 1. CLI arguments",
        "# 2. Environment variables (BSSH_* prefix)",
        "# 3. This configuration file",
        "# 4. Default values",
    ];

    // NOTE(review): the header says "Uncomment", but the serialized defaults
    // below are emitted as live (uncommented) YAML.
    let defaults_yaml =
        serde_yaml::to_string(&ServerFileConfig::default()).unwrap_or_default();

    let mut template = HEADER_LINES.join("\n");
    // Terminate the last header line and leave one blank line before the YAML.
    template.push_str("\n\n");
    template.push_str(&defaults_yaml);
    template
}
fn load_config_file(path: &Path) -> Result<ServerFileConfig> {
#[cfg(unix)]
check_config_file_permissions(path)?;
let content =
std::fs::read_to_string(path).context(format!("Failed to read {}", path.display()))?;
serde_yaml::from_str(&content).context(format!("Failed to parse {}", path.display()))
}
#[cfg(unix)]
fn check_config_file_permissions(path: &Path) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
let metadata = std::fs::metadata(path)
.context(format!("Failed to read metadata for {}", path.display()))?;
let permissions = metadata.permissions();
let mode = permissions.mode();
if mode & 0o077 != 0 {
tracing::warn!(
path = %path.display(),
mode = format!("{:o}", mode & 0o777),
"Configuration file is readable by group or others. \
Consider using 'chmod 600 {}' to restrict access.",
path.display()
);
}
Ok(())
}
/// Candidate configuration file locations, in the order `load_config` tries them.
///
/// The locations are:
/// 1. `./bssh-server.yaml` (current working directory)
/// 2. `/etc/bssh/server.yaml`
/// 3. `<user config dir>/bssh/server.yaml`, only when `dirs::config_dir()`
///    reports a directory
///
/// Because `/etc` is listed before the per-user directory, a system-wide file
/// takes precedence over a user's file when both exist.
fn default_config_paths() -> Vec<PathBuf> {
    let mut candidates = vec![
        PathBuf::from("./bssh-server.yaml"),
        PathBuf::from("/etc/bssh/server.yaml"),
    ];
    // `Option` is iterable, so this appends zero or one path.
    candidates.extend(dirs::config_dir().map(|base| base.join("bssh/server.yaml")));
    candidates
}
fn apply_env_overrides(mut config: ServerFileConfig) -> Result<ServerFileConfig> {
if let Ok(port_str) = std::env::var("BSSH_PORT") {
config.server.port = port_str
.parse()
.context(format!("Invalid BSSH_PORT value: {port_str}"))?;
tracing::debug!(port = config.server.port, "Applied BSSH_PORT override");
}
if let Ok(addr) = std::env::var("BSSH_BIND_ADDRESS") {
config.server.bind_address = addr.clone();
tracing::debug!(address = %addr, "Applied BSSH_BIND_ADDRESS override");
}
if let Ok(keys) = std::env::var("BSSH_HOST_KEY") {
config.server.host_keys = keys.split(',').map(|s| PathBuf::from(s.trim())).collect();
tracing::debug!(
key_count = config.server.host_keys.len(),
"Applied BSSH_HOST_KEY override"
);
}
if let Ok(max_str) = std::env::var("BSSH_MAX_CONNECTIONS") {
config.server.max_connections = max_str
.parse()
.context(format!("Invalid BSSH_MAX_CONNECTIONS value: {max_str}"))?;
tracing::debug!(
max = config.server.max_connections,
"Applied BSSH_MAX_CONNECTIONS override"
);
}
if let Ok(interval_str) = std::env::var("BSSH_KEEPALIVE_INTERVAL") {
config.server.keepalive_interval = interval_str.parse().context(format!(
"Invalid BSSH_KEEPALIVE_INTERVAL value: {interval_str}"
))?;
tracing::debug!(
interval = config.server.keepalive_interval,
"Applied BSSH_KEEPALIVE_INTERVAL override"
);
}
if let Ok(methods_str) = std::env::var("BSSH_AUTH_METHODS") {
use super::types::AuthMethod;
let mut methods = Vec::new();
for method in methods_str.split(',') {
let method = method.trim().to_lowercase();
match method.as_str() {
"publickey" => methods.push(AuthMethod::PublicKey),
"password" => methods.push(AuthMethod::Password),
_ => {
anyhow::bail!("Unknown auth method in BSSH_AUTH_METHODS: {}", method);
}
}
}
config.auth.methods = methods;
tracing::debug!(
methods = ?config.auth.methods,
"Applied BSSH_AUTH_METHODS override"
);
}
if let Ok(dir) = std::env::var("BSSH_AUTHORIZED_KEYS_DIR") {
config.auth.publickey.authorized_keys_dir = Some(PathBuf::from(dir.clone()));
config.auth.publickey.authorized_keys_pattern = None;
tracing::debug!(dir = %dir, "Applied BSSH_AUTHORIZED_KEYS_DIR override");
}
if let Ok(pattern) = std::env::var("BSSH_AUTHORIZED_KEYS_PATTERN") {
config.auth.publickey.authorized_keys_pattern = Some(pattern.clone());
config.auth.publickey.authorized_keys_dir = None;
tracing::debug!(
pattern = %pattern,
"Applied BSSH_AUTHORIZED_KEYS_PATTERN override"
);
}
if let Ok(shell) = std::env::var("BSSH_SHELL") {
config.shell.default = PathBuf::from(shell.clone());
tracing::debug!(shell = %shell, "Applied BSSH_SHELL override");
}
if let Ok(timeout_str) = std::env::var("BSSH_COMMAND_TIMEOUT") {
config.shell.command_timeout = timeout_str
.parse()
.context(format!("Invalid BSSH_COMMAND_TIMEOUT value: {timeout_str}"))?;
tracing::debug!(
timeout = config.shell.command_timeout,
"Applied BSSH_COMMAND_TIMEOUT override"
);
}
Ok(config)
}
/// Checks a fully merged configuration for values the server cannot run with.
///
/// Validation stops at the first failing check. In order, it requires that:
/// - at least one host key is configured and every host key file exists;
/// - at least one authentication method is enabled;
/// - `server.bind_address` parses as a `std::net::IpAddr`;
/// - `auth.publickey.authorized_keys_pattern`, if set, is an absolute path that
///   does not contain `..`;
/// - the default shell exists;
/// - every `security.allowed_ips` / `security.blocked_ips` entry parses as an
///   `ipnetwork::IpNetwork`;
/// - `server.port` and `server.max_connections` are non-zero.
///
/// # Errors
/// Returns a descriptive error for the first failed check.
fn validate_config(config: &ServerFileConfig) -> Result<()> {
    if config.server.host_keys.is_empty() {
        anyhow::bail!(
            "At least one host key must be configured (server.host_keys or BSSH_HOST_KEY)"
        );
    }
    for key_path in &config.server.host_keys {
        if !key_path.exists() {
            anyhow::bail!("Host key file not found: {}", key_path.display());
        }
    }

    if config.auth.methods.is_empty() {
        anyhow::bail!("At least one authentication method must be enabled (auth.methods)");
    }

    config
        .server
        .bind_address
        .parse::<std::net::IpAddr>()
        .with_context(|| format!("Invalid bind_address: {}", config.server.bind_address))?;

    if let Some(pattern) = &config.auth.publickey.authorized_keys_pattern {
        validate_authorized_keys_pattern(pattern)?;
    }

    if !config.shell.default.exists() {
        anyhow::bail!(
            "Default shell does not exist: {}",
            config.shell.default.display()
        );
    }

    for cidr in &config.security.allowed_ips {
        cidr.parse::<ipnetwork::IpNetwork>()
            .with_context(|| format!("Invalid CIDR notation in allowed_ips: {cidr}"))?;
    }
    for cidr in &config.security.blocked_ips {
        cidr.parse::<ipnetwork::IpNetwork>()
            .with_context(|| format!("Invalid CIDR notation in blocked_ips: {cidr}"))?;
    }

    if config.server.port == 0 {
        anyhow::bail!("Server port cannot be 0");
    }
    if config.server.max_connections == 0 {
        anyhow::bail!("max_connections must be greater than 0");
    }

    tracing::info!("Configuration validation passed");
    Ok(())
}

/// Rejects an `authorized_keys_pattern` that contains `..` or is not absolute.
///
/// The absolute-path check runs on the raw pattern, with any `{user}`
/// placeholder still present. Stripping the placeholder first would accept a
/// pattern such as `{user}/.ssh/authorized_keys`, which resolves relative to
/// the working directory once a username is substituted.
fn validate_authorized_keys_pattern(pattern: &str) -> Result<()> {
    // Substring match (not just whole path components), so that text adjacent
    // to the placeholder cannot form a parent reference after substitution.
    if pattern.contains("..") {
        anyhow::bail!(
            "authorized_keys_pattern contains '..' which could lead to path traversal: {}",
            pattern
        );
    }
    if !pattern.starts_with('/') {
        anyhow::bail!(
            "authorized_keys_pattern must use absolute paths: {}",
            pattern
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Every environment variable read by `apply_env_overrides`.
    const BSSH_ENV_VARS: &[&str] = &[
        "BSSH_PORT",
        "BSSH_BIND_ADDRESS",
        "BSSH_HOST_KEY",
        "BSSH_MAX_CONNECTIONS",
        "BSSH_KEEPALIVE_INTERVAL",
        "BSSH_AUTH_METHODS",
        "BSSH_AUTHORIZED_KEYS_DIR",
        "BSSH_AUTHORIZED_KEYS_PATTERN",
        "BSSH_SHELL",
        "BSSH_COMMAND_TIMEOUT",
    ];

    /// Removes all `BSSH_*` override variables on construction and again on drop.
    ///
    /// Each env test therefore starts from a clean environment, and a failing
    /// assertion cannot leak variables into later tests. Use only inside
    /// `#[serial_test::serial]` tests, because the process environment is
    /// shared by all test threads.
    struct EnvGuard;

    impl EnvGuard {
        fn new() -> Self {
            clear_bssh_env();
            EnvGuard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_bssh_env();
        }
    }

    /// Unsets every variable listed in `BSSH_ENV_VARS`.
    fn clear_bssh_env() {
        for name in BSSH_ENV_VARS {
            std::env::remove_var(name);
        }
    }

    /// Returns a default config whose only host key is a real temporary file.
    ///
    /// The returned `NamedTempFile` must be kept alive for as long as the
    /// config is validated, because the file is deleted when it is dropped.
    fn config_with_fake_host_key() -> (NamedTempFile, ServerFileConfig) {
        let mut key_file = NamedTempFile::new().unwrap();
        key_file.write_all(b"fake host key").unwrap();
        key_file.flush().unwrap();
        let mut config = ServerFileConfig::default();
        config.server.host_keys.push(key_file.path().to_path_buf());
        (key_file, config)
    }

    #[test]
    fn test_generate_config_template() {
        let template = generate_config_template();
        assert!(template.contains("bssh-server configuration"));
        assert!(template.contains("server:"));
        assert!(template.contains("auth:"));
        assert!(template.contains("shell:"));
        let parsed: Result<ServerFileConfig, _> = serde_yaml::from_str(&template);
        assert!(parsed.is_ok());
    }

    #[test]
    fn test_load_config_from_file() {
        let yaml_content = r#"
server:
  port: 2223
  bind_address: "127.0.0.1"
  host_keys:
    - /tmp/test_key
auth:
  methods:
    - publickey
"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(yaml_content.as_bytes()).unwrap();
        temp_file.flush().unwrap();
        let config = load_config_file(temp_file.path()).unwrap();
        assert_eq!(config.server.port, 2223);
        assert_eq!(config.server.bind_address, "127.0.0.1");
        assert_eq!(config.server.host_keys.len(), 1);
    }

    #[test]
    #[serial_test::serial]
    fn test_env_override_port() {
        let _env = EnvGuard::new();
        std::env::set_var("BSSH_PORT", "3333");
        let config = apply_env_overrides(ServerFileConfig::default()).unwrap();
        assert_eq!(config.server.port, 3333);
    }

    #[test]
    #[serial_test::serial]
    fn test_env_override_bind_address() {
        let _env = EnvGuard::new();
        std::env::set_var("BSSH_BIND_ADDRESS", "192.168.1.1");
        let config = apply_env_overrides(ServerFileConfig::default()).unwrap();
        assert_eq!(config.server.bind_address, "192.168.1.1");
    }

    #[test]
    #[serial_test::serial]
    fn test_env_override_host_keys() {
        let _env = EnvGuard::new();
        std::env::set_var("BSSH_HOST_KEY", "/key1,/key2,/key3");
        let config = apply_env_overrides(ServerFileConfig::default()).unwrap();
        assert_eq!(config.server.host_keys.len(), 3);
        assert_eq!(config.server.host_keys[0], PathBuf::from("/key1"));
    }

    #[test]
    #[serial_test::serial]
    fn test_env_override_auth_methods() {
        let _env = EnvGuard::new();
        std::env::set_var("BSSH_AUTH_METHODS", "publickey,password");
        let config = apply_env_overrides(ServerFileConfig::default()).unwrap();
        assert_eq!(config.auth.methods.len(), 2);
    }

    #[test]
    #[serial_test::serial]
    fn test_env_override_invalid_port() {
        let _env = EnvGuard::new();
        std::env::set_var("BSSH_PORT", "invalid");
        let result = apply_env_overrides(ServerFileConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_config_no_host_keys() {
        let mut config = ServerFileConfig::default();
        config.server.host_keys.clear();
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("At least one host key"));
    }

    #[test]
    fn test_validate_config_no_auth_methods() {
        let (_key, mut config) = config_with_fake_host_key();
        config.auth.methods.clear();
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("authentication method"));
    }

    #[test]
    fn test_validate_config_invalid_cidr() {
        let (_key, mut config) = config_with_fake_host_key();
        config.security.allowed_ips.push("invalid-cidr".to_string());
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("CIDR"));
    }

    #[test]
    fn test_validate_config_valid_cidr() {
        let (_key, mut config) = config_with_fake_host_key();
        config
            .security
            .allowed_ips
            .push("192.168.1.0/24".to_string());
        config.security.blocked_ips.push("10.0.0.0/8".to_string());
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_zero_port() {
        let (_key, mut config) = config_with_fake_host_key();
        config.server.port = 0;
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("port cannot be 0"));
    }

    #[test]
    fn test_validate_config_zero_max_connections() {
        let (_key, mut config) = config_with_fake_host_key();
        config.server.max_connections = 0;
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("max_connections must be greater than 0"));
    }

    #[test]
    fn test_validate_config_invalid_bind_address() {
        let (_key, mut config) = config_with_fake_host_key();
        config.server.bind_address = "not-an-ip-address".to_string();
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid bind_address"));
    }

    #[test]
    fn test_validate_config_valid_bind_address() {
        let (_key, mut config) = config_with_fake_host_key();
        config.server.bind_address = "127.0.0.1".to_string();
        assert!(validate_config(&config).is_ok());
        config.server.bind_address = "::1".to_string();
        assert!(validate_config(&config).is_ok());
        config.server.bind_address = "0.0.0.0".to_string();
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_authorized_keys_pattern_path_traversal() {
        let (_key, mut config) = config_with_fake_host_key();
        config.auth.publickey.authorized_keys_pattern = Some("/home/../etc/passwd".to_string());
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("path traversal"));
    }

    #[test]
    fn test_validate_config_authorized_keys_pattern_relative_path() {
        let (_key, mut config) = config_with_fake_host_key();
        config.auth.publickey.authorized_keys_pattern = Some("relative/path".to_string());
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("absolute paths"));
    }

    #[test]
    fn test_validate_config_authorized_keys_pattern_valid() {
        let (_key, mut config) = config_with_fake_host_key();
        config.auth.publickey.authorized_keys_pattern =
            Some("/home/{user}/.ssh/authorized_keys".to_string());
        assert!(validate_config(&config).is_ok());
        config.auth.publickey.authorized_keys_pattern =
            Some("/etc/bssh/authorized_keys".to_string());
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_shell_not_exists() {
        let (_key, mut config) = config_with_fake_host_key();
        config.shell.default = PathBuf::from("/nonexistent/shell");
        let result = validate_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Default shell does not exist"));
    }

    #[cfg(unix)]
    #[test]
    fn test_config_file_permissions_warning() {
        use std::os::unix::fs::PermissionsExt;
        let yaml_content = r#"
server:
  port: 2222
  host_keys:
    - /tmp/test_key
"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(yaml_content.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        // A group/other-readable file only warns; it must not fail the check.
        let mut permissions = temp_file.as_file().metadata().unwrap().permissions();
        permissions.set_mode(0o644);
        std::fs::set_permissions(temp_file.path(), permissions).unwrap();
        assert!(check_config_file_permissions(temp_file.path()).is_ok());

        let mut permissions = temp_file.as_file().metadata().unwrap().permissions();
        permissions.set_mode(0o600);
        std::fs::set_permissions(temp_file.path(), permissions).unwrap();
        assert!(check_config_file_permissions(temp_file.path()).is_ok());
    }
}