#![cfg(feature = "discovery-dns-update")]
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use base64::Engine;
use hickory_client::client::{AsyncClient, ClientHandle, Signer};
use hickory_client::error::ClientError;
use hickory_client::rr::rdata::{SRV, TXT};
use hickory_client::rr::{Name, RData, Record};
use hickory_client::tcp::TcpClientStream;
use hickory_proto::iocompat::AsyncIoTokioAsStd;
use hickory_proto::rr::dnssec::rdata::tsig::TsigAlgorithm;
use hickory_proto::rr::dnssec::tsig::TSigner;
use tokio::net::TcpStream as TokioTcpStream;
/// HMAC algorithms supported for TSIG-signing DNS UPDATE messages.
///
/// In BIND key files these appear as `hmac-sha256` and `hmac-sha512`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TsigAlgo {
    /// HMAC-SHA256 (`hmac-sha256`).
    HmacSha256,
    /// HMAC-SHA512 (`hmac-sha512`).
    HmacSha512,
}
impl TsigAlgo {
    /// Converts this algorithm into the matching `hickory_proto` TSIG algorithm.
    fn to_proto(self) -> TsigAlgorithm {
        match self {
            Self::HmacSha256 => TsigAlgorithm::HmacSha256,
            Self::HmacSha512 => TsigAlgorithm::HmacSha512,
        }
    }
}
/// A TSIG shared-secret key used to sign DNS UPDATE messages.
///
/// `Debug` is implemented by hand so that the secret never ends up in logs.
/// This matters because containing types such as `DnsRegistration` derive
/// `Debug` and may be printed with `?value`.
#[derive(Clone)]
pub struct TsigKey {
    /// Key name. It is parsed as a DNS name when the signer is built, so it
    /// must be a valid domain name.
    pub name: String,
    /// HMAC algorithm the key is used with.
    pub algorithm: TsigAlgo,
    /// Raw (already base64-decoded) shared secret bytes.
    pub secret: Vec<u8>,
}

impl std::fmt::Debug for TsigKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TsigKey")
            .field("name", &self.name)
            .field("algorithm", &self.algorithm)
            // Never print key material.
            .field("secret", &"<redacted>")
            .finish()
    }
}
impl TsigKey {
    /// Reads a BIND-style TSIG key file (for example one produced by
    /// `tsig-keygen`) and parses it with [`TsigKey::from_bind_str`].
    ///
    /// # Errors
    /// Returns the I/O error from reading `path`. Otherwise returns any
    /// `InvalidData` error from [`TsigKey::from_bind_str`].
    pub fn from_bind_file(path: impl AsRef<std::path::Path>) -> Result<Self, std::io::Error> {
        let content = std::fs::read_to_string(path)?;
        Self::from_bind_str(&content)
    }

