use core::time::Duration;
use std::{
fmt,
time::{SystemTime, UNIX_EPOCH},
};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use ts_capabilityversion::CapabilityVersion;
use ts_control_serde::{HostInfo, RegisterRequest};
use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt, StatusCode};
use url::Url;
use crate::tokio::connect::ConnectionError;
/// Name of the header that carries the node public key on the logout request
/// (presumably used by a load balancer to route by node — NOTE(review): confirm).
const LOAD_BALANCER_HEADER_KEY: &str = "Ts-Lb";
/// Upper bound on the whole logout operation: connecting to control plus the request.
const LOGOUT_TIMEOUT: Duration = Duration::new(30, 0);
/// Number of seconds subtracted from "now" when computing the requested node-key expiry,
/// so the requested expiry is already in the past.
const EXPIRY_BACKDATE_SECS: u64 = 10;
/// Returns a timestamp slightly in the past, used as the requested node-key expiry on logout.
///
/// The current wall-clock time is truncated to whole seconds and backdated by
/// [`EXPIRY_BACKDATE_SECS`]. The function falls back to the Unix epoch, which is also in the
/// past, in these cases:
/// - the system clock reads earlier than the Unix epoch;
/// - the seconds value does not fit in an `i64`;
/// - chrono cannot represent the timestamp.
fn past_expiry() -> DateTime<Utc> {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
        .saturating_sub(EXPIRY_BACKDATE_SECS);
    // Lossless conversion: an out-of-range value falls back to the epoch instead of
    // silently wrapping to an arbitrary negative timestamp as `as i64` would.
    i64::try_from(secs)
        .ok()
        .and_then(|secs| DateTime::<Utc>::from_timestamp(secs, 0))
        .unwrap_or(DateTime::<Utc>::UNIX_EPOCH)
}
/// Classification of non-network failures encountered while logging out.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LogoutInternalErrorKind {
    /// A URL could not be built or parsed.
    Url,
    /// A request could not be serialized (or a payload could not be (de)serialized).
    SerDe,
    /// An HTTP-level failure: a non-success status, a non-recoverable HTTP error, or a
    /// non-URL/non-serde failure while establishing the control connection.
    Http,
}

