1use ena::unify::UnifyValue;
2use quote::ToTokens;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use syn::punctuated::Punctuated;
6use syn::{FieldsNamed, Path, PathSegment, Type, TypeArray};
7
/// Maps a source location (`crate::location::Loc`) to a function signature.
pub type TypeMap = HashMap<crate::location::Loc, RustTypeSignature>;
10
11#[derive(Debug)]
12pub enum Error {
13 UnUnifiableTypes(RustType, RustType),
14}
15
/// Type-level context of a program.
///
/// `.0` maps alias names to their aliased types (from `type` items) and `.1` maps
/// struct names to their definitions.
pub type ProgramTypeContext = (
    HashMap<syn::Ident, RustType>,
    HashMap<syn::Ident, RustStruct>,
);
20
/// A generic type variable, written `T<n>` in source (e.g. `T0`, `T12`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TVar(pub usize);

impl From<&str> for TVar {
    /// Parses a generic parameter name of the form `T<n>`.
    ///
    /// # Panics
    /// If `s` is not a single `T` followed by a decimal index.
    fn from(s: &str) -> Self {
        // Strip exactly one `T`: `trim_start_matches` would also accept e.g. `TT1` and
        // silently alias it to `T1`.
        let index = match s.strip_prefix('T') {
            Some(index) => index,
            None => panic!("invalid assumption: unknown generic var \"{}\"", s),
        };
        TVar(index.parse().expect("Index for generic could not be extracted"))
    }
}

impl From<String> for TVar {
    /// Same as the `&str` conversion.
    fn from(s: String) -> Self {
        TVar::from(s.as_str())
    }
}

impl std::fmt::Display for TVar {
    /// Renders as `T<n>`, the inverse of the string conversions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "T{}", self.0)
    }
}
48
49#[derive(Clone, Debug, Hash, PartialEq, Eq)]
50pub enum RustMutability {
51 Immutable,
52 Mutable,
53}
54
55impl std::fmt::Display for RustMutability {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 RustMutability::Immutable => write!(f, "immutable"),
59 RustMutability::Mutable => write!(f, "mutable"),
60 }
61 }
62}
63
64impl From<Option<syn::token::Mut>> for RustMutability {
65 fn from(mut_: Option<syn::token::Mut>) -> Self {
66 match mut_ {
67 Some(_) => RustMutability::Mutable,
68 None => RustMutability::Immutable,
69 }
70 }
71}
72
73impl Into<Option<syn::token::Mut>> for RustMutability {
74 fn into(self) -> Option<syn::token::Mut> {
75 match self {
76 RustMutability::Mutable => Some(Default::default()),
77 RustMutability::Immutable => None,
78 }
79 }
80}
81
/// Width of a C integer type. `Display` yields the name used after the `c_`/`c_u`
/// prefix of the corresponding `libc` type (e.g. `int` for `c_int`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CIntegralSize {
    Char,
    Short,
    Int,
    Long,
    LongLong,
}

impl std::fmt::Display for CIntegralSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Char => "char",
            Self::Short => "short",
            Self::Int => "int",
            Self::Long => "long",
            Self::LongLong => "longlong",
        };
        f.write_str(name)
    }
}
101}
102
/// Width of a C floating-point type. `Display` yields the name used after the `c_`
/// prefix of the corresponding `libc` type (e.g. `double` for `c_double`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CFloatSize {
    Float,
    Double,
}

