1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{
9 DeriveInput, Expr, ExprLit, Fields, ItemStruct, Lit, Meta,
10 parse::{Parse, ParseStream, Result},
11 parse_macro_input,
12};
13
/// Which binrw direction(s) an SMB message type must support.
///
/// The variant selects which binrw attribute(s) `SmbMsgType::get_attr` emits.
enum SmbMsgType {
    /// Read when the invoking crate enables `server`, written when it enables `client`.
    Request,
    /// Written when the invoking crate enables `server`, read when it enables `client`.
    Response,
    /// Always both read and written, regardless of features.
    Both,
}
19
20impl SmbMsgType {
21 fn get_attr(&self) -> proc_macro2::TokenStream {
25 match self {
26 SmbMsgType::Request => quote! {
27 #[cfg_attr(all(feature = "server", feature = "client"), ::binrw::binrw)]
28 #[cfg_attr(all(feature = "server", not(feature = "client")), ::binrw::binread)]
29 #[cfg_attr(all(not(feature = "server"), feature = "client"), ::binrw::binwrite)]
30 },
31 SmbMsgType::Response => quote! {
32 #[cfg_attr(all(feature = "server", feature = "client"), ::binrw::binrw)]
33 #[cfg_attr(all(feature = "server", not(feature = "client")), ::binrw::binwrite)]
34 #[cfg_attr(all(not(feature = "server"), feature = "client"), ::binrw::binread)]
35 },
36 SmbMsgType::Both => quote! {
37 #[::binrw::binrw]
38 },
39 }
40 }
41}
42
/// Parsed argument of the sized message macros, written as `size = <u16>`.
#[derive(Debug)]
struct SmbReqResAttr {
    /// The declared structure size; it becomes the `_structure_size` field's value.
    value: u16,
}
47
48impl Parse for SmbReqResAttr {
49 fn parse(input: ParseStream) -> Result<Self> {
50 let meta: Meta = input.parse()?;
51
52 match meta {
53 Meta::NameValue(nv) if nv.path.is_ident("size") => {
54 if let Expr::Lit(ExprLit {
55 lit: Lit::Int(lit), ..
56 }) = nv.value
57 {
58 let value: u16 = lit.base10_parse()?;
59 Ok(SmbReqResAttr { value })
60 } else {
61 Err(syn::Error::new_spanned(
62 nv.value,
63 "expected integer literal",
64 ))
65 }
66 }
67 _ => Err(syn::Error::new_spanned(meta, "expected `size = <u16>`")),
68 }
69 }
70}
71
72fn make_size_field(size: u16) -> syn::Field {
73 syn::Field {
78 attrs: vec![
79 syn::parse_quote! {
80 #[bw(calc = #size)]
81 },
82 syn::parse_quote! {
83 #[br(temp)]
84 },
85 syn::parse_quote! {
86 #[br(assert(_structure_size == #size))]
87 },
88 ],
89 vis: syn::Visibility::Inherited,
90 ident: Some(syn::Ident::new(
91 "_structure_size",
92 proc_macro2::Span::call_site(),
93 )),
94 colon_token: Some(syn::token::Colon {
95 spans: [proc_macro2::Span::call_site()],
96 }),
97 ty: syn::parse_quote! { u16 },
98 mutability: syn::FieldMutability::None,
99 }
100}
101
102fn modify_smb_msg(msg_type: SmbMsgType, item: TokenStream, attr: TokenStream) -> TokenStream {
108 let item = common_struct_changes(msg_type, item);
109
110 let mut item = parse_macro_input!(item as ItemStruct);
111 let attr = parse_macro_input!(attr as SmbReqResAttr);
112
113 let size_field = make_size_field(attr.value);
114 match item.fields {
115 Fields::Named(ref mut fields) => {
116 fields.named.insert(0, size_field);
117 }
118 _ => {
119 return syn::Error::new_spanned(
120 &item.fields,
121 "Expected named fields for smb request/response",
122 )
123 .to_compile_error()
124 .into();
125 }
126 }
127
128 TokenStream::from(quote! {
129 #item
130 })
131}
132
133fn common_struct_changes(msg_type: SmbMsgType, item: TokenStream) -> TokenStream {
139 let input = parse_macro_input!(item as DeriveInput);
140
141 let is_struct = matches!(input.data, syn::Data::Struct(_));
142
143 let cfg_attrs = msg_type.get_attr();
144 let output_all = TokenStream::from(quote! {
145 #cfg_attrs
146 #[derive(Debug, PartialEq, Eq)]
147 #input
148 });
149
150 if !is_struct {
151 return output_all;
152 }
153
154 let mut item = parse_macro_input!(output_all as ItemStruct);
155
156 if let Fields::Named(ref mut fields) = item.fields {
157 for field in fields.named.iter_mut() {
158 if field.ident.as_ref().is_some_and(|id| *id == "reserved") {
159 if field.vis != syn::Visibility::Inherited {
160 return syn::Error::new_spanned(
161 &field.vis,
162 "reserved field must have no visibility defined",
163 )
164 .to_compile_error()
165 .into();
166 }
167
168 let line_number = proc_macro2::Span::call_site().start().line;
170 field.ident = Some(syn::Ident::new(
171 &format!("_reserved{}", line_number),
172 proc_macro2::Span::call_site(),
173 ));
174
175 field.attrs.push(syn::parse_quote! {
177 #[br(temp)]
178 });
179
180 let default_bw_calc = if let syn::Type::Array(arr) = &field.ty {
182 let len = arr.len.clone();
183 syn::parse_quote! {
184 #[bw(calc = [0; #len])]
185 }
186 } else {
187 syn::parse_quote! {
188 #[bw(calc = Default::default())]
189 }
190 };
191
192 field.attrs.push(default_bw_calc);
193 }
194 }
195 }
196
197 TokenStream::from(quote! {
198 #item
199 })
200}
201
202#[proc_macro_attribute]
206pub fn smb_request(attr: TokenStream, input: TokenStream) -> TokenStream {
207 modify_smb_msg(SmbMsgType::Request, input, attr)
208}
209
210#[proc_macro_attribute]
214pub fn smb_response(attr: TokenStream, input: TokenStream) -> TokenStream {
215 modify_smb_msg(SmbMsgType::Response, input, attr)
216}
217
218#[proc_macro_attribute]
222pub fn smb_request_response(attr: TokenStream, input: TokenStream) -> TokenStream {
223 modify_smb_msg(SmbMsgType::Both, input, attr)
224}
225
226#[proc_macro_attribute]
230pub fn smb_request_binrw(_attr: TokenStream, input: TokenStream) -> TokenStream {
231 common_struct_changes(SmbMsgType::Request, input)
232}
233
234#[proc_macro_attribute]
238pub fn smb_response_binrw(_attr: TokenStream, input: TokenStream) -> TokenStream {
239 common_struct_changes(SmbMsgType::Response, input)
240}
241
242#[proc_macro_attribute]
246pub fn smb_message_binrw(_attr: TokenStream, input: TokenStream) -> TokenStream {
247 common_struct_changes(SmbMsgType::Both, input)
248}