#![warn(missing_docs)]
pub use async_trait::async_trait;
use bytes::Bytes;
use serde_json::Value as JsonValue;
use std::any::{Any, TypeId};
use std::collections::HashMap;
pub use actus_controller_macros::{app_routes, controller};
pub use actus_reply::prelude::*;
/// An HTTP request method.
///
/// [`Verb::as_str`] returns the upper-case method token for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Verb {
    /// `GET`
    GET,
    /// `POST`
    POST,
    /// `PUT`
    PUT,
    /// `DELETE`
    DELETE,
    /// `PATCH`
    PATCH,
    /// `HEAD`
    HEAD,
    /// `OPTIONS`
    OPTIONS,
}
impl Verb {
pub fn as_str(&self) -> &'static str {
match self {
Verb::GET => "GET",
Verb::POST => "POST",
Verb::PUT => "PUT",
Verb::DELETE => "DELETE",
Verb::PATCH => "PATCH",
Verb::HEAD => "HEAD",
Verb::OPTIONS => "OPTIONS",
}
}
}
/// Formats the verb as its upper-case method token (see [`Verb::as_str`]).
///
/// Uses [`core::fmt::Formatter::pad`], so width, fill, alignment and
/// precision flags are honoured exactly as they are for `str`. For example,
/// `format!("{:<7}|", Verb::GET)` yields `"GET    |"`. With no flags the
/// output is the bare token, as before.
impl core::fmt::Display for Verb {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(self.as_str())
    }
}
/// The verb set `GET` + `POST`.
///
/// NOTE(review): presumably the verbs a handler accepts when it declares none
/// explicitly (e.g. in the `controller` macro); not used in this file — confirm.
pub const DEFAULT_VERBS: &[Verb] = &[Verb::GET, Verb::POST];
/// Controls how [`routing::resolve`] treats query parameters that the matched
/// route does not declare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ControllerMode {
    /// Reject the request with `WebError::BadRequest` when the query string
    /// contains names not declared as [`ParamSource::Query`] parameters.
    Strict,
    /// Ignore undeclared query parameters.
    Lax,
}
/// Declared value type of a route parameter.
///
/// NOTE(review): the mapping from variant to extraction getter is presumably
/// performed by the `controller` macro — confirm there.
#[derive(Debug, Clone, Copy)]
pub enum ParamType {
    /// Text value.
    String,
    /// Signed integer (cf. `ParamDefault::Int(i64)`).
    Int,
    /// Unsigned 64-bit integer.
    U64,
    /// Unsigned 32-bit integer.
    U32,
    /// 64-bit float.
    F64,
    /// Boolean flag.
    Bool,
    /// Every value of a repeated query parameter. An absent `StringArray`
    /// query parameter is never reported as missing by `routing::resolve`.
    StringArray,
    /// JSON value (presumably the parsed request body — confirm).
    Json,
    /// Raw bytes (presumably the unparsed request body — confirm).
    Bytes,
}
/// Compile-time default value attached to a [`ParamDef`].
///
/// `routing::resolve` only checks whether a default is *present*: a query
/// parameter with a default is not reported as missing. NOTE(review): where the
/// default value itself is applied is not visible here — presumably in
/// macro-generated handler code; confirm.
#[derive(Debug, Clone)]
pub enum ParamDefault {
    /// Default for a string parameter.
    String(&'static str),
    /// Default for a signed integer parameter.
    Int(i64),
    /// Default for a `u64` parameter.
    U64(u64),
    /// Default for a `u32` parameter.
    U32(u32),
    /// Default for a float parameter.
    F64(f64),
    /// Default for a boolean parameter.
    Bool(bool),
}
/// Where a route parameter's value comes from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParamSource {
    /// Captured by a `{name}` / `{...name}` segment of the route pattern.
    Path,
    /// Taken from the request query string.
    Query,
    /// Taken from the request body (not inspected by `routing::resolve`).
    Body,
}
/// Static description of one handler parameter.
#[derive(Debug, Clone)]
pub struct ParamDef {
    /// Parameter name, as it appears in the pattern or query string.
    pub name: &'static str,
    /// Declared value type.
    pub ty: ParamType,
    /// Where the value is read from.
    pub source: ParamSource,
    /// Optional default; when present, an absent query parameter is not an error.
    pub default: Option<ParamDefault>,
}
/// Static description of one route served by a controller.
#[derive(Debug, Clone)]
pub struct RouteDef {
    /// Path pattern, matched by `routing::match_pattern` (literals, `{name}`,
    /// and a trailing `{...name}`).
    pub pattern: &'static str,
    /// Identifier of the handler. NOTE(review): how this differs from
    /// `handler` is not visible here — confirm with the `controller` macro.
    pub handler_id: &'static str,
    /// Name of the handler (presumably the method name — confirm).
    pub handler: &'static str,
    /// Verbs this route accepts; other verbs on a matching path yield 405.
    pub verb: &'static [Verb],
    /// Declared parameters of the handler.
    pub params: &'static [ParamDef],
    /// Optional documentation text (presumably the handler's doc comment — confirm).
    pub doc: Option<&'static str>,
}
/// Raw request data passed to [`Controller::actus_dispatch`].
///
/// Holds the verb, query string, body and headers, plus a type-keyed map of
/// request-scoped extension values.
pub struct Params {
    /// HTTP method of the request.
    verb: Verb,
    /// Query-string values by name; a name may carry several values.
    query: HashMap<String, Vec<String>>,
    /// Parsed JSON body, if the caller of `Params::new` supplied one.
    body: Option<JsonValue>,
    /// Unparsed request body.
    raw_body: Bytes,
    /// Header values by name. Lookups lowercase the requested name, so keys
    /// are presumably stored lowercase — confirm with the constructor's caller.
    headers: HashMap<String, Vec<String>>,
    /// Request-scoped values keyed by their concrete type (see `insert`/`get`).
    extensions: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
impl Params {
pub fn new(
verb: Verb,
query: HashMap<String, Vec<String>>,
body: Option<JsonValue>,
raw_body: Bytes,
headers: HashMap<String, Vec<String>>,
) -> Self {
Self {
verb,
query,
body,
raw_body,
headers,
extensions: HashMap::new(),
}
}
fn first(&self, name: &str) -> Option<&str> {
self.query
.get(name)
.and_then(|values| values.first())
.map(String::as_str)
}
pub fn query(&self) -> &HashMap<String, Vec<String>> {
&self.query
}
pub fn body_bytes(&self) -> &Bytes {
&self.raw_body
}
pub fn insert<T: Any + Send + Sync>(&mut self, value: T) -> Option<T> {
self.extensions
.insert(TypeId::of::<T>(), Box::new(value))
.and_then(|prev| prev.downcast::<T>().ok().map(|b| *b))
}
pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
self.extensions
.get(&TypeId::of::<T>())
.and_then(|b| b.downcast_ref::<T>())
}
pub fn verb(&self) -> Verb {
self.verb
}
pub fn header(&self, name: &str) -> Option<&str> {
self.headers
.get(&name.to_ascii_lowercase())
.and_then(|values| values.first())
.map(String::as_str)
}
pub fn header_all(&self, name: &str) -> &[String] {
self.headers
.get(&name.to_ascii_lowercase())
.map(Vec::as_slice)
.unwrap_or(&[])
}
pub fn bearer_token(&self) -> Option<&str> {
let auth = self.header("authorization")?;
auth.strip_prefix("Bearer ")
.or_else(|| auth.strip_prefix("bearer "))
}
pub fn require(&self, name: &str) -> Result<&str, WebError> {
self.first(name)
.ok_or_else(|| WebError::BadRequest(format!("Missing required parameter: {}", name)))
}
pub fn get_optional(&self, name: &str) -> Option<&str> {
self.first(name)
}
pub fn get_int(&self, name: &str) -> Result<i64, WebError> {
self.require(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid integer: {}", name)))
}
pub fn get_int_optional(&self, name: &str) -> Result<Option<i64>, WebError> {
match self.get_optional(name) {
Some(s) => s
.parse()
.map(Some)
.map_err(|_| WebError::BadRequest(format!("Invalid integer: {}", name))),
None => Ok(None),
}
}
pub fn get_u64(&self, name: &str) -> Result<u64, WebError> {
self.require(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid u64: {}", name)))
}
pub fn get_u64_optional(&self, name: &str) -> Result<Option<u64>, WebError> {
match self.get_optional(name) {
Some(s) => s
.parse()
.map(Some)
.map_err(|_| WebError::BadRequest(format!("Invalid u64: {}", name))),
None => Ok(None),
}
}
pub fn get_u32(&self, name: &str) -> Result<u32, WebError> {
self.require(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid u32: {}", name)))
}
pub fn get_u32_optional(&self, name: &str) -> Result<Option<u32>, WebError> {
match self.get_optional(name) {
Some(s) => s
.parse()
.map(Some)
.map_err(|_| WebError::BadRequest(format!("Invalid u32: {}", name))),
None => Ok(None),
}
}
pub fn get_f64(&self, name: &str) -> Result<f64, WebError> {
self.require(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid float: {}", name)))
}
pub fn get_f64_optional(&self, name: &str) -> Result<Option<f64>, WebError> {
match self.get_optional(name) {
Some(s) => s
.parse()
.map(Some)
.map_err(|_| WebError::BadRequest(format!("Invalid float: {}", name))),
None => Ok(None),
}
}
pub fn get_bool(&self, name: &str) -> bool {
self.first(name)
.map(|s| !s.is_empty() && s != "false" && s != "0")
.unwrap_or(false)
}
pub fn get_bool_optional(&self, name: &str) -> Option<bool> {
self.first(name)
.map(|s| !s.is_empty() && s != "false" && s != "0")
}
pub fn get_all(&self, name: &str) -> Result<Vec<String>, WebError> {
Ok(self.query.get(name).cloned().unwrap_or_default())
}
pub fn get_all_optional(&self, name: &str) -> Option<Vec<String>> {
self.query.get(name).cloned()
}
pub fn json_body(&self) -> Result<JsonValue, WebError> {
self.body
.clone()
.ok_or_else(|| WebError::BadRequest("Missing JSON body".to_string()))
}
pub fn check_unexpected(&self, expected: &[&str]) -> Option<Vec<String>> {
let unexpected: Vec<String> = self
.query
.keys()
.filter(|k| !expected.contains(&k.as_str()))
.cloned()
.collect();
if unexpected.is_empty() {
None
} else {
Some(unexpected)
}
}
}
/// Parameters extracted for a matched route by `routing::resolve`.
///
/// Holds the path captures, the declared query parameters that were present in
/// the request, and the request body.
#[derive(Debug)]
pub struct ExtractedParams {
    /// Captures from the route pattern (`{name}` / `{...name}`).
    path: HashMap<String, String>,
    /// Only the query parameters the route declared and the request supplied.
    query: HashMap<String, Vec<String>>,
    /// Parsed JSON body copied from the request, if any.
    body: Option<JsonValue>,
    /// Unparsed request body copied from the request.
    raw_body: Bytes,
}
impl ExtractedParams {
fn scalar(&self, name: &str) -> Option<&str> {
self.path.get(name).map(String::as_str).or_else(|| {
self.query
.get(name)
.and_then(|values| values.first())
.map(String::as_str)
})
}
fn require_scalar(&self, name: &str) -> Result<&str, WebError> {
self.scalar(name)
.ok_or_else(|| WebError::BadRequest(format!("Missing parameter: {}", name)))
}
pub fn get_string(&self, name: &str) -> Result<String, WebError> {
self.require_scalar(name).map(str::to_string)
}
pub fn get_i64(&self, name: &str) -> Result<i64, WebError> {
self.require_scalar(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid integer: {}", name)))
}
pub fn get_u64(&self, name: &str) -> Result<u64, WebError> {
self.require_scalar(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid u64: {}", name)))
}
pub fn get_u32(&self, name: &str) -> Result<u32, WebError> {
self.require_scalar(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid u32: {}", name)))
}
pub fn get_f64(&self, name: &str) -> Result<f64, WebError> {
self.require_scalar(name)?
.parse()
.map_err(|_| WebError::BadRequest(format!("Invalid float: {}", name)))
}
pub fn get_bool(&self, name: &str) -> Result<bool, WebError> {
Ok(self
.scalar(name)
.map(|s| !s.is_empty() && s != "false" && s != "0")
.unwrap_or(false))
}
pub fn get_string_array(&self, name: &str) -> Result<Vec<String>, WebError> {
Ok(self.query.get(name).cloned().unwrap_or_default())
}
pub fn get_json_body(&self) -> Result<JsonValue, WebError> {
self.body
.clone()
.ok_or_else(|| WebError::BadRequest("Missing JSON body".to_string()))
}
pub fn get_body_bytes(&self) -> Bytes {
self.raw_body.clone()
}
}
/// Route resolution: matching an action path and request against declared
/// [`RouteDef`]s.
pub mod routing {
    use super::*;
    use std::collections::HashMap;

    /// Selects the route that handles `action` and extracts its parameters.
    ///
    /// Routes are tried in slice order. The first route whose `pattern`
    /// matches `action` (see [`match_pattern`]) *and* whose `verb` list
    /// contains the request verb wins. Its path captures, the declared query
    /// parameters present in the request, and the request body are packaged
    /// into an [`ExtractedParams`].
    ///
    /// # Errors
    ///
    /// * `WebError::BadRequest` in two cases:
    ///   * the winning route declares a query parameter that is absent, has no
    ///     default, and is not a [`ParamType::StringArray`];
    ///   * in [`ControllerMode::Strict`], the query string contains names the
    ///     route does not declare as [`ParamSource::Query`] parameters.
    ///
    ///   The missing-parameter check runs first.
    /// * `WebError::MethodNotAllowed` when at least one pattern matched but no
    ///   matching route accepted the verb. The payload lists the verbs those
    ///   routes accept, sorted and without duplicates.
    /// * `WebError::NotFound` when no pattern matched.
    #[inline]
    pub fn resolve<'a>(
        routes: &'a [RouteDef],
        action: &str,
        params: &Params,
        mode: ControllerMode,
    ) -> Result<(&'a RouteDef, ExtractedParams), WebError> {
        let request_verb = params.verb();
        let mut allowed_methods: Vec<&'static str> = Vec::new();
        for route in routes {
            let path_params = match match_pattern(route.pattern, action) {
                Some(p) => p,
                None => continue,
            };
            if !route.verb.contains(&request_verb) {
                // The path exists but not for this verb: remember what it does
                // accept so the 405 can advertise it.
                for v in route.verb {
                    let token = v.as_str();
                    if !allowed_methods.contains(&token) {
                        allowed_methods.push(token);
                    }
                }
                continue;
            }
            let extracted = extract_params(route, path_params, params)?;
            if mode == ControllerMode::Strict {
                reject_undeclared_query(route, params)?;
            }
            return Ok((route, extracted));
        }
        if allowed_methods.is_empty() {
            Err(WebError::NotFound)
        } else {
            // Already unique (see the `contains` check above); sort for a stable list.
            allowed_methods.sort_unstable();
            Err(WebError::MethodNotAllowed(allowed_methods))
        }
    }

    /// Builds the [`ExtractedParams`] for `route` and reports the first
    /// required query parameter that is missing.
    fn extract_params(
        route: &RouteDef,
        path: HashMap<String, String>,
        params: &Params,
    ) -> Result<ExtractedParams, WebError> {
        let mut extracted = ExtractedParams {
            path,
            query: HashMap::new(),
            body: params.body.clone(),
            raw_body: params.raw_body.clone(),
        };
        for param_def in route.params {
            match param_def.source {
                ParamSource::Path => {
                    // A declared path parameter the pattern did not capture means
                    // the route definition is inconsistent; checked in debug builds only.
                    debug_assert!(
                        extracted.path.contains_key(param_def.name),
                        "path parameter `{}` not captured by pattern",
                        param_def.name
                    );
                }
                ParamSource::Query => {
                    if let Some(value) = params.query.get(param_def.name) {
                        extracted
                            .query
                            .insert(param_def.name.to_string(), value.clone());
                    } else if param_def.default.is_none()
                        && !matches!(param_def.ty, ParamType::StringArray)
                    {
                        return Err(WebError::BadRequest(format!(
                            "Missing required parameter: {}",
                            param_def.name
                        )));
                    }
                }
                // Body parameters are not validated here; the whole body is
                // forwarded in `extracted`.
                ParamSource::Body => {}
            }
        }
        Ok(extracted)
    }

    /// Strict mode: rejects query names that `route` does not declare as
    /// [`ParamSource::Query`] parameters.
    fn reject_undeclared_query(route: &RouteDef, params: &Params) -> Result<(), WebError> {
        let expected: Vec<&str> = route
            .params
            .iter()
            .filter(|p| p.source == ParamSource::Query)
            .map(|p| p.name)
            .collect();
        match params.check_unexpected(&expected) {
            Some(unexpected) => Err(WebError::BadRequest(format!(
                "Unexpected parameters: {}",
                unexpected.join(", ")
            ))),
            None => Ok(()),
        }
    }

    /// Name of a `{...name}` rest segment, or `None` for any other segment
    /// (including the nameless `{...}`).
    fn rest_token_name(segment: &str) -> Option<&str> {
        segment
            .strip_prefix("{...")
            .and_then(|s| s.strip_suffix('}'))
            .filter(|name| !name.is_empty())
    }

    /// Matches one non-rest pattern segment against one path segment.
    ///
    /// A `{name}` segment captures the path segment under `name`; any other
    /// segment must equal the path segment exactly.
    fn match_fixed_segment(
        pattern_part: &str,
        path_part: &str,
        params: &mut HashMap<String, String>,
    ) -> bool {
        if let Some(param_name) = pattern_part
            .strip_prefix('{')
            .and_then(|s| s.strip_suffix('}'))
        {
            params.insert(param_name.to_string(), path_part.to_string());
            true
        } else {
            pattern_part == path_part
        }
    }

    /// Splits on `/`, dropping empty segments.
    fn segments(s: &str) -> Vec<&str> {
        s.split('/').filter(|seg| !seg.is_empty()).collect()
    }

    /// Matches an action `path` against a route `pattern`, returning the
    /// named captures, or `None` if the path does not match.
    ///
    /// Both strings are split on `/` with empty segments discarded, so leading,
    /// trailing and repeated slashes are ignored. Pattern segments are one of:
    ///
    /// * a literal, compared exactly;
    /// * `{name}`, which captures exactly one path segment;
    /// * `{...name}`, as the **last** segment only. It captures all remaining
    ///   segments joined by `/`, which may be the empty string.
    ///
    /// Without a rest segment the segment counts must be equal. A `{...name}`
    /// in a non-final position is treated like `{name}` with the literal name
    /// `...name`. Captured values are returned verbatim (no decoding).
    #[inline]
    pub fn match_pattern(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
        let pattern_parts = segments(pattern);
        let path_parts = segments(path);
        let mut params = HashMap::new();
        if let Some(rest_name) = pattern_parts.last().and_then(|s| rest_token_name(s)) {
            let fixed = &pattern_parts[..pattern_parts.len() - 1];
            if path_parts.len() < fixed.len() {
                return None;
            }
            for (pattern_part, path_part) in fixed.iter().zip(path_parts.iter()) {
                if !match_fixed_segment(pattern_part, path_part, &mut params) {
                    return None;
                }
            }
            params.insert(rest_name.to_string(), path_parts[fixed.len()..].join("/"));
            return Some(params);
        }
        if pattern_parts.len() != path_parts.len() {
            return None;
        }
        for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
            if !match_fixed_segment(pattern_part, path_part, &mut params) {
                return None;
            }
        }
        Some(params)
    }
}
/// A request handler object that dispatches an action to its handler methods.
///
/// NOTE(review): implementations are presumably generated by the `controller`
/// attribute macro — confirm. The `actus_`-prefixed methods expose
/// per-controller metadata and have defaults.
#[async_trait]
pub trait Controller: Send + Sync {
    /// Handles `action` with the request `params`, producing a [`Reply`].
    ///
    /// NOTE(review): `action` is presumably the request path relative to the
    /// controller's mount point (it is what `routing::resolve` matches) — confirm.
    async fn actus_dispatch(&self, action: &str, params: Params) -> Reply;
    /// A static name for this controller.
    ///
    /// NOTE(review): presumably the type or registration name — confirm.
    fn __name(&self) -> &'static str;
    /// Describes the routes this controller serves. Defaults to an empty list.
    fn actus_describe_routes(&self) -> Vec<RouteDef> {
        vec![]
    }
    /// Controller-specific request-body size limit in bytes. Defaults to `None`.
    ///
    /// NOTE(review): presumably `None` means "use the server default" — confirm.
    fn actus_max_body_bytes(&self) -> Option<usize> {
        None
    }
    /// Rate-limit setting for this controller as an opaque static string.
    /// Defaults to `None`.
    ///
    /// NOTE(review): the format is interpreted by the consumer — confirm.
    fn actus_rate_limit(&self) -> Option<&'static str> {
        None
    }
}
/// A list of `(key, factory)` pairs; each factory builds a fresh boxed
/// [`Controller`].
///
/// NOTE(review): the `&'static str` key is presumably the mount path or prefix
/// of the controller — confirm against the router that consumes this.
pub type Routes = Vec<(
    &'static str,
    Box<dyn Fn() -> Box<dyn Controller> + Send + Sync>,
)>;
/// Placeholder macro: accepts any token stream and expands to nothing.
///
/// NOTE(review): presumably kept so older `routes! { ... }` invocations still
/// compile now that route tables come from `app_routes`/`controller` — confirm.
#[macro_export]
macro_rules! routes {
    ($($tokens:tt)*) => {};
}
#[cfg(test)]
mod match_pattern_tests {
    //! Unit tests for `routing::match_pattern`: literal and `{name}` segments,
    //! plus the trailing `{...rest}` capture.
    use super::routing::match_pattern;

    /// Runs `match_pattern` and returns its captures as sorted `(name, value)`
    /// pairs, so assertions do not depend on `HashMap` iteration order.
    fn cap(pattern: &str, path: &str) -> Option<Vec<(String, String)>> {
        let captures = match_pattern(pattern, path)?;
        let mut sorted: Vec<(String, String)> = captures.into_iter().collect();
        sorted.sort_unstable();
        Some(sorted)
    }

    /// Shorthand for an owned `(name, value)` capture.
    fn pair(name: &str, value: &str) -> (String, String) {
        (String::from(name), String::from(value))
    }

    #[test]
    fn fixed_patterns_still_work() {
        assert_eq!(cap("", ""), Some(vec![]));
        assert_eq!(cap("{id}", "42"), Some(vec![pair("id", "42")]));
        assert_eq!(
            cap("posts/{id}/comments", "posts/3/comments"),
            Some(vec![pair("id", "3")])
        );
        let mismatches = [("a/b", "a/b/c"), ("a/b/c", "a/b"), ("posts/{id}", "users/3")];
        for (pattern, path) in mismatches {
            assert_eq!(cap(pattern, path), None, "pattern {:?} vs path {:?}", pattern, path);
        }
    }

    #[test]
    fn required_segments_dont_match_the_empty_action() {
        let mismatches = [("{id}", ""), ("posts", ""), ("{a}/{b}", "x"), ("", "x")];
        for (pattern, path) in mismatches {
            assert_eq!(cap(pattern, path), None, "pattern {:?} vs path {:?}", pattern, path);
        }
        assert_eq!(cap("", ""), Some(vec![]));
    }

    #[test]
    fn rest_param_captures_remainder() {
        let pattern = "{folder_id}/{...path}";
        assert_eq!(
            cap(pattern, "abc/x/y/z"),
            Some(vec![pair("folder_id", "abc"), pair("path", "x/y/z")])
        );
        assert_eq!(
            cap(pattern, "abc"),
            Some(vec![pair("folder_id", "abc"), pair("path", "")])
        );
        assert_eq!(cap(pattern, ""), None);
    }

    #[test]
    fn rest_param_as_sole_token() {
        for (path, captured) in [("a/b/c", "a/b/c"), ("a", "a"), ("", "")] {
            assert_eq!(cap("{...path}", path), Some(vec![pair("path", captured)]));
        }
    }

    #[test]
    fn rest_param_after_literal_prefix() {
        let pattern = "files/{...path}";
        assert_eq!(cap(pattern, "files/x/y"), Some(vec![pair("path", "x/y")]));
        assert_eq!(cap(pattern, "files"), Some(vec![pair("path", "")]));
        assert_eq!(cap(pattern, "other/x"), None);
        assert_eq!(cap("a/b/{...path}", "a"), None);
    }
}
#[cfg(test)]
mod resolve_tests {
    //! Tests for `routing::resolve` and the request-side `Params` accessors.
    use super::routing::resolve;
    use super::*;
    use bytes::Bytes;
    use std::collections::HashMap;

    /// Request `Params` with the given verb and query, and no body or headers.
    fn params_with(verb: Verb, query: HashMap<String, Vec<String>>) -> Params {
        Params::new(verb, query, None, Bytes::new(), HashMap::new())
    }

    #[test]
    fn headers_are_a_multimap_first_value_wins_for_scalar_access() {
        let mut headers = HashMap::new();
        headers.insert(
            "forwarded".to_string(),
            vec!["for=1.2.3.4".to_string(), "for=10.0.0.1".to_string()],
        );
        headers.insert("x-trace-id".to_string(), vec!["abc-123".to_string()]);
        let p = Params::new(Verb::GET, HashMap::new(), None, Bytes::new(), headers);
        assert_eq!(p.header("Forwarded"), Some("for=1.2.3.4"));
        assert_eq!(p.header("FORWARDED"), Some("for=1.2.3.4"));
        assert_eq!(p.header_all("Forwarded"), ["for=1.2.3.4", "for=10.0.0.1"]);
        assert_eq!(p.header("X-Trace-Id"), Some("abc-123"));
        assert_eq!(p.header_all("X-Trace-Id"), ["abc-123"]);
        assert_eq!(p.header("Authorization"), None);
        assert!(p.header_all("Authorization").is_empty());
    }

    #[test]
    fn params_query_exposes_the_whole_multimap() {
        let mut q = HashMap::new();
        q.insert("a".to_string(), vec!["1".to_string(), "2".to_string()]);
        q.insert("b".to_string(), vec!["3".to_string()]);
        let p = params_with(Verb::GET, q);
        assert_eq!(p.query().len(), 2);
        assert_eq!(
            p.query().get("a").unwrap(),
            &["1".to_string(), "2".to_string()]
        );
        assert_eq!(p.get_optional("a"), Some("1"));
    }

    #[test]
    fn verb_mismatch_yields_405_with_sorted_deduped_allow_list() {
        static ROUTES: &[RouteDef] = &[
            RouteDef {
                pattern: "",
                handler_id: "create",
                handler: "create",
                verb: &[Verb::POST],
                params: &[],
                doc: None,
            },
            RouteDef {
                pattern: "",
                handler_id: "list",
                handler: "list",
                verb: &[Verb::GET],
                params: &[],
                doc: None,
            },
        ];
        match resolve(
            ROUTES,
            "",
            &params_with(Verb::PUT, HashMap::new()),
            ControllerMode::Strict,
        ) {
            Err(WebError::MethodNotAllowed(methods)) => assert_eq!(methods, ["GET", "POST"]),
            other => panic!("expected 405, got {other:?}"),
        }
        assert!(
            resolve(
                ROUTES,
                "",
                &params_with(Verb::GET, HashMap::new()),
                ControllerMode::Strict
            )
            .is_ok()
        );
    }

    #[test]
    fn no_pattern_match_is_404_not_405() {
        static ROUTES: &[RouteDef] = &[RouteDef {
            pattern: "items",
            handler_id: "h",
            handler: "h",
            verb: &[Verb::GET],
            params: &[],
            doc: None,
        }];
        match resolve(
            ROUTES,
            "other",
            &params_with(Verb::DELETE, HashMap::new()),
            ControllerMode::Strict,
        ) {
            Err(WebError::NotFound) => {}
            other => panic!("expected 404, got {other:?}"),
        }
    }

    #[test]
    fn strict_mode_rejects_undeclared_query_but_lax_accepts_it() {
        static ROUTES: &[RouteDef] = &[RouteDef {
            pattern: "items/{id}",
            handler_id: "show",
            handler: "show",
            verb: &[Verb::GET],
            params: &[ParamDef {
                name: "id",
                ty: ParamType::Int,
                source: ParamSource::Path,
                default: None,
            }],
            doc: None,
        }];
        let mut q = HashMap::new();
        q.insert("extra".to_string(), vec!["1".to_string()]);
        match resolve(
            ROUTES,
            "items/7",
            &params_with(Verb::GET, q.clone()),
            ControllerMode::Strict,
        ) {
            Err(WebError::BadRequest(msg)) => assert_eq!(msg, "Unexpected parameters: extra"),
            other => panic!("expected 400, got {other:?}"),
        }
        let (route, extracted) = resolve(
            ROUTES,
            "items/7",
            &params_with(Verb::GET, q),
            ControllerMode::Lax,
        )
        .expect("lax mode ignores undeclared query params");
        assert_eq!(route.handler_id, "show");
        assert_eq!(extracted.get_i64("id").unwrap(), 7);
    }

    #[test]
    fn missing_required_query_param_is_400_but_defaulted_one_is_not() {
        static ROUTES: &[RouteDef] = &[RouteDef {
            pattern: "",
            handler_id: "h",
            handler: "h",
            verb: &[Verb::GET],
            params: &[
                ParamDef {
                    name: "page",
                    ty: ParamType::U32,
                    source: ParamSource::Query,
                    default: Some(ParamDefault::U32(1)),
                },
                ParamDef {
                    name: "q",
                    ty: ParamType::String,
                    source: ParamSource::Query,
                    default: None,
                },
            ],
            doc: None,
        }];
        match resolve(
            ROUTES,
            "",
            &params_with(Verb::GET, HashMap::new()),
            ControllerMode::Strict,
        ) {
            Err(WebError::BadRequest(msg)) => assert_eq!(msg, "Missing required parameter: q"),
            other => panic!("expected 400, got {other:?}"),
        }
    }

    #[test]
    fn vec_string_query_param_collects_all_values() {
        static ROUTES: &[RouteDef] = &[RouteDef {
            pattern: "",
            handler_id: "h",
            handler: "h",
            verb: &[Verb::GET],
            params: &[ParamDef {
                name: "tags",
                ty: ParamType::StringArray,
                source: ParamSource::Query,
                default: None,
            }],
            doc: None,
        }];
        let mut q = HashMap::new();
        q.insert(
            "tags".to_string(),
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        let (_, extracted) = resolve(
            ROUTES,
            "",
            &params_with(Verb::GET, q),
            ControllerMode::Strict,
        )
        .expect("route matches");
        assert_eq!(extracted.get_string_array("tags").unwrap(), ["a", "b", "c"]);
        let mut q1 = HashMap::new();
        q1.insert("tags".to_string(), vec!["solo".to_string()]);
        let (_, e1) = resolve(
            ROUTES,
            "",
            &params_with(Verb::GET, q1),
            ControllerMode::Strict,
        )
        .expect("route matches");
        assert_eq!(e1.get_string_array("tags").unwrap(), ["solo"]);
        assert_eq!(e1.get_string("tags").unwrap(), "solo");
        let (_, e2) = resolve(
            ROUTES,
            "",
            &params_with(Verb::GET, HashMap::new()),
            ControllerMode::Strict,
        )
        .expect("route matches with no query");
        assert!(e2.get_string_array("tags").unwrap().is_empty());
    }
}