use std::{
convert::Infallible,
fmt,
future::IntoFuture,
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
ops::{Deref, Range},
sync::Arc,
};
use axum::{body::Body, response::IntoResponse, routing::any_service};
use http::{HeaderValue, Method, Request, StatusCode, header};
use matrix_sdk_base::{boxed_into_future, locks::Mutex};
use matrix_sdk_common::executor::spawn;
use rand::{Rng, thread_rng};
use tokio::{net::TcpListener, sync::oneshot};
use tower::service_fn;
use url::Url;
/// Port range a random port is drawn from when the builder has no explicit
/// range: 20000 (inclusive) up to 30000 (exclusive).
const DEFAULT_PORT_RANGE: Range<u16> = 20000..30000;
/// Value used for the bind retry budget when the builder has no explicit
/// `bind_tries` setting.
const DEFAULT_BIND_TRIES: u8 = 10;
/// Builder for a local HTTP server that captures the query string of an
/// incoming request.
///
/// Every option is optional: a field left as `None` falls back to a default
/// when the server is spawned.
#[derive(Debug, Default, Clone)]
pub struct LocalServerBuilder {
/// Address(es) to bind to; `None` selects the default
/// `LocalServerIpAddress` variant.
ip_address: Option<LocalServerIpAddress>,
/// Range a random port is drawn from on every bind attempt; `None` selects
/// `DEFAULT_PORT_RANGE`.
port_range: Option<Range<u16>>,
/// How many times a failed bind is retried; `None` selects
/// `DEFAULT_BIND_TRIES`.
bind_tries: Option<u8>,
/// Body served to accepted requests; `None` selects the default
/// `LocalServerResponse`.
response: Option<LocalServerResponse>,
}
impl LocalServerBuilder {
    /// Create a builder with no option set; defaults are applied by
    /// [`Self::spawn`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the IP address(es) the server should try to bind to.
    ///
    /// When unset, the default [`LocalServerIpAddress`] variant is used.
    pub fn ip_address(mut self, ip_address: LocalServerIpAddress) -> Self {
        self.ip_address = Some(ip_address);
        self
    }

    /// Set the range from which a random port is drawn for every bind
    /// attempt.
    ///
    /// When unset, `20000..30000` is used. An empty range makes
    /// [`Self::spawn`] fail with [`io::ErrorKind::InvalidInput`].
    pub fn port_range(mut self, range: Range<u16>) -> Self {
        self.port_range = Some(range);
        self
    }

    /// Set how many times a failed bind is retried with a newly drawn port.
    ///
    /// The first attempt is not counted, so at most `tries + 1` bind attempts
    /// are made. When unset, `10` is used.
    pub fn bind_tries(mut self, tries: u8) -> Self {
        self.bind_tries = Some(tries);
        self
    }

    /// Set the body that is served to every accepted request.
    ///
    /// When unset, [`LocalServerResponse::default()`] is used.
    pub fn response(mut self, response: LocalServerResponse) -> Self {
        self.response = Some(response);
        self
    }

