use anyhow::{Context, Result};
use std::ffi::OsString;
use std::io::Write;
use std::path::Path;
use std::process::Stdio;
use std::time::Instant;
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info};
use crate::metrics;
/// Maximum accepted TSIG key name length, in bytes (compared against `str::len`).
/// Presumably chosen to match the 253-character limit on a textual domain name —
/// TODO confirm.
const MAX_TSIG_KEY_NAME_LEN: usize = 253;
/// TSIG algorithms accepted by `build_tsig_key_file_content`, spelled in the
/// canonical lowercase `hmac-*` form that is written into the key file.
/// The order is user-visible: the list is joined into the
/// "Unsupported TSIG algorithm" error message.
const ALLOWED_TSIG_ALGORITHMS: &[&str] = &[
"hmac-md5",
"hmac-sha1",
"hmac-sha224",
"hmac-sha256",
"hmac-sha384",
"hmac-sha512",
];
/// Applies dynamic DNS updates to a single server by driving the `nsupdate` CLI.
///
/// NOTE: do not add a plain `#[derive(Debug)]` — `tsig_secret` holds the TSIG
/// key secret and would end up in logs; write a redacting `Debug` impl instead.
#[derive(Clone)]
pub struct NsupdateExecutor {
    /// Server name/address emitted in the nsupdate `server` command.
    server: String,
    /// Port emitted alongside `server` in the nsupdate `server` command.
    port: u16,
    /// When `true`, `nsupdate` is invoked with `-v` (TCP).
    use_tcp: bool,
    /// TSIG key name; TSIG signing needs all three `tsig_*` fields to be `Some`.
    tsig_key_name: Option<String>,
    /// TSIG algorithm, normalized to `hmac-*` when the key file is written.
    tsig_algorithm: Option<String>,
    /// Base64 TSIG secret written into the temporary key file.
    tsig_secret: Option<String>,
}
impl NsupdateExecutor {
pub fn new(
server: String,
port: u16,
tsig_key_name: Option<String>,
tsig_algorithm: Option<String>,
tsig_secret: Option<String>,
) -> Result<Self> {
info!(
"Creating nsupdate executor for {}:{} with TSIG: {}",
server,
port,
tsig_key_name.is_some()
);
let use_tcp = std::env::var("NSUPDATE_TCP")
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
.unwrap_or(false);
Ok(Self {
tsig_key_name,
tsig_algorithm,
tsig_secret,
server,
port,
use_tcp,
})
}
pub(crate) fn create_tsig_key_file(&self) -> Result<Option<tempfile::NamedTempFile>> {
let (Some(key_name), Some(algorithm), Some(secret)) = (
self.tsig_key_name.as_deref(),
self.tsig_algorithm.as_deref(),
self.tsig_secret.as_deref(),
) else {
return Ok(None);
};
let content = build_tsig_key_file_content(key_name, algorithm, secret)?;
let mut keyfile =
tempfile::NamedTempFile::new().context("Failed to create temporary TSIG key file")?;
keyfile
.write_all(content.as_bytes())
.context("Failed to write TSIG key file")?;
keyfile.flush().context("Failed to flush TSIG key file")?;
debug!("Using TSIG authentication with key: {}", key_name);
Ok(Some(keyfile))
}
async fn execute(&self, commands: &str) -> Result<String> {
let start = Instant::now();
debug!("Executing nsupdate commands:\n{}", commands);
let keyfile = self.create_tsig_key_file()?;
let mut cmd = tokio::process::Command::new("nsupdate");
cmd.args(build_nsupdate_args(
self.use_tcp,
keyfile.as_ref().map(tempfile::NamedTempFile::path),
));
cmd.env_clear().envs(minimal_child_env());
cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
let mut child = cmd.spawn().context("Failed to spawn nsupdate process")?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(commands.as_bytes())
.await
.context("Failed to write to nsupdate stdin")?;
stdin.flush().await.context("Failed to flush stdin")?;
}
let output = child
.wait_with_output()
.await
.context("Failed to wait for nsupdate")?;
drop(keyfile);
let duration = start.elapsed().as_secs_f64();
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = parse_nsupdate_error(&stderr);
error!("nsupdate failed: {}", error_msg);
metrics::record_nsupdate_command("update", false, duration);
return Err(anyhow::anyhow!("nsupdate failed: {}", error_msg));
}
metrics::record_nsupdate_command("update", true, duration);
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
debug!("nsupdate completed successfully in {:.3}s", duration);
Ok(stdout)
}
pub async fn add_record(
&self,
zone: &str,
name: &str,
ttl: u32,
record_type: &str,
value: &str,
) -> Result<String> {
info!(
"Adding {} record: {} -> {} (TTL: {})",
record_type, name, value, ttl
);
reject_injection_chars("zone", zone)?;
reject_injection_chars("name", name)?;
reject_injection_chars("value", value)?;
let commands = format!(
"server {} {}\nzone {}\nupdate add {} {} IN {} {}\nsend\n",
self.server, self.port, zone, name, ttl, record_type, value
);
self.execute(&commands).await
}
pub async fn remove_record(
&self,
zone: &str,
name: &str,
record_type: &str,
value: &str,
) -> Result<String> {
info!(
"Removing {} record: {} {} {}",
record_type,
name,
if value.is_empty() { "(all)" } else { value },
""
);
reject_injection_chars("zone", zone)?;
reject_injection_chars("name", name)?;
reject_injection_chars("value", value)?;
let delete_cmd = if value.is_empty() {
format!("update delete {} {}", name, record_type)
} else {
format!("update delete {} {} {}", name, record_type, value)
};
let commands = format!(
"server {} {}\nzone {}\n{}\nsend\n",
self.server, self.port, zone, delete_cmd
);
self.execute(&commands).await
}
pub async fn update_record(
&self,
zone: &str,
name: &str,
ttl: u32,
record_type: &str,
old_value: &str,
new_value: &str,
) -> Result<String> {
info!(
"Updating {} record: {} from {} to {} (TTL: {})",
record_type, name, old_value, new_value, ttl
);
reject_injection_chars("zone", zone)?;
reject_injection_chars("name", name)?;
reject_injection_chars("old_value", old_value)?;
reject_injection_chars("new_value", new_value)?;
let commands = format!(
"server {} {}\nzone {}\nupdate delete {} {} {}\nupdate add {} {} IN {} {}\nsend\n",
self.server,
self.port,
zone,
name,
record_type,
old_value,
name,
ttl,
record_type,
new_value
);
self.execute(&commands).await
}
}
/// Names of the only environment variables forwarded to the `nsupdate` child.
const CHILD_ENV_ALLOWLIST: &[&str] = &["PATH"];

