use super::AppState;
#[cfg(any(test, feature = "testing"))]
use crate::MockDataDir;
use crate::{
app_context::{AppContext, AppContextConversionError},
PersistentDataDir,
};
use anyhow::Result;
use futures_util::TryFutureExt;
use pubky_common::auth::AuthVerifier;
use std::net::TcpListener;
use std::path::PathBuf;
use std::time::Duration;
use axum::{
routing::{get, post},
Router,
};
use axum_server::{
tls_rustls::{RustlsAcceptor, RustlsConfig},
Handle,
};
use std::{net::SocketAddr, sync::Arc};
use tower_cookies::CookieManagerLayer;
use tower_http::cors::CorsLayer;
use super::layers::{
pubky_host::PubkyHostLayer,
rate_limiter::{BandwidthQuotaLimitLayer, RequestRateLimitLayer},
trace::with_trace_layer,
};
use super::routes::{auth, events, root, signup_tokens, tenants};
/// Errors that can occur while building and starting a [`ClientServer`].
#[derive(Debug, thiserror::Error)]
pub enum ClientServerBuildError {
    /// Starting the ICANN (plain `http://`) web server failed, for example because
    /// its configured listen socket could not be bound.
    #[error("ICANN web server error: {0}")]
    IcannWebServer(anyhow::Error),
    /// Starting the Pubky TLS web server failed, for example because its
    /// configured listen socket could not be bound.
    #[error("Pubky TLS web server error: {0}")]
    PubkyTlsServer(anyhow::Error),
    /// Building the [`AppContext`] from a data directory failed. Converted
    /// automatically via `?` thanks to `#[from]`.
    #[error("AppContext conversion error: {0}")]
    AppContext(#[from] AppContextConversionError),
    /// The per-path request-count rate limits in the configuration were rejected
    /// by `RequestRateLimitLayer::from_path_limits`; the `String` is its message.
    #[error("Request-count rate limit configuration error: {0}")]
    RequestRateLimits(String),
}
/// A running homeserver client API: an ICANN HTTP server and a Pubky TLS server
/// that both serve clones of the same router.
pub struct ClientServer {
    /// Context the servers were started with; read later for the keypair when
    /// building the Pubky TLS DNS URL.
    context: AppContext,
    /// Handle controlling the spawned ICANN HTTP server (used for shutdown).
    pub(crate) icann_http_handle: Handle<SocketAddr>,
    /// Local address the ICANN HTTP listener actually bound.
    pub(crate) icann_http_socket: SocketAddr,
    /// Handle controlling the spawned Pubky TLS server (used for shutdown).
    pub(crate) pubky_tls_handle: Handle<SocketAddr>,
    /// Local address the Pubky TLS listener actually bound.
    pub(crate) pubky_tls_socket: SocketAddr,
}
impl ClientServer {
pub async fn start_with_persistent_data_dir_path(
dir_path: PathBuf,
) -> Result<Self, ClientServerBuildError> {
let data_dir = PersistentDataDir::new(dir_path);
let context = AppContext::read_from(data_dir).await?;
Self::start(context).await
}
pub async fn start_with_persistent_data_dir(
dir: PersistentDataDir,
) -> Result<Self, ClientServerBuildError> {
let context = AppContext::read_from(dir).await?;
Self::start(context).await
}
#[cfg(any(test, feature = "testing"))]
pub async fn start_with_mock_data_dir(
dir: MockDataDir,
) -> Result<Self, ClientServerBuildError> {
let context = AppContext::read_from(dir).await?;
Self::start(context).await
}
pub async fn start(context: AppContext) -> std::result::Result<Self, ClientServerBuildError> {
let router = Self::create_router(&context)?;
let (icann_http_handle, icann_http_socket) =
Self::start_icann_http_server(&context, router.clone())
.await
.map_err(ClientServerBuildError::IcannWebServer)?;
let (pubky_tls_handle, pubky_tls_socket) = Self::start_pubky_tls_server(&context, router)
.await
.map_err(ClientServerBuildError::PubkyTlsServer)?;
Ok(Self {
context,
icann_http_handle,
pubky_tls_handle,
icann_http_socket,
pubky_tls_socket,
})
}
pub(crate) fn create_router(
context: &AppContext,
) -> std::result::Result<Router, ClientServerBuildError> {
let state = AppState {
verifier: AuthVerifier::default(),
sql_db: context.sql_db.clone(),
file_service: context.file_service.clone(),
signup_mode: context.config_toml.general.signup_mode.clone(),
metrics: context.metrics.clone(),
events_service: context.events_service.clone(),
user_service: context.user_service.clone(),
default_storage_mb: context.config_toml.storage.default_quota_mb,
};
super::create_app(state.clone(), context)
}
async fn start_icann_http_server(
context: &AppContext,
router: Router,
) -> Result<(Handle<SocketAddr>, SocketAddr)> {
let http_listener = TcpListener::bind(context.config_toml.drive.icann_listen_socket)?;
http_listener.set_nonblocking(true)?;
let http_socket = http_listener.local_addr()?;
let http_handle = Handle::new();
let server = axum_server::from_tcp(http_listener)?;
tokio::spawn(
server
.handle(http_handle.clone())
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.map_err(|error| {
tracing::error!(?error, "Homeserver icann http server error");
println!("Homeserver icann http server error: {:?}", error);
}),
);
Ok((http_handle, http_socket))
}
async fn start_pubky_tls_server(
context: &AppContext,
router: Router,
) -> Result<(Handle<SocketAddr>, SocketAddr)> {
let https_listener = TcpListener::bind(context.config_toml.drive.pubky_listen_socket)?;
https_listener.set_nonblocking(true)?;
let https_socket = https_listener.local_addr()?;
let https_handle = Handle::new();
let server = axum_server::from_tcp(https_listener)?;
tokio::spawn(
server
.acceptor(RustlsAcceptor::new(RustlsConfig::from_config(Arc::new(
context.keypair.to_rpk_rustls_server_config(),
))))
.handle(https_handle.clone())
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.map_err(|error| {
tracing::error!(?error, "Homeserver pubky tls server error");
println!("Homeserver pubky tls server error: {:?}", error);
}),
);
Ok((https_handle, https_socket))
}
pub fn icann_http_url_string(&self) -> String {
format!("http://{}", self.icann_http_socket)
}
pub fn pubky_tls_dns_url_string(&self) -> String {
format!("https://{}", self.context.keypair.public_key().z32())
}
pub fn pubky_tls_ip_url_ring(&self) -> String {
format!("https://{}", self.pubky_tls_socket)
}
pub fn shutdown(&self) {
self.icann_http_handle
.graceful_shutdown(Some(Duration::from_secs(5)));
self.pubky_tls_handle
.graceful_shutdown(Some(Duration::from_secs(5)));
}
}
impl Drop for ClientServer {
    /// Requests a graceful shutdown of both spawned servers when this value goes
    /// out of scope.
    fn drop(&mut self) {
        // Route through the public shutdown path so explicit and implicit
        // shutdown always behave the same.
        Self::shutdown(self);
    }
}
/// Top-level routes of the client API, before the tenant routes are merged in
/// and before any middleware layers are applied.
fn base() -> Router<AppState> {
    // Root endpoint.
    let router = Router::new().route("/", get(root::handler));

    // Signup and signup-token lookup.
    let router = router
        .route("/signup", post(auth::signup))
        .route("/signup_tokens/{token}", get(signup_tokens::get));

    // Session creation (sign-in).
    let router = router.route("/session", post(auth::signin));

    // Event feed endpoints.
    router
        .route("/events/", get(events::feed))
        .route("/events-stream", get(events::feed_stream))
}
/// Builds the full client API router: base and tenant routes, wrapped in the CORS,
/// bandwidth-quota, request-rate-limit, cookie and Pubky-host layers, with `state`
/// attached. The result is passed through `with_trace_layer`.
///
/// # Errors
/// Returns [`ClientServerBuildError::RequestRateLimits`] if the configured per-path
/// request rate limits are rejected.
pub fn create_app(
    state: AppState,
    context: &AppContext,
) -> std::result::Result<Router, ClientServerBuildError> {
    let config = &context.config_toml;

    // The only fallible step; build it first so bad config fails fast.
    let path_limits = config.drive.rate_limits.clone();
    let request_rate_limit_layer = RequestRateLimitLayer::from_path_limits(path_limits)
        .map_err(ClientServerBuildError::RequestRateLimits)?;

    let routes = base().merge(tenants::router(state.clone()));

    // In axum each `.layer` wraps everything added before it, so the last layer
    // here (`PubkyHostLayer`) is the first to see an incoming request.
    let layered = routes
        .layer(CorsLayer::very_permissive())
        .layer(BandwidthQuotaLimitLayer::new(
            context.user_service.clone(),
            config.default_quotas.clone(),
        ))
        .layer(request_rate_limit_layer)
        .layer(CookieManagerLayer::new())
        .layer(PubkyHostLayer);

    let app = layered.with_state(state);

    // NOTE(review): presumably wraps the whole stack in request tracing — confirm in `trace` layer.
    Ok(with_trace_layer(app))
}