use std::collections::BTreeMap;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::{Path, Query, Request, State, WebSocketUpgrade};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum::routing::{any, get};
use axum::{Json, Router};
use futures::future::BoxFuture;
use serde::Deserialize;
use serde_json::Value as JsonValue;
use uuid::Uuid;
use super::core::Core;
use super::render::{self, Hrefs};
use crate::core::TransitionInput;
use crate::siren::SIREN_CONTENT_TYPE;
/// Callback run after a `/peers/{name}` connection upgrade succeeds.
///
/// It receives the peer name taken from the URL path, the `connectionId` query value,
/// and the upgraded connection. `peers_upgrade` awaits the returned future inside a
/// spawned task.
pub(crate) type PeerHandler = Arc<
    dyn Fn(String, Uuid, hyper::upgrade::Upgraded) -> BoxFuture<'static, ()> + Send + Sync,
>;
/// Source of live peer tunnels, used to forward requests addressed to other servers.
#[async_trait::async_trait]
pub trait PeerSenders: Send + Sync + 'static {
    /// Returns an HTTP/2 request sender for the peer called `name`, or `None` when no
    /// such sender is available.
    async fn sender(
        &self,
        name: &str,
    ) -> Option<hyper::client::conn::http2::SendRequest<axum::body::Body>>;

    /// Names of the peers this implementation currently reports.
    async fn names(&self) -> Vec<String>;

    /// Whether `name` is one of [`PeerSenders::names`]. The default implementation
    /// fetches the full list; implementors may override it with a direct lookup.
    async fn has_active_peer(&self, name: &str) -> bool {
        let known = self.names().await;
        known.iter().any(|candidate| candidate.as_str() == name)
    }
}
/// A device-creation request decoded from the `POST /servers/{name}/devices` form body.
///
/// The `type`, `name` and `id` form fields fill the dedicated fields below. Every other
/// form field is kept verbatim in `fields`.
#[derive(Debug, Clone, Default)]
pub(crate) struct DeviceRegistration {
    /// Device type from the `type` form field; `devices_post` rejects an empty value.
    pub type_: String,
    /// Optional device name from the `name` form field.
    pub name: Option<String>,
    /// Optional requested id, parsed from the `id` form field.
    pub id: Option<Uuid>,
    /// All remaining form fields, keyed by field name.
    pub fields: std::collections::HashMap<String, String>,
}
/// Async factory hook that creates a device from a [`DeviceRegistration`].
///
/// On success it yields the new device's id, which `devices_post` then looks up and
/// renders. Leaving this unset in `AppState` makes device creation answer 501.
pub(crate) type DeviceRegistrar = Arc<
    dyn Fn(DeviceRegistration) -> BoxFuture<'static, Result<Uuid, crate::core::DeviceError>> + Send + Sync,
>;
/// Set of pending peer-initiation ids.
///
/// Ids are added with [`register`](Self::register) and removed with
/// [`consume`](Self::consume), which `GET /_initiate_peer/{id}` calls. A registered id
/// can therefore be consumed at most once. Clones share the same underlying set.
#[derive(Clone, Default)]
pub(crate) struct PeerInitState {
    inner: Arc<std::sync::Mutex<std::collections::HashSet<Uuid>>>,
}

impl PeerInitState {
    /// Marks `id` as pending. Registering an id that is already pending is a no-op.
    pub fn register(&self, id: Uuid) {
        self.pending().insert(id);
    }

    /// Removes `id` from the pending set. Returns `true` if it was pending.
    pub fn consume(&self, id: &Uuid) -> bool {
        self.pending().remove(id)
    }

