use alloc::string::String;
use core::{fmt, str::FromStr};
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use ts_capabilityversion::CapabilityVersion;
use ts_http_util::{BytesBody, ClientExt, EmptyBody, HeaderName, HeaderValue, Http2, ResponseExt};
use ts_keys::{MachineKeyPair, MachinePublicKey};
use url::Url;
use zerocopy::network_endian::U32;
use crate::tokio::prefixed_reader::PrefixedReader;
/// Magic bytes that open an early "challenge" payload on a TS2021 stream:
/// three `0xFF` bytes followed by ASCII `TS`.
const CHALLENGE_MAGIC: [u8; 5] = *b"\xFF\xFF\xFFTS";
/// Request header carrying the Noise handshake initiation message during the upgrade.
const HANDSHAKE_HEADER_KEY: &str = "X-Tailscale-Handshake";
/// Largest early challenge body, in bytes, accepted by `read_challenge_packet`.
const MAX_CHALLENGE_LENGTH: usize = 1 << 10;
/// Upgrade value handed to `ts_http_util::make_upgrade_req` (presumably sent as the
/// `Upgrade:` header — confirm in `ts_http_util`).
const UPGRADE_HEADER_VALUE: &str = "tailscale-control-protocol";
lazy_static::lazy_static! {
/// Protocol version string passed as the first argument of
/// `ts_control_noise::Handshake::initialize` in `connect`.
///
/// It embeds `CapabilityVersion::CURRENT`, so it is built lazily at first use rather
/// than as a `const`.
pub static ref CONTROL_PROTOCOL_VERSION: String = format!("Tailscale Control Protocol v{}", CapabilityVersion::CURRENT);
}
#[derive(Debug, thiserror::Error, Clone, Copy, Eq, PartialEq)]
pub enum ConnectionError {
#[error("internal error during connection: {0}")]
Internal(InternalErrorKind),
#[error("Network error")]
NetworkError,
}
/// Coarse category of an internal (non-network) connection failure.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InternalErrorKind {
    /// A URL could not be parsed or joined.
    Url,
    /// An HTTP request or protocol upgrade did not succeed.
    Http,
    /// JSON (de)serialization failed.
    SerDe,
    /// A Noise message was malformed.
    MessageFormat,
    /// An I/O error that was not classified as network-related.
    Io,
    /// The early challenge payload advertised an oversized length.
    ChallengeLength,
    /// The Noise handshake failed.
    NoiseHandshake,
}