impl std::fmt::Display for CFloatSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(if let Self::Float = self { "float" } else { "double" })
    }
}
117
/// The subset of Rust types this module models: `libc` C scalar types, a few std
/// containers and primitives, generic type variables, references, raw pointers and
/// `unsafe extern "C"` function pointers.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum RustType {
    /// `libc::c_void`.
    CVoid,
    /// A `libc` C integer type; e.g. `c_uint` is `{ unsigned: true, size: Int }`.
    CInt {
        unsigned: bool,
        size: CIntegralSize,
    },
    /// `libc::c_float` / `libc::c_double`.
    CFloat(CFloatSize),
    /// A named type referred to by identifier: a `type` alias, a struct, or any other
    /// unrecognised single path name.
    CAlias(syn::Ident),

    /// `[elem; len]` with a literal length.
    Array(Box<RustType>, usize),

    /// `Option<T>`.
    Option(Box<RustType>),
    /// `Vec<T>`.
    Vec(Box<RustType>),
    /// `()`.
    Unit,
    I32,
    U8,
    /// `size_t`, written as an unqualified path.
    SizeT,
    Isize,
    Usize,
    /// A generic type variable `T<n>`.
    TVar(TVar),
    /// The never type `!`.
    Never,

    /// `unsafe extern "C" fn(args..) -> ret`: argument types, whether the function is
    /// variadic, and the return type.
    ExternFn(Vec<Box<RustType>>, bool, Box<RustType>),

    /// `&T` / `&mut T` (without an explicit lifetime).
    Reference(RustMutability, Box<RustType>),
    /// A raw pointer. Both `*const T` and `*mut T` parse to this; it renders as `*mut T`.
    Pointer(Box<RustType>),
}
149
150impl RustType {
151 fn uses(&self, set: &mut HashSet<syn::Ident>) {
152 match self {
153 RustType::CAlias(id) => {
154 set.insert(id.clone());
155 ()
156 }
157 RustType::Option(ty)
158 | RustType::Vec(ty)
159 | RustType::Reference(_, ty)
160 | RustType::Pointer(ty)
161 | RustType::Array(ty, _) => ty.uses(set),
162
163 _ => (),
170 }
171 }
172
173 fn resolve_checked(
174 &mut self,
175 path: &mut HashSet<syn::Ident>,
176 ctxt: &ProgramTypeContext,
177 ) -> bool {
178 match self {
179 RustType::CAlias(id) if path.contains(id) => true,
181 RustType::CAlias(id) if !path.contains(id) => {
182 path.insert(id.clone());
184 if ctxt.1.contains_key(id) {
186 false
187 } else {
188 match ctxt.0.get(id) {
189 Some(inner) => {
190 *self = inner.clone();
192 self.resolve_checked(path, ctxt)
194 }
195 None => {
196 log::warn!("attempted to resolve type {} that has no defined alias or definition", id.to_string());
197 false
198 }
199 }
200 }
201 }
202 RustType::Option(elt)
203 | RustType::Vec(elt)
204 | RustType::Pointer(elt)
205 | RustType::Reference(_, elt)
206 | RustType::Array(elt, _) => elt.resolve_checked(path, ctxt),
207 RustType::ExternFn(args, _, out) => {
208 let mut any_rec = false;
209 let mut base_path = path.clone();
210 for arg in args.iter_mut() {
211 let mut rec_path = base_path.clone();
212 any_rec |= arg.resolve_checked(&mut rec_path, ctxt);
213 path.extend(rec_path.into_iter());
214 }
215 any_rec |= out.resolve_checked(&mut base_path, ctxt);
216 path.extend(base_path.into_iter());
217 any_rec
218 }
219 _ => false,
220 }
221 }
222
223 pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
225 let mut set = HashSet::new();
226 self.resolve_checked(&mut set, ctxt);
227 set
228 }
229}
230
231impl std::fmt::Display for RustType {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 RustType::Never => write!(f, "never"),
235 RustType::Array(box ty, size) => write!(f, "array({}, {})", ty, size),
236 RustType::Option(box ty) => write!(f, "option({})", ty),
237 RustType::Vec(box ty) => write!(f, "vec({})", ty),
238 RustType::CInt { unsigned, size } => {
239 write!(f, "c_{}{}", if *unsigned { "u" } else { "" }, size)
240 }
241 RustType::CFloat(size) => write!(f, "c_{}", size),
242 RustType::CVoid => write!(f, "c_void"),
243 RustType::CAlias(ident) => write!(f, "{}", ident),
244 RustType::Unit => write!(f, "()"),
245 RustType::SizeT => write!(f, "size_t"),
246 RustType::U8 => write!(f, "u8"),
247 RustType::I32 => write!(f, "i32"),
248 RustType::Isize => write!(f, "isize"),
249 RustType::Usize => write!(f, "usize"),
250 RustType::TVar(tvar) => write!(f, "{}", tvar),
251 RustType::Pointer(box x) => write!(f, "mut_ptr_{}", x),
252 RustType::Reference(mt, box x) => write!(f, "ref_{}_{}", mt, x),
253 RustType::ExternFn(args, variadic, body) => write!(
254 f,
255 "extern_fn_({}, {}, {})",
256 variadic,
257 args.iter()
258 .map(|v| format!("{}", v))
259 .collect::<Vec<_>>()
260 .join(","),
261 body
262 ),
263 }
264 }
265}
266
267impl Into<Type> for RustType {
268 fn into(self) -> Type {
269 match self {
270 RustType::Never => syn::Type::Never(syn::TypeNever {
271 bang_token: Default::default(),
272 }),
273 RustType::Array(box ty, size) => Type::Array(syn::TypeArray {
274 bracket_token: Default::default(),
275 elem: Box::new(ty.into()),
276 semi_token: Default::default(),
277 len: syn::parse_str::<syn::Expr>(&format!("{}", size)).unwrap(),
278 }),
279 RustType::Option(box ty) => Type::Path(syn::TypePath {
280 qself: None,
281 path: syn::Path {
282 leading_colon: None,
283 segments: [PathSegment {
284 ident: syn::parse_str::<syn::Ident>("Option").unwrap(),
285 arguments: syn::PathArguments::AngleBracketed(
286 syn::AngleBracketedGenericArguments {
287 colon2_token: None,
288 lt_token: Default::default(),
289 args: [syn::GenericArgument::Type(ty.into())]
290 .into_iter()
291 .collect(),
292 gt_token: Default::default(),
293 },
294 ),
295 }]
296 .into_iter()
297 .collect(),
298 },
299 }),
300 RustType::Vec(box ty) => Type::Path(syn::TypePath {
301 qself: None,
302 path: syn::Path {
303 leading_colon: None,
304 segments: [PathSegment {
305 ident: syn::parse_str::<syn::Ident>("Vec").unwrap(),
306 arguments: syn::PathArguments::AngleBracketed(
307 syn::AngleBracketedGenericArguments {
308 colon2_token: None,
309 lt_token: Default::default(),
310 args: [syn::GenericArgument::Type(ty.into())]
311 .into_iter()
312 .collect(),
313 gt_token: Default::default(),
314 },
315 ),
316 }]
317 .into_iter()
318 .collect(),
319 },
320 }),
321 RustType::Unit => syn::parse_str::<Type>("()").unwrap(),
322 ty @ RustType::CInt { .. } => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
323 ty @ RustType::CFloat(_) => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
324 RustType::CVoid => syn::parse_str::<Type>("libc::c_void").unwrap(),
325
326 RustType::CAlias(ident) => Type::Path(syn::TypePath {
327 qself: None,
328 path: syn::Path {
329 leading_colon: None,
330 segments: [syn::PathSegment::from(ident)].into_iter().collect(),
331 },
332 }),
333 RustType::SizeT => syn::parse_str::<Type>("size_t").unwrap(),
334 RustType::U8 => syn::parse_str::<Type>("u8").unwrap(),
335 RustType::I32 => syn::parse_str::<Type>("i32").unwrap(),
336 RustType::Isize => syn::parse_str::<Type>("isize").unwrap(),
337 RustType::Usize => syn::parse_str::<Type>("usize").unwrap(),
338 RustType::TVar(n) => syn::parse_str::<Type>(&format!("{}", n)).unwrap(),
339 RustType::Pointer(box v) => Type::Ptr(syn::TypePtr {
340 const_token: None,
341 mutability: Some(Default::default()),
342 elem: Box::new(v.into()),
343 star_token: Default::default(),
344 }),
345 RustType::Reference(muta, box v) => Type::Reference(syn::TypeReference {
346 and_token: Default::default(),
347 mutability: muta.into(),
348 elem: Box::new(v.into()),
349 lifetime: None,
350 }),
351 RustType::ExternFn(args, variadic, box res) => {
352 let inputs = if variadic {
353 args.into_iter()
354 .map(|f| syn::BareFnArg {
355 attrs: Default::default(),
356 name: None,
357 ty: (*f).into(),
358 })
359 .collect()
360 } else {
361 args.into_iter()
362 .map(|f| syn::BareFnArg {
363 attrs: Default::default(),
364 name: None,
365 ty: (*f).into(),
366 })
367 .chain(
368 [syn::BareFnArg {
369 attrs: Default::default(),
370 name: None,
371 ty: Type::Verbatim("...".to_token_stream()),
372 }]
373 .into_iter(),
374 )
375 .collect()
376 };
377
378 Type::BareFn(syn::TypeBareFn {
379 lifetimes: None,
380 unsafety: Some(Default::default()),
381 abi: Some(syn::Abi {
382 extern_token: Default::default(),
383 name: Some(syn::parse_str::<syn::LitStr>("\"C\"").unwrap()),
384 }),
385 fn_token: Default::default(),
386 paren_token: Default::default(),
387 inputs,
388 variadic: None,
389 output: syn::ReturnType::Type(Default::default(), Box::new(res.into())),
390 })
391 }
392 }
393 }
394}
395
396impl From<Type> for RustType {
397 fn from(ty: Type) -> Self {
398 match ty {
399 Type::Path(syn::TypePath {
400 path: Path { segments, .. },
401 ..
402 }) if segments.last().is_some_and(|segment| {
403 segment.ident == "Option" && !segment.arguments.is_empty()
404 }) =>
405 {
406 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
407 &segments.last().unwrap().arguments else {
408 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
409 };
410
411 let syn::GenericArgument::Type(ty) = &args[0] else {
412 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
413 };
414
415 RustType::Option(Box::new(ty.clone().into()))
416 }
417 Type::Path(syn::TypePath {
418 path: Path { segments, .. },
419 ..
420 }) if segments
421 .last()
422 .is_some_and(|segment| segment.ident == "Vec" && !segment.arguments.is_empty()) =>
423 {
424 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
425 &segments.last().unwrap().arguments else {
426 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
427 };
428
429 let syn::GenericArgument::Type(ty) = &args[0] else {
430 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
431 };
432
433 RustType::Vec(Box::new(ty.clone().into()))
434 }
435 Type::Path(syn::TypePath {
436 path: ref path @ Path { ref segments, .. },
437 ..
438 }) if segments
439 .last()
440 .is_some_and(|segment| segment.arguments.is_empty()) =>
441 {
442 let ident = &segments.last().unwrap().ident;
443
444 if !(segments.len() == 1)
445 && !(segments.len() > 1 && segments[0].ident == "libc")
446 && !(segments.len() > 3
447 && segments[0].ident == "std"
448 && segments[1].ident == "os"
449 && segments[2].ident == "raw")
450 {
451 log::warn!("found use of non libc type {:?}, optimistically assuming nothing crazy is going on",
452 path.to_token_stream().to_string()
453 )
454 }
455
456 use RustType::*;
457
458 match ident.to_string().as_str() {
459 "isize" => Isize,
460 "i32" => I32,
461 "size_t" => SizeT,
462 "u8" => U8,
463 "usize" => Usize,
464 "c_float" => CFloat(CFloatSize::Float),
465 "c_double" => CFloat(CFloatSize::Double),
466
467 "c_char" => CInt {
468 unsigned: false,
469 size: CIntegralSize::Char,
470 },
471 "c_schar" => CInt {
472 unsigned: false,
473 size: CIntegralSize::Char,
474 },
475 "c_uchar" => CInt {
476 unsigned: true,
477 size: CIntegralSize::Char,
478 },
479 "c_short" => CInt {
480 unsigned: false,
481 size: CIntegralSize::Short,
482 },
483 "c_ushort" => CInt {
484 unsigned: true,
485 size: CIntegralSize::Short,
486 },
487
488 "c_int" => CInt {
489 unsigned: false,
490 size: CIntegralSize::Int,
491 },
492 "c_uint" => CInt {
493 unsigned: true,
494 size: CIntegralSize::Int,
495 },
496
497 "c_long" => CInt {
498 unsigned: false,
499 size: CIntegralSize::Long,
500 },
501 "c_ulong" => CInt {
502 unsigned: true,
503 size: CIntegralSize::Long,
504 },
505
506 "c_longlong" => CInt {
507 unsigned: false,
508 size: CIntegralSize::Long,
509 },
510 "c_ulonglong" => CInt {
511 unsigned: true,
512 size: CIntegralSize::Long,
513 },
514
515 "c_void" => CVoid,
516
517 _txt => RustType::CAlias(ident.clone()),
518 }
519 }
520 Type::Ptr(syn::TypePtr {
521 const_token,
522 mutability,
523 elem: box ty,
524 ..
525 }) => {
526 match (const_token, mutability) {
527 (None, Some(_)) => RustType::Pointer(Box::new(ty.into())), (Some(_), None) => RustType::Pointer(Box::new(ty.into())), (_, _) => {
530 todo!()
531 }
532 }
533 }
534 Type::Reference(syn::TypeReference {
535 lifetime: None,
536 mutability,
537 elem: box elem,
538 ..
539 }) => RustType::Reference(mutability.into(), Box::new(elem.into())),
540
541 Type::Tuple(syn::TypeTuple { elems, .. }) if elems.len() == 0 => RustType::Unit,
542
543 Type::BareFn(syn::TypeBareFn {
544 unsafety: Some(_),
545 abi: Some(_),
546 inputs,
547 output,
548 variadic: None,
549 ..
550 }) => {
551 let mut variadic = false;
552 let inputs = inputs
553 .into_iter()
554 .map(|v| match v.ty {
555 Type::Verbatim(_v) => {
556 variadic = true;
557 None
558 }
559 _ => Some(Box::new(v.ty.into())),
560 })
561 .flatten()
562 .collect();
563
564 let output = match output {
565 syn::ReturnType::Default => RustType::Unit,
566 syn::ReturnType::Type(_, box ty) => ty.into(),
567 };
568 RustType::ExternFn(inputs, variadic, Box::new(output))
569 }
570
571 Type::Array(TypeArray {
572 elem: box ty,
573 len:
574 syn::Expr::Lit(syn::ExprLit {
575 lit: syn::Lit::Int(i),
576 ..
577 }),
578 ..
579 }) => RustType::Array(Box::new(ty.into()), i.base10_parse().unwrap()),
580 Type::Never(_) => RustType::Never,
581 _ => panic!("unsupported type {:?}", ty.to_token_stream().to_string()),
582 }
583 }
584}
585
586#[derive(Clone, Debug, Hash)]
587pub struct RustStruct {
588 name: syn::Ident,
589 fields: Vec<(syn::Ident, RustType)>,
590}
591
592impl std::fmt::Display for RustStruct {
593 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
594 write!(
595 f,
596 "struct {} {{ {} }}",
597 self.name,
598 self.fields
599 .iter()
600 .map(|(name, ty)| format!("{}: {}", name, ty))
601 .collect::<Vec<_>>()
602 .join(",")
603 )
604 }
605}
606
607impl RustStruct {
608 pub fn name(&self) -> &syn::Ident {
609 &self.name
610 }
611
612 pub fn fields(&self) -> &Vec<(syn::Ident, RustType)> {
613 &self.fields
614 }
615
616 pub fn uses(&self) -> HashSet<syn::Ident> {
618 let mut uses = HashSet::new();
619 for (_, ty) in self.fields.iter() {
620 ty.uses(&mut uses);
621 }
622 uses
623 }
624
625 pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
627 let mut acc = HashSet::new();
628 for (_, ty) in self.fields.iter_mut() {
629 acc.extend(ty.resolve(ctxt).into_iter())
630 }
631 acc
632 }
633}
634
635impl From<syn::ItemStruct> for RustStruct {
636 fn from(i: syn::ItemStruct) -> Self {
637 if i.attrs.len() == 0 {
638 log::warn!("skipping unknown attributes {:?}", i.attrs)
639 }
640 if i.generics.params.len() > 0 {
641 panic!("found unsupported use of generic struct {:?} - @Bryan, if you are implementing lifetimes or something, talk to me first.",
642 i.generics)
643 }
644
645 let syn::Fields::Named(FieldsNamed { named, .. }) = i.fields else {
646 panic!("found unsupported struct declaration {}", i.to_token_stream().to_string())
647 };
648 let name = i.ident;
649 let fields = named
650 .into_iter()
651 .map(|v| (v.ident.unwrap(), v.ty.into()))
652 .collect();
653 RustStruct { name, fields }
654 }
655}
656
657impl RustStruct {}
658
/// A supported trait bound on a generic type variable.
#[derive(Clone, Debug, Hash)]
pub enum RustTypeConstraint {
    /// `Index<I, Output = O>`, stored as `Index(I, O)`.
    Index(RustType, RustType),
    /// `IndexMut<I, Output = O>`, stored as `IndexMut(I, O)`.
    IndexMut(RustType, RustType),
}
666
667impl Into<syn::TypeParamBound> for RustTypeConstraint {
668 fn into(self) -> syn::TypeParamBound {
669 match self {
670 RustTypeConstraint::Index(t1, t2) => {
672 let mut path = syn::Path {
673 leading_colon: None,
674 segments: Punctuated::new(),
675 };
676 let mut args = Punctuated::new();
677 args.push(syn::GenericArgument::Type(t1.into()));
678 args.push(syn::GenericArgument::AssocType(syn::AssocType {
679 ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
680 generics: Default::default(),
681 eq_token: Default::default(),
682 ty: t2.into(),
683 }));
684
685 path.segments.push(syn::PathSegment {
686 ident: syn::parse_str::<syn::Ident>("Index").unwrap(),
687 arguments: syn::PathArguments::AngleBracketed(
688 syn::AngleBracketedGenericArguments {
689 colon2_token: None,
690 lt_token: Default::default(),
691 args,
692 gt_token: Default::default(),
693 },
694 ),
695 });
696
697 syn::TypeParamBound::Trait(syn::TraitBound {
698 paren_token: None,
699 modifier: syn::TraitBoundModifier::None,
700 lifetimes: None,
701 path,
702 })
703 }
704 RustTypeConstraint::IndexMut(t1, t2) => {
705 let mut path = syn::Path {
706 leading_colon: None,
707 segments: Punctuated::new(),
708 };
709 let mut args = Punctuated::new();
710 args.push(syn::GenericArgument::Type(t1.into()));
711 args.push(syn::GenericArgument::AssocType(syn::AssocType {
712 ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
713 generics: Default::default(),
714 eq_token: Default::default(),
715 ty: t2.into(),
716 }));
717
718 path.segments.push(syn::PathSegment {
719 ident: syn::parse_str::<syn::Ident>("IndexMut").unwrap(),
720 arguments: syn::PathArguments::AngleBracketed(
721 syn::AngleBracketedGenericArguments {
722 colon2_token: None,
723 lt_token: Default::default(),
724 args,
725 gt_token: Default::default(),
726 },
727 ),
728 });
729
730 syn::TypeParamBound::Trait(syn::TraitBound {
731 paren_token: None,
732 modifier: syn::TraitBoundModifier::None,
733 lifetimes: None,
734 path,
735 })
736 }
737 }
738 }
739}
740
741impl From<syn::TypeParamBound> for RustTypeConstraint {
742 fn from(ty: syn::TypeParamBound) -> Self {
743 match ty {
744 syn::TypeParamBound::Trait(syn::TraitBound {
745 path: syn::Path { segments, .. },
746 ..
747 }) if segments.len() == 1 => {
748 let segment = segments[0].clone();
749 let trait_name = segment.ident.to_string();
750 match (trait_name.as_str(), segment.arguments) {
751 ("Index", syn::PathArguments::AngleBracketed(args)) => {
752 let in_ty = match args.args[0].clone() {
753 syn::GenericArgument::Type(ty) => ty,
754 _ => panic!("invalid constraint structure {:#?}", segments),
755 };
756 let out_ty = match args.args[1].clone() {
757 syn::GenericArgument::AssocType(binding) => binding.ty,
758 _ => panic!("invalid constraint structure {:#?}", segments),
759 };
760 RustTypeConstraint::Index(in_ty.into(), out_ty.into())
761 }
762 ("IndexMut", syn::PathArguments::AngleBracketed(args)) => {
763 let in_ty = match args.args[0].clone() {
764 syn::GenericArgument::Type(ty) => ty,
765 _ => panic!("invalid constraint structure {:#?}", segments),
766 };
767 let out_ty = match args.args[1].clone() {
768 syn::GenericArgument::AssocType(binding) => binding.ty,
769 _ => panic!("invalid constraint structure {:#?}", segments),
770 };
771 RustTypeConstraint::IndexMut(in_ty.into(), out_ty.into())
772 }
773 _ => panic!("unsupported type constraint {:#?}", segments),
774 }
775 }
776 ty => panic!("unsupported type constraint {:#?}", ty),
777 }
778 }
779}
780
781impl std::fmt::Display for RustTypeConstraint {
782 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783 match self {
784 RustTypeConstraint::Index(ind_ty, out_ty) => write!(f, "Index<{},{}>", ind_ty, out_ty),
785 RustTypeConstraint::IndexMut(ind_ty, out_ty) => {
786 write!(f, "IndexMut<{},{}>", ind_ty, out_ty)
787 }
788 }
789 }
790}
791
792#[derive(Clone)]
793pub struct RustTypeSignature {
794 name: String,
795 constraints: Vec<(TVar, Vec<RustTypeConstraint>)>,
796 args: Vec<(String, RustType)>,
797 out_ty: Option<RustType>,
798}
799
800impl RustTypeSignature {
801 pub fn constraints(&self) -> &Vec<(TVar, Vec<RustTypeConstraint>)> {
802 &self.constraints
803 }
804
805 pub fn args(&self) -> &Vec<(String, RustType)> {
806 &self.args
807 }
808}
809
810impl From<syn::Signature> for RustTypeSignature {
811 fn from(sig: syn::Signature) -> Self {
812 let name = sig.ident.to_string();
813 let constraints = sig
814 .generics
815 .type_params()
816 .into_iter()
817 .map(|param| {
818 let tvar = param.ident.to_string().into();
819 let bounds = param
820 .bounds
821 .clone()
822 .into_pairs()
823 .map(|v| v.into_value())
824 .map(|v| v.into())
825 .collect();
826 (tvar, bounds)
827 })
828 .collect::<Vec<_>>();
829 let args = sig
830 .inputs
831 .into_pairs()
832 .map(|v| match v.into_value() {
833 syn::FnArg::Typed(syn::PatType {
834 pat: box syn::Pat::Ident(syn::PatIdent { ident, .. }),
835 ty: box ty,
836 ..
837 }) => (ident.to_string(), ty.into()),
838 v => panic!("unsupported pattern {:?} in signature", v),
839 })
840 .collect();
841 let out_ty = match sig.output {
842 syn::ReturnType::Default => None,
843 syn::ReturnType::Type(_, box ty) => Some(ty.into()),
844 };
845 RustTypeSignature {
846 name,
847 constraints,
848 args,
849 out_ty,
850 }
851 }
852}
853
854impl std::fmt::Display for RustTypeSignature {
855 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
856 write!(f, "fn {}<", self.name)?;
857
858 {
859 let constraints = self.constraints.iter().map(|v| Some(v)).intersperse(None);
860 for constraint in constraints {
861 match constraint {
862 None => write!(f, ", ")?,
863 Some((tvar, constraints)) => {
864 write!(f, "{}: ", tvar)?;
865 let constraints = constraints
866 .iter()
867 .map(|v| format!("{}", v))
868 .intersperse(" + ".into());
869 for opt in constraints {
870 write!(f, "{}", opt)?
871 }
872 }
873 }
874 }
875 }
876 write!(f, ">(")?;
877 {
878 let mut args = self.args.iter();
879 let mut next = args.next();
880 while next.is_some() {
881 let (name, ty) = next.unwrap();
882 write!(f, "{}: {}", name, ty)?;
883 next = args.next();
884 if next.is_some() {
885 write!(f, ",")?;
886 }
887 }
888 }
889 write!(f, ")")?;
890 match &self.out_ty {
891 None => (),
892 Some(ty) => writeln!(f, " -> {}", ty)?,
893 }
894
895 Ok(())
896 }
897}
898
899#[derive(Debug, Default, Clone)]
900pub struct CTypeContextCollector {
901 aliases: HashMap<syn::Ident, RustType>,
902 structs: HashMap<syn::Ident, RustStruct>,
903}
904
905impl CTypeContextCollector {
906 pub fn to_type_context(self) -> ProgramTypeContext {
907 (self.aliases, self.structs)
908 }
909}
910
911impl<'ast> syn::visit::Visit<'ast> for CTypeContextCollector {
912 fn visit_item_type(&mut self, i: &'ast syn::ItemType) {
913 let typ: RustType = (&*i.ty).clone().into();
914
915 self.aliases.insert(i.ident.clone(), typ);
916 }
917
918 fn visit_item_struct(&mut self, i: &'ast syn::ItemStruct) {
919 self.structs.insert(i.ident.clone(), i.clone().into());
920 }
921
922 fn visit_item_enum(&mut self, i: &'ast syn::ItemEnum) {
923 panic!(
924 "found use of unsupported enum {:?} declaration",
925 i.to_token_stream().to_string()
926 )
927 }
928}
929
/// Returns whether `checking` is reachable from `current` in `usage_map`: that is,
/// whether `current`, or something it transitively uses, has `checking` among its
/// direct uses.
///
/// Nodes already in `path` are treated as visited and never expanded, and nodes without
/// an entry in `usage_map` use nothing. With `current == checking` and an empty `path`
/// this answers "is `checking` (mutually) recursive?".
///
/// Generic over the key type so any hashable identifier works (callers pass `syn::Ident`).
fn check_recursive<K: Eq + Hash + Clone>(
    checking: &K,
    current: K,
    path: HashSet<K>,
    usage_map: &HashMap<K, HashSet<K>>,
) -> bool {
    // A single shared visited set (seeded with `path`) keeps the search linear in the size
    // of `usage_map`. Cloning the path per branch gives the same answer but can revisit
    // nodes exponentially often.
    let mut visited = path;
    let mut pending = vec![current];
    while let Some(node) = pending.pop() {
        let usages = match usage_map.get(&node) {
            Some(usages) => usages,
            None => continue,
        };
        if !visited.insert(node) {
            continue;
        }
        if usages.contains(checking) {
            return true;
        }
        pending.extend(
            usages
                .iter()
                .filter(|used| !visited.contains(*used))
                .cloned(),
        );
    }
    false
}
955
956pub fn normalize_type_context(ctxt: &mut ProgramTypeContext) -> HashSet<syn::Ident> {
957 let mut usage_map = HashMap::new();
958 let ref_ctx = (ctxt.0.clone(), ctxt.1.clone());
959 for (_name, st) in ctxt.0.iter_mut() {
960 st.resolve(&ref_ctx);
961 }
962 for (name, st) in ctxt.1.iter_mut() {
963 st.resolve(&ref_ctx);
964 usage_map.insert(name.clone(), st.uses());
965 }
966
967 let mut recursive = HashSet::new();
968
969 for (name, uses) in usage_map.iter() {
970 if uses.contains(name) || check_recursive(name, name.clone(), HashSet::new(), &usage_map) {
971 recursive.insert(name.clone());
972 }
973 }
974 for (id, alias) in ctxt.0.iter().clone() {
975 let mut uses = HashSet::new();
976 alias.uses(&mut uses);
977 if recursive.intersection(&uses).any(|_| true) {
978 recursive.insert(id.clone());
979 }
980 }
981 recursive
982}
983
984impl UnifyValue for RustType {
985 type Error = Error;
986
987 fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
988 match (value1, value2) {
989 (t1, t2) if t1 == t2 => Ok(t1.clone()),
990 (RustType::TVar(t1), RustType::TVar(t2)) if t1 == t2 => Ok(RustType::TVar(*t1)),
991 (RustType::Pointer(box x), RustType::Pointer(box y)) => {
992 let contents = Self::unify_values(x, y)?;
993 Ok(RustType::Pointer(Box::new(contents)))
994 }
995 (t1, t2) => Err(Error::UnUnifiableTypes(t1.clone(), t2.clone())),
996 }
997 }
998}