#![deny(clippy::all)]
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use axum::{
extract::{Path, State},
http::{header, HeaderValue, StatusCode},
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
Router,
};
use serde_json::Value;
use tokio::net::TcpListener;
/// Path ids the test server treats as existing resources.
///
/// Every handler answers `404 Not Found` for an id outside this list.
pub const KNOWN_IDS: &[&str] = &["1", "42", "100"];

/// Email addresses `unique` reports as already taken (`409 Conflict`).
const TAKEN_EMAILS: &[&str] = &["alice@example.com", "bob@example.com"];

/// `status` values `state_machine` accepts with `200 OK`.
const VALID_STATES: &[&str] = &["active", "pending", "suspended"];

/// Requests per id that `ratelimited` answers with `200 OK` before it starts
/// returning `429 Too Many Requests`.
const RATE_LIMIT_THRESHOLD: u32 = 3;

/// The only `Authorization` header value `scoped` grants `200 OK` to.
const FULL_ACCESS_TOKEN: &str = "Bearer full-access";

/// Reports whether `id` appears in [`KNOWN_IDS`].
fn is_known(id: &str) -> bool {
    KNOWN_IDS.contains(&id)
}

/// Shared, lock-protected map from path id to the number of `ratelimited`
/// hits seen for it.
type RateCounter = Arc<Mutex<HashMap<String, u32>>>;
/// Simulated state-transition endpoint (`PATCH`/`PUT /state/{id}`).
///
/// Status codes:
/// - `404 Not Found` when `id` is not in `KNOWN_IDS` (checked first);
/// - `400 Bad Request` when `body` is `None`, or the JSON has no string
///   `status` field;
/// - `200 OK` when `status` is one of `VALID_STATES`;
/// - `409 Conflict` for any other `status` string.
async fn state_machine(
    Path(id): Path<String>,
    body: Option<axum::extract::Json<Value>>,
) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    let code = match body {
        None => StatusCode::BAD_REQUEST,
        Some(axum::extract::Json(json)) => match json.get("status").and_then(Value::as_str) {
            None => StatusCode::BAD_REQUEST,
            Some(state) if VALID_STATES.contains(&state) => StatusCode::OK,
            Some(_) => StatusCode::CONFLICT,
        },
    };
    code.into_response()
}
/// Simulated uniqueness-constrained create (`POST`/`PUT /unique/{id}`).
///
/// The handler first answers `404 Not Found` for an unknown `id`. Otherwise:
/// - `400 Bad Request` when `body` is `None` or has no string `email` field;
/// - `409 Conflict` when the email is listed in `TAKEN_EMAILS`;
/// - `201 Created` for any other email.
async fn unique(Path(id): Path<String>, body: Option<axum::extract::Json<Value>>) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    let email = body
        .as_ref()
        .and_then(|wrapped| wrapped.0.get("email"))
        .and_then(Value::as_str);
    let code = match email {
        None => StatusCode::BAD_REQUEST,
        Some(addr) if TAKEN_EMAILS.contains(&addr) => StatusCode::CONFLICT,
        Some(_) => StatusCode::CREATED,
    };
    code.into_response()
}
/// Handler for `DELETE /dependent/{id}` that never succeeds.
///
/// An unknown `id` answers `404 Not Found`. A known one always answers
/// `409 Conflict`.
async fn dependent(Path(id): Path<String>) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    StatusCode::CONFLICT.into_response()
}
/// Simulated rate-limited resource (`GET`/`HEAD /ratelimited/{id}`).
///
/// Each request for a known `id`, whether `GET` or `HEAD`, bumps that id's
/// entry in the shared [`RateCounter`]. The first `RATE_LIMIT_THRESHOLD`
/// requests answer `200 OK`. Every later one answers `429 Too Many Requests`
/// with `Retry-After: 30`.
///
/// Unknown ids answer `404 Not Found` and are never counted.
async fn ratelimited(Path(id): Path<String>, State(counter): State<RateCounter>) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    // Scope the guard so the lock is released before the response is built.
    let count = {
        // The map holds only plain counters, so the data behind a poisoned
        // lock is still consistent. Recover the guard rather than panicking
        // on every later request once any holder has panicked.
        let mut map = counter
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let hits = map.entry(id).or_insert(0);
        // Saturate instead of `+= 1`. Overflow would panic in debug builds
        // (poisoning the lock) and wrap to 0 in release, re-opening the limit.
        *hits = hits.saturating_add(1);
        *hits
    };
    if count > RATE_LIMIT_THRESHOLD {
        let mut resp = StatusCode::TOO_MANY_REQUESTS.into_response();
        resp.headers_mut()
            .insert(header::RETRY_AFTER, HeaderValue::from_static("30"));
        resp
    } else {
        StatusCode::OK.into_response()
    }
}
/// Handler for `GET`/`HEAD /headered/{id}`.
///
/// For a known `id` it returns `200 OK` with two fixed headers:
/// `x-ratelimit-limit: 100` and `x-ratelimit-remaining: 99`. The values
/// never change between calls. Unknown ids answer `404 Not Found`.
async fn headered(Path(id): Path<String>) -> Response {
    if is_known(&id) {
        let mut response = StatusCode::OK.into_response();
        let fixed = [("x-ratelimit-limit", "100"), ("x-ratelimit-remaining", "99")];
        for (name, value) in fixed {
            response
                .headers_mut()
                .insert(name, HeaderValue::from_static(value));
        }
        response
    } else {
        StatusCode::NOT_FOUND.into_response()
    }
}
/// Simulated resource no caller may access (`GET`/`HEAD /forbidden/{id}`).
///
/// An unknown `id` answers `404 Not Found`. For a known id:
/// - any `Authorization` header, whatever its value, yields `403 Forbidden`;
/// - a missing header yields `401 Unauthorized` carrying a
///   `WWW-Authenticate: Bearer realm="test"` challenge.
async fn forbidden(Path(id): Path<String>, req: axum::http::Request<axum::body::Body>) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    let has_credentials = req.headers().contains_key(header::AUTHORIZATION);
    if has_credentials {
        return StatusCode::FORBIDDEN.into_response();
    }
    let challenge = HeaderValue::from_static("Bearer realm=\"test\"");
    let mut denied = StatusCode::UNAUTHORIZED.into_response();
    denied
        .headers_mut()
        .insert(header::WWW_AUTHENTICATE, challenge);
    denied
}
/// Simulated scope-checked resource (`GET`/`HEAD /scoped/{id}`).
///
/// An unknown `id` answers `404 Not Found`. For a known id:
/// - `Authorization` exactly equal to `FULL_ACCESS_TOKEN` yields `200 OK`;
/// - any other visible-ASCII `Authorization` value yields `403 Forbidden`;
/// - a missing header, or one whose value is not visible ASCII (so
///   `to_str` fails), is treated as absent. It yields `401 Unauthorized`
///   with a `WWW-Authenticate: Bearer realm="test"` challenge.
async fn scoped(Path(id): Path<String>, req: axum::http::Request<axum::body::Body>) -> Response {
    if !is_known(&id) {
        return StatusCode::NOT_FOUND.into_response();
    }
    let presented = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok());
    let Some(token) = presented else {
        let mut challenge = StatusCode::UNAUTHORIZED.into_response();
        challenge.headers_mut().insert(
            header::WWW_AUTHENTICATE,
            HeaderValue::from_static("Bearer realm=\"test\""),
        );
        return challenge;
    };
    if token == FULL_ACCESS_TOKEN {
        StatusCode::OK.into_response()
    } else {
        StatusCode::FORBIDDEN.into_response()
    }
}
/// Builds the test-server [`Router`] with every simulated endpoint mounted.
///
/// `counter` is installed as the router state. `ratelimited` is the only
/// handler that extracts it, so passing a fresh counter gives independent
/// rate-limit bookkeeping.
pub fn router(counter: RateCounter) -> Router {
    let app = Router::new()
        // Auth-related endpoints.
        .route("/forbidden/{id}", get(forbidden).head(forbidden))
        .route("/scoped/{id}", get(scoped).head(scoped))
        // Header / rate-limit endpoints.
        .route("/ratelimited/{id}", get(ratelimited).head(ratelimited))
        .route("/headered/{id}", get(headered).head(headered))
        // Write endpoints with conflict semantics.
        .route("/state/{id}", patch(state_machine).put(state_machine))
        .route("/unique/{id}", post(unique).put(unique))
        .route("/dependent/{id}", delete(dependent));
    app.with_state(counter)
}
pub async fn spawn() -> SocketAddr {
let counter: RateCounter = Arc::new(Mutex::new(HashMap::new()));
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("failed to bind elicit server listener");
let addr = listener
.local_addr()
.expect("failed to read local_addr from listener");
tokio::spawn(async move {
axum::serve(listener, router(counter))
.await
.expect("elicit server failed");
});
addr
}
/// Starts the test server on `127.0.0.1:port` and returns the bound address.
///
/// Each call creates its own empty rate counter, so separately spawned
/// servers track `ratelimited` hits independently. The function returns
/// once the listener is bound. Serving happens on a detached `tokio::spawn`
/// task whose handle is not kept.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be
/// read. The background task panics with "elicit server failed" if
/// `axum::serve` returns an error.
pub async fn spawn_on(port: u16) -> SocketAddr {
    let state: RateCounter = Arc::new(Mutex::new(HashMap::new()));
    let bound = TcpListener::bind(("127.0.0.1", port))
        .await
        .expect("failed to bind elicit server listener");
    let local = bound
        .local_addr()
        .expect("failed to read local_addr from listener");
    let server = async move {
        axum::serve(bound, router(state))
            .await
            .expect("elicit server failed");
    };
    tokio::spawn(server);
    local
}