use self::{
backend_state::{handle_backend_status, handle_backend_status_stream},
cluster_state::handle_cluster_state,
connect::handle_revoke,
dns::handle_dns_socket,
drain::handle_drain,
error::IntoApiError,
proxy::handle_proxy_socket,
};
use crate::{
cleanup,
controller::{connect::handle_connect, core::Controller, drone::handle_drone_socket},
database::{connect_and_migrate, PlaneDatabase},
heartbeat_consts::HEARTBEAT_INTERVAL,
signals::wait_for_shutdown_signal,
util::GuardHandle,
};
use anyhow::{Context, Result};
use axum::{
extract::State,
http::{header, Method},
response::Response,
routing::{get, post},
Json, Router,
};
use forward_auth::forward_layer;
use futures_util::never::Never;
use plane_common::{
names::ControllerName,
protocol::StatusResponse,
types::ClusterName,
version::{PLANE_GIT_HASH, PLANE_VERSION},
PlaneClient,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::net::SocketAddr;
use tokio::{
net::TcpListener,
sync::oneshot::{self},
task::JoinHandle,
};
use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
use tower_http::{
cors::{Any, CorsLayer},
trace::DefaultOnFailure,
};
use tracing::Level;
use url::Url;
mod backend_state;
mod cluster_state;
pub mod command;
mod connect;
mod core;
mod dns;
mod drain;
mod drone;
pub mod error;
mod forward_auth;
mod proxy;
mod terminate;
/// How long [`ControllerServer::terminate`] waits for the server task to finish after
/// signalling graceful shutdown.
const TERMINATE_TIMEOUT_DURATION: std::time::Duration = std::time::Duration::from_millis(2_000);
/// Status handler (routed at `/status` under the `/ctrl` prefix in this module).
///
/// Verifies database connectivity, then reports `"ok"` together with the Plane version and
/// git hash this binary was built from.
///
/// # Errors
/// If the database health check fails, returns the error response produced by
/// `or_internal_error`.
pub async fn status(
    State(controller): State<Controller>,
) -> Result<Json<StatusResponse>, Response> {
    let db_check = controller.db.health_check().await;
    db_check.or_internal_error("Database health check failed")?;

    let response = StatusResponse {
        status: "ok".to_string(),
        version: PLANE_VERSION.to_string(),
        hash: PLANE_GIT_HASH.to_string(),
    };
    Ok(Json(response))
}
/// Public health handler (routed at `/health` under the `/pub` prefix in this module).
///
/// Responds with `{"status": "ok"}` only when the database health check succeeds.
///
/// # Errors
/// If the database health check fails, returns the error response produced by
/// `or_internal_error`.
pub async fn health(State(controller): State<Controller>) -> Result<Json<Value>, Response> {
    let db_check = controller.db.health_check().await;
    db_check.or_internal_error("Database health check failed")?;

    let body = json!({ "status": "ok" });
    Ok(Json(body))
}
/// Periodically records a liveness heartbeat for this controller in the database.
///
/// The background task is stopped by [`HeartbeatSender::terminate`] or, if that is never
/// called, when the sender is dropped. Dropping a tokio `JoinHandle` only detaches its task,
/// so without the `Drop` impl below a discarded server would keep reporting itself as alive.
struct HeartbeatSender {
    /// Background task sending a heartbeat every `HEARTBEAT_INTERVAL`; it never completes.
    handle: JoinHandle<Never>,
    /// Database used for the heartbeats, including the final offline one.
    db: PlaneDatabase,
    /// Controller the heartbeats are recorded for.
    controller_id: ControllerName,
}

impl HeartbeatSender {
    /// Sends an initial heartbeat (with `true`), then spawns a task that repeats it every
    /// `HEARTBEAT_INTERVAL`. Failures inside the loop are logged and the loop keeps going.
    ///
    /// # Errors
    /// Returns an error if the initial heartbeat fails; no task is spawned in that case.
    pub async fn start(db: PlaneDatabase, controller_id: ControllerName) -> Result<Self> {
        db.controller().heartbeat(&controller_id, true).await?;

        let db_clone = db.clone();
        let controller_id_clone = controller_id.clone();
        let handle: JoinHandle<Never> = tokio::spawn(async move {
            loop {
                tokio::time::sleep(HEARTBEAT_INTERVAL).await;
                if let Err(err) = db_clone
                    .controller()
                    .heartbeat(&controller_id_clone, true)
                    .await
                {
                    tracing::error!(?err, "Failed to send heartbeat");
                }
            }
        });

        Ok(Self {
            handle,
            db,
            controller_id,
        })
    }

    /// Stops the periodic heartbeat and sends a final heartbeat with `false` (offline).
    ///
    /// A failure to send the offline heartbeat is logged, not returned.
    pub async fn terminate(&self) {
        self.handle.abort();
        if let Err(err) = self
            .db
            .controller()
            .heartbeat(&self.controller_id, false)
            .await
        {
            tracing::error!(?err, "Failed to send offline heartbeat");
        }
    }
}

impl Drop for HeartbeatSender {
    fn drop(&mut self) {
        // `JoinHandle` detaches (does not cancel) its task on drop, so abort explicitly.
        // Aborting a task that was already aborted by `terminate` is a no-op.
        self.handle.abort();
    }
}
/// A running controller: the spawned HTTP server task plus its heartbeat and cleanup
/// background work.
///
/// Created by [`ControllerServer::run`] or [`ControllerServer::run_with_listener`]; call
/// [`ControllerServer::terminate`] to shut it down.
///
/// NOTE: fields are dropped in declaration order, so keep the order stable when editing.
pub struct ControllerServer {
    /// Local address of the listener, as reported by `listener.local_addr()`.
    bind_addr: SocketAddr,
    /// Name this controller was started with.
    controller_id: ControllerName,
    /// Signals the server task to shut down gracefully. `terminate` takes it, leaving `None`.
    /// The server's shutdown future ignores the receive error, so dropping this sender also
    /// triggers graceful shutdown.
    graceful_terminate_sender: Option<oneshot::Sender<()>>,
    /// Periodic database heartbeat for this controller.
    heartbeat_handle: HeartbeatSender,
    /// The spawned `axum::serve` task; `terminate` takes it when it waits on it.
    server_handle: Option<JoinHandle<Result<(), std::io::Error>>>,
    /// Holds the cleanup-loop task for the lifetime of the server (presumably aborting it on
    /// drop — confirm in `util::GuardHandle`).
    _cleanup_handle: GuardHandle,
}
impl ControllerServer {
pub async fn run(config: ControllerConfig) -> Result<Self> {
let listener = TcpListener::bind(config.bind_addr).await?;
tracing::info!("Attempting to connect to database...");
let db = connect_and_migrate(&config.db_url)
.await
.context("Failed to connect to database and run migrations.")?;
tracing::info!("Connected to database. Listening for connections.");
Self::run_with_listener(
db,
listener,
config.id,
config.controller_url,
config.default_cluster,
config.cleanup_min_age_days,
config.cleanup_batch_size,
config.forward_auth,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn run_with_listener(
db: PlaneDatabase,
listener: TcpListener,
id: ControllerName,
controller_url: Url,
default_cluster: Option<ClusterName>,
cleanup_min_age_days: Option<i32>,
cleanup_batch_size: Option<i32>,
forward_auth: Option<Url>,
) -> Result<Self> {
let bind_addr = listener.local_addr()?;
let cleanup_handle = {
let db = db.clone();
GuardHandle::new(async move {
cleanup::run_cleanup_loop(db.clone(), cleanup_min_age_days, cleanup_batch_size)
.await
})
};
let (graceful_terminate_sender, graceful_terminate_receiver) =
tokio::sync::oneshot::channel::<()>();
let controller =
Controller::new(db.clone(), id.clone(), controller_url, default_cluster).await;
let trace_layer = TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().level(Level::INFO))
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_failure(DefaultOnFailure::new().level(Level::WARN))
.on_response(DefaultOnResponse::new().level(Level::DEBUG));
let heartbeat_handle = HeartbeatSender::start(db.clone(), id.clone()).await?;
let mut control_routes = Router::new()
.route("/status", get(status))
.route("/c/:cluster/state", get(handle_cluster_state))
.route("/c/:cluster/drone-socket", get(handle_drone_socket))
.route("/c/:cluster/proxy-socket", get(handle_proxy_socket))
.route("/dns-socket", get(handle_dns_socket))
.route("/connect", post(handle_connect))
.route("/c/:cluster/d/:drone/drain", post(handle_drain))
.route(
"/b/:backend/soft-terminate",
post(terminate::handle_soft_terminate),
)
.route(
"/b/:backend/hard-terminate",
post(terminate::handle_hard_terminate),
)
.route(
"/b/revoke",
post(handle_revoke), );
if let Some(forward_auth_url) = forward_auth {
tracing::info!(?forward_auth_url, "Forward auth enabled");
let forward_url = forward_auth_url.clone();
control_routes = control_routes.layer(axum::middleware::from_fn_with_state(
forward_url.clone(),
forward_layer,
));
}
let cors_public = CorsLayer::new()
.allow_methods(vec![Method::GET, Method::POST])
.allow_headers(vec![header::CONTENT_TYPE])
.allow_origin(Any);
let public_routes = Router::new()
.route("/b/:backend/status", get(handle_backend_status))
.route(
"/b/:backend/status-stream",
get(handle_backend_status_stream),
)
.route("/health", get(health))
.layer(cors_public.clone());
let app = Router::new()
.nest("/pub", public_routes)
.nest("/ctrl", control_routes)
.layer(trace_layer)
.with_state(controller);
let server_handle = tokio::spawn(async {
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(async {
graceful_terminate_receiver.await.ok();
})
.await
});
Ok(Self {
graceful_terminate_sender: Some(graceful_terminate_sender),
heartbeat_handle,
server_handle: Some(server_handle),
controller_id: id,
bind_addr,
_cleanup_handle: cleanup_handle,
})
}
pub async fn terminate(&mut self) {
self.heartbeat_handle.terminate().await;
tracing::info!("Initiating graceful shutdown of server");
let Some(graceful_terminate_sender) = self.graceful_terminate_sender.take() else {
return;
};
if let Err(err) = graceful_terminate_sender.send(()) {
tracing::error!(?err, "Failed to send graceful terminate signal");
} else {
let Some(server_handle) = self.server_handle.take() else {
return;
};
match tokio::time::timeout(TERMINATE_TIMEOUT_DURATION, server_handle).await {
Ok(Ok(Ok(()))) => {
tracing::info!("Server gracefully terminated");
}
Ok(Ok(Err(err))) => {
tracing::error!(?err, "Server error");
}
Ok(Err(err)) => {
tracing::error!(?err, "Server error");
}
Err(_) => {
tracing::warn!("Server did not terminate gracefully in time, forcing shutdown");
}
}
}
}
pub fn id(&self) -> &ControllerName {
&self.controller_id
}
pub fn url(&self) -> Url {
let base_url: Url = format!("http://{}", self.bind_addr)
.parse()
.expect("Generated URI is always valid.");
base_url
}
pub fn client(&self) -> PlaneClient {
let base_url: Url = self.url();
PlaneClient::new(base_url)
}
}
/// Configuration consumed by [`ControllerServer::run`] and [`run_controller`].
///
/// NOTE: with the serde derives, field order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ControllerConfig {
    /// Database URL passed to `connect_and_migrate`.
    pub db_url: String,
    /// Address the HTTP listener binds to.
    pub bind_addr: SocketAddr,
    /// Name of this controller; used for heartbeats and passed to `Controller::new`.
    pub id: ControllerName,
    /// Passed to `Controller::new` (presumably this controller's externally reachable URL —
    /// confirm).
    pub controller_url: Url,
    /// Passed to `Controller::new` (presumably the cluster used when a request names none —
    /// confirm).
    pub default_cluster: Option<ClusterName>,
    /// Passed to `cleanup::run_cleanup_loop` (presumably the minimum record age in days before
    /// cleanup — confirm).
    pub cleanup_min_age_days: Option<i32>,
    /// Passed to `cleanup::run_cleanup_loop` (presumably the rows handled per cleanup batch —
    /// confirm).
    pub cleanup_batch_size: Option<i32>,
    /// When set, every `/ctrl` route is wrapped in the `forward_layer` middleware with this URL
    /// as its state.
    pub forward_auth: Option<Url>,
}
/// Runs a controller until a shutdown signal arrives, then shuts it down.
///
/// # Errors
/// Returns an error only if the server fails to start. Problems during shutdown are logged by
/// `ControllerServer::terminate` rather than returned.
pub async fn run_controller(config: ControllerConfig) -> Result<()> {
    let mut controller_server = ControllerServer::run(config).await?;

    // Block here for the lifetime of the process until asked to stop.
    wait_for_shutdown_signal().await;

    controller_server.terminate().await;
    Ok(())
}