fix_rs_macros/
lib.rs

1// Copyright 2017 James Bendig. See the COPYRIGHT file at the top-level
2// directory of this distribution.
3//
4// Licensed under:
5//   the MIT license
6//     <LICENSE-MIT or https://opensource.org/licenses/MIT>
7//   or the Apache License, Version 2.0
8//     <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>,
9// at your option. This file may not be copied, modified, or distributed
10// except according to those terms.
11
12#![feature(proc_macro)]
13#![crate_type = "proc-macro"]
14#![recursion_limit = "256"]
15
16extern crate proc_macro;
17#[macro_use]
18extern crate quote;
19extern crate syn;
20
21use proc_macro::TokenStream;
22use quote::Tokens;
23
24fn str_to_tokens(input: &str) -> Tokens {
25    let mut tokens = Tokens::new();
26    tokens.append(input);
27    tokens
28}
29
/// Reasons `extract_attribute_value` and its typed wrappers can fail to find
/// an attribute literal on a struct field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExtractAttributeError {
    /// The derive input is not a struct.
    BodyNotStruct,
    /// The struct has no named field with the requested identifier.
    FieldNotFound,
    /// The field exists but carries no attribute with the requested name.
    AttributeNotFound,
    /// The attribute exists but is not written in `name = value` form.
    AttributeNotNameValue,
    /// The attribute's value is a literal of an unexpected kind.
    AttributeValueWrongType,
}
37
38fn extract_attribute_value(ast: &syn::DeriveInput,field_ident: &'static str,attr_ident: &'static str) -> Result<syn::Lit,ExtractAttributeError> {
39    if let syn::Body::Struct(ref data) = ast.body {
40        for field in data.fields() {
41            if field.ident.as_ref().expect("Field must have an identifier") != field_ident {
42                continue;
43            }
44
45            for attr in &field.attrs {
46                if attr.name() != attr_ident {
47                    continue;
48                }
49
50                if let syn::MetaItem::NameValue(_,ref lit) = attr.value {
51                    return Ok(lit.clone());
52                }
53                else {
54                    return Err(ExtractAttributeError::AttributeNotNameValue);
55                }
56            }
57
58            return Err(ExtractAttributeError::AttributeNotFound);
59        }
60
61        return Err(ExtractAttributeError::FieldNotFound);
62    }
63
64    Err(ExtractAttributeError::BodyNotStruct)
65}
66
67fn extract_attribute_byte_str(ast: &syn::DeriveInput,field_ident: &'static str,attr_ident: &'static str) -> Result<Vec<u8>,ExtractAttributeError> {
68    let lit = try!(extract_attribute_value(ast,field_ident,attr_ident));
69
70    if let syn::Lit::ByteStr(ref bytes,_) = lit {
71       return Ok(bytes.clone());
72    }
73
74    Err(ExtractAttributeError::AttributeValueWrongType)
75}
76
77fn extract_attribute_int(ast: &syn::DeriveInput,field_ident: &'static str,attr_ident: &'static str) -> Result<u64,ExtractAttributeError> {
78    let lit = try!(extract_attribute_value(ast,field_ident,attr_ident));
79
80    if let syn::Lit::Int(value,_) = lit {
81       return Ok(value);
82    }
83
84    Err(ExtractAttributeError::AttributeValueWrongType)
85}
86
87#[proc_macro_derive(BuildMessage,attributes(message_type))]
88pub fn build_message(input: TokenStream) -> TokenStream {
89    let source = input.to_string();
90    let ast = syn::parse_derive_input(&source[..]).unwrap();
91
92    let message_type = match extract_attribute_byte_str(&ast,"_message_type_gen","message_type") {
93        Ok(bytes) => bytes,
94        Err(ExtractAttributeError::BodyNotStruct) => panic!("#[derive(BuildMessage)] can only be used with structs"),
95        Err(ExtractAttributeError::FieldNotFound) => panic!("#[derive(BuildMessage)] requires a _message_type_gen field to be specified"),
96        Err(ExtractAttributeError::AttributeNotFound) => Vec::new(),
97        Err(ExtractAttributeError::AttributeNotNameValue) |
98        Err(ExtractAttributeError::AttributeValueWrongType) => panic!("#[derive(BuildMessage)] message_type attribute must be a byte string value like #[message_type=b\"1234\"]"),
99    };
100    let is_fixt_message = source.contains("sender_comp_id") && source.contains("target_comp_id");
101
102    //Setup symbols.
103    let message_name = ast.ident;
104    let build_message_name = String::from("Build") + &message_name.to_string()[..];
105    let mut message_type_header = "b\"35=".to_string();
106    message_type_header += &String::from_utf8_lossy(&message_type[..]).into_owned()[..];
107    message_type_header += "\\x01\"";
108
109    //Convert symbols into tokens so quote's ToTokens trait doesn't quote them.
110    let build_message_name = str_to_tokens(&build_message_name[..]);
111    let message_type_header = str_to_tokens(&message_type_header[..]);
112
113    let tokens = quote! {
114        impl #message_name {
115            fn msg_type_header() -> &'static [u8] {
116                #message_type_header
117            }
118        }
119
120        pub struct #build_message_name {
121            cache: message::BuildMessageInternalCache,
122        }
123
124        impl #build_message_name {
125            fn new() -> #build_message_name {
126                #build_message_name {
127                    cache: message::BuildMessageInternalCache {
128                        fields_fix40: None,
129                        fields_fix41: None,
130                        fields_fix42: None,
131                        fields_fix43: None,
132                        fields_fix44: None,
133                        fields_fix50: None,
134                        fields_fix50sp1: None,
135                        fields_fix50sp2: None,
136                    },
137                }
138            }
139
140            fn new_into_box() -> Box<message::BuildMessage + Send> {
141                Box::new(#build_message_name::new())
142            }
143        }
144
145        impl message::BuildMessage for #build_message_name {
146            fn first_field(&self,version: message_version::MessageVersion) -> field_tag::FieldTag {
147                #message_name::first_field(version)
148            }
149
150            fn field_count(&self,version: message_version::MessageVersion) -> usize {
151                #message_name::field_count(version)
152            }
153
154            fn fields(&mut self,version: message_version::MessageVersion) -> message::FieldHashMap {
155                fn get_or_set_fields(option_fields: &mut Option<message::FieldHashMap>,
156                                     version: message_version::MessageVersion) -> message::FieldHashMap {
157                    if option_fields.is_none() {
158                        let fields = #message_name::fields(version);
159                        *option_fields = Some(fields);
160                    }
161
162                    option_fields.as_ref().unwrap().clone()
163                }
164
165                match version {
166                    message_version::MessageVersion::FIX40 => get_or_set_fields(&mut self.cache.fields_fix40,version),
167                    message_version::MessageVersion::FIX41 => get_or_set_fields(&mut self.cache.fields_fix41,version),
168                    message_version::MessageVersion::FIX42 => get_or_set_fields(&mut self.cache.fields_fix42,version),
169                    message_version::MessageVersion::FIX43 => get_or_set_fields(&mut self.cache.fields_fix43,version),
170                    message_version::MessageVersion::FIX44 => get_or_set_fields(&mut self.cache.fields_fix44,version),
171                    message_version::MessageVersion::FIX50 => get_or_set_fields(&mut self.cache.fields_fix50,version),
172                    message_version::MessageVersion::FIX50SP1 => get_or_set_fields(&mut self.cache.fields_fix50sp1,version),
173                    message_version::MessageVersion::FIX50SP2 => get_or_set_fields(&mut self.cache.fields_fix50sp2,version),
174                }
175            }
176
177            fn required_fields(&self,version: message_version::MessageVersion) -> message::FieldHashSet {
178                #message_name::required_fields(version)
179            }
180
181            fn new_into_box(&self) -> Box<message::BuildMessage + Send> {
182                #build_message_name::new_into_box()
183            }
184
185            fn build(&self) -> Box<message::Message + Send> {
186                Box::new(#message_name::new())
187            }
188        }
189
190
191        impl message::MessageBuildable for #message_name {
192            fn builder(&self) -> Box<message::BuildMessage + Send> {
193                #build_message_name::new_into_box()
194            }
195
196            fn builder_func(&self) -> fn() -> Box<message::BuildMessage + Send> {
197                #build_message_name::new_into_box
198            }
199        }
200    };
201    let mut result = String::from(tokens.as_str());
202
203    if is_fixt_message {
204        let tokens = quote! {
205            impl fixt::message::BuildFIXTMessage for #build_message_name {
206                fn new_into_box(&self) -> Box<fixt::message::BuildFIXTMessage + Send> {
207                    Box::new(#build_message_name::new())
208                }
209
210                fn build(&self) -> Box<fixt::message::FIXTMessage + Send> {
211                    Box::new(#message_name::new())
212                }
213            }
214
215            impl fixt::message::FIXTMessageBuildable for #message_name {
216                fn builder(&self) -> Box<fixt::message::BuildFIXTMessage + Send> {
217                    Box::new(#build_message_name::new())
218                }
219            }
220        };
221        result += tokens.as_str();
222    }
223
224    result.parse().unwrap()
225}
226
227#[proc_macro_derive(BuildField,attributes(tag))]
228pub fn build_field(input: TokenStream) -> TokenStream {
229    let source = input.to_string();
230    let ast = syn::parse_derive_input(&source[..]).unwrap();
231
232    let tag = match extract_attribute_int(&ast,"_tag_gen","tag") {
233        Ok(bytes) => bytes,
234        Err(ExtractAttributeError::BodyNotStruct) => panic!("#[derive(BuildField)] can only be used with structs"),
235        Err(ExtractAttributeError::FieldNotFound) => panic!("#[derive(BuildField)] requires a _tag_gen field to be specified"),
236        Err(ExtractAttributeError::AttributeNotFound) => panic!("#[derive(BuildField)] requires the _tag_gen field to have the tag attribute"),
237        Err(ExtractAttributeError::AttributeNotNameValue) |
238        Err(ExtractAttributeError::AttributeValueWrongType) => panic!("#[derive(BuildField)] tag attribute must be as an unsigned integer like #[tag=1234]"),
239    };
240    let tag = tag.to_string();
241
242    let mut tag_bytes = "b\"".to_string();
243    tag_bytes += &tag[..];
244    tag_bytes += "\"";
245
246    let field_name = ast.ident;
247    let tag = str_to_tokens(&tag[..]);
248    let tag_bytes = str_to_tokens(&tag_bytes[..]);
249
250    let tokens = quote! {
251        impl #field_name {
252            fn tag_bytes() -> &'static [u8] {
253                #tag_bytes
254            }
255
256            fn tag() -> field_tag::FieldTag {
257                field_tag::FieldTag(#tag)
258            }
259        }
260    };
261    tokens.parse().unwrap()
262}
263