use std::path::PathBuf;
use super::error::{Error, Result};
/// Maps a dotted sysctl key onto its file under `/proc/sys`.
///
/// The key is validated first; every `.` in it is then turned into a path
/// separator, so `net.ipv4.ip_forward` becomes `/proc/sys/net/ipv4/ip_forward`.
///
/// # Errors
/// Propagates whatever error `validate_key` returns for the key.
fn sysctl_path(key: &str) -> Result<PathBuf> {
    validate_key(key)?;

    let mut full_path = PathBuf::from("/proc/sys");
    // NOTE(review): `push` replaces the whole path when handed an absolute
    // component, so this relies on validation rejecting any key whose
    // converted form would begin with `/`.
    full_path.push(key.replace('.', "/"));
    Ok(full_path)
}
/// Checks that `key` is a well-formed dotted sysctl name that cannot point
/// outside `/proc/sys` once its dots are turned into path separators.
///
/// A key is rejected when it
/// - is empty,
/// - has an empty dot-separated segment (a leading, trailing or doubled `.`),
/// - starts with `/`, or
/// - contains a NUL byte.
///
/// # Errors
/// Returns [`Error::InvalidMessage`] describing the problem for any rejected key.
fn validate_key(key: &str) -> Result<()> {
    if key.is_empty() {
        return Err(Error::InvalidMessage("sysctl key cannot be empty".into()));
    }

    // Dots become `/` when the path is built, so a leading `.` would yield an
    // absolute path (".etc.hostname" -> "/etc/hostname") and replace the
    // `/proc/sys` base when joined. Rejecting every empty segment covers that
    // case as well as ".." sequences and trailing dots.
    let has_empty_segment = key.split('.').any(str::is_empty);

    if has_empty_segment || key.starts_with('/') || key.contains('\0') {
        return Err(Error::InvalidMessage(format!(
            "invalid sysctl key: {}",
            key
        )));
    }

    Ok(())
}
pub fn get(key: &str) -> Result<String> {
let path = sysctl_path(key)?;
let contents = std::fs::read_to_string(&path).map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
Error::InvalidMessage(format!("sysctl key not found: {}", key))
}
std::io::ErrorKind::PermissionDenied => Error::Io(e),
_ => Error::Io(e),
})?;
Ok(contents.trim_end().to_string())
}
/// Writes `value` to the sysctl identified by `key`.
///
/// The target file is opened for writing (truncating) but is never created:
/// a key that does not exist is reported as an error instead.
///
/// # Errors
/// - Propagates the error from `sysctl_path` if the key cannot be turned into a path.
/// - Returns [`Error::InvalidMessage`] when the file does not exist.
/// - Returns [`Error::Io`] for every other I/O failure, permission errors included.
pub fn set(key: &str, value: &str) -> Result<()> {
    let path = sysctl_path(key)?;

    // `std::fs::write` opens with `create(true)`, which turns "no such key"
    // into an attempt to create a new file. Opening the existing file only
    // keeps a missing key on the `NotFound` path.
    // NOTE(review): on procfs the create attempt is believed to surface as a
    // permission error rather than not-found — confirm on a target kernel.
    std::fs::OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(&path)
        .and_then(|mut file| std::io::Write::write_all(&mut file, value.as_bytes()))
        .map_err(|e| match e.kind() {
            std::io::ErrorKind::NotFound => {
                Error::InvalidMessage(format!("sysctl key not found: {}", key))
            }
            _ => Error::Io(e),
        })
}
/// Applies each `(key, value)` pair in `entries` via `set`, in slice order.
///
/// Processing stops at the first failure. Pairs before the failing one have
/// already been written and are not rolled back; pairs after it are not
/// attempted. An empty slice succeeds without doing anything.
///
/// # Errors
/// Returns the first error produced by `set`.
pub fn set_many(entries: &[(&str, &str)]) -> Result<()> {
    entries
        .iter()
        .try_for_each(|&(name, setting)| set(name, setting))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dotted keys map to files under `/proc/sys`, one directory per segment.
    #[test]
    fn test_sysctl_path_conversion() {
        let cases = [
            ("net.ipv4.ip_forward", "/proc/sys/net/ipv4/ip_forward"),
            (
                "net.ipv6.conf.all.forwarding",
                "/proc/sys/net/ipv6/conf/all/forwarding",
            ),
        ];

        for &(key, expected) in &cases {
            assert_eq!(sysctl_path(key).unwrap(), PathBuf::from(expected));
        }
    }

    /// Keys with `..`, a leading `/`, no content, or an embedded NUL are refused.
    #[test]
    fn test_validate_key_rejects_traversal() {
        let rejected = ["net..ipv4", "/etc/passwd", "", "net.ipv4\0.ip_forward"];

        for &key in &rejected {
            assert!(validate_key(key).is_err());
        }
    }

    /// Ordinary dotted sysctl names pass validation.
    #[test]
    fn test_validate_key_accepts_valid() {
        let accepted = [
            "net.ipv4.ip_forward",
            "net.ipv6.conf.all.forwarding",
            "kernel.hostname",
        ];

        for &key in &accepted {
            assert!(validate_key(key).is_ok());
        }
    }
}