    /// Locks the pending set.
    ///
    /// Poisoning is recovered from rather than propagated as a panic. The set has no
    /// cross-entry invariant that a panicking lock holder could leave half-updated.
    fn pending(&self) -> std::sync::MutexGuard<'_, std::collections::HashSet<Uuid>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// State shared by every route handler. axum's `State` extractor clones it per request.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Local server core: device registry, transitions, event bus and server `name`.
    pub core: Arc<Core>,
    /// Callback for completed `/peers/{name}` upgrades; `None` answers those with 503.
    pub peer_handler: Option<PeerHandler>,
    /// Pending ids consumed by `GET /_initiate_peer/{id}`.
    pub peer_init: PeerInitState,
    /// Live peer tunnels used for request forwarding; `None` means no peers are reachable.
    pub peer_senders: Option<Arc<dyn PeerSenders>>,
    /// Peer stream hub. NOTE(review): not read by the handlers in this file; presumably
    /// consumed via `super::ws` — confirm.
    pub peer_streams: super::peer_streams::PeerStreamHub,
    /// Device factory hook for `POST /servers/{name}/devices`; `None` answers 501.
    pub device_registrar: Option<DeviceRegistrar>,
}
/// Builds the HTTP router for `core` with peering and device registration switched off.
///
/// Peer upgrades answer 503 and device creation answers 501. Use `router_with` to
/// supply the optional hooks.
pub fn router(core: Arc<Core>) -> Router {
    let state = AppState {
        core,
        peer_handler: None,
        peer_init: PeerInitState::default(),
        peer_senders: None,
        peer_streams: super::peer_streams::PeerStreamHub::new(),
        device_registrar: None,
    };
    router_with(state)
}
/// Builds the HTTP router over a fully specified [`AppState`].
pub(crate) fn router_with(state: AppState) -> Router {
    use axum::routing::post;

    // Root plus per-server resources. Handlers forward to a peer when `{name}` is not
    // this node.
    let app: Router<AppState> = Router::new()
        .route("/", get(root))
        .route("/servers/{name}", get(server_get))
        .route("/servers/{name}/devices", get(devices_get).post(devices_post))
        .route("/servers/{name}/devices/{id}", get(device_get).post(device_post))
        .route("/servers/{name}/meta", get(meta_get))
        .route("/servers/{name}/meta/{type}", get(meta_type_get))
        .route("/servers/{name}/events", get(server_events_stream))
        .route("/servers/{name}/events/unsubscribe", post(events_unsubscribe));

    // Node-level endpoints: peer management, the event websocket and peer tunnels.
    let app = app
        .route("/peer-management", get(peer_management_get))
        .route("/events", get(events_ws))
        .route("/peers/{name}", any(peers_upgrade))
        .route("/_initiate_peer/{id}", get(initiate_peer));

    app.with_state(state)
}
/// Routes a request addressed to the server called `target_name`.
///
/// Returns `None` only when `target_name` is this node; the caller then handles the
/// request locally. Otherwise it always returns `Some(response)`, which is one of:
/// - the peer's response, forwarded over its HTTP/2 tunnel;
/// - 404 `unknown server` when no tunnel to that peer is available (or peering is off);
/// - 500 when the outgoing request cannot be built;
/// - 502 when sending it to the peer fails.
async fn maybe_forward_or_404(
    state: &AppState,
    target_name: &str,
    method: Method,
    uri: &Uri,
    headers: &HeaderMap,
    body: Body,
) -> Option<Response> {
    if target_name == state.core.name {
        return None;
    }
    // An unreachable server must NOT fall through to the local handler. The local
    // handler would serve this node's data under another server's name.
    let sender = match &state.peer_senders {
        Some(senders) => senders.sender(target_name).await,
        None => None,
    };
    let Some(mut sender) = sender else {
        return Some((StatusCode::NOT_FOUND, "unknown server").into_response());
    };
    let path_and_query = uri
        .path_and_query()
        .map(|p| p.as_str())
        .unwrap_or(uri.path());
    // The authority is a synthetic per-peer name. The request travels over the already
    // established tunnel, so the host is never resolved.
    let target_uri = format!(
        "http://{}.peer.boardwalk.invalid{}",
        urlencoding::encode(target_name),
        path_and_query
    );
    let mut builder = http::Request::builder().method(method).uri(target_uri);
    for (name, value) in headers.iter() {
        // The original Host names this node; the target URI carries the peer authority.
        if name == http::header::HOST {
            continue;
        }
        builder = builder.header(name.clone(), value.clone());
    }
    let req = match builder.body(body) {
        Ok(r) => r,
        Err(e) => {
            return Some(
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("forward build: {e}"),
                )
                    .into_response(),
            );
        }
    };
    tracing::debug!(
        peer = %target_name,
        method = %req.method(),
        path = %path_and_query,
        "forwarding request to peer"
    );
    let resp = match sender.send_request(req).await {
        Ok(r) => r,
        Err(e) => {
            tracing::warn!(
                peer = %target_name,
                error = %e,
                "peer forward send failed"
            );
            return Some((StatusCode::BAD_GATEWAY, format!("peer forward: {e}")).into_response());
        }
    };
    let (parts, incoming) = resp.into_parts();
    let mut out = Response::builder().status(parts.status);
    for (name, value) in parts.headers.iter() {
        // Framing is re-decided by our own server for the outgoing response.
        if name == http::header::TRANSFER_ENCODING {
            continue;
        }
        out = out.header(name.clone(), value.clone());
    }
    match out.body(Body::new(incoming)) {
        Ok(r) => Some(r),
        Err(e) => Some((StatusCode::INTERNAL_SERVER_ERROR, format!("{e}")).into_response()),
    }
}
/// GET-flavoured wrapper around [`maybe_forward_or_404`] that forwards an empty body.
///
/// It returns exactly what `maybe_forward_or_404` returns. `None` tells the caller to
/// serve the request locally.
async fn maybe_forward_get_or_404(
    state: &AppState,
    target_name: &str,
    uri: &Uri,
    headers: &HeaderMap,
) -> Option<Response> {
    let no_body = Body::empty();
    maybe_forward_or_404(state, target_name, Method::GET, uri, headers, no_body).await
}
fn build_hrefs(headers: &HeaderMap, uri: &Uri, server: &str) -> Hrefs {
let host = headers
.get(axum::http::header::HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("localhost");
let scheme = headers
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.or_else(|| uri.scheme_str())
.unwrap_or("http");
let http_base: url::Url = format!("{scheme}://{host}/").parse().unwrap();
let ws_scheme = if scheme == "https" { "wss" } else { "ws" };
let ws_base: url::Url = format!("{ws_scheme}://{host}/").parse().unwrap();
Hrefs {
http: http_base,
ws: ws_base,
server: server.to_string(),
}
}
fn siren_response(entity: crate::siren::Entity) -> Response {
let body = serde_json::to_vec(&entity).unwrap();
let mut resp = Response::new(axum::body::Body::from(body));
resp.headers_mut().insert(
axum::http::header::CONTENT_TYPE,
HeaderValue::from_static(SIREN_CONTENT_TYPE),
);
resp
}
/// Query string accepted by `GET /` and `GET /servers/{name}`.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Optional CaQL expression. When present, the handlers return search results
    /// filtered through `filter_by_ql`.
    ql: Option<String>,
    /// Parsed but not acted on by the handlers in this file.
    server: Option<String>,
}
async fn root(
State(state): State<AppState>,
headers: HeaderMap,
uri: Uri,
Query(params): Query<QueryParams>,
) -> Response {
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
if let Some(ql) = params.ql {
let devices = core.list_devices().await;
let filtered = filter_by_ql(&devices, &ql);
return siren_response(render::render_search_results(&h, &ql, &filtered));
}
let _ = params.server;
let peers = match &state.peer_senders {
Some(p) => p.names().await,
None => Vec::new(),
};
siren_response(render::render_root(&core, &h, &peers))
}
async fn server_get(
State(state): State<AppState>,
Path(name): Path<String>,
headers: HeaderMap,
uri: Uri,
Query(params): Query<QueryParams>,
) -> Response {
if let Some(r) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
return r;
}
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
let devices = core.list_devices().await;
if let Some(ql) = params.ql {
let filtered = filter_by_ql(&devices, &ql);
return siren_response(render::render_search_results(&h, &ql, &filtered));
}
siren_response(render::render_server(&h, &devices))
}
async fn devices_get(
State(state): State<AppState>,
Path(name): Path<String>,
headers: HeaderMap,
uri: Uri,
) -> Response {
if let Some(r) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
return r;
}
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
let devices = core.list_devices().await;
siren_response(render::render_server(&h, &devices))
}
/// `POST /servers/{name}/devices`: creates a device from a URL-encoded form.
///
/// The `type` field is required. `name` and `id` are optional, and every other field
/// is passed through in `DeviceRegistration::fields`. A non-blank `id` must be a valid
/// UUID, otherwise the request is rejected with 400 (it is not silently ignored).
///
/// Responses:
/// - 201 with the rendered device on success;
/// - 400 for a malformed form, invalid id, missing type, or a registrar
///   `Invalid` error;
/// - 409 for a registrar `Conflict` error;
/// - 501 when no registrar is configured;
/// - 500 otherwise.
///
/// Requests for another server name are forwarded to that peer, or answered with 404.
async fn devices_post(
    State(state): State<AppState>,
    Path(name): Path<String>,
    headers: HeaderMap,
    uri: Uri,
    body_bytes: bytes::Bytes,
) -> Response {
    if name != state.core.name {
        let body = Body::from(body_bytes);
        return maybe_forward_or_404(&state, &name, Method::POST, &uri, &headers, body)
            .await
            .unwrap_or_else(|| (StatusCode::NOT_FOUND, "unknown server").into_response());
    }
    let Some(registrar) = state.device_registrar.clone() else {
        return (
            StatusCode::NOT_IMPLEMENTED,
            "no factories registered; call Boardwalk::register_factory(...)",
        )
            .into_response();
    };
    let pairs: Vec<(String, String)> = match serde_urlencoded::from_bytes(&body_bytes) {
        Ok(p) => p,
        Err(e) => return (StatusCode::BAD_REQUEST, format!("bad form: {e}")).into_response(),
    };
    let mut reg = DeviceRegistration::default();
    for (k, v) in pairs {
        match k.as_str() {
            "type" => reg.type_ = v,
            "name" => reg.name = Some(v),
            "id" => {
                let raw = v.trim();
                // A blank `id` (e.g. an empty HTML form input) is treated as absent. Any
                // other value must parse; discarding it would register the device under
                // an id the client did not ask for.
                reg.id = if raw.is_empty() {
                    None
                } else {
                    match Uuid::parse_str(raw) {
                        Ok(id) => Some(id),
                        Err(_) => {
                            return (StatusCode::BAD_REQUEST, "invalid `id` field")
                                .into_response();
                        }
                    }
                };
            }
            _ => {
                reg.fields.insert(k, v);
            }
        }
    }
    if reg.type_.is_empty() {
        return (StatusCode::BAD_REQUEST, "missing `type` field").into_response();
    }
    let new_id = match registrar(reg).await {
        Ok(id) => id,
        Err(crate::core::DeviceError::Invalid(msg)) => {
            return (StatusCode::BAD_REQUEST, msg).into_response();
        }
        Err(crate::core::DeviceError::Conflict(msg)) => {
            return (StatusCode::CONFLICT, msg).into_response();
        }
        Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("{e}")).into_response(),
    };
    let h = build_hrefs(&headers, &uri, &state.core.name);
    let Some(snap) = state.core.get_device(&new_id).await else {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            "device missing after register",
        )
            .into_response();
    };
    let mut resp = siren_response(render::render_device(&h, &snap));
    *resp.status_mut() = StatusCode::CREATED;
    resp
}
async fn device_get(
State(state): State<AppState>,
Path((name, id)): Path<(String, String)>,
headers: HeaderMap,
uri: Uri,
) -> Response {
if let Some(r) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
return r;
}
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
let id = match uuid::Uuid::parse_str(&id) {
Ok(id) => id,
Err(_) => return (StatusCode::BAD_REQUEST, "invalid device id").into_response(),
};
match core.get_device(&id).await {
Some(d) => siren_response(render::render_device(&h, &d)),
None => (StatusCode::NOT_FOUND, "unknown device").into_response(),
}
}
/// `POST /servers/{name}/devices/{id}`: runs a device transition from a URL-encoded form.
///
/// The form must contain `action` (the transition name). All other fields become the
/// string-valued `TransitionInput::fields`.
///
/// Responses:
/// - the re-rendered device on success;
/// - 400 for a bad id, a bad form, a missing `action`, or a `DeviceError::Invalid`;
/// - 409 for `NotAllowed` or `Conflict` (the same mapping `devices_post` uses for
///   `Conflict`);
/// - 500 otherwise.
///
/// Requests for another server name are forwarded to that peer, or answered with 404.
async fn device_post(
    State(state): State<AppState>,
    Path((name, id)): Path<(String, String)>,
    headers: HeaderMap,
    uri: Uri,
    body: String,
) -> Response {
    if name != state.core.name {
        // Anything other than a forwarded response means the peer is not reachable.
        return maybe_forward_or_404(
            &state,
            &name,
            Method::POST,
            &uri,
            &headers,
            Body::from(body),
        )
        .await
        .unwrap_or_else(|| (StatusCode::NOT_FOUND, "unknown server").into_response());
    }
    let core = state.core.clone();
    let id = match uuid::Uuid::parse_str(&id) {
        Ok(id) => id,
        Err(_) => return (StatusCode::BAD_REQUEST, "invalid device id").into_response(),
    };
    let h = build_hrefs(&headers, &uri, &core.name);
    let pairs: Vec<(String, String)> = match serde_urlencoded::from_str(&body) {
        Ok(v) => v,
        Err(e) => return (StatusCode::BAD_REQUEST, format!("bad form body: {e}")).into_response(),
    };
    let mut map: BTreeMap<String, JsonValue> = BTreeMap::new();
    let mut action_name = None;
    for (k, v) in pairs {
        if k == "action" {
            action_name = Some(v);
        } else {
            map.insert(k, JsonValue::String(v));
        }
    }
    let action_name = match action_name {
        Some(n) => n,
        None => return (StatusCode::BAD_REQUEST, "missing `action` field").into_response(),
    };
    let input = TransitionInput { fields: map };
    match core.run_transition(&id, &action_name, input).await {
        Ok(snap) => siren_response(render::render_device(&h, &snap)),
        Err(crate::core::DeviceError::NotAllowed(_)) => (
            StatusCode::CONFLICT,
            "transition not allowed in current state",
        )
            .into_response(),
        Err(crate::core::DeviceError::Conflict(msg)) => {
            (StatusCode::CONFLICT, msg).into_response()
        }
        Err(crate::core::DeviceError::Invalid(msg)) => {
            (StatusCode::BAD_REQUEST, msg).into_response()
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("{e}")).into_response(),
    }
}
/// Query string for `GET /servers/{name}/events`.
#[derive(Debug, Deserialize)]
struct EventsQuery {
    /// Topic pattern to subscribe to. The handler answers 400 when it is missing or
    /// fails `TopicPattern::parse`.
    topic: Option<String>,
}
/// `POST /servers/{name}/events/unsubscribe`.
///
/// For this node it is a no-op that answers 202 with an empty body; the request body is
/// not read. For another server name the request, body included, is forwarded to that
/// peer, or answered with 404.
async fn events_unsubscribe(
    State(state): State<AppState>,
    Path(name): Path<String>,
    headers: HeaderMap,
    uri: Uri,
    body: Body,
) -> Response {
    if name == state.core.name {
        return (StatusCode::ACCEPTED, "").into_response();
    }
    let forwarded = maybe_forward_or_404(&state, &name, Method::POST, &uri, &headers, body).await;
    forwarded.unwrap_or_else(|| (StatusCode::NOT_FOUND, "unknown server").into_response())
}
/// `GET /servers/{name}/events?topic=...`: streams matching bus events as NDJSON.
///
/// Each line is a JSON object with `topic`, `timestamp` (from `timestamp_ms`) and
/// `data`. Events that fail to serialize are skipped. Answers 400 for a missing or
/// unparsable topic. Requests for another server name are handed to the peer
/// forwarder first.
async fn server_events_stream(
    State(state): State<AppState>,
    Path(name): Path<String>,
    headers: HeaderMap,
    uri: Uri,
    Query(q): Query<EventsQuery>,
) -> Response {
    if let Some(forwarded) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
        return forwarded;
    }
    let Some(topic) = q.topic else {
        return (StatusCode::BAD_REQUEST, "missing ?topic=").into_response();
    };
    let pattern = match crate::events::TopicPattern::parse(&topic) {
        Ok(pattern) => pattern,
        Err(err) => return (StatusCode::BAD_REQUEST, format!("topic: {err}")).into_response(),
    };
    let subscription = state
        .core
        .bus
        .subscribe(pattern, crate::events::SubscribeOpts::default());
    // NOTE(review): only `rx` moves into the stream. Any other fields of the subscription
    // are dropped when this function returns — confirm none of them must outlive it.
    let mut events = subscription.rx;
    let lines = async_stream::stream! {
        while let Some(event) = events.recv().await {
            let record = serde_json::json!({
                "topic": event.topic,
                "timestamp": event.timestamp_ms,
                "data": event.data,
            });
            let Ok(encoded) = serde_json::to_string(&record) else {
                continue;
            };
            yield Ok::<_, std::convert::Infallible>(format!("{encoded}\n"));
        }
    };
    let mut response = Response::new(Body::from_stream(lines));
    response.headers_mut().insert(
        http::header::CONTENT_TYPE,
        HeaderValue::from_static("application/x-ndjson"),
    );
    response
}
async fn meta_get(
State(state): State<AppState>,
Path(name): Path<String>,
headers: HeaderMap,
uri: Uri,
) -> Response {
if let Some(r) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
return r;
}
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
let devices = core.list_devices().await;
siren_response(render::render_meta(&h, &devices))
}
async fn meta_type_get(
State(state): State<AppState>,
Path((name, ty)): Path<(String, String)>,
headers: HeaderMap,
uri: Uri,
) -> Response {
if let Some(r) = maybe_forward_get_or_404(&state, &name, &uri, &headers).await {
return r;
}
let core = state.core.clone();
let h = build_hrefs(&headers, &uri, &core.name);
let devices = core.list_devices().await;
let dev = devices.iter().find(|d| d.type_ == ty);
match dev {
Some(d) => siren_response(
crate::siren::Entity::new()
.with_class("type")
.with_property("type", JsonValue::String(d.type_.clone()))
.with_link(crate::siren::Link::new(
crate::siren::rels::SELF,
h.meta_type_url(&d.type_),
)),
),
None => (StatusCode::NOT_FOUND, "unknown type").into_response(),
}
}
/// `GET /peer-management`: a static, empty `peer-management` entity, served as plain
/// JSON.
async fn peer_management_get() -> Response {
    let entity = serde_json::json!({
        "class": ["peer-management"],
        "actions": [],
        "entities": [],
        "links": [],
    });
    Json(entity).into_response()
}
/// WebSocket subprotocol the `/events` endpoint advertises as supported.
pub const EVENTS_SUBPROTOCOL: &str = "boardwalk-events/1";

