use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::{self, Next},
response::{Html, IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
/// One named value of an enumerated bitfield, serialized to the web UI as
/// part of a [`BitfieldInfo`].
#[derive(Debug, Clone, Serialize)]
pub struct VariantInfo {
    /// Name of this variant.
    pub name: &'static str,
    /// Raw numeric value that selects this variant.
    pub value: u64,
}
/// Description of one bitfield inside a register, serialized to the web UI as
/// part of a [`RegisterInfo`].
#[derive(Debug, Clone, Serialize)]
pub struct BitfieldInfo {
    /// Bitfield name.
    pub name: &'static str,
    /// Free-form documentation text for the bitfield.
    pub doc: &'static str,
    /// Lower bit position of the field.
    /// NOTE(review): whether `lo..=hi` or `lo..hi` is meant is not established here — confirm.
    pub lo: u32,
    /// Upper bit position of the field (see the inclusivity note on `lo`).
    pub hi: u32,
    /// Textual name of the field's value type, as rendered by the UI.
    pub field_type: &'static str,
    /// Named values for enumerated fields; presumably empty for plain numeric fields — confirm.
    pub variants: Vec<VariantInfo>,
}
/// Description of one register, returned by [`RegisterMapInfo::registers`]
/// and serialized in the `/api/{slug}/info` response.
#[derive(Debug, Clone, Serialize)]
pub struct RegisterInfo {
    /// Register name (owned, unlike the other descriptive strings).
    pub name: String,
    /// Free-form documentation text for the register.
    pub doc: &'static str,
    /// Register offset; presumably relative to the map's base address — confirm.
    pub offset: usize,
    /// Textual access mode shown by the UI (exact vocabulary defined by implementors).
    pub access: &'static str,
    /// Register width; units (bits vs. bytes) are not established here — confirm.
    pub width: usize,
    /// Bitfields contained in this register.
    pub bitfields: Vec<BitfieldInfo>,
}
/// Introspection and access interface that the web UI uses for one register map.
///
/// The web UI stores implementors type-erased behind
/// `Arc<tokio::sync::Mutex<dyn RegisterMapInfo + Send>>`. The HTTP handlers in
/// this module call these methods only while holding that mutex.
pub trait RegisterMapInfo {
    /// Human-readable map name, reported by `GET /api/maps` and `GET /api/{slug}/info`.
    fn map_name(&self) -> &'static str;
    /// Bus width reported by `/info`; units are not established here (presumably bits — confirm).
    fn bus_width(&self) -> usize;
    /// Base address reported by `/info`.
    fn base_address(&self) -> usize;
    /// Full register/bitfield description reported by `/info`.
    fn registers(&self) -> Vec<RegisterInfo>;
    /// Reads the register at `offset`.
    ///
    /// Returning `None` makes `POST /api/{slug}/read` answer `400 Bad Request`.
    fn read_register(&self, offset: usize) -> Option<u64>;
    /// Writes `value` to the register at `offset`.
    ///
    /// Returning `None` makes `POST /api/{slug}/write` answer `400 Bad Request`.
    fn write_register(&mut self, offset: usize, value: u64) -> Option<()>;
}
/// JSON body of `GET /api/{slug}/info`: a snapshot of a map's metadata and layout.
#[derive(Serialize)]
struct RegisterMapDescription {
    /// Value of `RegisterMapInfo::map_name`.
    name: &'static str,
    /// Value of `RegisterMapInfo::bus_width`.
    bus_width: usize,
    /// Value of `RegisterMapInfo::base_address`.
    base_address: usize,
    /// Value of `RegisterMapInfo::registers`.
    registers: Vec<RegisterInfo>,
}
/// JSON request body of `POST /api/{slug}/read`.
#[derive(Deserialize)]
struct ReadReq {
    /// Offset passed verbatim to `RegisterMapInfo::read_register`.
    offset: usize,
}
/// JSON response body of a successful `POST /api/{slug}/read`.
#[derive(Serialize)]
struct ReadResp {
    /// Register value as returned by `RegisterMapInfo::read_register`.
    // NOTE(review): this is emitted as a plain JSON number. A client that parses it
    // with JavaScript's `JSON.parse` loses precision above 2^53 — confirm how the
    // web UI handles 64-bit registers.
    value: u64,
}
/// JSON request body of `POST /api/{slug}/write`.
#[derive(Deserialize)]
struct WriteReq {
    /// Offset passed verbatim to `RegisterMapInfo::write_register`.
    offset: usize,
    /// Value passed verbatim to `RegisterMapInfo::write_register`.
    value: u64,
}
pub fn ct_eq(a: &str, b: &str) -> bool {
use subtle::ConstantTimeEq;
a.as_bytes().ct_eq(b.as_bytes()).into()
}
/// Boxed future returned by a credential check; resolves to `true` to allow the request.
pub type AuthFuture = Pin<Box<dyn Future<Output = bool> + Send>>;
/// Type-erased credential check, called with `(user, password)` from HTTP Basic auth.
type AuthFn = Arc<dyn Fn(String, String) -> AuthFuture + Send + Sync>;
fn extract_basic_credentials(req: &Request<Body>) -> Option<(String, String)> {
let header = req.headers().get("Authorization")?.to_str().ok()?;
let b64 = header.strip_prefix("Basic ")?;
use base64::Engine;
let bytes = base64::engine::general_purpose::STANDARD.decode(b64).ok()?;
let decoded = String::from_utf8(bytes).ok()?;
let (user, pass) = decoded.split_once(':')?;
Some((user.to_owned(), pass.to_owned()))
}
/// Builds the `401 Unauthorized` response sent when the credential check fails.
///
/// The response carries a `WWW-Authenticate: Basic` challenge naming the
/// "ddevmem register map" realm, with a plain-text body.
fn unauthorized_response() -> Response {
    let challenge = ("WWW-Authenticate", "Basic realm=\"ddevmem register map\"");
    let parts = (StatusCode::UNAUTHORIZED, [challenge], "Unauthorized");
    parts.into_response()
}
/// Serves the single-page web UI at `GET /`.
async fn index_page() -> Html<&'static str> {
    // The minified page is embedded at compile time from OUT_DIR. It is presumably
    // produced by this crate's build script — confirm.
    const PAGE: &str = include_str!(concat!(env!("OUT_DIR"), "/web_ui.min.html"));
    Html(PAGE)
}
/// Shared, type-erased handle to one register map, guarded by a tokio mutex.
type DynMap = Arc<Mutex<dyn RegisterMapInfo + Send>>;
/// Router state shared by the auth middleware and every API handler.
///
/// `#[derive(Clone)]` replaces a hand-written impl that cloned each field
/// verbatim. The derive adds no bounds, because the struct has no generic parameters.
#[derive(Clone)]
struct WebUiState {
    /// `(slug, map)` pairs in registration order; the slug is the `{slug}` URL segment.
    maps: Vec<(String, DynMap)>,
    /// Optional credential check; when `None`, all requests are allowed.
    auth: Option<AuthFn>,
    /// Optional title reported by `GET /api/maps`.
    title: Option<String>,
}
/// Builder for the register-map web UI router.
///
/// Register maps with `add`, optionally configure `with_auth` / `with_title`,
/// then call `build` to obtain an axum `Router`.
pub struct WebUi {
    /// `(slug, map)` pairs in registration order.
    maps: Vec<(String, DynMap)>,
    /// Optional credential check installed by `with_auth`.
    auth: Option<AuthFn>,
    /// Optional UI title installed by `with_title`.
    title: Option<String>,
}
impl Default for WebUi {
fn default() -> Self {
Self::new()
}
}
impl WebUi {
    /// Creates an empty builder: no maps, no authentication, no title.
    #[must_use]
    pub fn new() -> Self {
        Self {
            maps: Vec::new(),
            auth: None,
            title: None,
        }
    }

