#![forbid(unsafe_code)]
#![warn(missing_docs)]
use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
/// Derives `::reliakit_codec::CanonicalEncode` for a non-generic struct.
///
/// Named, tuple, and unit structs are supported. The generated `encode`
/// method calls `CanonicalEncode::encode` on every field in declaration order
/// and returns the first error it encounters (unit structs encode nothing).
///
/// Enums, unions, and structs with generic parameters are rejected: the
/// derive expands to a `compile_error!` describing the problem instead of an
/// impl.
#[proc_macro_derive(CanonicalEncode)]
pub fn derive_canonical_encode(input: TokenStream) -> TokenStream {
    match Parsed::from_input(input) {
        Ok(parsed) => parsed.canonical_encode_impl(),
        // Surface parse failures as ordinary compiler errors rather than panics.
        Err(message) => compile_error(&message),
    }
}
/// Derives `::reliakit_codec::CanonicalDecode` for a non-generic struct.
///
/// Named, tuple, and unit structs are supported. The generated `decode`
/// function reads every field in declaration order with
/// `CanonicalDecode::decode`, builds `Self` from the results, and returns the
/// first error it encounters (unit structs read nothing).
///
/// Enums, unions, and structs with generic parameters are rejected: the
/// derive expands to a `compile_error!` describing the problem instead of an
/// impl.
#[proc_macro_derive(CanonicalDecode)]
pub fn derive_canonical_decode(input: TokenStream) -> TokenStream {
    match Parsed::from_input(input) {
        Ok(parsed) => parsed.canonical_decode_impl(),
        // Surface parse failures as ordinary compiler errors rather than panics.
        Err(message) => compile_error(&message),
    }
}
/// Field layout of the struct the derive was applied to.
enum Shape {
    /// Braced struct (`struct S { a: A, b: B }`): the field names, in declaration order.
    Named(Vec<String>),
    /// Tuple struct (`struct S(A, B);`): the number of fields.
    Tuple(usize),
    /// Unit struct (`struct S;`): no fields.
    Unit,
}
/// The parts of a derive input that the code generators need.
struct Parsed {
    /// Type name, as written directly after the `struct` keyword.
    name: String,
    /// Which kind of struct body was found, together with its fields.
    shape: Shape,
}
impl Parsed {
    /// Extracts the struct name and field layout from a derive input.
    ///
    /// Top-level tokens are scanned for the first `struct`, `enum`, or `union`
    /// keyword; anything before it (attributes, visibility) is skipped. The
    /// identifier after `struct` is the type name, and the token after that
    /// determines the [`Shape`].
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the input is an enum or a union,
    /// when no `struct` keyword or type name is found, when the name is
    /// followed by `<` (generic parameters), or when the body is neither a
    /// `{ .. }` group, a `( .. )` group, nor a `;`.
    fn from_input(input: TokenStream) -> Result<Self, String> {
        let tokens: Vec<TokenTree> = input.into_iter().collect();

        // Index of the `struct` keyword. Hitting `enum`/`union` first, or
        // running out of tokens, aborts with the matching message.
        let keyword_at = tokens
            .iter()
            .enumerate()
            .find_map(|(position, tree)| match tree {
                TokenTree::Ident(ident) => match ident.to_string().as_str() {
                    "struct" => Some(Ok(position)),
                    "enum" => Some(Err("reliakit-derive does not support enums yet")),
                    "union" => Some(Err("reliakit-derive does not support unions")),
                    _ => None,
                },
                _ => None,
            })
            .unwrap_or(Err("reliakit-derive: expected a struct"))?;

        let name = match tokens.get(keyword_at + 1) {
            Some(TokenTree::Ident(ident)) => ident.to_string(),
            _ => return Err("reliakit-derive: expected a type name after `struct`".into()),
        };

        let shape = match tokens.get(keyword_at + 2) {
            // `<` right after the name opens a generic parameter list.
            Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {
                return Err("reliakit-derive does not support generic types yet".into());
            }
            Some(TokenTree::Punct(punct)) if punct.as_char() == ';' => Shape::Unit,
            Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => {
                Shape::Named(named_fields(group.stream()))
            }
            Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
                Shape::Tuple(count_fields(group.stream()))
            }
            _ => return Err("reliakit-derive: unexpected struct body".into()),
        };