impl fmt::Display for InternalErrorKind {
    /// Writes a short, lowercase human-readable description of the error kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Url => "URL parsing error",
            Self::Http => "unsuccessful HTTP request or upgrade",
            Self::SerDe => "serialization/deserialization error",
            Self::MessageFormat => "message format error",
            Self::Io => "I/O error",
            Self::ChallengeLength => "challenge too long",
            Self::NoiseHandshake => "error in Noise handshake",
        };
        f.write_str(description)
    }
}
impl ConnectionError {
    /// Logs a failed read of `field` while processing the `stage` message, then
    /// classifies the I/O error.
    ///
    /// Returns [`ConnectionError::NetworkError`] when `crate::is_network_error` reports
    /// the error as network-related; otherwise returns an internal
    /// [`InternalErrorKind::Io`] error.
    fn io_error(field: &'static str, stage: &'static str, err: std::io::Error) -> Self {
        tracing::error!("could not read {field} from {stage} message: {err}");
        let network_related = crate::is_network_error(&err);
        if !network_related {
            return Self::Internal(InternalErrorKind::Io);
        }
        Self::NetworkError
    }
}
impl From<serde_json::Error> for ConnectionError {
fn from(error: serde_json::Error) -> Self {
tracing::error!(%error, "deserialization error");
ConnectionError::Internal(InternalErrorKind::SerDe)
}
}
impl From<ts_http_util::Error> for ConnectionError {
fn from(error: ts_http_util::Error) -> Self {
tracing::error!(%error, "http error connecting to control server");
if crate::http_error_is_recoverable(error) {
ConnectionError::NetworkError
} else {
ConnectionError::Internal(InternalErrorKind::Http)
}
}
}
impl From<url::ParseError> for ConnectionError {
fn from(error: url::ParseError) -> Self {
tracing::error!(%error, "bad URL");
ConnectionError::Internal(InternalErrorKind::Url)
}
}
impl From<ts_control_noise::Error> for ConnectionError {
fn from(error: ts_control_noise::Error) -> Self {
match error {
ts_control_noise::Error::BadFormat => {
ConnectionError::Internal(InternalErrorKind::MessageFormat)
}
ts_control_noise::Error::HandshakeFailed => {
ConnectionError::Internal(InternalErrorKind::NoiseHandshake)
}
ts_control_noise::Error::Io(error) => {
tracing::error!(%error, "IO error in Noise communication");
ConnectionError::Internal(InternalErrorKind::Io)
}
}
}
}
impl From<ConnectionError> for crate::Error {
    /// Lifts a [`ConnectionError`] into the crate-level error type, tagging it with
    /// `crate::Operation::ConnectToControlServer`.
    fn from(e: ConnectionError) -> Self {
        let op = crate::Operation::ConnectToControlServer;
        match e {
            ConnectionError::NetworkError => Self::NetworkError(op),
            ConnectionError::Internal(kind) => Self::Internal(kind.into(), op),
        }
    }
}
impl From<InternalErrorKind> for crate::InternalErrorKind {
    /// Translates this module's error kind into the crate-wide equivalent.
    ///
    /// The mapping is one-to-one except that `ChallengeLength` maps to the crate's
    /// more general `Challenge` kind.
    fn from(e: InternalErrorKind) -> Self {
        match e {
            InternalErrorKind::ChallengeLength => Self::Challenge,
            InternalErrorKind::Http => Self::Http,
            InternalErrorKind::Io => Self::Io,
            InternalErrorKind::MessageFormat => Self::MessageFormat,
            InternalErrorKind::NoiseHandshake => Self::NoiseHandshake,
            InternalErrorKind::SerDe => Self::SerDe,
            InternalErrorKind::Url => Self::Url,
        }
    }
}
/// JSON body returned by the control server's `/key` endpoint.
///
/// Field names are camelCase on the wire (`legacyPublicKey`, `publicKey`).
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ControlPublicKeys {
    /// Value of `legacyPublicKey`; this module only (de)serializes it.
    legacy_public_key: MachinePublicKey,
    /// Value of `publicKey`; `fetch_control_key` returns this one, and `connect`
    /// uses it for the Noise handshake.
    public_key: MachinePublicKey,
}