impl fmt::Display for LogoutInternalErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::Url => "URL parsing error",
            Self::SerDe => "serialization error",
            Self::Http => "unsuccessful HTTP request",
        };
        f.write_str(description)
    }
}
#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)]
pub enum LogoutError {
#[error("network error logging out")]
NetworkError,
#[error("error logging out: {0}")]
Internal(LogoutInternalErrorKind),
}
impl From<url::ParseError> for LogoutError {
    /// Logs the URL failure at error level and classifies it as
    /// [`LogoutInternalErrorKind::Url`].
    fn from(error: url::ParseError) -> Self {
        tracing::error!(%error, "bad URL building logout request");
        Self::Internal(LogoutInternalErrorKind::Url)
    }
}
impl From<serde_json::Error> for LogoutError {
    /// Logs the serialization failure at error level and classifies it as
    /// [`LogoutInternalErrorKind::SerDe`].
    fn from(error: serde_json::Error) -> Self {
        tracing::error!(%error, "serde error in logout request");
        Self::Internal(LogoutInternalErrorKind::SerDe)
    }
}
impl From<ts_http_util::Error> for LogoutError {
fn from(error: ts_http_util::Error) -> Self {
tracing::error!(%error, "http error in logout request");
if crate::http_error_is_recoverable(error) {
LogoutError::NetworkError
} else {
LogoutError::Internal(LogoutInternalErrorKind::Http)
}
}
}
impl From<ConnectionError> for LogoutError {
    /// Translates a control-connection failure into a logout failure.
    ///
    /// Network errors stay network errors. URL and serde kinds keep their kind. Every other
    /// connection kind collapses into [`LogoutInternalErrorKind::Http`]. That covers message
    /// format, I/O, challenge length and Noise handshake.
    fn from(error: ConnectionError) -> Self {
        use crate::tokio::connect::InternalErrorKind;

        let conn_kind = match error {
            ConnectionError::NetworkError => return Self::NetworkError,
            ConnectionError::Internal(conn_kind) => conn_kind,
        };
        // Exhaustive on purpose: a new connection error kind must be classified here.
        let logout_kind = match conn_kind {
            InternalErrorKind::Url => LogoutInternalErrorKind::Url,
            InternalErrorKind::SerDe => LogoutInternalErrorKind::SerDe,
            InternalErrorKind::Http
            | InternalErrorKind::MessageFormat
            | InternalErrorKind::Io
            | InternalErrorKind::ChallengeLength
            | InternalErrorKind::NoiseHandshake => LogoutInternalErrorKind::Http,
        };
        Self::Internal(logout_kind)
    }
}
/// Logs this node out of the control server by requesting that its node key be expired.
///
/// Connects to `config.server_url` with the node's machine keys and then sends the logout
/// request over that connection. Connecting and the request together are bounded by
/// [`LOGOUT_TIMEOUT`].
///
/// # Errors
///
/// Returns [`LogoutError::NetworkError`] in these cases:
/// - a network-level connection failure;
/// - a recoverable HTTP error;
/// - the timeout elapsing.
///
/// Returns [`LogoutError::Internal`] for all other failures.
pub async fn logout(
    config: &crate::Config,
    node_keystate: &ts_keys::NodeState,
) -> Result<(), LogoutError> {
    let server_url = &config.server_url;
    let connect_and_logout = async {
        let conn = crate::tokio::connect(
            server_url,
            &node_keystate.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        logout_with(config, server_url, node_keystate, &conn).await
    };
    tokio::time::timeout(LOGOUT_TIMEOUT, connect_and_logout)
        .await
        .unwrap_or_else(|_elapsed| {
            tracing::error!(timeout = ?LOGOUT_TIMEOUT, "logout request timed out");
            Err(LogoutError::NetworkError)
        })
}
/// Sends the logout (node-key expiry) request over an existing control connection.
///
/// Steps:
/// 1. Build a [`RegisterRequest`] for the node's current public key. Its `expiry` is set
///    slightly in the past (see [`past_expiry`]).
/// 2. POST it as JSON to `machine/register`, resolved relative to `control_url`.
/// 3. Classify the response status.
///
/// The request carries the [`LOAD_BALANCER_HEADER_KEY`] header set to the node public key.
///
/// # Errors
///
/// - [`LogoutInternalErrorKind::SerDe`] if the request cannot be serialized.
/// - [`LogoutInternalErrorKind::Url`] if the endpoint URL cannot be built.
/// - [`LogoutError::NetworkError`] or [`LogoutInternalErrorKind::Http`] if the POST fails,
///   depending on whether the HTTP error is recoverable.
/// - [`LogoutInternalErrorKind::Http`] if the server answers with a non-success status.
///
/// # Panics
///
/// Panics only if the constant header name or the node key's string form is rejected by the
/// header parser. Both are treated as invariants.
pub(crate) async fn logout_with(
    config: &crate::Config,
    control_url: &Url,
    node_keystate: &ts_keys::NodeState,
    http2_conn: &Http2<BytesBody>,
) -> Result<(), LogoutError> {
    let node_public_key = node_keystate.node_keys.public;
    let client_name = config.format_client_name();
    let logout_req = RegisterRequest {
        version: CapabilityVersion::CURRENT,
        node_key: node_public_key,
        nl_key: Some(node_keystate.network_lock_keys.public),
        // An expiry in the past asks control to expire the node key, i.e. log out.
        expiry: Some(past_expiry()),
        hostinfo: HostInfo {
            hostname: config.hostname.as_deref(),
            app: &client_name,
            ipn_version: crate::PKG_VERSION,
            ..Default::default()
        },
        ..Default::default()
    };
    let body = serde_json::to_string(&logout_req)?;
    let url = control_url.join("machine/register")?;
    tracing::debug!(url = %url.as_str(), "logging out (expiring node key) via control");
    let response = http2_conn
        .post(
            &url,
            [(
                LOAD_BALANCER_HEADER_KEY
                    .parse()
                    .expect("LOAD_BALANCER_HEADER_KEY is a valid header name"),
                // NOTE(review): assumes the key's Display form is header-safe ASCII.
                node_public_key
                    .to_string()
                    .parse()
                    .expect("node public key string is a valid header value"),
            )],
            Bytes::from(body).into(),
        )
        .await?;
    let status = response.status();
    // Best effort: the body is only used for failure diagnostics, so a read failure
    // falls back to an empty body and never changes the status-based result.
    let body = response.collect_bytes().await.unwrap_or_default();
    classify_logout_response(status, &body)
}
/// Maps the HTTP status of a logout response to the logout result.
///
/// Any 2xx status is success. Any other status is handled as follows:
/// - it is logged at error level;
/// - up to 512 bytes of the body are logged at debug level;
/// - it is reported as [`LogoutInternalErrorKind::Http`].
///
/// The body never affects the classification.
fn classify_logout_response(status: StatusCode, body: &[u8]) -> Result<(), LogoutError> {
    const PREVIEW_MAX_BYTES: usize = 512;

    if status.is_success() {
        return Ok(());
    }
    tracing::error!(%status, "logout request failed");
    let preview = logout_body_preview(body, PREVIEW_MAX_BYTES);
    tracing::debug!(body = %preview, %status, "logout failure response body");
    Err(LogoutError::Internal(LogoutInternalErrorKind::Http))
}

/// Returns at most `max_len` leading bytes of `body` as text, for logging.
///
/// Borrows a prefix instead of copying the whole body. When the length limit cuts through a
/// multi-byte UTF-8 character, that partial character is dropped. A fixed byte limit should
/// not make an otherwise valid text body appear as `"<invalid utf8>"`. Prefixes that contain
/// genuinely invalid UTF-8 still yield that placeholder.
fn logout_body_preview(body: &[u8], max_len: usize) -> &str {
    let prefix = &body[..body.len().min(max_len)];
    match core::str::from_utf8(prefix) {
        Ok(text) => text,
        // `error_len() == None` means the input ended mid-character. When we truncated,
        // that is our cut, not bad data, so keep everything before it.
        Err(err) if err.error_len().is_none() && prefix.len() < body.len() => {
            core::str::from_utf8(&prefix[..err.valid_up_to()]).unwrap_or("<invalid utf8>")
        }
        Err(_) => "<invalid utf8>",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokio::connect::{ConnectionError, InternalErrorKind as ConnKind};

    #[test]
    fn connection_error_network_maps_to_network() {
        assert_eq!(
            LogoutError::from(ConnectionError::NetworkError),
            LogoutError::NetworkError
        );
    }

    #[test]
    fn connection_error_internal_kinds_map_correctly() {
        use LogoutInternalErrorKind as L;
        let cases = [
            (ConnKind::Url, L::Url),
            (ConnKind::SerDe, L::SerDe),
            (ConnKind::Http, L::Http),
            (ConnKind::MessageFormat, L::Http),
            (ConnKind::Io, L::Http),
            (ConnKind::ChallengeLength, L::Http),
            (ConnKind::NoiseHandshake, L::Http),
        ];
        for (conn, expected) in cases {
            assert_eq!(
                LogoutError::from(ConnectionError::Internal(conn)),
                LogoutError::Internal(expected),
                "ConnectionError::Internal({conn:?}) should map to Internal({expected:?})"
            );
        }
    }

    #[test]
    fn http_util_error_recoverable_maps_to_network() {
        assert_eq!(
            LogoutError::from(ts_http_util::Error::Io),
            LogoutError::NetworkError
        );
    }

    #[test]
    fn http_util_error_non_recoverable_maps_to_internal_http() {
        assert_eq!(
            LogoutError::from(ts_http_util::Error::InvalidResponse),
            LogoutError::Internal(LogoutInternalErrorKind::Http)
        );
    }

    #[test]
    fn classify_logout_response_2xx_is_ok() {
        assert!(classify_logout_response(StatusCode::OK, b"{}").is_ok());
        assert!(classify_logout_response(StatusCode::NO_CONTENT, b"").is_ok());
    }

    #[test]
    fn classify_logout_response_non_success_is_http() {
        let err = classify_logout_response(StatusCode::INTERNAL_SERVER_ERROR, b"boom").unwrap_err();
        assert_eq!(err, LogoutError::Internal(LogoutInternalErrorKind::Http));
    }

    #[test]
    fn classify_logout_response_invalid_utf8_body_still_classifies() {
        let err = classify_logout_response(StatusCode::BAD_GATEWAY, &[0xff, 0xfe]).unwrap_err();
        assert_eq!(err, LogoutError::Internal(LogoutInternalErrorKind::Http));
    }

    #[test]
    fn classify_logout_response_large_body_still_classifies() {
        let body = vec![b'x'; 4096];
        let err = classify_logout_response(StatusCode::BAD_REQUEST, &body).unwrap_err();
        assert_eq!(err, LogoutError::Internal(LogoutInternalErrorKind::Http));
    }

    #[test]
    fn classify_logout_response_multibyte_at_preview_boundary_still_classifies() {
        // 511 ASCII bytes followed by a 2-byte character straddling the 512-byte preview cut.
        let mut body = vec![b'a'; 511];
        body.extend_from_slice("é".as_bytes());
        let err = classify_logout_response(StatusCode::SERVICE_UNAVAILABLE, &body).unwrap_err();
        assert_eq!(err, LogoutError::Internal(LogoutInternalErrorKind::Http));
    }

    #[test]
    fn internal_error_kind_display_messages() {
        assert_eq!(LogoutInternalErrorKind::Url.to_string(), "URL parsing error");
        assert_eq!(LogoutInternalErrorKind::SerDe.to_string(), "serialization error");
        assert_eq!(
            LogoutInternalErrorKind::Http.to_string(),
            "unsuccessful HTTP request"
        );
    }

    #[test]
    fn logout_error_display_messages() {
        assert_eq!(
            LogoutError::NetworkError.to_string(),
            "network error logging out"
        );
        assert_eq!(
            LogoutError::Internal(LogoutInternalErrorKind::Http).to_string(),
            "error logging out: unsuccessful HTTP request"
        );
    }

    #[test]
    fn expiry_is_in_the_past() {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0);
        let expiry = past_expiry();
        assert!(
            expiry.timestamp() < now_secs,
            "logout expiry ({}) must be before now ({now_secs})",
            expiry.timestamp()
        );
        // `now_secs` is sampled before `past_expiry` reads the clock, so the expiry can be
        // at most EXPIRY_BACKDATE_SECS earlier than it.
        let backdate = i64::try_from(EXPIRY_BACKDATE_SECS).expect("backdate fits in i64");
        assert!(
            expiry.timestamp() >= now_secs - backdate,
            "logout expiry ({}) must be only slightly before now ({now_secs})",
            expiry.timestamp()
        );
    }
}