/// Returns the allowlisted variables that are set in this process, as
/// `(name, value)` pairs in allowlist order. Unset variables are skipped rather
/// than passed through as empty strings.
pub(crate) fn minimal_child_env() -> Vec<(OsString, OsString)> {
    let mut forwarded = Vec::with_capacity(CHILD_ENV_ALLOWLIST.len());
    for &key in CHILD_ENV_ALLOWLIST {
        if let Some(val) = std::env::var_os(key) {
            forwarded.push((OsString::from(key), val));
        }
    }
    forwarded
}
/// Builds the `nsupdate` command-line flags.
///
/// Produces `-v` (TCP) when `use_tcp` is set, followed by `-k <keyfile>` when a
/// TSIG key file path is supplied; with neither, the result is empty.
pub(crate) fn build_nsupdate_args(use_tcp: bool, keyfile: Option<&Path>) -> Vec<OsString> {
    let tcp_flag = use_tcp.then(|| OsString::from("-v"));
    let key_flags = keyfile
        .into_iter()
        .flat_map(|path| [OsString::from("-k"), path.as_os_str().to_os_string()]);
    tcp_flag.into_iter().chain(key_flags).collect()
}
/// Renders the TSIG `key` statement that `nsupdate -k` reads.
///
/// Each input is validated before being interpolated, and none of the accepted
/// character sets includes `"`, so no input can escape its quotes:
/// - `key_name`: 1 to `MAX_TSIG_KEY_NAME_LEN` bytes of ASCII letters, digits,
///   `.`, `-` or `_`;
/// - `algorithm`: matched case-insensitively, `hmac-` is prepended when missing,
///   and the result must appear in `ALLOWED_TSIG_ALGORITHMS`;
/// - `secret`: non-empty, ASCII letters, digits, `+`, `/` or `=` only (padding
///   placement is not checked).
///
/// # Errors
///
/// Returns an error describing the first rule that fails.
pub(crate) fn build_tsig_key_file_content(
    key_name: &str,
    algorithm: &str,
    secret: &str,
) -> Result<String> {
    if !(1..=MAX_TSIG_KEY_NAME_LEN).contains(&key_name.len()) {
        return Err(anyhow::anyhow!(
            "TSIG key name must be 1-{} characters",
            MAX_TSIG_KEY_NAME_LEN
        ));
    }
    let allowed_in_name =
        |ch: char| ch.is_ascii_alphanumeric() || ch == '.' || ch == '-' || ch == '_';
    if let Some(bad) = key_name.chars().find(|&ch| !allowed_in_name(ch)) {
        return Err(anyhow::anyhow!(
            "TSIG key name contains invalid character: {:?}",
            bad
        ));
    }

    let lowered = algorithm.to_ascii_lowercase();
    let canonical_algorithm = if lowered.starts_with("hmac-") {
        lowered
    } else {
        format!("hmac-{lowered}")
    };
    let is_supported = ALLOWED_TSIG_ALGORITHMS
        .iter()
        .any(|&known| known == canonical_algorithm);
    if !is_supported {
        return Err(anyhow::anyhow!(
            "Unsupported TSIG algorithm: {} (allowed: {})",
            algorithm,
            ALLOWED_TSIG_ALGORITHMS.join(", ")
        ));
    }

    if secret.is_empty() {
        return Err(anyhow::anyhow!("TSIG secret cannot be empty"));
    }
    let allowed_in_secret = |ch: char| ch.is_ascii_alphanumeric() || "+/=".contains(ch);
    if let Some(bad) = secret.chars().find(|&ch| !allowed_in_secret(ch)) {
        return Err(anyhow::anyhow!(
            "TSIG secret contains non-base64 character: {:?}",
            bad
        ));
    }

    Ok(format!(
        "key \"{key_name}\" {{\n algorithm {canonical_algorithm};\n secret \"{secret}\";\n}};\n"
    ))
}
/// Fails if `value` contains any control character (per `char::is_control`),
/// which includes `\n` and `\r`.
///
/// Callers interpolate values into a newline-separated `nsupdate` script, so an
/// embedded line break would let input smuggle in extra commands. `field` only
/// names the offending input in the error message.
pub(crate) fn reject_injection_chars(field: &str, value: &str) -> Result<()> {
    match value.chars().find(|ch| ch.is_control()) {
        None => Ok(()),
        Some(bad) => Err(anyhow::anyhow!(
            "{} contains illegal control character {:?}",
            field,
            bad
        )),
    }
}
/// Turns `nsupdate` stderr into a human-readable explanation.
///
/// The table is scanned in order, and the first response code that occurs
/// anywhere in `stderr` selects its hint. If none occurs, the trimmed stderr is
/// returned unchanged.
fn parse_nsupdate_error(stderr: &str) -> String {
    const HINTS: [(&str, &str); 6] = [
        (
            "REFUSED",
            "Zone refused the update (check allow-update configuration)",
        ),
        ("NOTAUTH", "Not authorized (check TSIG key configuration)"),
        ("SERVFAIL", "Server failure (check BIND9 logs)"),
        ("NOTZONE", "Zone not found on server"),
        ("FORMERR", "Format error (check record syntax)"),
        ("NXDOMAIN", "Domain name does not exist"),
    ];
    HINTS
        .iter()
        .find(|&&(rcode, _)| stderr.contains(rcode))
        .map_or_else(|| stderr.trim().to_string(), |&(_, hint)| hint.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_nsupdate_error_refused() {
        let stderr = "update failed: REFUSED\n";
        assert_eq!(
            parse_nsupdate_error(stderr),
            "Zone refused the update (check allow-update configuration)"
        );
    }

    #[test]
    fn test_parse_nsupdate_error_notauth() {
        let stderr = "update failed: NOTAUTH\n";
        assert_eq!(
            parse_nsupdate_error(stderr),
            "Not authorized (check TSIG key configuration)"
        );
    }

    #[test]
    fn test_parse_nsupdate_error_servfail() {
        let stderr = "update failed: SERVFAIL\n";
        assert_eq!(
            parse_nsupdate_error(stderr),
            "Server failure (check BIND9 logs)"
        );
    }

    #[test]
    fn test_parse_nsupdate_error_notzone() {
        let stderr = "update failed: NOTZONE\n";
        assert_eq!(parse_nsupdate_error(stderr), "Zone not found on server");
    }

    #[test]
    fn test_parse_nsupdate_error_formerr() {
        let stderr = "update failed: FORMERR\n";
        assert_eq!(
            parse_nsupdate_error(stderr),
            "Format error (check record syntax)"
        );
    }

    #[test]
    fn test_parse_nsupdate_error_nxdomain() {
        let stderr = "update failed: NXDOMAIN\n";
        assert_eq!(parse_nsupdate_error(stderr), "Domain name does not exist");
    }

    #[test]
    fn test_parse_nsupdate_error_unknown() {
        let stderr = "some other error\n";
        assert_eq!(parse_nsupdate_error(stderr), "some other error");
    }

    #[test]
    fn test_build_nsupdate_args() {
        assert!(build_nsupdate_args(false, None).is_empty());
        assert_eq!(
            build_nsupdate_args(true, Some(std::path::Path::new("/tmp/key.conf"))),
            vec![
                std::ffi::OsString::from("-v"),
                std::ffi::OsString::from("-k"),
                std::ffi::OsString::from("/tmp/key.conf"),
            ]
        );
    }

    #[test]
    fn test_minimal_child_env_only_forwards_allowlist() {
        assert!(minimal_child_env().iter().all(|(name, _)| {
            CHILD_ENV_ALLOWLIST
                .iter()
                .any(|allowed| name.as_os_str() == std::ffi::OsStr::new(allowed))
        }));
    }

    #[test]
    fn test_tsig_key_file_content_normalizes_algorithm() {
        let content = build_tsig_key_file_content("test-key", "SHA256", "dGVzdA==").unwrap();
        assert!(content.starts_with("key \"test-key\" {"));
        assert!(content.contains("algorithm hmac-sha256;"));
        assert!(content.contains("secret \"dGVzdA==\";"));
    }

    #[test]
    fn test_tsig_key_file_content_rejects_bad_input() {
        let too_long = "a".repeat(MAX_TSIG_KEY_NAME_LEN + 1);
        assert!(build_tsig_key_file_content("", "hmac-sha256", "dGVzdA==").is_err());
        assert!(build_tsig_key_file_content(&too_long, "hmac-sha256", "dGVzdA==").is_err());
        assert!(build_tsig_key_file_content("bad\"key", "hmac-sha256", "dGVzdA==").is_err());
        assert!(build_tsig_key_file_content("test-key", "hmac-sha3", "dGVzdA==").is_err());
        assert!(build_tsig_key_file_content("test-key", "hmac-sha256", "").is_err());
        assert!(build_tsig_key_file_content("test-key", "hmac-sha256", "dG\";VzdA==").is_err());
    }

    #[test]
    fn test_reject_injection_chars() {
        assert!(reject_injection_chars("name", "www.example.com.").is_ok());
        assert!(reject_injection_chars("name", "").is_ok());
        assert!(reject_injection_chars("value", "1.2.3.4\nupdate delete example.com").is_err());
        assert!(reject_injection_chars("value", "a\rb").is_err());
    }

    #[test]
    fn test_new_executor_with_tsig() {
        let executor = NsupdateExecutor::new(
            "127.0.0.1".to_string(),
            53,
            Some("test-key".to_string()),
            Some("HMAC-SHA256".to_string()),
            Some("dGVzdA==".to_string()),
        );
        assert!(executor.is_ok());
    }

    #[test]
    fn test_new_executor_without_tsig() {
        let executor = NsupdateExecutor::new("127.0.0.1".to_string(), 53, None, None, None);
        assert!(executor.is_ok());
    }

    #[test]
    fn test_create_tsig_key_file_without_tsig_is_none() {
        let executor =
            NsupdateExecutor::new("127.0.0.1".to_string(), 53, None, None, None).unwrap();
        assert!(executor.create_tsig_key_file().unwrap().is_none());
    }

    #[test]
    fn test_create_tsig_key_file_writes_key() {
        let executor = NsupdateExecutor::new(
            "127.0.0.1".to_string(),
            53,
            Some("test-key".to_string()),
            Some("hmac-sha256".to_string()),
            Some("dGVzdA==".to_string()),
        )
        .unwrap();
        let keyfile = executor
            .create_tsig_key_file()
            .unwrap()
            .expect("TSIG is fully configured");
        let content = std::fs::read_to_string(keyfile.path()).unwrap();
        assert!(content.contains("algorithm hmac-sha256;"));
        assert!(content.contains("secret \"dGVzdA==\";"));
    }
}