use core::time::Duration;
use std::fmt;
use bytes::Bytes;
use ts_capabilityversion::CapabilityVersion;
use ts_control_serde::{SetDnsRequest, SetDnsResponse};
use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt, StatusCode};
use url::Url;
use crate::tokio::connect::ConnectionError;
/// Header that carries this node's public key on set-dns requests
/// (presumably used by a load balancer in front of control for routing — confirm).
const LOAD_BALANCER_HEADER_KEY: &str = "Ts-Lb";

/// Upper bound on connecting to control *and* completing one set-dns exchange.
const SET_DNS_TIMEOUT: Duration = Duration::from_secs(30);
/// Category of a non-network [`SetDnsError`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SetDnsInternalErrorKind {
    /// A URL could not be parsed or built.
    Url,
    /// JSON serialization or deserialization failed.
    SerDe,
    /// The HTTP exchange failed or returned a non-success status. This kind is also used for
    /// internal connection failures that have no dedicated kind here.
    Http,
    /// A response body was not valid UTF-8.
    Utf8,
}

impl fmt::Display for SetDnsInternalErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::Url => "URL parsing error",
            Self::SerDe => "serialization/deserialization error",
            Self::Http => "unsuccessful HTTP request",
            Self::Utf8 => "invalid UTF8",
        };
        f.write_str(description)
    }
}
#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)]
pub enum SetDnsError {
#[error("network error publishing dns record")]
NetworkError,
#[error("error publishing dns record: {0}")]
Internal(SetDnsInternalErrorKind),
}
impl From<url::ParseError> for SetDnsError {
fn from(error: url::ParseError) -> Self {
tracing::error!(%error, "bad URL building set-dns request");
SetDnsError::Internal(SetDnsInternalErrorKind::Url)
}
}
impl From<serde_json::Error> for SetDnsError {
fn from(error: serde_json::Error) -> Self {
tracing::error!(%error, "serde error in set-dns request");
SetDnsError::Internal(SetDnsInternalErrorKind::SerDe)
}
}
impl From<core::str::Utf8Error> for SetDnsError {
fn from(error: core::str::Utf8Error) -> Self {
tracing::error!(%error, "invalid utf8 in set-dns response");
SetDnsError::Internal(SetDnsInternalErrorKind::Utf8)
}
}
impl From<ts_http_util::Error> for SetDnsError {
fn from(error: ts_http_util::Error) -> Self {
tracing::error!(%error, "http error in set-dns request");
if crate::http_error_is_recoverable(error) {
SetDnsError::NetworkError
} else {
SetDnsError::Internal(SetDnsInternalErrorKind::Http)
}
}
}
impl From<ConnectionError> for SetDnsError {
    /// Translates a connection-setup failure into a set-dns failure.
    ///
    /// - Network failures stay network failures.
    /// - URL and serde problems keep their category.
    /// - Every other internal connection failure is reported as
    ///   [`SetDnsInternalErrorKind::Http`].
    fn from(error: ConnectionError) -> Self {
        use crate::tokio::connect::InternalErrorKind as ConnKind;

        let kind = match error {
            ConnectionError::NetworkError => return Self::NetworkError,
            ConnectionError::Internal(ConnKind::Url) => SetDnsInternalErrorKind::Url,
            ConnectionError::Internal(ConnKind::SerDe) => SetDnsInternalErrorKind::SerDe,
            ConnectionError::Internal(
                ConnKind::Http
                | ConnKind::MessageFormat
                | ConnKind::Io
                | ConnKind::ChallengeLength
                | ConnKind::NoiseHandshake,
            ) => SetDnsInternalErrorKind::Http,
        };
        Self::Internal(kind)
    }
}
/// Publishes a DNS record for this node through the control server.
///
/// Opens a new connection to `config.server_url` and sends the record with [`set_dns_with`].
/// Connecting and the request together are bounded by `SET_DNS_TIMEOUT`.
///
/// # Errors
///
/// Returns [`SetDnsError::NetworkError`] on timeout or on a failure classified as a network
/// error. Any other failure returns [`SetDnsError::Internal`].
pub async fn set_dns(
    config: &crate::Config,
    node_keystate: &ts_keys::NodeState,
    name: &str,
    record_type: &str,
    value: &str,
) -> Result<(), SetDnsError> {
    let server = &config.server_url;

    // Build the whole RPC (connect + publish) as one future so a single deadline covers both.
    let connect_and_publish = async {
        let conn = crate::tokio::connect(
            server,
            &node_keystate.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        set_dns_with(server, node_keystate, name, record_type, value, &conn).await
    };

    tokio::time::timeout(SET_DNS_TIMEOUT, connect_and_publish)
        .await
        .unwrap_or_else(|_elapsed| {
            tracing::error!(timeout = ?SET_DNS_TIMEOUT, "set-dns request timed out");
            Err(SetDnsError::NetworkError)
        })
}
/// Publishes a single DNS record through control over an existing HTTP/2 connection.
///
/// The request:
/// - is a JSON-encoded [`SetDnsRequest`], stamped with [`CapabilityVersion::CURRENT`] and this
///   node's public key;
/// - is POSTed to `machine/set-dns`, resolved against `control_url`;
/// - carries the node key in the `LOAD_BALANCER_HEADER_KEY` header.
///
/// The response is then validated by `check_set_dns_status`.
///
/// # Errors
///
/// - [`SetDnsInternalErrorKind::SerDe`] if the request cannot be serialized or the response
///   body is not a valid [`SetDnsResponse`].
/// - [`SetDnsInternalErrorKind::Url`] if the endpoint URL cannot be built.
/// - [`SetDnsError::NetworkError`] or [`SetDnsInternalErrorKind::Http`] for transport failures
///   (per `crate::http_error_is_recoverable`) and non-success statuses.
/// - [`SetDnsInternalErrorKind::Utf8`] if a successful response body is not UTF-8.
///
/// # Panics
///
/// Panics if the header name constant, or the string form of the node public key, is not a
/// valid HTTP header name or value respectively. Either case is an invariant violation.
pub(crate) async fn set_dns_with(
    control_url: &Url,
    node_keystate: &ts_keys::NodeState,
    name: &str,
    record_type: &str,
    value: &str,
    http2_conn: &Http2<BytesBody>,
) -> Result<(), SetDnsError> {
    let node_public_key = node_keystate.node_keys.public;
    let req = SetDnsRequest {
        version: CapabilityVersion::CURRENT,
        node_key: node_public_key,
        name: name.to_string(),
        r#type: record_type.to_string(),
        value: value.to_string(),
    };
    let body = serde_json::to_string(&req)?;
    let url = control_url.join("machine/set-dns")?;
    tracing::debug!(url = %url.as_str(), name, record_type, "publishing dns record via control");

    let response = http2_conn
        .post(
            &url,
            [(
                LOAD_BALANCER_HEADER_KEY
                    .parse()
                    .expect("LOAD_BALANCER_HEADER_KEY must be a valid HTTP header name"),
                node_public_key
                    .to_string()
                    .parse()
                    .expect("node public key string form must be a valid HTTP header value"),
            )],
            Bytes::from(body).into(),
        )
        .await?;

    // Capture the status before the response is consumed by body collection.
    let status = response.status();
    let body = response.collect_bytes().await?;
    check_set_dns_status(status, &body)
}
/// Validates the status and body of a `machine/set-dns` response.
///
/// A non-success status is always a failure; at most the first 512 bytes of the body are
/// logged to aid debugging. On success, an empty or whitespace-only body is accepted.
/// Otherwise the body must be UTF-8 JSON that deserializes into a [`SetDnsResponse`]. Its
/// contents are not inspected further.
///
/// # Errors
///
/// - [`SetDnsInternalErrorKind::Http`] if `status` is not a success code.
/// - [`SetDnsInternalErrorKind::Utf8`] if a successful body is not valid UTF-8.
/// - [`SetDnsInternalErrorKind::SerDe`] if a non-blank successful body is not a valid
///   [`SetDnsResponse`].
fn check_set_dns_status(status: StatusCode, body: &[u8]) -> Result<(), SetDnsError> {
    if !status.is_success() {
        let preview = error_body_preview(body);
        tracing::error!(body = %preview, %status, "set-dns request failed");
        return Err(SetDnsError::Internal(SetDnsInternalErrorKind::Http));
    }
    let body = core::str::from_utf8(body)?;
    if body.trim().is_empty() {
        return Ok(());
    }
    let _resp: SetDnsResponse = serde_json::from_str(body)?;
    Ok(())
}

/// Returns at most the first 512 bytes of `body` as text for logging, without copying.
///
/// If the byte cut lands inside a multi-byte UTF-8 character, the incomplete trailing
/// character is dropped rather than rejecting the whole preview. Bodies that are invalid for
/// any other reason are rendered as `"<invalid utf8>"`.
fn error_body_preview(body: &[u8]) -> &str {
    const LIMIT: usize = 512;
    let head = &body[..body.len().min(LIMIT)];
    match core::str::from_utf8(head) {
        Ok(text) => text,
        // `error_len() == None` means the input ended mid-character. When we actually cut the
        // body, that is our truncation splitting a code point, and everything before
        // `valid_up_to()` is valid UTF-8.
        Err(err) if body.len() > LIMIT && err.error_len().is_none() => {
            core::str::from_utf8(&head[..err.valid_up_to()]).unwrap_or("<invalid utf8>")
        }
        Err(_) => "<invalid utf8>",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokio::connect::{ConnectionError, InternalErrorKind as ConnKind};

    #[test]
    fn connection_error_network_maps_to_network() {
        assert_eq!(
            SetDnsError::from(ConnectionError::NetworkError),
            SetDnsError::NetworkError
        );
    }

    #[test]
    fn connection_error_internal_kinds_map_correctly() {
        use SetDnsInternalErrorKind as Sd;
        let cases = [
            (ConnKind::Url, Sd::Url),
            (ConnKind::SerDe, Sd::SerDe),
            (ConnKind::Http, Sd::Http),
            (ConnKind::MessageFormat, Sd::Http),
            (ConnKind::Io, Sd::Http),
            (ConnKind::ChallengeLength, Sd::Http),
            (ConnKind::NoiseHandshake, Sd::Http),
        ];
        for (conn, expected) in cases {
            assert_eq!(
                SetDnsError::from(ConnectionError::Internal(conn)),
                SetDnsError::Internal(expected),
                "ConnectionError::Internal({conn:?}) should map to Internal({expected:?})"
            );
        }
    }

    #[test]
    fn serde_error_maps_to_internal_serde() {
        let err = serde_json::from_str::<SetDnsResponse>("not json").unwrap_err();
        assert_eq!(
            SetDnsError::from(err),
            SetDnsError::Internal(SetDnsInternalErrorKind::SerDe)
        );
    }

    #[test]
    fn url_parse_error_maps_to_internal_url() {
        let err = Url::parse("not a url").unwrap_err();
        assert_eq!(
            SetDnsError::from(err),
            SetDnsError::Internal(SetDnsInternalErrorKind::Url)
        );
    }

    #[test]
    fn utf8_error_maps_to_internal_utf8() {
        let bytes = vec![0xffu8, 0xfe];
        let err = core::str::from_utf8(&bytes).unwrap_err();
        assert_eq!(
            SetDnsError::from(err),
            SetDnsError::Internal(SetDnsInternalErrorKind::Utf8)
        );
    }

    #[test]
    fn http_util_error_non_recoverable_maps_to_internal_http() {
        let err = ts_http_util::Error::InvalidResponse;
        assert_eq!(
            SetDnsError::from(err),
            SetDnsError::Internal(SetDnsInternalErrorKind::Http)
        );
    }

    #[test]
    fn http_util_error_recoverable_maps_to_network() {
        let err = ts_http_util::Error::Io;
        assert_eq!(SetDnsError::from(err), SetDnsError::NetworkError);
    }

    #[test]
    fn check_set_dns_status_ok_empty_body() {
        check_set_dns_status(StatusCode::OK, b"").unwrap();
    }

    #[test]
    fn check_set_dns_status_ok_empty_json() {
        check_set_dns_status(StatusCode::OK, b"{}").unwrap();
    }

    #[test]
    fn check_set_dns_status_ok_whitespace_only_body() {
        check_set_dns_status(StatusCode::OK, b" \r\n\t ").unwrap();
    }

    #[test]
    fn check_set_dns_status_other_success_code_is_ok() {
        check_set_dns_status(StatusCode::NO_CONTENT, b"").unwrap();
    }

    #[test]
    fn check_set_dns_status_ok_invalid_utf8_is_error() {
        let err = check_set_dns_status(StatusCode::OK, &[0xff, 0xfe]).unwrap_err();
        assert_eq!(err, SetDnsError::Internal(SetDnsInternalErrorKind::Utf8));
    }

    #[test]
    fn check_set_dns_status_ok_malformed_json_is_error() {
        let err = check_set_dns_status(StatusCode::OK, b"not json").unwrap_err();
        assert_eq!(err, SetDnsError::Internal(SetDnsInternalErrorKind::SerDe));
    }

    #[test]
    fn check_set_dns_status_self_hosted_501_is_error() {
        let err =
            check_set_dns_status(StatusCode::NOT_IMPLEMENTED, b"not implemented").unwrap_err();
        assert_eq!(err, SetDnsError::Internal(SetDnsInternalErrorKind::Http));
    }

    #[test]
    fn check_set_dns_status_500_is_error() {
        let err =
            check_set_dns_status(StatusCode::INTERNAL_SERVER_ERROR, b"upstream boom").unwrap_err();
        assert_eq!(err, SetDnsError::Internal(SetDnsInternalErrorKind::Http));
    }

    #[test]
    fn check_set_dns_status_failure_with_large_multibyte_body_is_error() {
        // A 2-byte character starting at byte 511 straddles the 512-byte log-preview cut.
        let mut body = vec![b'a'; 511];
        body.extend_from_slice("\u{e9}".as_bytes());
        body.extend_from_slice(&[b'b'; 100]);
        let err = check_set_dns_status(StatusCode::BAD_GATEWAY, &body).unwrap_err();
        assert_eq!(err, SetDnsError::Internal(SetDnsInternalErrorKind::Http));
    }
}