/// `GET /events`: upgrades to a WebSocket and hands the socket, together with the
/// shared state, to `super::ws::handle_socket`.
async fn events_ws(State(state): State<AppState>, ws: WebSocketUpgrade) -> Response {
    let negotiated = ws.protocols([EVENTS_SUBPROTOCOL]);
    negotiated.on_upgrade(move |socket| super::ws::handle_socket(socket, state))
}
/// Query string for the `/peers/{name}` tunnel upgrade.
#[derive(Debug, Deserialize)]
struct PeerQuery {
    /// Value of the `connectionId` parameter. `peers_upgrade` answers 400 when it is
    /// absent; a non-UUID value fails query extraction.
    #[serde(rename = "connectionId")]
    connection_id: Option<Uuid>,
}
/// `/peers/{name}?connectionId=...`: accepts a peer's tunnel upgrade.
///
/// Checks, in order:
/// 1. 400 if `connectionId` is missing;
/// 2. 503 if no peer handler is configured;
/// 3. 409 if `peer_senders` already reports a peer with this name;
/// 4. 400 if the upgrade headers are rejected.
///
/// On success it returns the upgrade response and runs the peer handler on the
/// upgraded connection in a spawned task.
///
/// NOTE(review): the already-connected check and the handler's eventual registration of
/// the peer are not atomic, so two simultaneous upgrades for one name may both pass —
/// confirm whether the handler guards against this.
async fn peers_upgrade(
    State(state): State<AppState>,
    Path(peer_name): Path<String>,
    Query(query): Query<PeerQuery>,
    mut req: Request<Body>,
) -> Response {
    let Some(connection_id) = query.connection_id else {
        return (StatusCode::BAD_REQUEST, "missing connectionId").into_response();
    };
    let Some(handler) = state.peer_handler.clone() else {
        return (StatusCode::SERVICE_UNAVAILABLE, "peering disabled").into_response();
    };
    let already_connected = match &state.peer_senders {
        Some(senders) => senders.has_active_peer(&peer_name).await,
        None => false,
    };
    if already_connected {
        let message = format!("peer `{peer_name}` is already connected");
        return (StatusCode::CONFLICT, message).into_response();
    }
    let upgrade_response = match boardwalk_tunnel_upgrade_response(req.headers()) {
        Ok(accepted) => accepted,
        Err(reason) => {
            return (StatusCode::BAD_REQUEST, format!("upgrade: {reason}")).into_response();
        }
    };
    // The upgrade future resolves only after the response below has been sent.
    let pending_upgrade = hyper::upgrade::on(&mut req);
    tokio::spawn(async move {
        match pending_upgrade.await {
            Ok(io) => handler(peer_name, connection_id, io).await,
            Err(e) => tracing::warn!(%e, "peer upgrade failed"),
        }
    });
    upgrade_response
}
/// Builds the tunnel upgrade response for the given request headers.
///
/// The status and headers come from `crate::tunnel::build_upgrade_response`; its body is
/// discarded in favour of an empty axum body. Errors are returned as their display text.
fn boardwalk_tunnel_upgrade_response(headers: &HeaderMap) -> Result<Response, String> {
    let (parts, _) = crate::tunnel::build_upgrade_response(headers)
        .map_err(|e| e.to_string())?
        .into_parts();
    let builder = parts.headers.iter().fold(
        Response::builder().status(parts.status),
        |acc, (name, value)| acc.header(name.clone(), value.clone()),
    );
    builder.body(Body::empty()).map_err(|e| e.to_string())
}
async fn initiate_peer(State(state): State<AppState>, Path(id): Path<String>) -> Response {
let id = match Uuid::parse_str(&id) {
Ok(id) => id,
Err(_) => return (StatusCode::BAD_REQUEST, "invalid id").into_response(),
};
if state.peer_init.consume(&id) {
(StatusCode::OK, "ok").into_response()
} else {
(StatusCode::NOT_FOUND, "unknown connection id").into_response()
}
}
/// Returns clones of the devices matched by the CaQL query `ql`.
///
/// Each device is matched as a JSON object with `id` (a string), `type`, `name` and
/// `state`. A query that fails to parse yields no results. A per-device match error
/// counts as "no match".
fn filter_by_ql(
    devices: &[super::core::DeviceSnapshot],
    ql: &str,
) -> Vec<super::core::DeviceSnapshot> {
    let Ok(query) = crate::caql::parse(ql) else {
        return Vec::new();
    };
    let mut hits = Vec::new();
    for device in devices {
        let candidate = serde_json::json!({
            "id": device.id.to_string(),
            "type": device.type_,
            "name": device.name,
            "state": device.state,
        });
        if crate::caql::matches(&query, &candidate).unwrap_or(false) {
            hits.push(device.clone());
        }
    }
    hits
}