rem_utils/
typ.rs

1use ena::unify::UnifyValue;
2use quote::ToTokens;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use syn::punctuated::Punctuated;
6use syn::{FieldsNamed, Path, PathSegment, Type, TypeArray};
7
/// Mapping from each function, identified by its source location (`crate::location::Loc`),
/// to that function's Rust type signature.
pub type TypeMap = HashMap<crate::location::Loc, RustTypeSignature>;
10
/// Type-level errors. Only one kind is currently defined.
#[derive(Debug)]
pub enum Error {
    /// The two carried types were found not to unify with each other
    /// (presumably raised during type inference — confirm at the use sites).
    UnUnifiableTypes(RustType, RustType),
}
15
/// Program-wide type context consulted by `RustType::resolve`:
/// `.0` maps type-alias names to the type they stand for, and
/// `.1` maps struct names to their definitions (names found here are left unexpanded).
pub type ProgramTypeContext = (
    HashMap<syn::Ident, RustType>,
    HashMap<syn::Ident, RustStruct>,
);
20
/// A numbered type variable (generic parameter), spelled `T<n>` in source, e.g. `T0`, `T12`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TVar(pub usize);

impl From<&str> for TVar {
    /// Parses a generic parameter name of the form `T<n>` into `TVar(n)`.
    ///
    /// # Panics
    ///
    /// Panics if `s` does not start with `T`, or if the text after that single leading `T`
    /// is not a valid `usize`. Exactly one `T` is stripped, so `"TT3"` is rejected instead of
    /// being silently read as `T3`.
    fn from(s: &str) -> Self {
        let Some(digits) = s.strip_prefix('T') else {
            panic!("invalid assumption: unknown generic var \"{}\"", s)
        };
        let ind = digits
            .parse()
            .expect("Index for generic could not be extracted");
        TVar(ind)
    }
}

impl From<String> for TVar {
    /// Same as `From<&str>`; see that impl for the accepted format and panics.
    fn from(s: String) -> Self {
        s.as_str().into()
    }
}

impl std::fmt::Display for TVar {
    /// Formats as `T<n>`, the same spelling `From<&str>` accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "T{}", self.0)
    }
}
48
/// Whether a reference is shared (`&T`) or exclusive (`&mut T`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum RustMutability {
    Immutable,
    Mutable,
}

impl std::fmt::Display for RustMutability {
    /// Renders as the lowercase word `immutable` or `mutable`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let word = if *self == RustMutability::Mutable {
            "mutable"
        } else {
            "immutable"
        };
        f.write_str(word)
    }
}
63
64impl From<Option<syn::token::Mut>> for RustMutability {
65    fn from(mut_: Option<syn::token::Mut>) -> Self {
66        match mut_ {
67            Some(_) => RustMutability::Mutable,
68            None => RustMutability::Immutable,
69        }
70    }
71}
72
73impl Into<Option<syn::token::Mut>> for RustMutability {
74    fn into(self) -> Option<syn::token::Mut> {
75        match self {
76            RustMutability::Mutable => Some(Default::default()),
77            RustMutability::Immutable => None,
78        }
79    }
80}
81
/// Width class of a C integer type: `char`, `short`, `int`, `long` or `long long`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CIntegralSize {
    Char,
    Short,
    Int,
    Long,
    LongLong,
}

impl std::fmt::Display for CIntegralSize {
    /// Prints the C spelling with spaces removed (`long long` is printed as `longlong`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Self::Char => "char",
            Self::Short => "short",
            Self::Int => "int",
            Self::Long => "long",
            Self::LongLong => "longlong",
        };
        f.write_str(spelling)
    }
}
102
/// Width class of a C floating-point type: `float` or `double`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CFloatSize {
    Float,
    Double,
}

