Skip to main content

rusteron_code_gen/
parser.rs

1use crate::generator::{parse_custom_methods, CBinding, CWrapper, Method};
2use crate::{Arg, ArgProcessing, CHandler};
3use itertools::Itertools;
4use quote::ToTokens;
5use std::collections::{BTreeMap, BTreeSet};
6use std::fs;
7use std::path::PathBuf;
8use syn::{Attribute, Item, ItemForeignMod, ItemStruct, ItemType, Lit, Meta, MetaNameValue};
9
10pub fn parse_bindings(out: &PathBuf) -> CBinding {
11    let file_content = fs::read_to_string(out.clone()).expect("Unable to read file");
12    let syntax_tree = syn::parse_file(&file_content).expect("Unable to parse file");
13    let mut wrappers = BTreeMap::new();
14    let mut methods = Vec::new();
15    let mut handlers = Vec::new();
16
17    let items = syntax_tree.items;
18
19    for item in &items {
20        if let Item::Type(ty) = item {
21            process_type(&mut wrappers, &mut handlers, ty);
22        }
23    }
24
25    let handler_names = handlers
26        .iter()
27        .filter(|h| {
28            !["aeron_udp_channel", "aeron_udp_transport"]
29                .iter()
30                .any(|&filter| h.type_name.starts_with(filter))
31        })
32        .map(|handler| handler.type_name.clone())
33        .collect();
34
35    for item in &items {
36        if let Item::Struct(s) = item {
37            process_struct(&mut wrappers, s, &handler_names);
38        }
39    }
40
41    for item in &items {
42        if let Item::ForeignMod(fm) = item {
43            process_c_method(&mut wrappers, &mut methods, fm, &handler_names);
44        }
45    }
46
47    let mut bindings = CBinding {
48        wrappers: wrappers
49            .into_iter()
50            .filter(|(_, wrapper)| {
51                // these are from media driver and do not follow convention
52                ![
53                    "aeron_thread",
54                    "aeron_command",
55                    "aeron_executor",
56                    "aeron_name_resolver",
57                    "aeron_udp_channel_transport", // this one I have issues with handlers
58                    "aeron_udp_transport",         // this one I have issues with handlers
59                ]
60                .iter()
61                .any(|&filter| wrapper.type_name.starts_with(filter))
62            })
63            .collect(),
64        methods,
65        handlers: handlers
66            .into_iter()
67            .filter(|h| {
68                !["aeron_udp_channel", "aeron_udp_transport"]
69                    .iter()
70                    .any(|&filter| h.type_name.starts_with(filter))
71            })
72            .collect(),
73    };
74
75    let mismatched_types = bindings
76        .wrappers
77        .iter()
78        .filter(|(key, w)| key.as_str() != w.type_name)
79        .map(|(a, b)| (a.clone(), b.clone()))
80        .collect_vec();
81    assert_eq!(Vec::<(String, CWrapper)>::new(), mismatched_types);
82
83    let custom = parse_custom_methods(crate::CUSTOM_AERON_CODE);
84    for wrapper in bindings.wrappers.values_mut() {
85        if let Some(methods) = custom.get(&wrapper.class_name) {
86            wrapper.skipped_methods = methods.clone();
87        }
88    }
89
90    bindings
91}
92
93fn process_c_method(
94    wrappers: &mut BTreeMap<String, CWrapper>,
95    methods: &mut Vec<Method>,
96    fm: &ItemForeignMod,
97    handler_names: &BTreeSet<String>,
98) {
99    // Extract functions inside extern "C" blocks
100    if fm.abi.name.is_some() && fm.abi.name.as_ref().unwrap().value() == "C" {
101        for foreign_item in &fm.items {
102            if let syn::ForeignItem::Fn(f) = foreign_item {
103                let docs = get_doc_comments(&f.attrs);
104                let fn_name = f.sig.ident.to_string();
105
106                // Get function arguments and return type as Rust code
107                let args = extract_function_arguments(&f.sig.inputs);
108                let ret = extract_return_type(&f.sig.output);
109
110                let option = if let Some(arg) = args
111                    .iter()
112                    .skip_while(|a| a.is_mut_pointer() && a.is_primitive())
113                    .next()
114                {
115                    let ty = &arg.c_type;
116                    let ty = ty.split(' ').last().map(|t| t.to_string()).unwrap();
117                    if wrappers.contains_key(&ty) {
118                        Some(ty)
119                    } else {
120                        find_closest_wrapper_from_method_name(wrappers, &fn_name)
121                    }
122                } else {
123                    find_closest_wrapper_from_method_name(wrappers, &fn_name)
124                };
125
126                match option {
127                    Some(key) => {
128                        let wrapper = wrappers.get_mut(&key).unwrap();
129                        wrapper.methods.push(Method {
130                            fn_name: fn_name.clone(),
131                            struct_method_name: fn_name
132                                .replace(&wrapper.type_name[..wrapper.type_name.len() - 1], "")
133                                .to_string(),
134                            return_type: Arg {
135                                name: "".to_string(),
136                                c_type: ret.clone(),
137                                processing: ArgProcessing::Default,
138                            },
139                            arguments: process_types(args.clone(), Some(handler_names)),
140                            docs: docs.clone(),
141                        });
142                    }
143                    None => methods.push(Method {
144                        fn_name: fn_name.clone(),
145                        struct_method_name: "".to_string(),
146                        return_type: Arg {
147                            name: "".to_string(),
148                            c_type: ret.clone(),
149                            processing: ArgProcessing::Default,
150                        },
151                        arguments: process_types(args.clone(), Some(handler_names)),
152                        docs: docs.clone(),
153                    }),
154                }
155            }
156        }
157    }
158}
159
160fn find_closest_wrapper_from_method_name(
161    wrappers: &mut BTreeMap<String, CWrapper>,
162    fn_name: &String,
163) -> Option<String> {
164    let type_names = get_possible_wrappers(&fn_name);
165
166    let mut value = None;
167    for ty in type_names {
168        if wrappers.contains_key(&ty) {
169            value = Some(ty);
170            break;
171        }
172    }
173
174    value
175}
176
/// Lists candidate wrapper type names for a C function, most specific first.
///
/// Each candidate is a prefix of `fn_name` that ends just before an underscore, with `_t`
/// appended. For example, `aeron_client_close` yields `["aeron_client_t", "aeron_t"]`. A name
/// without underscores yields an empty list.
pub fn get_possible_wrappers(fn_name: &str) -> Vec<String> {
    fn_name
        .rmatch_indices('_')
        .map(|(split_at, _)| format!("{}_t", &fn_name[..split_at]))
        .collect()
}
185
186fn process_type(
187    wrappers: &mut BTreeMap<String, CWrapper>,
188    handlers: &mut Vec<CHandler>,
189    ty: &ItemType,
190) {
191    // Handle type definitions and get docs
192    let docs = get_doc_comments(&ty.attrs);
193
194    let type_name = ty.ident.to_string();
195    let class_name = snake_to_pascal_case(&type_name);
196
197    if is_struct_typedef(&ty.ty) {
198        wrappers
199            .entry(type_name.clone())
200            .or_insert(CWrapper {
201                class_name,
202                without_name: type_name[..type_name.len() - 2].to_string(),
203                type_name,
204                ..Default::default()
205            })
206            .docs
207            .extend(docs);
208    } else {
209        // Parse the function pointer type -> it is typically used for handlers/callbacks
210        if let syn::Type::Path(type_path) = &*ty.ty {
211            if let Some(segment) = type_path.path.segments.last() {
212                if segment.ident.to_string() == "Option" {
213                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
214                        if let Some(syn::GenericArgument::Type(syn::Type::BareFn(bare_fn))) =
215                            args.args.first()
216                        {
217                            let args: Vec<Arg> = bare_fn
218                                .inputs
219                                .iter()
220                                .map(|arg| {
221                                    let arg_name = match &arg.name {
222                                        Some((ident, _)) => ident.to_string(),
223                                        None => "".to_string(),
224                                    };
225                                    let arg_type = arg.ty.to_token_stream().to_string();
226                                    (arg_name, arg_type)
227                                })
228                                .map(|(field_name, field_type)| Arg {
229                                    name: field_name,
230                                    c_type: field_type,
231                                    processing: ArgProcessing::Default,
232                                })
233                                .collect();
234                            let string = bare_fn.output.to_token_stream().to_string();
235                            let mut return_type = string.trim();
236
237                            if return_type.starts_with("-> ") {
238                                return_type = &return_type[3..];
239                            }
240
241                            if return_type.is_empty() {
242                                return_type = "()";
243                            }
244
245                            if is_handler_typedef(&args) {
246                                let value = CHandler {
247                                    type_name: ty.ident.to_string(),
248                                    args: process_types(args, None),
249                                    return_type: Arg {
250                                        name: "".to_string(),
251                                        c_type: return_type.to_string(),
252                                        processing: ArgProcessing::Default,
253                                    },
254                                    docs: docs.clone(),
255                                    fn_mut_signature: Default::default(),
256                                    closure_type_name: Default::default(),
257                                };
258                                handlers.push(value);
259                            }
260                        }
261                    }
262                }
263            }
264        }
265    }
266}
267
268fn is_handler_typedef(args: &[Arg]) -> bool {
269    args.iter().filter(|arg| arg.is_c_void()).count() == 1
270        || args
271            .first()
272            .map(|arg| arg.is_c_void() && is_client_data_arg(&arg.name))
273            .unwrap_or(false)
274}
275
/// Returns `true` for parameter names that denote callback client data.
///
/// Matches the exact names `clientd`, `state` and `task_clientd`, and any name ending in
/// `_clientd` or `_state`.
fn is_client_data_arg(name: &str) -> bool {
    matches!(name, "clientd" | "state" | "task_clientd")
        || ["_clientd", "_state"]
            .iter()
            .any(|&suffix| name.ends_with(suffix))
}
283
284fn is_struct_typedef(ty: &syn::Type) -> bool {
285    if let syn::Type::Path(type_path) = ty {
286        if let Some(segment) = type_path.path.segments.last() {
287            return segment.ident.to_string().ends_with("_stct");
288        }
289    }
290
291    false
292}
293
294fn process_struct(
295    wrappers: &mut BTreeMap<String, CWrapper>,
296    s: &ItemStruct,
297    handler_names: &BTreeSet<String>,
298) {
299    // Print the struct name and its doc comments
300    let docs = get_doc_comments(&s.attrs);
301    let type_name = s.ident.to_string().replace("_stct", "_t");
302    let class_name = snake_to_pascal_case(&type_name);
303
304    let fields: Vec<Arg> = s
305        .fields
306        .iter()
307        .map(|f| {
308            let field_name = f.ident.as_ref().unwrap().to_string();
309            let field_type = f.ty.to_token_stream().to_string();
310            (field_name, field_type)
311        })
312        .map(|(field_name, field_type)| Arg {
313            name: field_name,
314            c_type: field_type,
315            processing: ArgProcessing::Default,
316        })
317        .collect();
318
319    let w = wrappers.entry(type_name.to_string()).or_insert(CWrapper {
320        class_name,
321        without_name: type_name[..type_name.len() - 2].to_string(),
322        type_name,
323        ..Default::default()
324    });
325    w.docs.extend(docs);
326    w.fields = process_types(fields, Some(handler_names));
327}
328
329fn process_types(
330    mut name_and_type: Vec<Arg>,
331    handler_names: Option<&BTreeSet<String>>,
332) -> Vec<Arg> {
333    // now mark arguments which can be reduced
334    for i in 1..name_and_type.len() {
335        let param1 = &name_and_type[i - 1];
336        let param2 = &name_and_type[i];
337
338        let is_int = param2.c_type == "usize" || param2.c_type == "i32";
339        let length_field = param2.name == "length"
340            || param2.name == "len"
341            || (param2.name.ends_with("_length") && param2.name.starts_with(&param1.name));
342        if param2.is_c_void()
343            && !param1.is_mut_pointer()
344            && param1.c_type.ends_with("_t")
345            && handler_names
346                .map(|handler_names| handler_names.contains(&param1.c_type))
347                .unwrap_or(false)
348        {
349            // closures
350            //         handler: aeron_on_available_counter_t,
351            //         clientd: *mut ::std::os::raw::c_void,
352            let processing = ArgProcessing::Handler(vec![param1.clone(), param2.clone()]);
353            name_and_type[i - 1].processing = processing.clone();
354            name_and_type[i].processing = processing.clone();
355        } else if param1.is_c_string_any() && !param1.is_mut_pointer() && is_int && length_field {
356            //     pub stripped_channel: *mut ::std::os::raw::c_char,
357            //     pub stripped_channel_length: usize,
358            let processing = ArgProcessing::StringWithLength(vec![param1.clone(), param2.clone()]);
359            name_and_type[i - 1].processing = processing.clone();
360            name_and_type[i].processing = processing.clone();
361        } else if param1.is_byte_array()
362            // && !param1.is_mut_pointer()
363            && is_int
364            && length_field
365        {
366            //         key_buffer: *const u8,
367            //         key_buffer_length: usize,
368            let processing =
369                ArgProcessing::ByteArrayWithLength(vec![param1.clone(), param2.clone()]);
370            name_and_type[i - 1].processing = processing.clone();
371            name_and_type[i].processing = processing.clone();
372        }
373
374        //
375    }
376
377    name_and_type
378}
379
380// Helper function to extract doc comments
381fn get_doc_comments(attrs: &[Attribute]) -> BTreeSet<String> {
382    attrs
383        .iter()
384        .filter_map(|attr| {
385            // Parse the attribute meta to check if it is a `Meta::NameValue`
386            if let Meta::NameValue(MetaNameValue {
387                path,
388                value: syn::Expr::Lit(expr_lit),
389                ..
390            }) = &attr.meta
391            {
392                // Check if the path is "doc"
393                if path.is_ident("doc") {
394                    // Check if the literal is a string and return its value
395                    if let Lit::Str(lit_str) = &expr_lit.lit {
396                        return Some(lit_str.value().trim().to_string());
397                    }
398                }
399            }
400            None
401        })
402        .collect()
403}
404
405pub fn snake_to_pascal_case(mut snake: &str) -> String {
406    if snake.ends_with("_t") {
407        snake = &snake[..snake.len() - 2];
408    }
409    snake
410        .split('_')
411        .filter(|x| *x != "on") // Split the string by underscores
412        .map(|word| {
413            let mut chars = word.chars();
414            // Capitalize the first letter and collect the rest of the letters
415            match chars.next() {
416                Some(c) => c.to_uppercase().collect::<String>() + chars.as_str(),
417                None => String::new(),
418            }
419        })
420        .collect()
421}
422
423// Helper function to extract function arguments as Rust code
424fn extract_function_arguments(
425    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
426) -> Vec<Arg> {
427    inputs
428        .iter()
429        .map(|arg| match arg {
430            syn::FnArg::Receiver(_) => "self".to_string(), // Handle self receiver
431            syn::FnArg::Typed(pat_type) => pat_type.to_token_stream().to_string(), // Convert the pattern and type to Rust code
432        })
433        .map(|arg| {
434            arg.splitn(2, ':')
435                .map(|s| s.trim().to_string())
436                .collect_tuple()
437                .unwrap()
438        })
439        .map(|(name, ty)| Arg {
440            name,
441            c_type: ty,
442            processing: ArgProcessing::Default,
443        })
444        .collect_vec()
445}
446
447// Helper function to extract return type as Rust code
448fn extract_return_type(output: &syn::ReturnType) -> String {
449    match output {
450        syn::ReturnType::Default => "()".to_string(), // No return type, equivalent to ()
451        syn::ReturnType::Type(_, ty) => ty.to_token_stream().to_string(), // Convert the type to Rust code
452    }
453}
454
455#[cfg(test)]
456mod tests {
457    use crate::parser::parse_bindings;
458    use std::path::PathBuf;
459
460    fn running_under_valgrind() -> bool {
461        std::env::var_os("RUSTERON_VALGRIND").is_some()
462    }
463
464    #[test]
465    fn media_driver() {
466        if running_under_valgrind() {
467            return;
468        }
469
470        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
471            .join("bindings")
472            .join("media-driver.rs");
473        let bindings = parse_bindings(&path);
474        assert_eq!(
475            "AeronImageFragmentAssembler",
476            bindings
477                .wrappers
478                .get("aeron_image_fragment_assembler_t")
479                .unwrap()
480                .class_name
481        );
482    }
483    #[test]
484    fn client() {
485        if running_under_valgrind() {
486            return;
487        }
488
489        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
490            .join("bindings")
491            .join("client.rs");
492        let bindings = parse_bindings(&path);
493        assert_eq!(
494            "AeronImageFragmentAssembler",
495            bindings
496                .wrappers
497                .get("aeron_image_fragment_assembler_t")
498                .unwrap()
499                .class_name
500        );
501        assert!(bindings.handlers.len() > 1);
502    }
503}