use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use thiserror::Error;
/// Token-level result of compiling a rewrite template, as produced by
/// `compile_rewrite_template`.
pub(crate) struct CompiledTemplate {
/// One token stream per non-empty route-path segment, in path order. Each is
/// either a `SrcSeg::Lit(..)` or a `SrcSeg::Param` expression rooted at the
/// caller-supplied crate path.
pub src_tokens: Vec<TokenStream2>,
/// One token stream per chunk of the `to` template, in template order. Each is
/// either a `DstChunk::Lit(..)` or a `DstChunk::Capture { src_index: .. }`
/// expression rooted at the caller-supplied crate path.
pub dst_tokens: Vec<TokenStream2>,
/// Sum of the byte lengths of all literal chunks in the `to` template
/// (placeholders contribute nothing).
pub static_len: usize,
}
/// Reasons a route path / `to` rewrite template pair can be rejected.
///
/// The `#[error]` strings escape braces as `{{` / `}}`, so the rendered
/// messages show single `{` / `}` characters.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub(crate) enum RewriteTemplateError {
/// A route-path segment was exactly `{}`, i.e. a parameter with no name.
#[error("empty parameter `{{}}` in route path")]
EmptyRouteParameter,
/// The same parameter name appeared in more than one route-path segment.
#[error("duplicate parameter `{name}` in route path")]
DuplicateRouteParameter { name: String },
/// A `{` in `to` was not followed by any `}`.
#[error("unclosed `{{` in `to`")]
UnclosedToParameter,
/// `to` contained `{}`, i.e. a placeholder with no name.
#[error("empty parameter `{{}}` in `to`")]
EmptyToParameter,
/// A placeholder in `to` named a parameter the route path does not declare.
#[error("parameter `{name}` in `to` not found in route path")]
UnknownToParameter { name: String },
/// A `}` appeared in `to` outside of any placeholder.
#[error("unmatched `}}` in `to`")]
UnmatchedToCloseBrace,
/// The route path declared more than `max` parameters.
#[error("too many path parameters in route path (max {max})")]
TooManyRouteParameters { max: usize },
/// A parameter's position did not fit in the `u8` used for capture indices.
#[error("rewrite capture index overflow")]
CaptureIndexOverflow,
}
/// Intermediate representation of one non-empty, `/`-separated route-path segment.
enum SrcSegIr {
/// A literal segment; holds the segment text without any surrounding `/`.
Lit(String),
/// A `{name}` segment. The name is not stored here; the parser returns the
/// names in a separate list, in order of appearance.
Param,
}
/// Intermediate representation of one piece of the `to` rewrite template.
enum DstChunkIr {
/// A run of literal text between placeholders, kept verbatim (including any `/`).
Lit(String),
/// A `{name}` placeholder, stored as the position of `name` among the
/// route-path parameters.
Capture(u8),
}
/// Parsed form of a route path and its rewrite target, ready to be turned into tokens.
struct RewriteIr {
/// Route-path segments in path order.
src: Vec<SrcSegIr>,
/// Chunks of the `to` template in template order.
dst: Vec<DstChunkIr>,
/// Total byte length of the literal chunks in `dst`.
static_len: usize,
}
/// Compiles a route path and a `to` rewrite template into token streams.
///
/// `route_path` is split into literal and `{name}` segments; `to` is split into
/// literal runs and `{name}` placeholders, each placeholder being resolved to
/// the position of the matching route parameter. The result is rendered as
/// expressions rooted at `apigate_path`.
///
/// # Errors
///
/// Returns a [`RewriteTemplateError`] when the route path has an empty,
/// duplicate, or excessive number of parameters, or when `to` has an unclosed,
/// empty, or unknown placeholder or a stray `}`.
pub(crate) fn compile_rewrite_template(
    apigate_path: &TokenStream2,
    route_path: &str,
    to: &str,
) -> Result<CompiledTemplate, RewriteTemplateError> {
    let (src, params) = parse_route_source(route_path)?;
    // Placeholders in `to` are resolved against the parameter names collected
    // from the route path, so the route must be parsed first.
    let (dst, static_len) = parse_rewrite_target(to, &params)?;

    let ir = RewriteIr {
        src,
        dst,
        static_len,
    };
    Ok(emit(apigate_path, ir))
}
/// Splits a route path into literal and parameter segments.
///
/// Empty segments (leading, trailing, or doubled `/`) are skipped. A segment
/// that both starts with `{` and ends with `}` is a parameter; anything else is
/// a literal. Returns the segments together with the parameter names in order
/// of appearance.
///
/// # Errors
///
/// * `EmptyRouteParameter` for a `{}` segment.
/// * `DuplicateRouteParameter` when a name is declared twice.
/// * `TooManyRouteParameters` when more parameters are declared than a `u8`
///   capture index can address.
fn parse_route_source(
    route_path: &str,
) -> Result<(Vec<SrcSegIr>, Vec<String>), RewriteTemplateError> {
    // Capture indices are stored as `u8`, so positions 0..=255 are addressable.
    const MAX_CAPTURE_PARAMS: usize = u8::MAX as usize + 1;

    let mut segments = Vec::new();
    let mut names: Vec<String> = Vec::new();

    for segment in route_path.split('/') {
        if segment.is_empty() {
            continue;
        }

        let braced = segment
            .strip_prefix('{')
            .and_then(|inner| inner.strip_suffix('}'));
        let name = match braced {
            Some(name) => name,
            None => {
                segments.push(SrcSegIr::Lit(segment.to_owned()));
                continue;
            }
        };

        if name.is_empty() {
            return Err(RewriteTemplateError::EmptyRouteParameter);
        }
        if names.iter().any(|known| known.as_str() == name) {
            return Err(RewriteTemplateError::DuplicateRouteParameter {
                name: name.to_owned(),
            });
        }
        if names.len() >= MAX_CAPTURE_PARAMS {
            return Err(RewriteTemplateError::TooManyRouteParameters {
                max: MAX_CAPTURE_PARAMS,
            });
        }

        names.push(name.to_owned());
        segments.push(SrcSegIr::Param);
    }

    Ok((segments, names))
}
/// Splits the `to` template into literal runs and `{name}` captures.
///
/// Each placeholder is resolved to the position of `name` in `params`. Returns
/// the chunks in template order together with the total byte length of the
/// literal chunks.
///
/// # Errors
///
/// * `UnmatchedToCloseBrace` when a `}` appears outside a placeholder.
/// * `UnclosedToParameter` when a `{` has no following `}`.
/// * `EmptyToParameter` for `{}`.
/// * `UnknownToParameter` when the name is not in `params`.
/// * `CaptureIndexOverflow` when the position does not fit in a `u8`.
fn parse_rewrite_target(
    to: &str,
    params: &[String],
) -> Result<(Vec<DstChunkIr>, usize), RewriteTemplateError> {
    let mut chunks = Vec::new();
    let mut literal_bytes = 0usize;
    // Unconsumed tail of the template; always starts right after the last
    // placeholder (or at the beginning).
    let mut rest = to;

    while let Some(brace_at) = rest.find(|c: char| c == '{' || c == '}') {
        let (literal, from_brace) = rest.split_at(brace_at);
        if from_brace.starts_with('}') {
            return Err(RewriteTemplateError::UnmatchedToCloseBrace);
        }

        if !literal.is_empty() {
            literal_bytes += literal.len();
            chunks.push(DstChunkIr::Lit(literal.to_owned()));
        }

        // `{` is a single byte, so slicing past it stays on a char boundary.
        let after_open = &from_brace[1..];
        let close_at = match after_open.find('}') {
            Some(offset) => offset,
            None => return Err(RewriteTemplateError::UnclosedToParameter),
        };
        let name = &after_open[..close_at];
        if name.is_empty() {
            return Err(RewriteTemplateError::EmptyToParameter);
        }

        let position = match params.iter().position(|known| known.as_str() == name) {
            Some(position) => position,
            None => {
                return Err(RewriteTemplateError::UnknownToParameter {
                    name: name.to_owned(),
                })
            }
        };
        let capture = match u8::try_from(position) {
            Ok(capture) => capture,
            Err(_) => return Err(RewriteTemplateError::CaptureIndexOverflow),
        };
        chunks.push(DstChunkIr::Capture(capture));

        rest = &after_open[close_at + 1..];
    }

    // Whatever follows the last placeholder is one trailing literal.
    if !rest.is_empty() {
        literal_bytes += rest.len();
        chunks.push(DstChunkIr::Lit(rest.to_owned()));
    }

    Ok((chunks, literal_bytes))
}
fn emit(apigate_path: &TokenStream2, ir: RewriteIr) -> CompiledTemplate {
let src_tokens = ir
.src
.into_iter()
.map(|seg| match seg {
SrcSegIr::Lit(lit) => quote!(#apigate_path::SrcSeg::Lit(#lit)),
SrcSegIr::Param => quote!(#apigate_path::SrcSeg::Param),
})
.collect();
let dst_tokens = ir
.dst
.into_iter()
.map(|chunk| match chunk {
DstChunkIr::Lit(lit) => quote!(#apigate_path::DstChunk::Lit(#lit)),
DstChunkIr::Capture(idx) => {
quote!(#apigate_path::DstChunk::Capture { src_index: #idx })
}
})
.collect();
CompiledTemplate {
src_tokens,
dst_tokens,
static_len: ir.static_len,
}
}
// Unit tests for the rewrite-template compiler: one happy-path check of the
// emitted tokens and one rejection test per family of template errors.
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    // Errors are extracted with `.err().expect(..)` rather than `unwrap_err()`
    // because `CompiledTemplate` does not implement `Debug`.

    #[test]
    fn compiles_rewrite_template_into_source_and_destination_tokens() {
        let apigate = quote!(::apigate);
        let compiled = compile_rewrite_template(
            &apigate,
            "/sales/{sale_id}/items/{item_id}",
            "/internal/{item_id}/sales/{sale_id}",
        )
        .unwrap();

        assert_eq!(compiled.src_tokens.len(), 4);
        assert_eq!(compiled.dst_tokens.len(), 4);
        // Only the literal parts of `to` count towards the static length.
        assert_eq!(compiled.static_len, "/internal//sales/".len());

        let dst = compiled
            .dst_tokens
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ");
        // `item_id` is the second route parameter, `sale_id` the first.
        assert!(dst.contains("src_index : 1"));
        assert!(dst.contains("src_index : 0"));
    }

    #[test]
    fn rejects_invalid_route_parameters() {
        let apigate = quote!(::apigate);
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{}", "/internal")
                .err()
                .expect("empty route parameter"),
            RewriteTemplateError::EmptyRouteParameter
        );
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{id}/{id}", "/internal")
                .err()
                .expect("duplicate route parameter"),
            RewriteTemplateError::DuplicateRouteParameter {
                name: "id".to_owned()
            }
        );
    }

    #[test]
    fn rejects_invalid_destination_template() {
        let apigate = quote!(::apigate);
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{id}", "/internal/{id")
                .err()
                .expect("unclosed destination parameter"),
            RewriteTemplateError::UnclosedToParameter
        );
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{id}", "/internal/{}")
                .err()
                .expect("empty destination parameter"),
            RewriteTemplateError::EmptyToParameter
        );
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{id}", "/internal/{missing}")
                .err()
                .expect("unknown destination parameter"),
            RewriteTemplateError::UnknownToParameter {
                name: "missing".to_owned()
            }
        );
        assert_eq!(
            compile_rewrite_template(&apigate, "/sales/{id}", "/internal/{id}}")
                .err()
                .expect("unmatched destination close brace"),
            RewriteTemplateError::UnmatchedToCloseBrace
        );
    }
}