    /// Registers `regs` under `slug`, which becomes the `{slug}` segment of its
    /// `/api/{slug}/...` routes.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `slug` is empty or contains characters outside `[a-zA-Z0-9_-]`;
    /// - `slug` was already registered. A second map with the same slug would
    ///   be silently unreachable, because slug lookup returns the first match.
    #[must_use]
    pub fn add<T: RegisterMapInfo + Send + 'static>(
        mut self,
        slug: &str,
        regs: Arc<Mutex<T>>,
    ) -> Self {
        assert!(
            !slug.is_empty()
                && slug
                    .bytes()
                    .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-'),
            "slug must be non-empty ASCII [a-zA-Z0-9_-], got: {slug:?}"
        );
        assert!(
            !self.maps.iter().any(|(existing, _)| existing == slug),
            "slug already registered: {slug:?}"
        );
        // Unsizing coercion erases `T` so maps of different types share one Vec.
        self.maps.push((slug.to_owned(), regs as DynMap));
        self
    }

    /// Requires HTTP Basic authentication on every route built by `build`.
    ///
    /// `check` receives the decoded `(user, password)` and resolves to `true`
    /// to allow the request. Requests with missing or malformed credentials,
    /// or that fail the check, receive `401 Unauthorized` with a Basic challenge.
    /// Consider comparing secrets with `ct_eq` inside `check`.
    #[must_use]
    pub fn with_auth<F, Fut>(mut self, check: F) -> Self
    where
        F: Fn(String, String) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
    {
        self.auth = Some(Arc::new(move |u, p| Box::pin(check(u, p)) as AuthFuture));
        self
    }

    /// Sets the title reported by `GET /api/maps`. When unset, the field is
    /// omitted from that response.
    #[must_use]
    pub fn with_title(mut self, title: impl Into<String>) -> Self {
        self.title = Some(title.into());
        self
    }

    /// Builds the axum router with these routes:
    ///
    /// - `GET  /` — embedded web UI page
    /// - `GET  /api/maps` — list of registered slugs and map names
    /// - `GET  /api/{slug}/info` — map metadata and register layout
    /// - `POST /api/{slug}/read` — read one register
    /// - `POST /api/{slug}/write` — write one register
    ///
    /// The authentication middleware wraps all of these routes. It is a no-op
    /// unless `with_auth` was called.
    #[must_use]
    pub fn build(self) -> Router {
        let state = WebUiState {
            maps: self.maps,
            auth: self.auth,
            title: self.title,
        };
        let api = Router::new()
            .route("/maps", get(api_list))
            .route("/{slug}/info", get(api_info))
            .route("/{slug}/read", post(api_read))
            .route("/{slug}/write", post(api_write));
        Router::new()
            .route("/", get(index_page))
            .nest("/api", api)
            // Added after the routes so the layer wraps all of them.
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state)
    }
}
/// Enforces the optional credential check before forwarding to the inner handler.
///
/// With no checker configured, the request passes through untouched. Otherwise
/// the request gets `unauthorized_response()` in either case:
/// - the Basic credentials are missing or malformed;
/// - the checker resolves to `false`.
async fn auth_middleware(
    State(state): State<WebUiState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let Some(verify) = state.auth.as_ref() else {
        return next.run(req).await;
    };
    // Parse in its own statement so the borrow of `req` ends before any `.await`.
    let Some((user, pass)) = extract_basic_credentials(&req) else {
        return unauthorized_response();
    };
    if !verify(user, pass).await {
        return unauthorized_response();
    }
    next.run(req).await
}
/// One entry of the `GET /api/maps` listing.
#[derive(Serialize)]
struct MapEntry {
    /// URL slug the map was registered under.
    slug: String,
    /// Value of `RegisterMapInfo::map_name` for that map.
    name: String,
}
/// JSON body of `GET /api/maps`.
#[derive(Serialize)]
struct MapList {
    /// UI title; omitted from the JSON entirely when not configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    title: Option<String>,
    /// Registered maps, in registration order.
    maps: Vec<MapEntry>,
}
async fn api_list(State(state): State<WebUiState>) -> Json<MapList> {
let mut entries = Vec::with_capacity(state.maps.len());
for (slug, regs) in &state.maps {
let regs = regs.lock().await;
entries.push(MapEntry {
slug: slug.clone(),
name: regs.map_name().to_owned(),
});
}
Json(MapList {
title: state.title.clone(),
maps: entries,
})
}
/// Looks up the map registered under `slug`.
///
/// Returns the first matching entry, or `404 Not Found` if no map has that slug.
fn find_map<'a>(maps: &'a [(String, DynMap)], slug: &str) -> Result<&'a DynMap, StatusCode> {
    for (name, handle) in maps {
        if name.as_str() == slug {
            return Ok(handle);
        }
    }
    Err(StatusCode::NOT_FOUND)
}
/// `GET /api/{slug}/info`: returns the map's name, bus width, base address and
/// register layout; answers `404` for an unknown slug.
async fn api_info(
    State(state): State<WebUiState>,
    axum::extract::Path(slug): axum::extract::Path<String>,
) -> Result<Json<RegisterMapDescription>, StatusCode> {
    let guard = find_map(&state.maps, &slug)?.lock().await;
    let description = RegisterMapDescription {
        name: guard.map_name(),
        bus_width: guard.bus_width(),
        base_address: guard.base_address(),
        registers: guard.registers(),
    };
    Ok(Json(description))
}
/// `POST /api/{slug}/read`: reads the register at `offset`.
///
/// Answers `404` for an unknown slug and `400` when the map's
/// `read_register` returns `None`.
async fn api_read(
    State(state): State<WebUiState>,
    axum::extract::Path(slug): axum::extract::Path<String>,
    Json(req): Json<ReadReq>,
) -> Result<Json<ReadResp>, StatusCode> {
    let handle = find_map(&state.maps, &slug)?;
    let guard = handle.lock().await;
    match guard.read_register(req.offset) {
        Some(value) => Ok(Json(ReadResp { value })),
        None => Err(StatusCode::BAD_REQUEST),
    }
}
/// `POST /api/{slug}/write`: writes `value` to the register at `offset`.
///
/// Answers `200` on success, `404` for an unknown slug, and `400` when the
/// map's `write_register` returns `None`.
async fn api_write(
    State(state): State<WebUiState>,
    axum::extract::Path(slug): axum::extract::Path<String>,
    Json(req): Json<WriteReq>,
) -> Result<StatusCode, StatusCode> {
    let mut guard = find_map(&state.maps, &slug)?.lock().await;
    if guard.write_register(req.offset, req.value).is_some() {
        Ok(StatusCode::OK)
    } else {
        Err(StatusCode::BAD_REQUEST)
    }
}