impl std::fmt::Display for CFloatSize {
    /// Prints the C keyword: `float` or `double`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = if matches!(self, CFloatSize::Float) {
            "float"
        } else {
            "double"
        };
        f.write_str(keyword)
    }
}
117
/// The types this tool models: C/`libc` scalars, a few Rust std types, numbered type
/// variables, `extern "C"` function pointers, references and raw pointers.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum RustType {
    /// `libc::c_void`.
    CVoid,
    /// A C integer type such as `libc::c_int` or `libc::c_ulong`.
    CInt {
        /// `true` for the `c_u*` variants.
        unsigned: bool,
        size: CIntegralSize,
    },
    /// `libc::c_float` or `libc::c_double`.
    CFloat(CFloatSize),
    /// A named type (struct name or type alias); aliases can be expanded with `resolve`.
    CAlias(syn::Ident),

    /// Fixed-size array `[T; N]`.
    Array(Box<RustType>, usize),

    /// `Option<T>`.
    Option(Box<RustType>),
    /// `Vec<T>`.
    Vec(Box<RustType>),
    /// `()`.
    Unit,
    I32,
    U8,
    /// `size_t`.
    SizeT,
    Isize,
    Usize,
    /// A numbered type variable / generic parameter (`T0`, `T1`, ...).
    TVar(TVar),

    /// The never type `!`.
    Never,

    /// `unsafe extern "C" fn(params...) -> ret`; the `bool` is `true` when the function is
    /// variadic.
    ExternFn(Vec<Box<RustType>>, bool, Box<RustType>),

    /// A reference, `&T` or `&mut T` depending on the `RustMutability`.
    Reference(RustMutability, Box<RustType>),
    /// A raw pointer. Converted to syntax as `*mut T`; parsed from both `*mut T` and
    /// `*const T`.
    Pointer(Box<RustType>),
}
149
150impl RustType {
151    fn uses(&self, set: &mut HashSet<syn::Ident>) {
152        match self {
153            RustType::CAlias(id) => {
154                set.insert(id.clone());
155                ()
156            }
157            RustType::Option(ty)
158            | RustType::Vec(ty)
159            | RustType::Reference(_, ty)
160            | RustType::Pointer(ty)
161            | RustType::Array(ty, _) => ty.uses(set),
162
163            // RustType::ExternFn(args, _, out_ty) => {
164            //     for arg in args.iter() {
165            //         arg.uses(set)
166            //     }
167            //     out_ty.uses(set)
168            // }
169            _ => (),
170        }
171    }
172
173    fn resolve_checked(
174        &mut self,
175        path: &mut HashSet<syn::Ident>,
176        ctxt: &ProgramTypeContext,
177    ) -> bool {
178        match self {
179            // recursion check
180            RustType::CAlias(id) if path.contains(id) => true,
181            RustType::CAlias(id) if !path.contains(id) => {
182                // add the visited alias to the path
183                path.insert(id.clone());
184                // if type alias to a defined struct, then we good boys
185                if ctxt.1.contains_key(id) {
186                    false
187                } else {
188                    match ctxt.0.get(id) {
189                        Some(inner) => {
190                            // update self to be the inner type
191                            *self = inner.clone();
192                            // recursive
193                            self.resolve_checked(path, ctxt)
194                        }
195                        None => {
196                            log::warn!("attempted to resolve type {} that has no defined alias or definition", id.to_string());
197                            false
198                        }
199                    }
200                }
201            }
202            RustType::Option(elt)
203            | RustType::Vec(elt)
204            | RustType::Pointer(elt)
205            | RustType::Reference(_, elt)
206            | RustType::Array(elt, _) => elt.resolve_checked(path, ctxt),
207            RustType::ExternFn(args, _, out) => {
208                let mut any_rec = false;
209                let mut base_path = path.clone();
210                for arg in args.iter_mut() {
211                    let mut rec_path = base_path.clone();
212                    any_rec |= arg.resolve_checked(&mut rec_path, ctxt);
213                    path.extend(rec_path.into_iter());
214                }
215                any_rec |= out.resolve_checked(&mut base_path, ctxt);
216                path.extend(base_path.into_iter());
217                any_rec
218            }
219            _ => false,
220        }
221    }
222
223    /// Resolves a type according to the type context, avoiding loops, returning the list of types visited
224    pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
225        let mut set = HashSet::new();
226        self.resolve_checked(&mut set, ctxt);
227        set
228    }
229}
230
231impl std::fmt::Display for RustType {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        match self {
234            RustType::Never => write!(f, "never"),
235            RustType::Array(box ty, size) => write!(f, "array({}, {})", ty, size),
236            RustType::Option(box ty) => write!(f, "option({})", ty),
237            RustType::Vec(box ty) => write!(f, "vec({})", ty),
238            RustType::CInt { unsigned, size } => {
239                write!(f, "c_{}{}", if *unsigned { "u" } else { "" }, size)
240            }
241            RustType::CFloat(size) => write!(f, "c_{}", size),
242            RustType::CVoid => write!(f, "c_void"),
243            RustType::CAlias(ident) => write!(f, "{}", ident),
244            RustType::Unit => write!(f, "()"),
245            RustType::SizeT => write!(f, "size_t"),
246            RustType::U8 => write!(f, "u8"),
247            RustType::I32 => write!(f, "i32"),
248            RustType::Isize => write!(f, "isize"),
249            RustType::Usize => write!(f, "usize"),
250            RustType::TVar(tvar) => write!(f, "{}", tvar),
251            RustType::Pointer(box x) => write!(f, "mut_ptr_{}", x),
252            RustType::Reference(mt, box x) => write!(f, "ref_{}_{}", mt, x),
253            RustType::ExternFn(args, variadic, body) => write!(
254                f,
255                "extern_fn_({}, {}, {})",
256                variadic,
257                args.iter()
258                    .map(|v| format!("{}", v))
259                    .collect::<Vec<_>>()
260                    .join(","),
261                body
262            ),
263        }
264    }
265}
266
267impl Into<Type> for RustType {
268    fn into(self) -> Type {
269        match self {
270            RustType::Never => syn::Type::Never(syn::TypeNever {
271                bang_token: Default::default(),
272            }),
273            RustType::Array(box ty, size) => Type::Array(syn::TypeArray {
274                bracket_token: Default::default(),
275                elem: Box::new(ty.into()),
276                semi_token: Default::default(),
277                len: syn::parse_str::<syn::Expr>(&format!("{}", size)).unwrap(),
278            }),
279            RustType::Option(box ty) => Type::Path(syn::TypePath {
280                qself: None,
281                path: syn::Path {
282                    leading_colon: None,
283                    segments: [PathSegment {
284                        ident: syn::parse_str::<syn::Ident>("Option").unwrap(),
285                        arguments: syn::PathArguments::AngleBracketed(
286                            syn::AngleBracketedGenericArguments {
287                                colon2_token: None,
288                                lt_token: Default::default(),
289                                args: [syn::GenericArgument::Type(ty.into())]
290                                    .into_iter()
291                                    .collect(),
292                                gt_token: Default::default(),
293                            },
294                        ),
295                    }]
296                    .into_iter()
297                    .collect(),
298                },
299            }),
300            RustType::Vec(box ty) => Type::Path(syn::TypePath {
301                qself: None,
302                path: syn::Path {
303                    leading_colon: None,
304                    segments: [PathSegment {
305                        ident: syn::parse_str::<syn::Ident>("Vec").unwrap(),
306                        arguments: syn::PathArguments::AngleBracketed(
307                            syn::AngleBracketedGenericArguments {
308                                colon2_token: None,
309                                lt_token: Default::default(),
310                                args: [syn::GenericArgument::Type(ty.into())]
311                                    .into_iter()
312                                    .collect(),
313                                gt_token: Default::default(),
314                            },
315                        ),
316                    }]
317                    .into_iter()
318                    .collect(),
319                },
320            }),
321            RustType::Unit => syn::parse_str::<Type>("()").unwrap(),
322            ty @ RustType::CInt { .. } => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
323            ty @ RustType::CFloat(_) => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
324            RustType::CVoid => syn::parse_str::<Type>("libc::c_void").unwrap(),
325
326            RustType::CAlias(ident) => Type::Path(syn::TypePath {
327                qself: None,
328                path: syn::Path {
329                    leading_colon: None,
330                    segments: [syn::PathSegment::from(ident)].into_iter().collect(),
331                },
332            }),
333            RustType::SizeT => syn::parse_str::<Type>("size_t").unwrap(),
334            RustType::U8 => syn::parse_str::<Type>("u8").unwrap(),
335            RustType::I32 => syn::parse_str::<Type>("i32").unwrap(),
336            RustType::Isize => syn::parse_str::<Type>("isize").unwrap(),
337            RustType::Usize => syn::parse_str::<Type>("usize").unwrap(),
338            RustType::TVar(n) => syn::parse_str::<Type>(&format!("{}", n)).unwrap(),
339            RustType::Pointer(box v) => Type::Ptr(syn::TypePtr {
340                const_token: None,
341                mutability: Some(Default::default()),
342                elem: Box::new(v.into()),
343                star_token: Default::default(),
344            }),
345            RustType::Reference(muta, box v) => Type::Reference(syn::TypeReference {
346                and_token: Default::default(),
347                mutability: muta.into(),
348                elem: Box::new(v.into()),
349                lifetime: None,
350            }),
351            RustType::ExternFn(args, variadic, box res) => {
352                let inputs = if variadic {
353                    args.into_iter()
354                        .map(|f| syn::BareFnArg {
355                            attrs: Default::default(),
356                            name: None,
357                            ty: (*f).into(),
358                        })
359                        .collect()
360                } else {
361                    args.into_iter()
362                        .map(|f| syn::BareFnArg {
363                            attrs: Default::default(),
364                            name: None,
365                            ty: (*f).into(),
366                        })
367                        .chain(
368                            [syn::BareFnArg {
369                                attrs: Default::default(),
370                                name: None,
371                                ty: Type::Verbatim("...".to_token_stream()),
372                            }]
373                            .into_iter(),
374                        )
375                        .collect()
376                };
377
378                Type::BareFn(syn::TypeBareFn {
379                    lifetimes: None,
380                    unsafety: Some(Default::default()),
381                    abi: Some(syn::Abi {
382                        extern_token: Default::default(),
383                        name: Some(syn::parse_str::<syn::LitStr>("\"C\"").unwrap()),
384                    }),
385                    fn_token: Default::default(),
386                    paren_token: Default::default(),
387                    inputs,
388                    variadic: None,
389                    output: syn::ReturnType::Type(Default::default(), Box::new(res.into())),
390                })
391            }
392        }
393    }
394}
395
396impl From<Type> for RustType {
397    fn from(ty: Type) -> Self {
398        match ty {
399            Type::Path(syn::TypePath {
400                path: Path { segments, .. },
401                ..
402            }) if segments.last().is_some_and(|segment| {
403                segment.ident == "Option" && !segment.arguments.is_empty()
404            }) =>
405            {
406                let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
407                        &segments.last().unwrap().arguments else {
408                        panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
409                    };
410
411                let syn::GenericArgument::Type(ty) = &args[0] else {
412                        panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
413                    };
414
415                RustType::Option(Box::new(ty.clone().into()))
416            }
417            Type::Path(syn::TypePath {
418                path: Path { segments, .. },
419                ..
420            }) if segments
421                .last()
422                .is_some_and(|segment| segment.ident == "Vec" && !segment.arguments.is_empty()) =>
423            {
424                let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
425                        &segments.last().unwrap().arguments else {
426                        panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
427                    };
428
429                let syn::GenericArgument::Type(ty) = &args[0] else {
430                        panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
431                    };
432
433                RustType::Vec(Box::new(ty.clone().into()))
434            }
435            Type::Path(syn::TypePath {
436                path: ref path @ Path { ref segments, .. },
437                ..
438            }) if segments
439                .last()
440                .is_some_and(|segment| segment.arguments.is_empty()) =>
441            {
442                let ident = &segments.last().unwrap().ident;
443
444                if !(segments.len() == 1)
445                    && !(segments.len() > 1 && segments[0].ident == "libc")
446                    && !(segments.len() > 3
447                        && segments[0].ident == "std"
448                        && segments[1].ident == "os"
449                        && segments[2].ident == "raw")
450                {
451                    log::warn!("found use of non libc type {:?}, optimistically assuming nothing crazy is going on",
452                               path.to_token_stream().to_string()
453                    )
454                }
455
456                use RustType::*;
457
458                match ident.to_string().as_str() {
459                    "isize" => Isize,
460                    "i32" => I32,
461                    "size_t" => SizeT,
462                    "u8" => U8,
463                    "usize" => Usize,
464                    "c_float" => CFloat(CFloatSize::Float),
465                    "c_double" => CFloat(CFloatSize::Double),
466
467                    "c_char" => CInt {
468                        unsigned: false,
469                        size: CIntegralSize::Char,
470                    },
471                    "c_schar" => CInt {
472                        unsigned: false,
473                        size: CIntegralSize::Char,
474                    },
475                    "c_uchar" => CInt {
476                        unsigned: true,
477                        size: CIntegralSize::Char,
478                    },
479                    "c_short" => CInt {
480                        unsigned: false,
481                        size: CIntegralSize::Short,
482                    },
483                    "c_ushort" => CInt {
484                        unsigned: true,
485                        size: CIntegralSize::Short,
486                    },
487
488                    "c_int" => CInt {
489                        unsigned: false,
490                        size: CIntegralSize::Int,
491                    },
492                    "c_uint" => CInt {
493                        unsigned: true,
494                        size: CIntegralSize::Int,
495                    },
496
497                    "c_long" => CInt {
498                        unsigned: false,
499                        size: CIntegralSize::Long,
500                    },
501                    "c_ulong" => CInt {
502                        unsigned: true,
503                        size: CIntegralSize::Long,
504                    },
505
506                    "c_longlong" => CInt {
507                        unsigned: false,
508                        size: CIntegralSize::Long,
509                    },
510                    "c_ulonglong" => CInt {
511                        unsigned: true,
512                        size: CIntegralSize::Long,
513                    },
514
515                    "c_void" => CVoid,
516
517                    _txt => RustType::CAlias(ident.clone()),
518                }
519            }
520            Type::Ptr(syn::TypePtr {
521                const_token,
522                mutability,
523                elem: box ty,
524                ..
525            }) => {
526                match (const_token, mutability) {
527                    (None, Some(_)) => RustType::Pointer(Box::new(ty.into())), //case where we have *mut T
528                    (Some(_), None) => RustType::Pointer(Box::new(ty.into())), //case where we have *const T
529                    (_, _) => {
530                        todo!()
531                    }
532                }
533            }
534            Type::Reference(syn::TypeReference {
535                lifetime: None,
536                mutability,
537                elem: box elem,
538                ..
539            }) => RustType::Reference(mutability.into(), Box::new(elem.into())),
540
541            Type::Tuple(syn::TypeTuple { elems, .. }) if elems.len() == 0 => RustType::Unit,
542
543            Type::BareFn(syn::TypeBareFn {
544                unsafety: Some(_),
545                abi: Some(_),
546                inputs,
547                output,
548                variadic: None,
549                ..
550            }) => {
551                let mut variadic = false;
552                let inputs = inputs
553                    .into_iter()
554                    .map(|v| match v.ty {
555                        Type::Verbatim(_v) => {
556                            variadic = true;
557                            None
558                        }
559                        _ => Some(Box::new(v.ty.into())),
560                    })
561                    .flatten()
562                    .collect();
563
564                let output = match output {
565                    syn::ReturnType::Default => RustType::Unit,
566                    syn::ReturnType::Type(_, box ty) => ty.into(),
567                };
568                RustType::ExternFn(inputs, variadic, Box::new(output))
569            }
570
571            Type::Array(TypeArray {
572                elem: box ty,
573                len:
574                    syn::Expr::Lit(syn::ExprLit {
575                        lit: syn::Lit::Int(i),
576                        ..
577                    }),
578                ..
579            }) => RustType::Array(Box::new(ty.into()), i.base10_parse().unwrap()),
580            Type::Never(_) => RustType::Never,
581            _ => panic!("unsupported type {:?}", ty.to_token_stream().to_string()),
582        }
583    }
584}
585
/// A non-generic struct with named fields, built from a `syn::ItemStruct`.
#[derive(Clone, Debug, Hash)]
pub struct RustStruct {
    /// The struct's identifier.
    name: syn::Ident,
    /// `(field name, field type)` pairs in declaration order.
    fields: Vec<(syn::Ident, RustType)>,
}
591
592impl std::fmt::Display for RustStruct {
593    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
594        write!(
595            f,
596            "struct {} {{ {} }}",
597            self.name,
598            self.fields
599                .iter()
600                .map(|(name, ty)| format!("{}: {}", name, ty))
601                .collect::<Vec<_>>()
602                .join(",")
603        )
604    }
605}
606
607impl RustStruct {
608    pub fn name(&self) -> &syn::Ident {
609        &self.name
610    }
611
612    pub fn fields(&self) -> &Vec<(syn::Ident, RustType)> {
613        &self.fields
614    }
615
616    /// Returns a list of all the structs that this struct references
617    pub fn uses(&self) -> HashSet<syn::Ident> {
618        let mut uses = HashSet::new();
619        for (_, ty) in self.fields.iter() {
620            ty.uses(&mut uses);
621        }
622        uses
623    }
624
625    /// Resolves a structs types, and returns the list of type names it references
626    pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
627        let mut acc = HashSet::new();
628        for (_, ty) in self.fields.iter_mut() {
629            acc.extend(ty.resolve(ctxt).into_iter())
630        }
631        acc
632    }
633}
634
635impl From<syn::ItemStruct> for RustStruct {
636    fn from(i: syn::ItemStruct) -> Self {
637        if i.attrs.len() == 0 {
638            log::warn!("skipping unknown attributes {:?}", i.attrs)
639        }
640        if i.generics.params.len() > 0 {
641            panic!("found unsupported use of generic struct {:?} - @Bryan, if you are implementing lifetimes or something, talk to me first.",
642            i.generics)
643        }
644
645        let syn::Fields::Named(FieldsNamed { named, .. }) = i.fields else {
646            panic!("found unsupported struct declaration {}", i.to_token_stream().to_string())
647        };
648        let name = i.ident;
649        let fields = named
650            .into_iter()
651            .map(|v| (v.ident.unwrap(), v.ty.into()))
652            .collect();
653        RustStruct { name, fields }
654    }
655}
656
657impl RustStruct {}
658
/// A trait bound placed on a generic type variable. Only the indexing traits are modelled.
#[derive(Clone, Debug, Hash)]
pub enum RustTypeConstraint {
    /// Index(T1, T2) represents Index<T1, Output=T2>
    Index(RustType, RustType),
    /// IndexMut(T1, T2) represents IndexMut<T1, Output=T2>
    IndexMut(RustType, RustType),
}
666
667impl Into<syn::TypeParamBound> for RustTypeConstraint {
668    fn into(self) -> syn::TypeParamBound {
669        match self {
670            // Index<T1, Output=T2>
671            RustTypeConstraint::Index(t1, t2) => {
672                let mut path = syn::Path {
673                    leading_colon: None,
674                    segments: Punctuated::new(),
675                };
676                let mut args = Punctuated::new();
677                args.push(syn::GenericArgument::Type(t1.into()));
678                args.push(syn::GenericArgument::AssocType(syn::AssocType {
679                    ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
680                    generics: Default::default(),
681                    eq_token: Default::default(),
682                    ty: t2.into(),
683                }));
684
685                path.segments.push(syn::PathSegment {
686                    ident: syn::parse_str::<syn::Ident>("Index").unwrap(),
687                    arguments: syn::PathArguments::AngleBracketed(
688                        syn::AngleBracketedGenericArguments {
689                            colon2_token: None,
690                            lt_token: Default::default(),
691                            args,
692                            gt_token: Default::default(),
693                        },
694                    ),
695                });
696
697                syn::TypeParamBound::Trait(syn::TraitBound {
698                    paren_token: None,
699                    modifier: syn::TraitBoundModifier::None,
700                    lifetimes: None,
701                    path,
702                })
703            }
704            RustTypeConstraint::IndexMut(t1, t2) => {
705                let mut path = syn::Path {
706                    leading_colon: None,
707                    segments: Punctuated::new(),
708                };
709                let mut args = Punctuated::new();
710                args.push(syn::GenericArgument::Type(t1.into()));
711                args.push(syn::GenericArgument::AssocType(syn::AssocType {
712                    ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
713                    generics: Default::default(),
714                    eq_token: Default::default(),
715                    ty: t2.into(),
716                }));
717
718                path.segments.push(syn::PathSegment {
719                    ident: syn::parse_str::<syn::Ident>("IndexMut").unwrap(),
720                    arguments: syn::PathArguments::AngleBracketed(
721                        syn::AngleBracketedGenericArguments {
722                            colon2_token: None,
723                            lt_token: Default::default(),
724                            args,
725                            gt_token: Default::default(),
726                        },
727                    ),
728                });
729
730                syn::TypeParamBound::Trait(syn::TraitBound {
731                    paren_token: None,
732                    modifier: syn::TraitBoundModifier::None,
733                    lifetimes: None,
734                    path,
735                })
736            }
737        }
738    }
739}
740
741impl From<syn::TypeParamBound> for RustTypeConstraint {
742    fn from(ty: syn::TypeParamBound) -> Self {
743        match ty {
744            syn::TypeParamBound::Trait(syn::TraitBound {
745                path: syn::Path { segments, .. },
746                ..
747            }) if segments.len() == 1 => {
748                let segment = segments[0].clone();
749                let trait_name = segment.ident.to_string();
750                match (trait_name.as_str(), segment.arguments) {
751                    ("Index", syn::PathArguments::AngleBracketed(args)) => {
752                        let in_ty = match args.args[0].clone() {
753                            syn::GenericArgument::Type(ty) => ty,
754                            _ => panic!("invalid constraint structure {:#?}", segments),
755                        };
756                        let out_ty = match args.args[1].clone() {
757                            syn::GenericArgument::AssocType(binding) => binding.ty,
758                            _ => panic!("invalid constraint structure {:#?}", segments),
759                        };
760                        RustTypeConstraint::Index(in_ty.into(), out_ty.into())
761                    }
762                    ("IndexMut", syn::PathArguments::AngleBracketed(args)) => {
763                        let in_ty = match args.args[0].clone() {
764                            syn::GenericArgument::Type(ty) => ty,
765                            _ => panic!("invalid constraint structure {:#?}", segments),
766                        };
767                        let out_ty = match args.args[1].clone() {
768                            syn::GenericArgument::AssocType(binding) => binding.ty,
769                            _ => panic!("invalid constraint structure {:#?}", segments),
770                        };
771                        RustTypeConstraint::IndexMut(in_ty.into(), out_ty.into())
772                    }
773                    _ => panic!("unsupported type constraint {:#?}", segments),
774                }
775            }
776            ty => panic!("unsupported type constraint {:#?}", ty),
777        }
778    }
779}
780
781impl std::fmt::Display for RustTypeConstraint {
782    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783        match self {
784            RustTypeConstraint::Index(ind_ty, out_ty) => write!(f, "Index<{},{}>", ind_ty, out_ty),
785            RustTypeConstraint::IndexMut(ind_ty, out_ty) => {
786                write!(f, "IndexMut<{},{}>", ind_ty, out_ty)
787            }
788        }
789    }
790}
791
/// A function's type signature: its name, the trait bounds on its generic type variables,
/// its named parameters and its optional return type.
#[derive(Clone)]
pub struct RustTypeSignature {
    /// Function name.
    name: String,
    /// Each generic type parameter paired with the bounds declared inline on it.
    constraints: Vec<(TVar, Vec<RustTypeConstraint>)>,
    /// `(parameter name, parameter type)` pairs, in order.
    args: Vec<(String, RustType)>,
    /// Return type; `None` when the signature has no `-> T`.
    out_ty: Option<RustType>,
}
799
800impl RustTypeSignature {
801    pub fn constraints(&self) -> &Vec<(TVar, Vec<RustTypeConstraint>)> {
802        &self.constraints
803    }
804
805    pub fn args(&self) -> &Vec<(String, RustType)> {
806        &self.args
807    }
808}
809
810impl From<syn::Signature> for RustTypeSignature {
811    fn from(sig: syn::Signature) -> Self {
812        let name = sig.ident.to_string();
813        let constraints = sig
814            .generics
815            .type_params()
816            .into_iter()
817            .map(|param| {
818                let tvar = param.ident.to_string().into();
819                let bounds = param
820                    .bounds
821                    .clone()
822                    .into_pairs()
823                    .map(|v| v.into_value())
824                    .map(|v| v.into())
825                    .collect();
826                (tvar, bounds)
827            })
828            .collect::<Vec<_>>();
829        let args = sig
830            .inputs
831            .into_pairs()
832            .map(|v| match v.into_value() {
833                syn::FnArg::Typed(syn::PatType {
834                    pat: box syn::Pat::Ident(syn::PatIdent { ident, .. }),
835                    ty: box ty,
836                    ..
837                }) => (ident.to_string(), ty.into()),
838                v => panic!("unsupported pattern {:?} in signature", v),
839            })
840            .collect();
841        let out_ty = match sig.output {
842            syn::ReturnType::Default => None,
843            syn::ReturnType::Type(_, box ty) => Some(ty.into()),
844        };
845        RustTypeSignature {
846            name,
847            constraints,
848            args,
849            out_ty,
850        }
851    }
852}
853
854impl std::fmt::Display for RustTypeSignature {
855    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
856        write!(f, "fn {}<", self.name)?;
857
858        {
859            let constraints = self.constraints.iter().map(|v| Some(v)).intersperse(None);
860            for constraint in constraints {
861                match constraint {
862                    None => write!(f, ", ")?,
863                    Some((tvar, constraints)) => {
864                        write!(f, "{}: ", tvar)?;
865                        let constraints = constraints
866                            .iter()
867                            .map(|v| format!("{}", v))
868                            .intersperse(" + ".into());
869                        for opt in constraints {
870                            write!(f, "{}", opt)?
871                        }
872                    }
873                }
874            }
875        }
876        write!(f, ">(")?;
877        {
878            let mut args = self.args.iter();
879            let mut next = args.next();
880            while next.is_some() {
881                let (name, ty) = next.unwrap();
882                write!(f, "{}: {}", name, ty)?;
883                next = args.next();
884                if next.is_some() {
885                    write!(f, ",")?;
886                }
887            }
888        }
889        write!(f, ")")?;
890        match &self.out_ty {
891            None => (),
892            Some(ty) => writeln!(f, " -> {}", ty)?,
893        }
894
895        Ok(())
896    }
897}
898
899#[derive(Debug, Default, Clone)]
900pub struct CTypeContextCollector {
901    aliases: HashMap<syn::Ident, RustType>,
902    structs: HashMap<syn::Ident, RustStruct>,
903}
904
905impl CTypeContextCollector {
906    pub fn to_type_context(self) -> ProgramTypeContext {
907        (self.aliases, self.structs)
908    }
909}
910
911impl<'ast> syn::visit::Visit<'ast> for CTypeContextCollector {
912    fn visit_item_type(&mut self, i: &'ast syn::ItemType) {
913        let typ: RustType = (&*i.ty).clone().into();
914
915        self.aliases.insert(i.ident.clone(), typ);
916    }
917
918    fn visit_item_struct(&mut self, i: &'ast syn::ItemStruct) {
919        self.structs.insert(i.ident.clone(), i.clone().into());
920    }
921
922    fn visit_item_enum(&mut self, i: &'ast syn::ItemEnum) {
923        panic!(
924            "found use of unsupported enum {:?} declaration",
925            i.to_token_stream().to_string()
926        )
927    }
928}
929
/// Returns `true` if `checking` can be reached from `current` by following
/// "uses" edges in `usage_map`.
///
/// More precisely, the result is `true` when some node reachable from
/// `current` has `checking` among its usages. Called with
/// `checking == current`, this detects whether `current` takes part in a
/// usage cycle.
///
/// - Nodes already present in `path` are treated as visited and never
///   expanded.
/// - Nodes without an entry in `usage_map` have no outgoing edges.
///
/// The search is an iterative depth-first traversal over a single shared
/// visited set, so each node is expanded at most once (O(V + E)). Cloning the
/// path for every branch would re-explore nodes reachable by several routes,
/// which is exponential on DAG-shaped usage graphs. It also avoids deep
/// recursion on long chains.
fn check_recursive<T: Eq + Hash + Clone>(
    checking: &T,
    current: T,
    mut path: HashSet<T>,
    usage_map: &HashMap<T, HashSet<T>>,
) -> bool {
    let mut stack = vec![current];
    while let Some(node) = stack.pop() {
        // Already expanded (or pre-seeded as blocked by the caller).
        if path.contains(&node) {
            continue;
        }
        let usages = match usage_map.get(&node) {
            Some(usages) => usages,
            None => continue,
        };
        if usages.contains(checking) {
            return true;
        }
        path.insert(node);
        stack.extend(usages.iter().filter(|id| !path.contains(*id)).cloned());
    }
    false
}
955
956pub fn normalize_type_context(ctxt: &mut ProgramTypeContext) -> HashSet<syn::Ident> {
957    let mut usage_map = HashMap::new();
958    let ref_ctx = (ctxt.0.clone(), ctxt.1.clone());
959    for (_name, st) in ctxt.0.iter_mut() {
960        st.resolve(&ref_ctx);
961    }
962    for (name, st) in ctxt.1.iter_mut() {
963        st.resolve(&ref_ctx);
964        usage_map.insert(name.clone(), st.uses());
965    }
966
967    let mut recursive = HashSet::new();
968
969    for (name, uses) in usage_map.iter() {
970        if uses.contains(name) || check_recursive(name, name.clone(), HashSet::new(), &usage_map) {
971            recursive.insert(name.clone());
972        }
973    }
974    for (id, alias) in ctxt.0.iter().clone() {
975        let mut uses = HashSet::new();
976        alias.uses(&mut uses);
977        if recursive.intersection(&uses).any(|_| true) {
978            recursive.insert(id.clone());
979        }
980    }
981    recursive
982}
983
984impl UnifyValue for RustType {
985    type Error = Error;
986
987    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
988        match (value1, value2) {
989            (t1, t2) if t1 == t2 => Ok(t1.clone()),
990            (RustType::TVar(t1), RustType::TVar(t2)) if t1 == t2 => Ok(RustType::TVar(*t1)),
991            (RustType::Pointer(box x), RustType::Pointer(box y)) => {
992                let contents = Self::unify_values(x, y)?;
993                Ok(RustType::Pointer(Box::new(contents)))
994            }
995            (t1, t2) => Err(Error::UnUnifiableTypes(t1.clone(), t2.clone())),
996        }
997    }
998}