    /// Parses a BIND-style `key` statement:
    ///
    /// ```text
    /// key "name" {
    ///     algorithm hmac-sha256;
    ///     secret "<base64>";
    /// };
    /// ```
    ///
    /// Parsing rules:
    /// - `;`, `{` and `}` all terminate statements. The multi-line layout above
    ///   and a one-line `key "k" { algorithm ...; secret "..."; };` therefore
    ///   parse identically.
    /// - Lines whose first non-blank characters are `#` or `//` are ignored.
    /// - Algorithm names are matched case-insensitively.
    /// - When a field appears more than once, the last occurrence wins.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidData` in any of these cases:
    /// - the algorithm is not `hmac-sha256` or `hmac-sha512`;
    /// - the secret is not valid standard base64;
    /// - the key name, `algorithm` or `secret` is missing.
    pub fn from_bind_str(s: &str) -> Result<Self, std::io::Error> {
        // Every parse failure is reported as `InvalidData`.
        fn invalid(msg: impl Into<String>) -> std::io::Error {
            std::io::Error::new(std::io::ErrorKind::InvalidData, msg.into())
        }

        let mut name: Option<String> = None;
        let mut algorithm: Option<TsigAlgo> = None;
        let mut secret: Option<Vec<u8>> = None;

        // Drop whole-line comments only. Inline `//` cannot be stripped
        // safely, because a base64 secret may legitimately contain "//".
        let body = s
            .lines()
            .map(str::trim)
            .filter(|line| !line.starts_with('#') && !line.starts_with("//"))
            .collect::<Vec<_>>()
            .join("\n");

        for stmt in body.split(|c: char| matches!(c, ';' | '{' | '}')) {
            let stmt = stmt.trim();
            let (keyword, rest) = stmt
                .split_once(char::is_whitespace)
                .unwrap_or((stmt, ""));
            let value = rest.trim().trim_matches('"');
            match keyword {
                "key" if !value.is_empty() => name = Some(value.to_string()),
                "algorithm" => {
                    algorithm = Some(match value.to_ascii_lowercase().as_str() {
                        "hmac-sha256" => TsigAlgo::HmacSha256,
                        "hmac-sha512" => TsigAlgo::HmacSha512,
                        _ => {
                            return Err(invalid(format!(
                                "unsupported TSIG algorithm: {value}"
                            )));
                        }
                    });
                }
                "secret" => {
                    let bytes = base64::engine::general_purpose::STANDARD
                        .decode(value)
                        .map_err(|e| {
                            invalid(format!("base64 decode of TSIG secret failed: {e}"))
                        })?;
                    secret = Some(bytes);
                }
                // Ignore anything else, including empty statements left over
                // from `};`.
                _ => {}
            }
        }

        let name = name.ok_or_else(|| invalid("missing 'key' line"))?;
        let algorithm = algorithm.ok_or_else(|| invalid("missing 'algorithm'"))?;
        let secret = secret.ok_or_else(|| invalid("missing 'secret'"))?;
        Ok(Self {
            name,
            algorithm,
            secret,
        })
    }
}
/// Parameters for publishing an `_epics-ca._tcp` DNS-SD service via DNS UPDATE.
#[derive(Debug, Clone)]
pub struct DnsRegistration {
    /// DNS server that receives the UPDATE messages (contacted over TCP).
    pub server: SocketAddr,
    /// Zone being updated. It is also the suffix of the service-type and
    /// instance names.
    pub zone: String,
    /// Instance label, published as `<instance>._epics-ca._tcp.<zone>`.
    pub instance: String,
    /// SRV target host. It is relative to `zone` unless it ends with `.`.
    pub host: String,
    /// SRV target port.
    pub port: u16,
    /// Key/value pairs published as `key=value` strings in the instance TXT record.
    pub txt: Vec<(String, String)>,
    /// TTL applied to every published record, sent as whole seconds.
    pub ttl: Duration,
    /// Interval between background re-sends of the records.
    pub keepalive: Duration,
    /// Optional TSIG key. Updates are sent unsigned when `None`.
    pub tsig: Option<TsigKey>,
}
impl Default for DnsRegistration {
fn default() -> Self {
Self {
server: "127.0.0.1:53".parse().unwrap(),
zone: "local.".to_string(),
instance: "ioc".to_string(),
host: "localhost".to_string(),
port: 5064,
txt: Vec::new(),
ttl: Duration::from_secs(60),
keepalive: Duration::from_secs(30),
tsig: None,
}
}
}
/// Handle to an active DNS-SD registration kept alive by a background task.
///
/// Dropping the handle signals that task to stop. The task then attempts to
/// delete the published records.
pub struct DnsUpdater {
    /// One-shot stop signal for the refresh task. Taken when it is sent.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// Join handle of the refresh task. Taken when it is detached or awaited.
    handle: Option<tokio::task::JoinHandle<()>>,
}
impl DnsUpdater {
pub async fn register(reg: DnsRegistration) -> Result<Self, ClientError> {
send_update(®, UpdateOp::Create).await?;
tracing::info!(zone = %reg.zone, instance = %reg.instance,
server = %reg.server, "DNS UPDATE: registered");
metrics::counter!("ca_server_dns_update_register_total").increment(1);
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let reg_clone = reg.clone();
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(reg_clone.keepalive);
interval.tick().await; loop {
tokio::select! {
_ = interval.tick() => {
if let Err(e) = send_update(®_clone, UpdateOp::Refresh).await {
tracing::warn!(error = %e, "DNS UPDATE refresh failed");
metrics::counter!("ca_server_dns_update_refresh_failures_total").increment(1);
}
}
_ = &mut shutdown_rx => break,
}
}
if let Err(e) = send_update(®_clone, UpdateOp::Delete).await {
tracing::warn!(error = %e, "DNS UPDATE delete on shutdown failed");
} else {
tracing::info!(zone = %reg_clone.zone, instance = %reg_clone.instance,
"DNS UPDATE: unregistered on shutdown");
}
});
Ok(Self {
shutdown_tx: Some(shutdown_tx),
handle: Some(handle),
})
}
}
impl Drop for DnsUpdater {
    /// Signals the refresh task to stop and detaches it without waiting.
    ///
    /// The task then attempts to delete the records on its own.
    fn drop(&mut self) {
        let stop_signal = self.shutdown_tx.take();
        let task = self.handle.take();
        if let Some(sender) = stop_signal {
            // The task may already have finished; a failed send is harmless.
            let _ = sender.send(());
        }
        // Dropping a JoinHandle detaches the task; it keeps running.
        drop(task);
    }
}
/// Which DNS UPDATE transaction `send_update` should perform.
#[derive(Clone, Copy, Debug)]
enum UpdateOp {
    /// Initial registration: add the SRV, TXT and PTR records.
    Create,
    /// Periodic keepalive: re-add the same records.
    Refresh,
    /// Deregistration: remove the specific records previously added.
    Delete,
}
async fn send_update(reg: &DnsRegistration, op: UpdateOp) -> Result<(), ClientError> {
let zone =
Name::from_str(®.zone).map_err(|e| ClientError::from(format!("bad zone: {e}")))?;
let svc_type = parse_or_err(&format!("_epics-ca._tcp.{}", reg.zone))?;
let instance_fqdn = parse_or_err(&format!("{}._epics-ca._tcp.{}", reg.instance, reg.zone))?;
let host_fqdn = if reg.host.ends_with('.') {
parse_or_err(®.host)?
} else {
parse_or_err(&format!("{}.{}", reg.host, reg.zone))?
};
let (stream, sender) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(reg.server);
let signer: Option<Arc<Signer>> = match ®.tsig {
Some(key) => {
let signer_name = parse_or_err(&key.name)?;
let tsigner = TSigner::new(
key.secret.clone(),
key.algorithm.to_proto(),
signer_name,
300, )
.map_err(|e| ClientError::from(format!("TSIG init: {e}")))?;
Some(Arc::new(Signer::TSIG(tsigner)))
}
None => None,
};
let (mut client, bg) = AsyncClient::new(stream, sender, signer).await?;
tokio::spawn(bg);
let ttl = reg.ttl.as_secs() as u32;
let srv_rdata = RData::SRV(SRV::new(0, 0, reg.port, host_fqdn));
let srv = Record::from_rdata(instance_fqdn.clone(), ttl, srv_rdata);
let txt_strs: Vec<String> = reg.txt.iter().map(|(k, v)| format!("{k}={v}")).collect();
let txt_rdata = RData::TXT(TXT::new(txt_strs));
let txt = Record::from_rdata(instance_fqdn.clone(), ttl, txt_rdata);
let ptr_rdata = RData::PTR(hickory_client::rr::rdata::PTR(instance_fqdn.clone()));
let ptr = Record::from_rdata(svc_type, ttl, ptr_rdata);
match op {
UpdateOp::Create | UpdateOp::Refresh => {
client.append(srv, zone.clone(), false).await.map_err(|e| {
tracing::debug!(error = %e, "SRV append failed");
e
})?;
client.append(txt, zone.clone(), false).await?;
client.append(ptr, zone, false).await?;
}
UpdateOp::Delete => {
let _ = client.delete_by_rdata(srv.clone(), zone.clone()).await;
let _ = client.delete_by_rdata(txt, zone.clone()).await;
let _ = client.delete_by_rdata(ptr, zone).await;
}
}
Ok(())
}
/// Parses `s` as a DNS name. On failure, returns a `ClientError` that quotes
/// the offending input.
fn parse_or_err(s: &str) -> Result<Name, ClientError> {
    match Name::from_str(s) {
        Ok(name) => Ok(name),
        Err(e) => Err(ClientError::from(format!("bad name {s:?}: {e}"))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_bind_key_file() {
        let content = r#"
key "epics-key" {
    algorithm hmac-sha256;
    secret "dGVzdC1zZWNyZXQ=";
};
"#;
        let key = TsigKey::from_bind_str(content).expect("parse");
        assert_eq!(key.name, "epics-key");
        assert!(matches!(key.algorithm, TsigAlgo::HmacSha256));
        assert_eq!(key.secret, b"test-secret");
    }

    #[test]
    fn parse_bind_key_sha512() {
        let content = r#"
key "k512" {
    algorithm hmac-sha512;
    secret "AAAA";
};
"#;
        let key = TsigKey::from_bind_str(content).expect("parse");
        assert_eq!(key.name, "k512");
        assert!(matches!(key.algorithm, TsigAlgo::HmacSha512));
        assert_eq!(key.secret, vec![0u8, 0, 0]);
    }

    #[test]
    fn parse_bind_key_rejects_bad_algo() {
        let content = r#"
key "k" { algorithm foo-bar; secret "AAAA"; };
"#;
        let err = TsigKey::from_bind_str(content).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn parse_bind_key_reports_unsupported_algo() {
        // Use the multi-line layout, so the failure comes from the algorithm
        // check and not from a missing field.
        let content = r#"
key "k" {
    algorithm foo-bar;
    secret "AAAA";
};
"#;
        let err = TsigKey::from_bind_str(content).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(
            err.to_string().contains("unsupported TSIG algorithm"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn parse_bind_key_rejects_bad_base64() {
        let content = r#"
key "k" {
    algorithm hmac-sha256;
    secret "not base64!";
};
"#;
        let err = TsigKey::from_bind_str(content).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn parse_bind_key_requires_secret() {
        let content = r#"
key "k" {
    algorithm hmac-sha256;
};
"#;
        let err = TsigKey::from_bind_str(content).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}