use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use hyper::Uri;
use hyper_util::{
client::legacy::connect::{Connected, Connection},
rt::TokioIo,
};
use tower_service::Service;
use crate::{Error, InternalErrorKind, loopback::OverlayDialer, netstack};
/// Connector that hyper's legacy client uses to open connections to tailnet peers.
///
/// Every connection is dialed through the wrapped [`OverlayDialer`]. Cloning the connector
/// clones the dialer handle.
#[derive(Clone)]
pub struct TailnetConnector {
    /// Dialer used for every outgoing connection (see the `Service<Uri>` impl).
    dialer: OverlayDialer,
}
impl TailnetConnector {
pub(crate) fn new(dialer: OverlayDialer) -> Self {
Self { dialer }
}
}
/// Connection returned by [`TailnetConnector`].
///
/// This is the netstack TCP stream produced by the overlay dialer, wrapped in hyper-util's
/// `TokioIo` adapter so it can be driven through hyper's `Read`/`Write` traits.
pub struct TailnetStream(TokioIo<netstack::TcpStream>);
/// Reads are delegated unchanged to the wrapped `TokioIo` stream.
impl hyper::rt::Read for TailnetStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `get_mut` is fine because the newtype is `Unpin` whenever its inner stream is.
        let inner = &mut self.get_mut().0;
        hyper::rt::Read::poll_read(Pin::new(inner), cx, buf)
    }
}
/// Writes are delegated unchanged to the wrapped `TokioIo` stream.
///
/// All methods are forwarded, including the vectored-write pair. Without them the trait
/// defaults would apply: `is_write_vectored` would report `false` and `poll_write_vectored`
/// would write one buffer at a time. That hides whatever vectored support the inner stream
/// has, and hyper picks its buffering strategy based on `is_write_vectored`.
impl hyper::rt::Write for TailnetStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        hyper::rt::Write::poll_write(Pin::new(&mut self.get_mut().0), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        hyper::rt::Write::poll_flush(Pin::new(&mut self.get_mut().0), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        hyper::rt::Write::poll_shutdown(Pin::new(&mut self.get_mut().0), cx)
    }

    /// Reports whatever the inner stream reports, so hyper can use vectored writes when
    /// they are supported.
    fn is_write_vectored(&self) -> bool {
        hyper::rt::Write::is_write_vectored(&self.0)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        hyper::rt::Write::poll_write_vectored(Pin::new(&mut self.get_mut().0), cx, bufs)
    }
}
/// Connection metadata reported to hyper-util's client.
impl Connection for TailnetStream {
    /// Returns default `Connected` info: no extra fields such as proxy or ALPN are set.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
impl Service<Uri> for TailnetConnector {
type Response = TailnetStream;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<TailnetStream, Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, uri: Uri) -> Self::Future {
let dialer = self.dialer.clone();
Box::pin(async move {
let (host, port) = host_port(&uri)?;
dialer
.dial_host_port(&host, port)
.await
.map(|stream| TailnetStream(TokioIo::new(stream)))
})
}
}
fn host_port(uri: &Uri) -> Result<(String, u16), Error> {
match uri.scheme_str() {
Some("http") | None => {}
_ => return Err(Error::Internal(InternalErrorKind::BadRequest)),
}
let host = uri
.host()
.ok_or(Error::Internal(InternalErrorKind::BadRequest))?;
let host = host
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
.unwrap_or(host)
.to_string();
let port = uri.port_u16().unwrap_or(80);
Ok((host, port))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a URI literal; every literal used in this module is expected to be valid.
    fn uri(s: &str) -> Uri {
        s.parse().expect("test URI literal must parse")
    }

    /// Asserts that `host_port` accepts `s` and yields exactly `(host, port)`.
    fn assert_target(s: &str, host: &str, port: u16) {
        let (h, p) = host_port(&uri(s)).unwrap();
        assert_eq!((h.as_str(), p), (host, port), "dial target for {s}");
    }

    /// Asserts that `host_port` rejects `s` with `BadRequest`.
    fn assert_bad_request(s: &str) {
        let err = host_port(&uri(s)).unwrap_err();
        assert!(
            matches!(err, Error::Internal(InternalErrorKind::BadRequest)),
            "expected BadRequest for {s}"
        );
    }

    #[test]
    fn host_port_defaults_http_to_80() {
        assert_target("http://peer/path", "peer", 80);
    }

    #[test]
    fn host_port_rejects_https() {
        for s in [
            "https://peer.tailnet.ts.net/",
            "https://peer:443/",
            "https://peer:8443/",
        ] {
            assert_bad_request(s);
        }
    }

    #[test]
    fn host_port_rejects_wss_even_with_explicit_port() {
        assert_bad_request("wss://peer:443/");
    }

    #[test]
    fn host_port_explicit_port_wins() {
        assert_target("http://peer:8080/", "peer", 8080);
    }

    #[test]
    fn host_port_ipv4_literal() {
        assert_target("http://100.64.0.1:9000/", "100.64.0.1", 9000);
    }

    #[test]
    fn host_port_strips_ipv6_brackets() {
        assert_target("http://[::1]:80/", "::1", 80);
    }

    #[test]
    fn host_port_unknown_scheme_without_port_rejected() {
        assert_bad_request("ftp://peer/");
    }

    // Authority-form URIs carry no scheme, which exercises the `None` arm of the scheme check.
    #[test]
    fn host_port_accepts_schemeless_authority() {
        assert_target("peer:8080", "peer", 8080);
    }

    // Origin-form URIs have no authority, so there is no host to dial.
    #[test]
    fn host_port_rejects_uri_without_host() {
        assert_bad_request("/path");
    }

    #[test]
    fn host_port_scheme_match_is_case_insensitive() {
        assert_target("HTTP://peer/", "peer", 80);
    }
}