        Ok(Self { name, shape })
    }

    /// Generates `impl ::reliakit_codec::CanonicalEncode for <name>`.
    ///
    /// The generated `encode` calls `CanonicalEncode::encode` on each field in
    /// declaration order, propagating the first error with `?`, then returns
    /// `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the generated source text does not parse as tokens, which
    /// would indicate a bug in this generator.
    fn canonical_encode_impl(&self) -> TokenStream {
        let body: String = match &self.shape {
            Shape::Named(fields) => fields
                .iter()
                .map(|field| {
                    format!("::reliakit_codec::CanonicalEncode::encode(&self.{field}, __writer)?;")
                })
                .collect(),
            Shape::Tuple(count) => (0..*count)
                .map(|index| {
                    format!("::reliakit_codec::CanonicalEncode::encode(&self.{index}, __writer)?;")
                })
                .collect(),
            Shape::Unit => String::new(),
        };
        let name = &self.name;
        format!(
            "impl ::reliakit_codec::CanonicalEncode for {name} {{\n\
             fn encode<__W: ::reliakit_codec::EncodeSink + ?Sized>(&self, __writer: &mut __W) \
             -> ::core::result::Result<(), ::reliakit_codec::CodecError> {{\n\
             {body}\n\
             ::core::result::Result::Ok(())\n\
             }}\n\
             }}"
        )
        .parse()
        .expect("reliakit-derive generated invalid CanonicalEncode tokens")
    }

    /// Generates `impl ::reliakit_codec::CanonicalDecode for <name>`.
    ///
    /// The generated `decode` builds `Self` by calling
    /// `CanonicalDecode::decode` once per field, in declaration order,
    /// propagating the first error with `?`.
    ///
    /// # Panics
    ///
    /// Panics if the generated source text does not parse as tokens, which
    /// would indicate a bug in this generator.
    fn canonical_decode_impl(&self) -> TokenStream {
        let construct = match &self.shape {
            Shape::Named(fields) => {
                let initializers: String = fields
                    .iter()
                    .map(|field| {
                        format!("{field}: ::reliakit_codec::CanonicalDecode::decode(__reader)?,")
                    })
                    .collect();
                format!("Self {{ {initializers} }}")
            }
            Shape::Tuple(count) => {
                let arguments =
                    "::reliakit_codec::CanonicalDecode::decode(__reader)?,".repeat(*count);
                format!("Self({arguments})")
            }
            Shape::Unit => String::from("Self"),
        };
        let name = &self.name;
        format!(
            "impl ::reliakit_codec::CanonicalDecode for {name} {{\n\
             fn decode<__R: ::reliakit_codec::DecodeSource + ?Sized>(__reader: &mut __R) \
             -> ::core::result::Result<Self, ::reliakit_codec::CodecError> {{\n\
             ::core::result::Result::Ok({construct})\n\
             }}\n\
             }}"
        )
        .parse()
        .expect("reliakit-derive generated invalid CanonicalDecode tokens")
    }
}
/// Extracts the field names from the contents of a braced struct body.
///
/// For each comma-separated segment, the first identifier immediately
/// followed by a `:` with `Spacing::Alone` is taken as the field name. That
/// spacing check excludes the first colon of a `::` path separator. Segments
/// without such a pair contribute nothing.
fn named_fields(stream: TokenStream) -> Vec<String> {
    top_level_segments(stream)
        .iter()
        .filter_map(|segment| {
            segment.windows(2).find_map(|pair| match pair {
                [TokenTree::Ident(ident), TokenTree::Punct(colon)]
                    if colon.as_char() == ':' && colon.spacing() == Spacing::Alone =>
                {
                    Some(ident.to_string())
                }
                _ => None,
            })
        })
        .collect()
}
/// Returns the number of fields declared in a tuple-struct body.
///
/// Each non-empty comma-separated segment of the parenthesised stream counts
/// as one field, so a trailing comma does not add a phantom field.
fn count_fields(stream: TokenStream) -> usize {
    let mut total = 0;
    for segment in top_level_segments(stream) {
        if !segment.is_empty() {
            total += 1;
        }
    }
    total
}
/// Splits a comma-separated field list into one token vector per field.
///
/// Only commas outside any nesting act as separators. Parenthesised,
/// bracketed, and braced sub-streams already arrive as single
/// [`TokenTree::Group`] tokens. Angle brackets, however, are plain punctuation
/// in `proc_macro`, so their depth is tracked here. Without this, a field type
/// such as `HashMap<K, V>` would be split into two segments and counted as two
/// fields. The `>` of a `->` arrow (e.g. in `Box<dyn Fn(u8) -> u8>`) does not
/// close an angle bracket.
///
/// Empty segments between consecutive commas are kept. A trailing empty
/// segment after a final comma is dropped.
fn top_level_segments(stream: TokenStream) -> Vec<Vec<TokenTree>> {
    let mut segments = Vec::new();
    let mut current: Vec<TokenTree> = Vec::new();
    // Number of `<` currently open at this level of the stream.
    let mut angle_depth: usize = 0;
    for token in stream {
        if let TokenTree::Punct(punct) = &token {
            match punct.as_char() {
                ',' if angle_depth == 0 => {
                    segments.push(core::mem::take(&mut current));
                    continue;
                }
                '<' => angle_depth += 1,
                '>' => {
                    // `->` lexes as `-` (Joint) followed by `>`; that `>` closes nothing.
                    let is_arrow = matches!(
                        current.last(),
                        Some(TokenTree::Punct(prev))
                            if prev.as_char() == '-' && prev.spacing() == Spacing::Joint
                    );
                    if !is_arrow {
                        // Saturate so malformed input cannot underflow the counter.
                        angle_depth = angle_depth.saturating_sub(1);
                    }
                }
                _ => {}
            }
        }
        current.push(token);
    }
    if !current.is_empty() {
        segments.push(current);
    }
    segments
}
/// Builds tokens for `::core::compile_error!(<message>);`.
///
/// Derive failures are reported this way as ordinary compiler errors rather
/// than panics. The message is embedded through its `Debug` representation,
/// which produces a quoted, escaped string literal.
///
/// # Panics
///
/// Panics if the assembled invocation fails to parse as tokens.
fn compile_error(message: &str) -> TokenStream {
    let invocation = format!("::core::compile_error!({message:?});");
    invocation
        .parse()
        .expect("compile_error message produced invalid tokens")
}