rem_utils/
wrappers.rs

1use quote::ToTokens;
2
3use crate::{pprint_ast, typ::RustType, CHRusty_build, CHRusty_parse};
4
5/// Represents a wrapper struct that encodes the fact that a given pointer should implement indexed
6#[derive(Clone, Debug)]
7pub struct IndexWrapper {
8    /// Depth of indexing supported by the wrapper
9    /// i.e 1 => a[x]
10    ///     2 => a[x][y]
11    indirection: usize,
12    /// base expression being wrapped
13    expr: syn::Expr,
14    /// type of inner expression
15    ty: RustType,
16}
17
18impl std::fmt::Display for IndexWrapper {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(
21            f,
22            "IndexWrapper {{ depth: {}, expr: {} }}",
23            self.indirection,
24            self.expr.clone().into_token_stream().to_string()
25        )
26    }
27}
28
29fn extract_path_argument(argument: syn::PathArguments) -> Option<RustType> {
30    match argument {
31        syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
32            args, ..
33        }) if args.len() == 1 => match &args[0] {
34            syn::GenericArgument::Type(ty) => Some(ty.clone().into()),
35            _ => None,
36        },
37        _ => None,
38    }
39}
40
41fn path_segments_to_str(path: syn::Path) -> Vec<(String, Option<RustType>)> {
42    path.segments
43        .into_iter()
44        .map(|v| {
45            let ident = v.ident.to_string();
46            let typ = extract_path_argument(v.arguments);
47            (ident, typ)
48        })
49        .collect()
50}
51
52fn is_tuple_struct(expr: &syn::ExprCall) -> bool {
53    expr.args.len() == 1
54}
55
56fn unwrap_tuple_struct(expr: syn::Expr) -> (syn::Path, syn::Expr) {
57    if let syn::Expr::Call(syn::ExprCall {
58        func: box syn::Expr::Path(syn::ExprPath { path, .. }),
59        args,
60        ..
61    }) = expr
62    {
63        assert!(args.len() == 1);
64        (path, args[0].clone())
65    } else {
66        unreachable!()
67    }
68}
69
70fn extract_tuple_struct(expr: &syn::Expr) -> (&syn::Path, &syn::Expr) {
71    if let syn::Expr::Call(syn::ExprCall {
72        func: box syn::Expr::Path(syn::ExprPath { path, .. }),
73        args,
74        ..
75    }) = expr
76    {
77        assert!(args.len() == 1);
78        (path, &args[0])
79    } else {
80        unreachable!()
81    }
82}
83
84fn is_tuple_call(expr: &syn::Expr) -> bool {
85    if let syn::Expr::Call(cs) = expr {
86        is_tuple_struct(cs)
87    } else {
88        false
89    }
90}
91
92impl IndexWrapper {
93    pub fn new(indirection: usize, expr: syn::Expr, ty: RustType) -> Self {
94        IndexWrapper {
95            indirection,
96            expr,
97            ty,
98        }
99    }
100
101    pub fn indirection(&self) -> usize {
102        self.indirection
103    }
104
105    pub fn base_expr(&self) -> &syn::Expr {
106        &self.expr
107    }
108
109    pub fn base_ty(&self) -> &RustType {
110        &self.ty
111    }
112
113    /// Test whether an expression is indeed an index wrapper
114    pub fn is_index_wrapper(expr: &syn::Expr) -> bool {
115        match &expr {
116            syn::Expr::Call(expr_call) if is_tuple_struct(expr_call) => {
117                let (path, _) = extract_tuple_struct(expr);
118                let path_segments = path_segments_to_str(path.clone());
119                let elts = path_segments
120                    .iter()
121                    .map(|(v, _)| v.as_str())
122                    .collect::<Vec<_>>();
123                &elts[..] == &["chrusty", "IndexWrapperFinal"]
124            }
125            _ => false,
126        }
127    }
128
129    /// Folds over the calls in a wrapper in order from IndexWrapperFinal to IndexWrapperBase
130    pub fn fold_calls<P, O>(mut f: P, expr: &syn::Expr) -> Vec<O>
131    where
132        P: FnMut(&syn::Expr) -> O,
133    {
134        let mut acc = vec![];
135        let mut expr = expr;
136        while is_tuple_call(expr) {
137            let (path, next_expr) = extract_tuple_struct(expr);
138            let slice = path_segments_to_str(path.clone());
139            let elts = slice
140                .iter()
141                .map(|(v, ty)| (v.as_str(), ty))
142                .collect::<Vec<_>>();
143            match &elts[..] {
144                [("chrusty", _), ("IndexWrapperFinal", _)] => {
145                    acc.push(f(expr));
146                    expr = next_expr
147                }
148                [("chrusty", _), ("IndexWrapper", _)] => {
149                    acc.push(f(expr));
150                    expr = next_expr
151                }
152                [("chrusty", _), ("IndexWrapperBase", _)] => {
153                    acc.push(f(expr));
154                    break;
155                }
156                elts => panic!("unexpected index wrapper structure {:?}", elts),
157            }
158        }
159        acc
160    }
161}
162
163impl Into<syn::Expr> for IndexWrapper {
164    fn into(self) -> syn::Expr {
165        fn wrap_with_constructor(name: &str, expr: syn::Expr, typ: Option<syn::Type>) -> syn::Expr {
166            let last_segment = {
167                let mut base = CHRusty_parse!((name) as syn::PathSegment);
168                if let Some(typ) = typ {
169                    base.arguments = syn::PathArguments::AngleBracketed(
170                        CHRusty_build!(syn::AngleBracketedGenericArguments {
171                            args: [syn::GenericArgument::Type(typ)].into_iter().collect(),
172                            // Note: explicitly setting the colon2
173                            // token to some is important, else the
174                            // code will not be valid rust
175                            colon2_token: Some(Default::default());
176                            default![lt_token, gt_token]
177                        }),
178                    )
179                }
180                base
181            };
182
183            syn::Expr::Call(CHRusty_build!(syn::ExprCall {
184                func: Box::new(syn::Expr::Path(CHRusty_build!(syn::ExprPath{
185                    path: syn::Path {
186                        leading_colon: None,
187                        segments: [
188                            CHRusty_parse!("chrusty" as syn::PathSegment),
189                            last_segment
190                        ].into_iter().collect()
191                    };
192                    default![attrs, qself]
193                }))),
194                args: [expr].into_iter().collect();
195                default![attrs,paren_token]
196            }))
197        }
198        let mut expr: syn::Expr =
199            wrap_with_constructor("IndexWrapperBase", self.expr, Some(self.ty.into()));
200
201        for _ in 1..self.indirection {
202            expr = wrap_with_constructor("IndexWrapper", expr, None)
203        }
204
205        wrap_with_constructor("IndexWrapperFinal", expr, None)
206    }
207}
208
209impl From<syn::Expr> for IndexWrapper {
210    fn from(expr: syn::Expr) -> Self {
211        let mut expr = match &expr {
212            syn::Expr::Call(expr_call) if is_tuple_struct(expr_call) => {
213                let (path, inner_expr) = unwrap_tuple_struct(expr);
214                let path_segments = path_segments_to_str(path);
215                let elts = path_segments
216                    .iter()
217                    .map(|(v, _)| v.as_str())
218                    .collect::<Vec<_>>();
219                assert!(&elts[..] == &["chrusty", "IndexWrapperFinal"]);
220                inner_expr
221            }
222            expr => panic!(
223                "unexpected structure for RustIndexWrapper {:?}",
224                pprint_ast!(expr)
225            ),
226        };
227
228        let base_expr;
229        let base_typ;
230        let mut indirection = 1;
231
232        loop {
233            let (path, inner_expr) = unwrap_tuple_struct(expr);
234            let slice = path_segments_to_str(path.clone());
235            let elts = slice
236                .iter()
237                .map(|(v, ty)| (v.as_str(), ty))
238                .collect::<Vec<_>>();
239            match &elts[..] {
240                [("chrusty", _), ("IndexWrapper", _)] => {
241                    indirection += 1;
242                    expr = inner_expr
243                }
244                [("chrusty", _), ("IndexWrapperBase", Some(ty))] => {
245                    base_expr = inner_expr;
246                    base_typ = ty.clone();
247                    break;
248                }
249                elts => panic!("unexpected index wrapper structure {:?}", elts),
250            }
251        }
252        let expr = base_expr;
253        let ty = base_typ;
254
255        IndexWrapper {
256            indirection,
257            expr,
258            ty,
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// A depth-1 wrapper survives the round trip through `syn::Expr`.
    #[test]
    fn test_index_wrapper_conversion_works_for_1() {
        let original = IndexWrapper::new(
            1,
            CHRusty_parse!("x.as_mut_ptr()" as syn::Expr),
            CHRusty_parse!("*mut i32" as syn::Type).into(),
        );

        let encoded: syn::Expr = original.into();
        let decoded: IndexWrapper = encoded.into();

        assert_eq!(decoded.indirection, 1);
        let recovered_ty: syn::Type = decoded.ty.into();
        assert_eq!(&pprint_ast!(recovered_ty), "* mut i32")
    }

    /// A depth-2 wrapper survives the round trip through `syn::Expr`.
    #[test]
    fn test_index_wrapper_conversion_works_for_2() {
        let original = IndexWrapper::new(
            2,
            CHRusty_parse!("x.as_mut_ptr()" as syn::Expr),
            CHRusty_parse!("*mut *mut i32" as syn::Type).into(),
        );

        let encoded: syn::Expr = original.into();
        let decoded: IndexWrapper = encoded.into();

        assert_eq!(decoded.indirection, 2);
        let recovered_ty: syn::Type = decoded.ty.into();
        assert_eq!(&pprint_ast!(recovered_ty), "* mut * mut i32")
    }

    /// A depth-3 wrapper survives the round trip through `syn::Expr`.
    #[test]
    fn test_index_wrapper_conversion_works_for_3() {
        let original = IndexWrapper::new(
            3,
            CHRusty_parse!("x.as_mut_ptr()" as syn::Expr),
            CHRusty_parse!("* mut * mut * mut i32" as syn::Type).into(),
        );

        let encoded: syn::Expr = original.into();
        let decoded: IndexWrapper = encoded.into();

        assert_eq!(decoded.indirection, 3);
        let recovered_ty: syn::Type = decoded.ty.into();
        assert_eq!(&pprint_ast!(recovered_ty), "* mut * mut * mut i32")
    }

    /// The generated expression nests Final > IndexWrapper > IndexWrapper > Base for depth 3,
    /// with the base type attached to the innermost constructor.
    #[test]
    fn test_index_wrapper_has_correct_internal_structure() {
        let original = IndexWrapper::new(
            3,
            CHRusty_parse!("x.as_mut_ptr()" as syn::Expr),
            CHRusty_parse!("* mut * mut * mut i32" as syn::Type).into(),
        );

        let encoded: syn::Expr = original.into();
        assert_eq!(&pprint_ast!(encoded), "chrusty :: IndexWrapperFinal (chrusty :: IndexWrapper (chrusty :: IndexWrapper (chrusty :: IndexWrapperBase :: < * mut * mut * mut i32 > (x . as_mut_ptr ()))))");
    }
}