impl fmt::Display for ControlPublicKeys {
    /// Shows only the current (`publicKey`) key.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ControlPublicKeys { public_key, .. } = self;
        write!(f, "{public_key}")
    }
}
/// Connects to the control server at `control_url` and returns an HTTP/2 client
/// running over the Noise-encrypted TS2021 transport.
///
/// Steps:
/// 1. fetch the control server's machine public key (`fetch_control_key`);
/// 2. open an HTTP/1.1 connection (plain TCP for `http`, TLS otherwise);
/// 3. start a Noise handshake with `machine_keys.private` and upgrade the
///    connection to TS2021 (`upgrade_ts2021`);
/// 4. consume an optional early challenge payload (`read_challenge_packet`);
/// 5. run the HTTP/2 client handshake over the resulting stream.
///
/// # Errors
///
/// Returns a [`ConnectionError`] if any step fails. Failures classified as
/// network-related are reported as [`ConnectionError::NetworkError`].
#[tracing::instrument(skip_all, fields(%control_url), err)]
pub async fn connect(
    control_url: &Url,
    machine_keys: &MachineKeyPair,
) -> Result<Http2<BytesBody>, ConnectionError> {
    // Fetch the key before dialing the upgrade connection. The key fetch opens its
    // own connection, so dialing first would leave this one idle for an extra full
    // round trip, and would open it for nothing if the key fetch fails.
    let control_public_key = fetch_control_key(control_url).await?;
    let h1_client = connect_h1(control_url).await?;

    let (handshake, init_msg) = ts_control_noise::Handshake::initialize(
        &CONTROL_PROTOCOL_VERSION,
        &machine_keys.private,
        &control_public_key,
        CapabilityVersion::CURRENT,
    );
    let conn = upgrade_ts2021(control_url, &init_msg, handshake, h1_client).await?;
    let conn = read_challenge_packet(conn).await?;
    let h2_conn = ts_http_util::http2::connect(conn).await?;
    Ok(h2_conn)
}
/// Opens an HTTP/1.1 client connection to `url`.
///
/// Uses plain TCP when the scheme is exactly `http`; any other scheme uses TLS.
async fn connect_h1(url: &Url) -> Result<ts_http_util::Http1<EmptyBody>, ConnectionError> {
    let client = match url.scheme() {
        "http" => ts_http_util::http1::connect_tcp(url).await?,
        _ => ts_http_util::http1::connect_tls(url).await?,
    };
    Ok(client)
}
#[tracing::instrument(skip_all, fields(%control_url), ret, err, level = "trace")]
pub async fn fetch_control_key(control_url: &Url) -> Result<MachinePublicKey, ConnectionError> {
let mut key_url = control_url.join("/key")?;
#[cfg(not(feature = "insecure-keyfetch"))]
key_url.set_scheme("https").unwrap();
if key_url.scheme() == "http" {
tracing::warn!("fetching control key over insecure http");
}
key_url
.query_pairs_mut()
.extend_pairs([("v", CapabilityVersion::CURRENT.to_string())]);
let client = connect_h1(&key_url).await?;
let response = client.get(&key_url, None).await?;
if !response.status().is_success() {
let status = response.status();
tracing::error!(
status_code = status.as_str(),
"failed to retrieve control server machine public key"
);
return Err(ConnectionError::Internal(InternalErrorKind::Http));
}
let control_keys: ControlPublicKeys = serde_json::from_slice(&response.collect_bytes().await?)?;
let control_public_key = control_keys.public_key;
Ok(control_public_key)
}
/// Upgrades an HTTP/1.1 connection to the TS2021 protocol and completes the Noise
/// handshake over it.
///
/// Sends an upgrade request to `/ts2021` on the control host. The request carries
/// `init_msg` in the `X-Tailscale-Handshake` header. `handshake` is then finished on
/// the upgraded stream, and the resulting stream is returned.
///
/// # Errors
///
/// - URL join or request-building failures;
/// - [`InternalErrorKind::Http`] if the server does not accept the upgrade;
/// - Noise handshake failures, converted through `From<ts_control_noise::Error>`.
///
/// # Panics
///
/// Panics if `init_msg` is not a valid HTTP header value.
#[tracing::instrument(skip_all, fields(%control_url, %init_msg), err)]
pub async fn upgrade_ts2021(
    control_url: &Url,
    init_msg: &str,
    mut handshake: ts_control_noise::Handshake,
    h1_client: impl ts_http_util::Client<EmptyBody>,
) -> Result<impl AsyncRead + AsyncWrite + Unpin + 'static, ConnectionError> {
    let ts2021_url = control_url.join("/ts2021")?;
    tracing::trace!(%ts2021_url, "started NoiseIK handshake, upgrading to TS2021");

    let handshake_header = (
        HeaderName::from_str(HANDSHAKE_HEADER_KEY).unwrap(),
        HeaderValue::from_str(init_msg).expect("handshake header is valid"),
    );
    let upgrade_req =
        ts_http_util::make_upgrade_req(&ts2021_url, UPGRADE_HEADER_VALUE, [handshake_header])?;
    let response = h1_client.send(upgrade_req).await?;

    let upgraded = match ts_http_util::do_upgrade(response).await {
        Ok(stream) => stream,
        Err(error) => {
            tracing::error!(%error, "could not upgrade control connection to TS2021 protocol");
            return Err(ConnectionError::Internal(InternalErrorKind::Http));
        }
    };

    let noise_conn = handshake.complete(upgraded).await?;
    tracing::debug!("upgraded control connection from HTTP/1.1 to TS2021");
    Ok(noise_conn)
}
/// Consumes the optional early "challenge" payload that the control server may send
/// at the start of the TS2021 stream, before HTTP/2 traffic begins.
///
/// The payload is framed as the 5-byte `CHALLENGE_MAGIC`, a big-endian `u32`
/// length, and then that many body bytes. The body is read and discarded.
///
/// If the first five bytes are not the magic, no challenge was sent. Those bytes are
/// handed back as the prefix of the returned [`PrefixedReader`], so the caller still
/// sees the complete stream.
///
/// # Errors
///
/// - I/O errors, including EOF before the magic, the length, or the *entire* body
///   has been read, go through `ConnectionError::io_error`;
/// - [`InternalErrorKind::ChallengeLength`] if the advertised length exceeds
///   `MAX_CHALLENGE_LENGTH`.
#[tracing::instrument(skip_all, err, level = "trace")]
pub async fn read_challenge_packet<Conn>(
    mut conn: Conn,
) -> Result<PrefixedReader<Conn>, ConnectionError>
where
    Conn: AsyncRead + Unpin,
{
    let mut magic = [0u8; CHALLENGE_MAGIC.len()];
    conn.read_exact(&mut magic)
        .await
        .map_err(|err| ConnectionError::io_error("header", "early_payload", err))?;
    if magic != CHALLENGE_MAGIC {
        // No challenge: replay the bytes already consumed in front of the stream.
        return Ok(PrefixedReader::new(conn, Bytes::copy_from_slice(&magic)));
    }

    let mut challenge_len: U32 = 0.into();
    conn.read_exact(challenge_len.as_mut())
        .await
        .map_err(|err| ConnectionError::io_error("length", "challenge", err))?;
    let challenge_len = challenge_len.get() as usize;
    if challenge_len > MAX_CHALLENGE_LENGTH {
        tracing::error!(
            challenge_len,
            "invalid challenge length; must be at most {MAX_CHALLENGE_LENGTH} bytes"
        );
        return Err(ConnectionError::Internal(
            InternalErrorKind::ChallengeLength,
        ));
    }

    // `read_exact` fails with `UnexpectedEof` if the stream ends early. The previous
    // `take` + `io::copy` approach returned `Ok` with fewer bytes and silently
    // accepted a truncated challenge. The slice is in bounds because of the check
    // above.
    let mut body = [0u8; MAX_CHALLENGE_LENGTH];
    conn.read_exact(&mut body[..challenge_len])
        .await
        .map_err(|err| ConnectionError::io_error("body", "challenge", err))?;
    tracing::trace!(
        n_bytes = challenge_len,
        "read and discarded early challenge payload"
    );

    Ok(PrefixedReader::new(conn, Bytes::new()))
}

