1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use proc_macro::{self, TokenStream};
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Error, Fields};

/// Derive `SerBolt` and `DeBolt` for a message type, serializing it in BOLT style: a
/// big-endian `u16` message type prefix followed by the encoded message body.
///
/// The message type comes from the `#[message_id ...]` attribute; the tokens following the
/// attribute name (e.g. `(42)` for `#[message_id(42)]`) are used verbatim as the value of
/// `DeBolt::TYPE`. If no such attribute is present, a compile error pointing at the type name
/// is emitted instead.
///
/// The generated code refers to `SerBolt`, `DeBolt`, `to_vec`, `Result`, `Error`,
/// `Decodable` and `serde_bolt::io::Cursor`, which must therefore be in scope at the derive
/// site.
#[proc_macro_derive(SerBolt, attributes(message_id))]
pub fn derive_ser_bolt(input: TokenStream) -> TokenStream {
    let DeriveInput { ident, attrs, .. } = parse_macro_input!(input);
    // Only the first `message_id` attribute is used.
    let message_id = attrs
        .into_iter()
        .find(|a| a.path.is_ident("message_id"))
        .map(|a| a.tokens)
        .unwrap_or_else(|| {
            // Substituting a `compile_error!` for the constant expression surfaces the
            // problem at the derive site without aborting the macro.
            Error::new(ident.span(), "missing message_id attribute").into_compile_error()
        });

    let output = quote! {
        impl SerBolt for #ident {
            fn as_vec(&self) -> Vec<u8> {
                let message_type = Self::TYPE;
                let mut buf = message_type.to_be_bytes().to_vec();
                let mut val_buf = to_vec(&self).expect("serialize");
                buf.append(&mut val_buf);
                buf
            }

            fn name(&self) -> &'static str {
                stringify!(#ident)
            }
        }

        impl DeBolt for #ident {
            const TYPE: u16 = #message_id;
            fn from_vec(ser: Vec<u8>) -> Result<Self> {
                let mut cursor = serde_bolt::io::Cursor::new(&ser);
                let message_type = cursor.read_u16_be()?;
                if message_type != Self::TYPE {
                    return Err(Error::UnexpectedType(message_type));
                }
                let res = Decodable::consensus_decode(&mut cursor)?;
                // Reject input that was not fully consumed, reporting how many bytes remain.
                // When bytes are left over the position is *below* the length, so the count
                // is `len - position`; `saturating_sub` guards against a position past the end.
                let consumed = cursor.position() as usize;
                if consumed != ser.len() {
                    return Err(Error::TrailingBytes(ser.len().saturating_sub(consumed), Self::TYPE));
                }
                Ok(res)
            }
        }
    };
    output.into()
}

#[proc_macro_derive(ReadMessage)]
pub fn derive_read_message(input: TokenStream) -> TokenStream {
    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
    let mut vs = Vec::new();
    let mut ts = Vec::new();
    let mut error: Option<Error> = None;

    if let Data::Enum(DataEnum { variants, .. }) = data {
        for v in variants {
            if v.ident == "Unknown" {
                continue;
            }
            let vident = v.ident.clone();
            let field = extract_single_type(&vident, &v.fields);
            match field {
                Ok(f) => {
                    vs.push(vident);
                    ts.push(f);
                }
                Err(e) => match error.as_mut() {
                    None => error = Some(e),
                    Some(o) => o.combine(e),
                },
            }
        }
    } else {
        unimplemented!()
    }

    if let Some(error) = error {
        return error.into_compile_error().into();
    }

    let output = quote! {
        impl #ident {
            fn read_message<R: Read + ?Sized>(mut reader: &mut R, message_type: u16) -> Result<Message> {
                let message = match message_type {
                    #(#vs::TYPE => Message::#ts(Decodable::consensus_decode(reader)?)),*,
                    _ => Message::Unknown(Unknown { message_type }),
                };
                Ok(message)
            }

            pub fn inner(&self) -> alloc::boxed::Box<&dyn SerBolt> {
                match self {
                    #(#ident::#vs(inner) => alloc::boxed::Box::new(inner)),*,
                    _ => alloc::boxed::Box::new(&UNKNOWN_PLACEHOLDER),
                }
            }
        }
    };

    output.into()
}

fn extract_single_type(vident: &Ident, fields: &Fields) -> Result<TokenStream2, Error> {
    let mut fields = fields.iter();
    let field =
        fields.next().ok_or_else(|| Error::new(vident.span(), "must have exactly one field"))?;
    if fields.next().is_some() {
        return Err(Error::new(vident.span(), "must have exactly one field"));
    }
    Ok(field.ty.clone().into_token_stream())
}