use std::collections::{HashMap, HashSet};
/// A URL pattern registered under a stable name so it can be turned back into a
/// concrete URL with [`reverse`] / [`reverse_owned`].
///
/// Instances are normally created with the `register_url!` macro, which submits them
/// to the `inventory` registry declared below this struct.
#[derive(Debug, Clone)]
pub struct NamedRoute {
/// Lookup key passed to [`reverse`] / [`reverse_owned`].
pub name: &'static str,
/// Path pattern; each `{param}` (or `{prefix:param}`) segment is filled in on reversal.
pub pattern: &'static str,
}
// Declares the link-time registry that `register_url!` submits into and that
// `all_routes`, `duplicates`, `reverse` and `reverse_owned` iterate.
inventory::collect!(NamedRoute);
/// Registers a named URL pattern in the global route registry.
///
/// Expands to an `inventory::submit!` of a `NamedRoute { name, pattern }`. Both
/// arguments must therefore be `&'static str` expressions (the struct's field types).
///
/// The expansion refers to `$crate::inventory` and `$crate::urls::NamedRoute`. It
/// assumes the crate root re-exports `inventory` and that this module is mounted at
/// `crate::urls`. NOTE(review): confirm against the crate root.
///
/// # Examples
///
/// ```ignore
/// register_url!("post_detail", "/posts/{id}");
/// ```
#[macro_export]
macro_rules! register_url {
($name:expr, $pattern:expr) => {
$crate::inventory::submit! {
$crate::urls::NamedRoute {
name: $name,
pattern: $pattern,
}
}
};
}
/// Errors produced while reversing a named route into a URL.
///
/// In the messages below, `{{` / `}}` are escaped braces, so `{{{param}}}` renders as
/// `{id}` for a parameter named `id`.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ReverseError {
/// No registered route has the requested name.
#[error("no URL registered for name `{0}`")]
UnknownName(String),
/// The route's pattern has a placeholder for which no value was supplied.
#[error("URL `{name}` requires placeholder `{{{param}}}` — pass it in `params`")]
MissingParam { name: String, param: String },
/// A supplied parameter does not match any placeholder in the route's pattern.
#[error("URL `{name}` doesn't have a `{{{param}}}` placeholder")]
UnexpectedParam { name: String, param: String },
/// The registered pattern itself could not be parsed (e.g. an unclosed `{`).
#[error("URL `{name}` has a malformed pattern: {detail}")]
MalformedPattern { name: String, detail: String },
}
/// Returns every route submitted through `register_url!`.
///
/// Routes appear in whatever order the `inventory` registry yields them; this function
/// does not sort them.
#[must_use]
pub fn all_routes() -> Vec<&'static NamedRoute> {
    let mut routes: Vec<&'static NamedRoute> = Vec::new();
    routes.extend(inventory::iter::<NamedRoute>);
    routes
}
#[must_use]
pub fn duplicates() -> Vec<&'static str> {
let mut seen: std::collections::HashMap<&'static str, usize> = std::collections::HashMap::new();
for r in inventory::iter::<NamedRoute> {
*seen.entry(r.name).or_insert(0) += 1;
}
let mut out: Vec<&'static str> = seen
.into_iter()
.filter_map(|(name, count)| (count > 1).then_some(name))
.collect();
out.sort_unstable();
out
}
/// Builds the concrete URL for the route registered as `name`.
///
/// Each `{placeholder}` in the route's pattern is replaced by the matching value from
/// `params`, passed through `url_codec::url_encode`. If `name` is registered more than
/// once, the first matching route the registry yields is used.
///
/// # Errors
///
/// - [`ReverseError::UnknownName`] if no route is registered under `name`.
/// - [`ReverseError::MissingParam`], [`ReverseError::UnexpectedParam`] or
///   [`ReverseError::MalformedPattern`] from substituting into the pattern.
pub fn reverse(name: &str, params: &HashMap<&str, String>) -> Result<String, ReverseError> {
    for route in inventory::iter::<NamedRoute> {
        if route.name == name {
            return substitute(name, route.pattern, params);
        }
    }
    Err(ReverseError::UnknownName(name.to_owned()))
}
pub fn reverse_owned(name: &str, params: &HashMap<String, String>) -> Result<String, ReverseError> {
let route = inventory::iter::<NamedRoute>
.into_iter()
.find(|r| r.name == name)
.ok_or_else(|| ReverseError::UnknownName(name.to_owned()))?;
let borrowed: HashMap<&str, String> = params
.iter()
.map(|(k, v)| (k.as_str(), v.clone()))
.collect();
substitute(name, route.pattern, &borrowed)
}
/// Expands `pattern` by replacing each `{placeholder}` with the value from `params`,
/// passed through `url_codec::url_encode`.
///
/// A placeholder may carry a prefix such as `{int:id}`. Only the text after the last
/// `:` is used as the parameter name. `name` is the route name and only labels errors.
/// Characters outside placeholders are copied through unchanged.
///
/// # Errors
///
/// - [`ReverseError::MalformedPattern`] for an unclosed `{`, or for a placeholder whose
///   parameter name is empty (`{}` / `{int:}`).
/// - [`ReverseError::MissingParam`] for a placeholder with no entry in `params`.
/// - [`ReverseError::UnexpectedParam`] for a `params` entry no placeholder uses. If there
///   are several, the lexicographically smallest is reported so the error is stable.
fn substitute(
    name: &str,
    pattern: &str,
    params: &HashMap<&str, String>,
) -> Result<String, ReverseError> {
    let malformed = |detail: String| ReverseError::MalformedPattern {
        name: name.to_owned(),
        detail,
    };

    let mut out = String::with_capacity(pattern.len() + 16);
    // Keys are borrowed straight out of `params`; no per-placeholder allocation needed.
    let mut used: HashSet<&str> = HashSet::new();
    let mut chars = pattern.chars();
    while let Some(c) = chars.next() {
        if c != '{' {
            out.push(c);
            continue;
        }
        let mut placeholder = String::new();
        let mut closed = false;
        for nc in chars.by_ref() {
            if nc == '}' {
                closed = true;
                break;
            }
            placeholder.push(nc);
        }
        if !closed {
            return Err(malformed(format!(
                "unclosed placeholder starting at `{{{placeholder}`"
            )));
        }
        // `{int:id}` -> `id`; a placeholder without `:` is its own key.
        let key = placeholder
            .rsplit_once(':')
            .map_or(placeholder.as_str(), |(_, k)| k);
        if key.is_empty() {
            // `{}` / `{int:}` can never be satisfied by a caller; report the pattern,
            // not a "missing" parameter with an empty name.
            return Err(malformed(format!(
                "placeholder `{{{placeholder}}}` has an empty parameter name"
            )));
        }
        let (&param_key, value) = params.get_key_value(key).ok_or_else(|| {
            ReverseError::MissingParam {
                name: name.to_owned(),
                param: key.to_owned(),
            }
        })?;
        out.push_str(&crate::url_codec::url_encode(value));
        used.insert(param_key);
    }

    // HashMap iteration order is unspecified; pick the smallest unused key so the
    // reported parameter does not vary between runs.
    if let Some(extra) = params.keys().copied().filter(|k| !used.contains(k)).min() {
        return Err(ReverseError::UnexpectedParam {
            name: name.to_owned(),
            param: extra.to_owned(),
        });
    }
    Ok(out)
}
// Unit tests for route reversal. Fixture routes go through the same global `inventory`
// registry as application routes; the `__test_` prefix presumably keeps them from
// clashing with real route names (confirm).
#[cfg(test)]
mod tests {
use super::*;
// Fixtures: static route, one placeholder, two placeholders, and a `{prefix:name}`
// placeholder. A deliberately malformed pattern is registered further down.
register_url!("__test_home", "/");
register_url!("__test_post_detail", "/posts/{id}");
register_url!("__test_two_args", "/users/{user_id}/posts/{post_id}");
register_url!("__test_typed_placeholder", "/items/{int:id}");
// Builds a `reverse`-compatible parameter map from literal key/value pairs.
fn params(pairs: &[(&'static str, &str)]) -> HashMap<&'static str, String> {
pairs.iter().map(|(k, v)| (*k, (*v).to_owned())).collect()
}
#[test]
fn reverse_resolves_static_pattern() {
assert_eq!(reverse("__test_home", &HashMap::new()).unwrap(), "/");
}
#[test]
fn reverse_substitutes_single_placeholder() {
let p = params(&[("id", "42")]);
assert_eq!(reverse("__test_post_detail", &p).unwrap(), "/posts/42");
}
#[test]
fn reverse_substitutes_multiple_placeholders() {
let p = params(&[("user_id", "5"), ("post_id", "10")]);
assert_eq!(reverse("__test_two_args", &p).unwrap(), "/users/5/posts/10");
}
// Parameter values go through `url_codec::url_encode` before insertion.
#[test]
fn reverse_percent_encodes_param_values() {
let p = params(&[("id", "hello world")]);
let url = reverse("__test_post_detail", &p).unwrap();
assert_eq!(url, "/posts/hello%20world");
}
#[test]
fn reverse_unknown_name_errors() {
let err = reverse("nope_doesnt_exist", &HashMap::new()).unwrap_err();
assert!(matches!(err, ReverseError::UnknownName(ref n) if n == "nope_doesnt_exist"));
}
#[test]
fn reverse_missing_param_errors_with_param_name() {
let err = reverse("__test_post_detail", &HashMap::new()).unwrap_err();
match err {
ReverseError::MissingParam { name, param } => {
assert_eq!(name, "__test_post_detail");
assert_eq!(param, "id");
}
other => panic!("expected MissingParam, got: {other:?}"),
}
}
// Extra keys (e.g. typos) are rejected rather than silently ignored.
#[test]
fn reverse_unexpected_param_errors() {
let p = params(&[("id", "1"), ("typo_extra", "x")]);
let err = reverse("__test_post_detail", &p).unwrap_err();
assert!(
matches!(err, ReverseError::UnexpectedParam { ref param, .. } if param == "typo_extra"),
"got: {err:?}"
);
}
// `{int:id}` is filled from the `id` key; the `int:` prefix is ignored.
#[test]
fn reverse_accepts_axum_style_typed_placeholder() {
let p = params(&[("id", "7")]);
assert_eq!(reverse("__test_typed_placeholder", &p).unwrap(), "/items/7");
}
#[test]
fn reverse_owned_takes_string_keyed_params() {
let mut p: HashMap<String, String> = HashMap::new();
p.insert("id".into(), "99".into());
assert_eq!(
reverse_owned("__test_post_detail", &p).unwrap(),
"/posts/99"
);
}
// Fixture with an unclosed `{`, used only by the test below.
register_url!("__test_malformed", "/items/{unclosed");
#[test]
fn reverse_malformed_pattern_surfaces_dedicated_error() {
let err = reverse("__test_malformed", &HashMap::new()).unwrap_err();
match err {
ReverseError::MalformedPattern { name, detail } => {
assert_eq!(name, "__test_malformed");
assert!(detail.contains("unclosed"), "detail: {detail}");
}
other => panic!("expected MalformedPattern, got: {other:?}"),
}
}
// Other routes linked into the test binary are also registered, so only containment
// of the fixtures is checked.
#[test]
fn all_routes_returns_at_least_registered_test_routes() {
let names: Vec<&str> = all_routes().iter().map(|r| r.name).collect();
for required in [
"__test_home",
"__test_post_detail",
"__test_two_args",
"__test_typed_placeholder",
"__test_malformed",
] {
assert!(names.contains(&required), "missing {required}: {names:?}");
}
}
// Only checks the sorted-output contract; whether any duplicates exist depends on every
// route linked into the test binary.
#[test]
fn duplicates_helper_is_callable() {
let dups = duplicates();
for w in dups.windows(2) {
assert!(w[0] <= w[1], "duplicates() must return sorted: {dups:?}");
}
}
}
/// Registers the `url(name='...', <param>=...)` template function on `tera`.
#[cfg(feature = "template_views")]
pub fn register_url_tag(tera: &mut tera::Tera) {
    // Name under which templates call the function: `{{ url(name='...') }}`.
    const FUNCTION_NAME: &str = "url";
    tera.register_function(FUNCTION_NAME, url_tag_fn);
}
/// Tera function backing `{{ url(name='route', param=value, ...) }}`.
///
/// `name` selects the registered route. Every other argument becomes a placeholder
/// value: strings are used as-is, numbers and booleans via their `Display` form.
///
/// # Errors
///
/// Returns a [`tera::Error`] when:
/// - `name` is missing or is not a string;
/// - an argument is `null` (typically an undefined template variable) or is not a scalar;
/// - [`reverse`] fails (unknown route, missing or unexpected placeholder, malformed pattern).
#[cfg(feature = "template_views")]
fn url_tag_fn(args: &std::collections::HashMap<String, tera::Value>) -> tera::Result<tera::Value> {
    let name = match args.get("name") {
        Some(tera::Value::String(s)) => s.as_str(),
        Some(other) => {
            return Err(tera::Error::msg(format!(
                "url(): `name` must be a string, got: {other:?}"
            )));
        }
        None => return Err(tera::Error::msg("url(): missing required `name` argument")),
    };
    // Keys borrow from `args`, so `reverse` can be called directly without the extra
    // re-keying/cloning pass `reverse_owned` would do.
    let mut params: HashMap<&str, String> = HashMap::with_capacity(args.len());
    for (k, v) in args {
        if k == "name" {
            continue;
        }
        let s = match v {
            tera::Value::String(s) => s.clone(),
            tera::Value::Number(n) => n.to_string(),
            tera::Value::Bool(b) => b.to_string(),
            tera::Value::Null => {
                return Err(tera::Error::msg(format!(
                    "url(): argument `{k}` is null — likely an undefined template variable"
                )));
            }
            other => {
                return Err(tera::Error::msg(format!(
                    "url(): argument `{k}` must be a scalar (string / number / bool), got: {other:?}"
                )));
            }
        };
        params.insert(k.as_str(), s);
    }
    reverse(name, &params)
        .map(tera::Value::String)
        .map_err(|e| tera::Error::msg(e.to_string()))
}
// Integration tests for the `url(...)` Tera function. The fixture routes are registered
// in the same global registry as everything else, under `__test_tag_` names.
#[cfg(all(test, feature = "template_views"))]
mod tera_tests {
use super::*;
register_url!("__test_tag_home", "/");
register_url!("__test_tag_post", "/posts/{id}");
register_url!("__test_tag_users_posts", "/users/{user_id}/posts/{post_id}");
// Fresh Tera instance with only the `url` function registered.
fn setup() -> tera::Tera {
let mut tera = tera::Tera::default();
register_url_tag(&mut tera);
tera
}
// Renders `src` as a one-off template on a clone of `tera`, with an empty context.
fn render(tera: &tera::Tera, src: &str) -> String {
let mut t = tera.clone();
t.add_raw_template("_", src).unwrap();
t.render("_", &tera::Context::new()).unwrap()
}
#[test]
fn url_tag_resolves_static_route() {
let tera = setup();
assert_eq!(render(&tera, "{{ url(name='__test_tag_home') }}"), "/");
}
#[test]
fn url_tag_substitutes_int_param_via_display() {
let tera = setup();
assert_eq!(
render(&tera, "{{ url(name='__test_tag_post', id=42) }}"),
"/posts/42"
);
}
#[test]
fn url_tag_substitutes_string_param() {
let tera = setup();
assert_eq!(
render(&tera, "{{ url(name='__test_tag_post', id='hello') }}"),
"/posts/hello"
);
}
#[test]
fn url_tag_substitutes_multiple_params() {
let tera = setup();
assert_eq!(
render(
&tera,
"{{ url(name='__test_tag_users_posts', user_id=5, post_id=10) }}"
),
"/users/5/posts/10"
);
}
// The function's result can be captured with `{% set %}` and reused.
#[test]
fn url_tag_set_capture_works_via_tera_set() {
let tera = setup();
let src = "{% set u = url(name='__test_tag_post', id=7) %}<a href='{{ u }}'>x</a>";
assert_eq!(render(&tera, src), "<a href='/posts/7'>x</a>");
}
// Joins an error's Display with every `source()` in its chain, so assertions can match
// messages that may be nested below the top-level render error.
fn full_error_chain(e: &tera::Error) -> String {
use std::error::Error as _;
let mut out = format!("{e}");
let mut cur: Option<&dyn std::error::Error> = e.source();
while let Some(c) = cur {
out.push_str(" | ");
out.push_str(&c.to_string());
cur = c.source();
}
out
}
#[test]
fn url_tag_missing_name_arg_errors() {
let mut tera = setup();
tera.add_raw_template("_", "{{ url(id=1) }}").unwrap();
let err = tera.render("_", &tera::Context::new()).unwrap_err();
let msg = full_error_chain(&err).to_lowercase();
assert!(
msg.contains("name") || msg.contains("url()"),
"expected error about missing `name`, got: {msg}"
);
}
#[test]
fn url_tag_unknown_route_propagates_reverse_error() {
let mut tera = setup();
tera.add_raw_template("_", "{{ url(name='nope_nope_nope') }}")
.unwrap();
let err = tera.render("_", &tera::Context::new()).unwrap_err();
let msg = full_error_chain(&err).to_lowercase();
assert!(
msg.contains("no url registered") || msg.contains("nope_nope_nope"),
"expected unknown-name error, got: {msg}"
);
}
#[test]
fn url_tag_non_string_name_errors_clearly() {
let mut tera = setup();
tera.add_raw_template("_", "{{ url(name=42) }}").unwrap();
let err = tera.render("_", &tera::Context::new()).unwrap_err();
let msg = full_error_chain(&err).to_lowercase();
assert!(
msg.contains("name") && msg.contains("string"),
"expected `name must be a string` error, got: {msg}"
);
}
// A null argument must fail loudly instead of producing a URL with an empty segment.
#[test]
fn url_tag_null_param_errors_instead_of_emitting_empty_segment() {
let mut tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("v", &serde_json::Value::Null);
tera.add_raw_template("_", "{{ url(name='__test_tag_post', id=v) }}")
.unwrap();
let err = tera.render("_", &ctx).unwrap_err();
let msg = full_error_chain(&err).to_lowercase();
assert!(
msg.contains("null") || msg.contains("undefined"),
"expected null/undefined error, got: {msg}"
);
}
}
/// Registers the `querystring` filter on `tera` (`{{ qs | querystring(key=value) }}`).
#[cfg(feature = "template_views")]
pub fn register_querystring_filter(tera: &mut tera::Tera) {
    // Name under which templates apply the filter.
    const FILTER_NAME: &str = "querystring";
    tera.register_filter(FILTER_NAME, querystring_filter);
}
/// Tera filter that rewrites a query string: `{{ qs | querystring(page=2, filter=v) }}`.
///
/// The input (with or without a leading `?`) is parsed into ordered key/value pairs.
/// Each argument then overrides one key:
/// - `null` removes every occurrence of the key;
/// - a string / number / bool replaces the first occurrence in place and drops any later
///   duplicates, or is appended if the key is absent.
///
/// Overrides are applied in ascending key order, so newly appended keys appear in a
/// reproducible order. The output is `?k=v&...` with keys and values passed through
/// `url_codec::url_encode`, or `""` when no pairs remain. A non-string input value is
/// treated as an empty query string.
///
/// # Errors
///
/// Returns an error if any argument is an array or an object.
#[cfg(feature = "template_views")]
fn querystring_filter(
    value: &tera::Value,
    args: &HashMap<String, tera::Value>,
) -> tera::Result<tera::Value> {
    let current = value.as_str().unwrap_or("");
    let mut pairs = parse_query_pairs(current);

    // `args` is a HashMap whose iteration order is unspecified (randomized per process),
    // so sort the overrides to make the position of appended keys deterministic.
    let mut overrides: Vec<(&String, &tera::Value)> = args.iter().collect();
    overrides.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

    for (k, v) in overrides {
        let s = match v {
            tera::Value::Null => {
                pairs.retain(|(pk, _)| pk != k);
                continue;
            }
            tera::Value::String(s) => s.clone(),
            tera::Value::Number(n) => n.to_string(),
            tera::Value::Bool(b) => b.to_string(),
            other => {
                return Err(tera::Error::msg(format!(
                    "querystring(): argument `{k}` must be a scalar (string / number / bool / null), got: {other:?}"
                )));
            }
        };
        // One O(n) pass: the first occurrence consumes `replacement` and is updated in
        // place; any later occurrence finds it empty and is dropped.
        let mut replacement = Some(s);
        pairs.retain_mut(|(pk, pv)| {
            if pk.as_str() != k.as_str() {
                return true;
            }
            match replacement.take() {
                Some(new_value) => {
                    *pv = new_value;
                    true
                }
                None => false,
            }
        });
        if let Some(new_value) = replacement {
            pairs.push((k.clone(), new_value));
        }
    }

    if pairs.is_empty() {
        return Ok(tera::Value::String(String::new()));
    }
    let encoded: Vec<String> = pairs
        .iter()
        .map(|(k, v)| {
            format!(
                "{}={}",
                crate::url_codec::url_encode(k),
                crate::url_codec::url_encode(v)
            )
        })
        .collect();
    Ok(tera::Value::String(format!("?{}", encoded.join("&"))))
}
/// Splits a raw query string into ordered `(key, value)` pairs, each passed through
/// `url_codec::url_decode`.
///
/// Leading `?` characters are ignored, and empty segments (from `&&` or a trailing `&`)
/// are skipped. A segment without `=` yields its decoded text as the key, paired with an
/// empty value.
#[cfg(feature = "template_views")]
fn parse_query_pairs(s: &str) -> Vec<(String, String)> {
    let query = s.trim_start_matches('?');
    let mut pairs = Vec::new();
    for segment in query.split('&') {
        if segment.is_empty() {
            continue;
        }
        let pair = if let Some((raw_key, raw_value)) = segment.split_once('=') {
            (
                crate::url_codec::url_decode(raw_key),
                crate::url_codec::url_decode(raw_value),
            )
        } else {
            (crate::url_codec::url_decode(segment), String::new())
        };
        pairs.push(pair);
    }
    pairs
}
// Tests for the `querystring` Tera filter and its `parse_query_pairs` parser.
#[cfg(all(test, feature = "template_views"))]
mod querystring_tests {
use super::*;
// Fresh Tera instance with only the `querystring` filter registered.
fn setup() -> tera::Tera {
let mut tera = tera::Tera::default();
register_querystring_filter(&mut tera);
tera
}
// Renders `src` as a one-off template on a clone of `tera` with the given context.
fn render(tera: &tera::Tera, src: &str, ctx: tera::Context) -> String {
let mut t = tera.clone();
t.add_raw_template("_", src).unwrap();
t.render("_", &ctx).unwrap()
}
#[test]
fn empty_input_with_overrides_emits_new_qs() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "");
assert_eq!(
render(&tera, "{{ q | querystring(page=2) | safe }}", ctx),
"?page=2"
);
}
// No pairs at all -> empty string, not a bare `?`.
#[test]
fn empty_input_with_no_overrides_emits_empty_string() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "");
assert_eq!(render(&tera, "{{ q | querystring() | safe }}", ctx), "");
}
#[test]
fn override_replaces_existing_key() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "page=1");
assert_eq!(
render(&tera, "{{ q | querystring(page=2) | safe }}", ctx),
"?page=2"
);
}
// An overridden key keeps its original position among the other keys.
#[test]
fn override_preserves_other_keys_and_position() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "q=hello&page=1&sort=asc");
let out = render(&tera, "{{ q | querystring(page=2) | safe }}", ctx);
assert_eq!(out, "?q=hello&page=2&sort=asc");
}
// Repeated keys collapse to a single occurrence carrying the override value.
#[test]
fn override_collapses_duplicate_existing_keys() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "tag=a&tag=b&tag=c");
assert_eq!(
render(&tera, "{{ q | querystring(tag='x') | safe }}", ctx),
"?tag=x"
);
}
#[test]
fn override_appends_new_key() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "q=hello");
assert_eq!(
render(&tera, "{{ q | querystring(filter='active') | safe }}", ctx),
"?q=hello&filter=active"
);
}
// A null argument removes the key entirely.
#[test]
fn null_override_removes_key() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "q=hello&filter=active");
ctx.insert("v", &serde_json::Value::Null);
assert_eq!(
render(&tera, "{{ q | querystring(filter=v) | safe }}", ctx),
"?q=hello"
);
}
#[test]
fn percent_encodes_special_chars_in_output() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "");
let out = render(
&tera,
"{{ q | querystring(name='hello world', special='a/b?c') | safe }}",
ctx,
);
assert!(out.contains("hello%20world"), "got: {out}");
assert!(out.contains("a%2Fb%3Fc"), "got: {out}");
}
#[test]
fn input_with_leading_question_mark_is_stripped() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "?q=hello&page=1");
assert_eq!(
render(&tera, "{{ q | querystring(page=2) | safe }}", ctx),
"?q=hello&page=2"
);
}
// Uses `contains` rather than an exact match, so it does not depend on the order in
// which newly appended keys are emitted.
#[test]
fn bool_and_number_args_stringify_via_display() {
let tera = setup();
let mut ctx = tera::Context::new();
ctx.insert("q", "");
let out = render(
&tera,
"{{ q | querystring(page=2, active=true) | safe }}",
ctx,
);
assert!(out.contains("page=2"), "got: {out}");
assert!(out.contains("active=true"), "got: {out}");
}
#[test]
fn parse_pairs_handles_trailing_ampersand_and_empty_chunks() {
let pairs = parse_query_pairs("?q=hello&&page=1&");
assert_eq!(
pairs,
vec![
("q".to_owned(), "hello".to_owned()),
("page".to_owned(), "1".to_owned()),
]
);
}
#[test]
fn parse_pairs_percent_decodes_keys_and_values() {
let pairs = parse_query_pairs("q=hello%20world&page=1");
assert_eq!(pairs[0].1, "hello world");
}
}