#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    use super::*;

    fn make_challenge(json: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&CHALLENGE_MAGIC);
        buf.extend_from_slice(&(json.len() as u32).to_be_bytes());
        buf.extend_from_slice(json);
        buf
    }

    #[tokio::test]
    async fn challenge_present() {
        let json = b"{\"nodeKeyChallenge\":\"test\"}";
        let payload = b"HTTP/2 data after challenge";
        let mut data = make_challenge(json);
        data.extend_from_slice(payload);
        let (mut writer, reader) = duplex(1024);
        writer.write_all(&data).await.unwrap();
        drop(writer);
        let mut conn = read_challenge_packet(reader).await.unwrap();
        let mut out = Vec::new();
        conn.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, payload);
    }

    #[tokio::test]
    async fn challenge_absent() {
        let payload = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
        let (mut writer, reader) = duplex(1024);
        writer.write_all(payload).await.unwrap();
        drop(writer);
        let mut conn = read_challenge_packet(reader).await.unwrap();
        let mut out = Vec::new();
        conn.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, payload);
    }

    #[tokio::test]
    async fn challenge_empty() {
        let payload = b"HTTP/2 data after empty challenge";
        let mut data = make_challenge(b"");
        data.extend_from_slice(payload);
        let (mut writer, reader) = duplex(1024);
        writer.write_all(&data).await.unwrap();
        drop(writer);
        let mut conn = read_challenge_packet(reader).await.unwrap();
        let mut out = Vec::new();
        conn.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, payload);
    }

    #[tokio::test]
    async fn challenge_max_length_accepted() {
        let json = [b'x'; MAX_CHALLENGE_LENGTH];
        let payload = b"after";
        let mut data = make_challenge(&json);
        data.extend_from_slice(payload);
        // Larger than `data` so `write_all` never blocks waiting for the reader.
        let (mut writer, reader) = duplex(4096);
        writer.write_all(&data).await.unwrap();
        drop(writer);
        let mut conn = read_challenge_packet(reader).await.unwrap();
        let mut out = Vec::new();
        conn.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, payload);
    }

    #[tokio::test]
    async fn challenge_too_long() {
        let mut data = CHALLENGE_MAGIC.to_vec();
        data.extend_from_slice(&(MAX_CHALLENGE_LENGTH as u32 + 1).to_be_bytes());
        let (mut writer, reader) = duplex(1024);
        writer.write_all(&data).await.unwrap();
        drop(writer);
        let err = read_challenge_packet(reader).await.err();
        assert_eq!(
            err,
            Some(ConnectionError::Internal(InternalErrorKind::ChallengeLength))
        );
    }

    #[tokio::test]
    async fn challenge_truncated() {
        let mut data = make_challenge(b"{\"nodeKeyChallenge\":\"test\"}");
        data.truncate(data.len() - 4);
        let (mut writer, reader) = duplex(1024);
        writer.write_all(&data).await.unwrap();
        drop(writer);
        assert!(read_challenge_packet(reader).await.is_err());
    }
}