// c3_lang_parser/rust_class_def.rs

1use c3_lang_linearization::{Class, Fn, Var};
2use quote::format_ident;
3use syn::{
4    parse::Parse, punctuated::Punctuated, Attribute, Expr, Field, Fields, ImplItem, ImplItemConst,
5    ImplItemMethod, ItemImpl, ItemStruct, Token, Visibility,
6};
7
8#[derive(Debug, PartialEq)]
9pub struct RustClassDef {
10    pub item_struct: ItemStruct,
11    pub item_impl: Option<ItemImpl>,
12}
13
14impl Parse for RustClassDef {
15    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16        let item_struct: ItemStruct = input.parse()?;
17        let lookahead = input.lookahead1();
18        let item_impl = if lookahead.peek(Token![impl]) || input.peek(Token![#]) {
19            let item_impl: ItemImpl = input.parse()?;
20            Some(item_impl)
21        } else {
22            None
23        };
24        // println!("asd");
25        Ok(RustClassDef {
26            item_struct,
27            item_impl,
28        })
29    }
30}
31
32impl RustClassDef {
33    pub fn class(&self) -> Class {
34        Class::from(self.item_struct.ident.to_string())
35    }
36
37    pub fn is_public(&self) -> bool {
38        matches!(self.item_struct.vis, Visibility::Public(_))
39    }
40
41    pub fn struct_attrs(&self) -> Vec<Attribute> {
42        self.item_struct.attrs.clone()
43    }
44
45    pub fn impl_attrs(&self) -> Vec<Attribute> {
46        match &self.item_impl {
47            None => Vec::new(),
48            Some(item_impl) => item_impl.attrs.clone(),
49        }
50    }
51
52    pub fn parents(&self) -> Vec<Class> {
53        if self.item_impl.is_none() {
54            return vec![];
55        }
56
57        let item_impl = self.item_impl.clone().unwrap();
58        let items: Vec<ImplItem> = item_impl.items;
59        let mut parents: Vec<Class> = vec![];
60        for item in items.iter() {
61            if let ImplItem::Const(item_const) = item {
62                let item_const: &ImplItemConst = item_const;
63                if item_const.ident == format_ident!("PARENTS") {
64                    let expr = &item_const.expr;
65                    if let Expr::Reference(expr_reference) = expr {
66                        let expr = *expr_reference.expr.clone();
67                        if let Expr::Array(expr_list) = expr {
68                            let exprs: Punctuated<Expr, Token![,]> = expr_list.elems;
69                            for expr in exprs.iter() {
70                                if let Expr::Path(expr) = expr {
71                                    let path = expr.path.clone();
72                                    let segments = path.segments;
73                                    if segments[0].ident == format_ident!("ClassName") {
74                                        parents.push(Class::from(segments[1].ident.to_string()));
75                                    }
76                                }
77                            }
78                        }
79                    }
80                }
81            }
82        }
83        parents
84    }
85
86    pub fn functions(&self) -> Vec<Fn> {
87        self.function_impls().into_iter().map(|x| x.0).collect()
88    }
89
90    pub fn function_impls(&self) -> Vec<(Fn, ImplItemMethod)> {
91        let item_impl = self.item_impl.clone().unwrap();
92        let items: Vec<ImplItem> = item_impl.items;
93        let mut functions: Vec<(Fn, ImplItemMethod)> = vec![];
94        for item in items.iter() {
95            if let ImplItem::Method(method) = item {
96                let name = method.sig.ident.to_string();
97                functions.push((Fn::from(name), method.clone()));
98            }
99        }
100        functions
101    }
102
103    pub fn variables(&self) -> Vec<Var> {
104        self.variables_impl().into_iter().map(|x| x.0).collect()
105    }
106
107    pub fn variables_impl(&self) -> Vec<(Var, Field)> {
108        let mut variables: Vec<(Var, Field)> = vec![];
109        if let Fields::Named(fields) = &self.item_struct.fields {
110            for field in &fields.named {
111                let var = Var::from(field.ident.clone().unwrap().to_string());
112                variables.push((var, field.clone()));
113            }
114        };
115        variables
116    }
117}
118
#[cfg(test)]
mod tests {
    use c3_lang_linearization::Class;
    use quote::quote;
    use syn::parse_quote;

    use super::RustClassDef;

    #[test]
    fn test_rust_class_def_without_impl() {
        let input = quote! {
            struct A {}
        };
        let result: RustClassDef = syn::parse2(input).unwrap();
        let target = RustClassDef {
            item_struct: parse_quote!(
                struct A {}
            ),
            item_impl: None,
        };
        assert_eq!(result, target);
        // Without an `impl` block there are no parents and no impl attributes.
        assert!(result.parents().is_empty());
        assert!(result.impl_attrs().is_empty());
    }

    #[test]
    fn test_rust_class_def_with_impl() {
        let input = quote! {
            struct A {}
            impl A for B {}
        };
        let result: RustClassDef = syn::parse2(input).unwrap();
        let target = RustClassDef {
            item_struct: parse_quote!(
                struct A {}
            ),
            item_impl: Some(parse_quote!( impl A for B {} )),
        };
        assert_eq!(result, target);
    }

    #[test]
    fn test_rust_class_def_getters() {
        let input = quote! {
            #[derive(Default)]
            pub struct A {
                x: u32
            }

            #[custom_macro]
            impl A {
                const PARENTS: &'static [ClassName; 2usize] = &[
                    ClassName::X,
                    ClassName::Y
                ];

                pub fn k(&self) -> u32 { 4 }
            }
        };
        let result: RustClassDef = syn::parse2(input).unwrap();
        assert!(result.is_public());
        assert_eq!(result.class(), Class::from("A"));
        assert_eq!(result.parents(), vec![Class::from("X"), Class::from("Y")]);
        assert_eq!(
            result.struct_attrs(),
            vec![parse_quote! { #[derive(Default)] }]
        );
        assert_eq!(result.impl_attrs(), vec![parse_quote! { #[custom_macro] }]);
    }

    #[test]
    fn test_rust_class_def_members() {
        let input = quote! {
            struct A {
                x: u32,
                y: u64
            }

            impl A {
                const PARENTS: &'static [ClassName; 2usize] = &[
                    ClassName::X,
                    Other::Y
                ];

                fn k(&self) -> u32 { 4 }
                fn m(&mut self) {}
            }
        };
        let result: RustClassDef = syn::parse2(input).unwrap();
        assert!(!result.is_public());
        // Only paths rooted at `ClassName` are treated as parents.
        assert_eq!(result.parents(), vec![Class::from("X")]);
        assert!(result.impl_attrs().is_empty());

        let method_names: Vec<String> = result
            .function_impls()
            .iter()
            .map(|(_, method)| method.sig.ident.to_string())
            .collect();
        assert_eq!(method_names, vec!["k", "m"]);
        assert_eq!(result.functions().len(), 2);

        let field_names: Vec<String> = result
            .variables_impl()
            .iter()
            .map(|(_, field)| field.ident.as_ref().unwrap().to_string())
            .collect();
        assert_eq!(field_names, vec!["x", "y"]);
        assert_eq!(result.variables().len(), 2);
    }
}