    /// Bind a TCP listener and spawn the server on it.
    ///
    /// The server answers `GET` and `HEAD` requests with the configured
    /// response and every other method with `405 Method Not Allowed`. The
    /// query string of the first accepted request is forwarded to the
    /// returned handle; later requests are still answered but nothing more is
    /// forwarded.
    ///
    /// Returns the URL (`http://<bound address>/`) the server listens on and
    /// a [`LocalServerRedirectHandle`] that can be awaited for the query
    /// string. Dropping the handle signals the server to shut down.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::InvalidInput`] if the configured port range is
    ///   empty.
    /// * The error of the last bind attempt once all retries are used up.
    pub async fn spawn(self) -> Result<(Url, LocalServerRedirectHandle), io::Error> {
        let Self { ip_address, port_range, bind_tries, response } = self;

        // `gen_range` panics when asked to sample from an empty range, so a
        // misconfigured range is reported as an error before any sampling.
        let port_range = port_range.unwrap_or(DEFAULT_PORT_RANGE);
        if port_range.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "the port range must contain at least one port",
            ));
        }

        let listener = {
            let ip_addresses = ip_address.unwrap_or_default().ip_addresses();
            let bind_tries = bind_tries.unwrap_or(DEFAULT_BIND_TRIES);
            // Number of retries performed so far.
            let mut n = 0u8;

            loop {
                // Every attempt draws a fresh port and pairs it with each
                // candidate IP address; all pairs go into a single bind call.
                let port = thread_rng().gen_range(port_range.clone());
                let socket_addresses =
                    ip_addresses.iter().map(|ip| SocketAddr::new(*ip, port)).collect::<Vec<_>>();

                match TcpListener::bind(socket_addresses.as_slice()).await {
                    Ok(l) => {
                        break l;
                    }
                    // Retries left: forget this error and try another port.
                    Err(_) if n < bind_tries => {
                        n += 1;
                    }
                    // Out of retries: surface the most recent error.
                    Err(e) => {
                        return Err(e);
                    }
                }
            }
        };

        let socket_address =
            listener.local_addr().expect("bound TCP listener should have an address");
        let uri = Url::parse(&format!("http://{socket_address}/"))
            .expect("socket address should parse as a URI host");

        // One channel asks the server to stop, the other carries the captured
        // query string back to the handle.
        let (shutdown_signal_sender, shutdown_signal_receiver) = oneshot::channel::<()>();
        let (data_sender, data_receiver) = oneshot::channel::<Option<QueryString>>();
        // The sender is consumed by the first accepted request, hence the
        // shared `Option` that is emptied on first use.
        let data_sender_mutex = Arc::new(Mutex::new(Some(data_sender)));

        // Resolve the fallback once instead of on every request.
        let response = response.unwrap_or_default();

        let router = any_service(service_fn(move |request: Request<_>| {
            let data_sender_mutex = data_sender_mutex.clone();
            let response = response.clone();

            async move {
                if request.method() != Method::HEAD && request.method() != Method::GET {
                    return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
                }

                if let Some(data_sender) = data_sender_mutex.lock().take() {
                    // A send error only means the receiving handle is gone,
                    // which is not a problem for the HTTP client.
                    let _ =
                        data_sender.send(request.uri().query().map(|s| QueryString(s.to_owned())));
                }

                Ok(response.into_response())
            }
        }));

        let server = axum::serve(listener, router)
            .with_graceful_shutdown(async {
                // Both an explicit signal and a dropped sender end the wait.
                shutdown_signal_receiver.await.ok();
            })
            .into_future();

        spawn(server);

        Ok((
            uri,
            LocalServerRedirectHandle {
                data_receiver: Some(data_receiver),
                shutdown_signal_sender: Arc::new(Mutex::new(Some(shutdown_signal_sender))),
            },
        ))
    }
}
/// Handle to a spawned local server.
///
/// Awaiting the handle yields the query string captured by the server, if
/// any. Dropping the handle sends the shutdown signal to the server unless it
/// was already sent.
// The manual `Debug` implementation of this type is compiled out under
// `cfg(tarpaulin_include)`, hence the lint exemption.
#[allow(missing_debug_implementations)]
pub struct LocalServerRedirectHandle {
/// Receiving end for the captured query string. Set at construction and
/// taken out when the handle is turned into a future.
data_receiver: Option<oneshot::Receiver<Option<QueryString>>>,
/// Sending end of the shutdown signal, shared with every
/// `LocalServerShutdownHandle`; emptied by whoever sends the signal first.
shutdown_signal_sender: Arc<Mutex<Option<oneshot::Sender<()>>>>,
}
impl LocalServerRedirectHandle {
pub fn shutdown_handle(&self) -> LocalServerShutdownHandle {
LocalServerShutdownHandle(self.shutdown_signal_sender.clone())
}
}
impl Drop for LocalServerRedirectHandle {
    /// Signal the server to shut down, unless the signal was already sent
    /// through a shutdown handle.
    fn drop(&mut self) {
        let pending_sender = self.shutdown_signal_sender.lock().take();

        if let Some(shutdown_tx) = pending_sender {
            // A send error means the server is already gone; nothing to do.
            let _ = shutdown_tx.send(());
        }
    }
}
impl IntoFuture for LocalServerRedirectHandle {
    /// The captured query string, or `None` if the request had none or the
    /// value could not be received.
    type Output = Option<QueryString>;
    // NOTE(review): presumably declares the boxed `IntoFuture` associated
    // type — confirm against the macro definition.
    boxed_into_future!();

    /// Wait for the query string forwarded by the server.
    ///
    /// # Panics
    ///
    /// Panics if the data receiver was already taken, which cannot happen for
    /// a handle built by the server builder.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move {
            // The future owns the whole handle, so the handle's `Drop` (and
            // with it the shutdown signal) only runs once the wait is over or
            // the future itself is dropped.
            let mut handle = self;
            let receiver = handle
                .data_receiver
                .take()
                .expect("data receiver is set during construction");

            match receiver.await {
                Ok(query) => query,
                Err(_) => None,
            }
        })
    }
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for LocalServerRedirectHandle {
    /// Print only the type name; the channel fields are left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("LocalServerRedirectHandle");
        repr.finish_non_exhaustive()
    }
}
/// Cloneable handle that can signal a spawned local server to shut down.
///
/// Wraps the shared, take-once sender of the shutdown signal.
#[derive(Clone)]
// The manual `Debug` implementation of this type is compiled out under
// `cfg(tarpaulin_include)`, hence the lint exemption.
#[allow(missing_debug_implementations)]
pub struct LocalServerShutdownHandle(Arc<Mutex<Option<oneshot::Sender<()>>>>);
impl LocalServerShutdownHandle {
    /// Signal the server to shut down.
    ///
    /// Does nothing if the signal was already sent through another clone of
    /// this handle or by dropping the redirect handle.
    pub fn shutdown(self) {
        let Some(shutdown_tx) = self.0.lock().take() else {
            return;
        };

        // A send error means the server is already gone; nothing to do.
        let _ = shutdown_tx.send(());
    }
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for LocalServerShutdownHandle {
    /// Print only the type name; the wrapped sender is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("LocalServerShutdownHandle");
        repr.finish_non_exhaustive()
    }
}
/// The IP address(es) a local server should try to bind to.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum LocalServerIpAddress {
    /// The IPv4 loopback address, `127.0.0.1`.
    Localhostv4,
    /// The IPv6 loopback address, `::1`.
    Localhostv6,
    /// Both loopback addresses, IPv4 listed before IPv6. This is the default.
    #[default]
    LocalhostAny,
    /// A single caller-provided address.
    Custom(IpAddr),
}

impl LocalServerIpAddress {
    /// List the candidate addresses for this setting, in the order they are
    /// handed to the listener.
    fn ip_addresses(self) -> Vec<IpAddr> {
        let loopback_v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let loopback_v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);

        match self {
            Self::Custom(address) => vec![address],
            Self::LocalhostAny => vec![loopback_v4, loopback_v6],
            Self::Localhostv4 => vec![loopback_v4],
            Self::Localhostv6 => vec![loopback_v6],
        }
    }
}
/// The body a local server sends back to accepted requests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalServerResponse {
/// Body served with the `text/plain; charset=utf-8` content type.
PlainText(String),
/// Body served with the `text/html; charset=utf-8` content type.
Html(String),
}
impl LocalServerResponse {
    /// Build the HTTP response for this body, with the `Content-Type` header
    /// matching the variant.
    fn into_response(self) -> http::Response<Body> {
        // The header values are built straight from the `mime` constants so
        // the borrowed strings keep their `'static` lifetime.
        let (payload, content_type) = match self {
            Self::Html(markup) => {
                (markup, HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()))
            }
            Self::PlainText(text) => {
                (text, HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()))
            }
        };

        let mut http_response = Body::from(payload).into_response();
        let headers = http_response.headers_mut();
        headers.insert(header::CONTENT_TYPE, content_type);
        http_response
    }
}
impl Default for LocalServerResponse {
fn default() -> Self {
LocalServerResponse::PlainText(
"The authorization step is complete. You can close this page.".to_owned(),
)
}
}
/// The raw query string of a request URI, without the leading `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryString(pub String);

impl AsRef<str> for QueryString {
    /// Borrow the query string as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for QueryString {
    type Target = str;

    /// Give access to the `str` API directly on the wrapper.
    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
// Integration-style tests: each one spawns a real local server and talks to
// it over HTTP with `reqwest`.
// NOTE(review): `assert_let_timeout!` presumably awaits the handle with a
// timeout and matches the result against the pattern — confirm against the
// macro definition.
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use http::header;
use matrix_sdk_test::async_test;
use crate::{
assert_let_timeout,
utils::local_server::{LocalServerBuilder, LocalServerIpAddress, LocalServerResponse},
};
// A request without a query string makes the handle resolve to `None`.
#[async_test]
async fn test_local_server_builder_no_query() {
let (uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap();
assert_let_timeout!(None = server_handle);
}
// A request with a query string makes the handle resolve to that string.
#[async_test]
async fn test_local_server_builder_with_query() {
let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
uri.set_query(Some("foo=bar"));
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap();
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// IPv4 loopback with a single-port range: the URL must expose exactly that
// host and port. Requires port 3000 to be free on the test machine.
#[async_test]
async fn test_local_server_builder_with_ipv4_and_port() {
let (mut uri, server_handle) = LocalServerBuilder::new()
.ip_address(LocalServerIpAddress::Localhostv4)
.port_range(3000..3001)
.bind_tries(1)
.spawn()
.await
.unwrap();
uri.set_query(Some("foo=bar"));
assert_eq!(uri.host_str(), Some("127.0.0.1"));
assert_eq!(uri.port(), Some(3000));
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap();
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// IPv6 loopback with a single-port range: the host is rendered in brackets.
// Requires port 10000 to be free on the test machine.
#[async_test]
async fn test_local_server_builder_with_ipv6_and_port() {
let (mut uri, server_handle) = LocalServerBuilder::new()
.ip_address(LocalServerIpAddress::Localhostv6)
.port_range(10000..10001)
.bind_tries(1)
.spawn()
.await
.unwrap();
uri.set_query(Some("foo=bar"));
assert_eq!(uri.host_str(), Some("[::1]"));
assert_eq!(uri.port(), Some(10000));
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap();
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// A custom IP address with a single-port range. Requires port 10040 to be
// free on the test machine.
#[async_test]
async fn test_local_server_builder_with_custom_ip_and_port() {
let (mut uri, server_handle) = LocalServerBuilder::new()
.ip_address(LocalServerIpAddress::Custom(Ipv4Addr::new(127, 0, 0, 1).into()))
.port_range(10040..10041)
.bind_tries(1)
.spawn()
.await
.unwrap();
uri.set_query(Some("foo=bar"));
assert_eq!(uri.host_str(), Some("127.0.0.1"));
assert_eq!(uri.port(), Some(10040));
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap();
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// A custom plain-text response is served verbatim with the plain-text
// content type.
#[async_test]
async fn test_local_server_builder_with_custom_plain_text_response() {
let text = "Hello world!";
let (mut uri, server_handle) = LocalServerBuilder::new()
.response(LocalServerResponse::PlainText(text.to_owned()))
.spawn()
.await
.unwrap();
uri.set_query(Some("foo=bar"));
let http_client = reqwest::Client::new();
let response = http_client.get(uri.as_str()).send().await.unwrap();
let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
assert_eq!(content_type, "text/plain; charset=utf-8");
assert_eq!(response.text().await.unwrap(), text);
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// A custom HTML response is served verbatim with the HTML content type.
#[async_test]
async fn test_local_server_builder_with_custom_html_response() {
let html = "<html><body><h1>Hello world!</h1></body></html>";
let (mut uri, server_handle) = LocalServerBuilder::new()
.response(LocalServerResponse::Html(html.to_owned()))
.spawn()
.await
.unwrap();
uri.set_query(Some("foo=bar"));
let http_client = reqwest::Client::new();
let response = http_client.get(uri.as_str()).send().await.unwrap();
let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
assert_eq!(content_type, "text/html; charset=utf-8");
assert_eq!(response.text().await.unwrap(), html);
assert_let_timeout!(Some(query) = server_handle);
assert_eq!(query.0, "foo=bar");
}
// Shutting down before any request: the request must fail and the handle
// must resolve to `None`.
#[async_test]
async fn test_local_server_builder_early_shutdown() {
let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
uri.set_query(Some("foo=bar"));
server_handle.shutdown_handle().shutdown();
let http_client = reqwest::Client::new();
http_client.get(uri.as_str()).send().await.unwrap_err();
assert_let_timeout!(None = server_handle);
}
}