1use ena::unify::UnifyValue;
2use quote::ToTokens;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use syn::punctuated::Punctuated;
6use syn::{FieldsNamed, Path, PathSegment, Type, TypeArray};
7
8pub type TypeMap = HashMap<crate::location::Loc, RustTypeSignature>;
10
11#[derive(Debug)]
12pub enum Error {
13 UnUnifiableTypes(RustType, RustType),
14}
15
16pub type ProgramTypeContext = (
17 HashMap<syn::Ident, RustType>,
18 HashMap<syn::Ident, RustStruct>,
19);
20
/// A numbered type variable, spelled `T<index>` in generated code (`T0`, `T1`, ...).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TVar(pub usize);

impl From<&str> for TVar {
    /// Parses a generic parameter name of the form `T<index>`.
    ///
    /// # Panics
    /// Panics if `s` does not start with `T`, or if what remains after stripping
    /// *all* leading `T` characters is not a valid `usize`.
    fn from(s: &str) -> Self {
        assert!(
            s.starts_with('T'),
            "invalid assumption: unknown generic var \"{}\"",
            s
        );
        let index: usize = s
            .trim_start_matches('T')
            .parse()
            .expect("Index for generic could not be extracted");
        Self(index)
    }
}

impl From<String> for TVar {
    /// Same as the `&str` conversion; see its panics.
    fn from(s: String) -> Self {
        Self::from(s.as_str())
    }
}

impl std::fmt::Display for TVar {
    /// Prints the variable as `T<index>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TVar(index) = self;
        write!(f, "T{}", index)
    }
}
48
/// Whether a reference permits mutation (`&mut T`) or not (`&T`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum RustMutability {
    Immutable,
    Mutable,
}

impl std::fmt::Display for RustMutability {
    /// Prints `immutable` or `mutable`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Immutable => "immutable",
            Self::Mutable => "mutable",
        };
        f.write_str(label)
    }
}
63
64impl From<Option<syn::token::Mut>> for RustMutability {
65 fn from(mut_: Option<syn::token::Mut>) -> Self {
66 match mut_ {
67 Some(_) => RustMutability::Mutable,
68 None => RustMutability::Immutable,
69 }
70 }
71}
72
73impl Into<Option<syn::token::Mut>> for RustMutability {
74 fn into(self) -> Option<syn::token::Mut> {
75 match self {
76 RustMutability::Mutable => Some(Default::default()),
77 RustMutability::Immutable => None,
78 }
79 }
80}
81
/// Width class of a C integer type: `char`, `short`, `int`, `long` or `long long`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CIntegralSize {
    Char,
    Short,
    Int,
    Long,
    LongLong,
}

impl std::fmt::Display for CIntegralSize {
    /// Prints the C spelling used as a `libc::c_*` suffix (`long long` as `longlong`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            CIntegralSize::Char => "char",
            CIntegralSize::Short => "short",
            CIntegralSize::Int => "int",
            CIntegralSize::Long => "long",
            CIntegralSize::LongLong => "longlong",
        })
    }
}
102
/// Width of a C floating-point type: `float` or `double`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CFloatSize {
    Float,
    Double,
}

impl std::fmt::Display for CFloatSize {
    /// Prints the C spelling used as a `libc::c_*` suffix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = if *self == CFloatSize::Float {
            "float"
        } else {
            "double"
        };
        f.write_str(spelling)
    }
}
117
118#[derive(Clone, Debug, Hash, PartialEq, Eq)]
119pub enum RustType {
120 CVoid,
121 CInt {
122 unsigned: bool,
123 size: CIntegralSize,
124 },
125 CFloat(CFloatSize),
126 CAlias(syn::Ident),
127
128 Array(Box<RustType>, usize),
129
130 Option(Box<RustType>),
131 Vec(Box<RustType>),
132 Unit,
133 I32,
134 U8,
135 SizeT,
136 Isize,
137 Usize,
138 TVar(TVar), Never,
141
142 ExternFn(Vec<Box<RustType>>, bool, Box<RustType>),
143
144 Reference(RustMutability, Box<RustType>),
146 Pointer(Box<RustType>),
148}
149
150impl RustType {
151 fn uses(&self, set: &mut HashSet<syn::Ident>) {
152 match self {
153 RustType::CAlias(id) => {
154 set.insert(id.clone());
155 ()
156 }
157 RustType::Option(ty)
158 | RustType::Vec(ty)
159 | RustType::Reference(_, ty)
160 | RustType::Pointer(ty)
161 | RustType::Array(ty, _) => ty.uses(set),
162
163 _ => (),
170 }
171 }
172
173 fn resolve_checked(
174 &mut self,
175 path: &mut HashSet<syn::Ident>,
176 ctxt: &ProgramTypeContext,
177 ) -> bool {
178 match self {
179 RustType::CAlias(id) if path.contains(id) => true,
181 RustType::CAlias(id) if !path.contains(id) => {
182 path.insert(id.clone());
184 if ctxt.1.contains_key(id) {
186 false
187 } else {
188 match ctxt.0.get(id) {
189 Some(inner) => {
190 *self = inner.clone();
192 self.resolve_checked(path, ctxt)
194 }
195 None => {
196 log::warn!("attempted to resolve type {} that has no defined alias or definition", id.to_string());
197 false
198 }
199 }
200 }
201 }
202 RustType::Option(elt)
203 | RustType::Vec(elt)
204 | RustType::Pointer(elt)
205 | RustType::Reference(_, elt)
206 | RustType::Array(elt, _) => elt.resolve_checked(path, ctxt),
207 RustType::ExternFn(args, _, out) => {
208 let mut any_rec = false;
209 let mut base_path = path.clone();
210 for arg in args.iter_mut() {
211 let mut rec_path = base_path.clone();
212 any_rec |= arg.resolve_checked(&mut rec_path, ctxt);
213 path.extend(rec_path.into_iter());
214 }
215 any_rec |= out.resolve_checked(&mut base_path, ctxt);
216 path.extend(base_path.into_iter());
217 any_rec
218 }
219 _ => false,
220 }
221 }
222
223 pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
225 let mut set = HashSet::new();
226 self.resolve_checked(&mut set, ctxt);
227 set
228 }
229}
230
231impl std::fmt::Display for RustType {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 RustType::Never => write!(f, "never"),
235 RustType::Array(box ty, size) => write!(f, "array({}, {})", ty, size),
236 RustType::Option(box ty) => write!(f, "option({})", ty),
237 RustType::Vec(box ty) => write!(f, "vec({})", ty),
238 RustType::CInt { unsigned, size } => {
239 write!(f, "c_{}{}", if *unsigned { "u" } else { "" }, size)
240 }
241 RustType::CFloat(size) => write!(f, "c_{}", size),
242 RustType::CVoid => write!(f, "c_void"),
243 RustType::CAlias(ident) => write!(f, "{}", ident),
244 RustType::Unit => write!(f, "()"),
245 RustType::SizeT => write!(f, "size_t"),
246 RustType::U8 => write!(f, "u8"),
247 RustType::I32 => write!(f, "i32"),
248 RustType::Isize => write!(f, "isize"),
249 RustType::Usize => write!(f, "usize"),
250 RustType::TVar(tvar) => write!(f, "{}", tvar),
251 RustType::Pointer(box x) => write!(f, "mut_ptr_{}", x),
252 RustType::Reference(mt, box x) => write!(f, "ref_{}_{}", mt, x),
253 RustType::ExternFn(args, variadic, body) => write!(
254 f,
255 "extern_fn_({}, {}, {})",
256 variadic,
257 args.iter()
258 .map(|v| format!("{}", v))
259 .collect::<Vec<_>>()
260 .join(","),
261 body
262 ),
263 }
264 }
265}
266
267impl Into<Type> for RustType {
268 fn into(self) -> Type {
269 match self {
270 RustType::Never => syn::Type::Never(syn::TypeNever {
271 bang_token: Default::default(),
272 }),
273 RustType::Array(box ty, size) => Type::Array(syn::TypeArray {
274 bracket_token: Default::default(),
275 elem: Box::new(ty.into()),
276 semi_token: Default::default(),
277 len: syn::parse_str::<syn::Expr>(&format!("{}", size)).unwrap(),
278 }),
279 RustType::Option(box ty) => Type::Path(syn::TypePath {
280 qself: None,
281 path: syn::Path {
282 leading_colon: None,
283 segments: [PathSegment {
284 ident: syn::parse_str::<syn::Ident>("Option").unwrap(),
285 arguments: syn::PathArguments::AngleBracketed(
286 syn::AngleBracketedGenericArguments {
287 colon2_token: None,
288 lt_token: Default::default(),
289 args: [syn::GenericArgument::Type(ty.into())]
290 .into_iter()
291 .collect(),
292 gt_token: Default::default(),
293 },
294 ),
295 }]
296 .into_iter()
297 .collect(),
298 },
299 }),
300 RustType::Vec(box ty) => Type::Path(syn::TypePath {
301 qself: None,
302 path: syn::Path {
303 leading_colon: None,
304 segments: [PathSegment {
305 ident: syn::parse_str::<syn::Ident>("Vec").unwrap(),
306 arguments: syn::PathArguments::AngleBracketed(
307 syn::AngleBracketedGenericArguments {
308 colon2_token: None,
309 lt_token: Default::default(),
310 args: [syn::GenericArgument::Type(ty.into())]
311 .into_iter()
312 .collect(),
313 gt_token: Default::default(),
314 },
315 ),
316 }]
317 .into_iter()
318 .collect(),
319 },
320 }),
321 RustType::Unit => syn::parse_str::<Type>("()").unwrap(),
322 ty @ RustType::CInt { .. } => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
323 ty @ RustType::CFloat(_) => syn::parse_str::<Type>(&format!("libc::{}", ty)).unwrap(),
324 RustType::CVoid => syn::parse_str::<Type>("libc::c_void").unwrap(),
325
326 RustType::CAlias(ident) => Type::Path(syn::TypePath {
327 qself: None,
328 path: syn::Path {
329 leading_colon: None,
330 segments: [syn::PathSegment::from(ident)].into_iter().collect(),
331 },
332 }),
333 RustType::SizeT => syn::parse_str::<Type>("size_t").unwrap(),
334 RustType::U8 => syn::parse_str::<Type>("u8").unwrap(),
335 RustType::I32 => syn::parse_str::<Type>("i32").unwrap(),
336 RustType::Isize => syn::parse_str::<Type>("isize").unwrap(),
337 RustType::Usize => syn::parse_str::<Type>("usize").unwrap(),
338 RustType::TVar(n) => syn::parse_str::<Type>(&format!("{}", n)).unwrap(),
339 RustType::Pointer(box v) => Type::Ptr(syn::TypePtr {
340 const_token: None,
341 mutability: Some(Default::default()),
342 elem: Box::new(v.into()),
343 star_token: Default::default(),
344 }),
345 RustType::Reference(muta, box v) => Type::Reference(syn::TypeReference {
346 and_token: Default::default(),
347 mutability: muta.into(),
348 elem: Box::new(v.into()),
349 lifetime: None,
350 }),
351 RustType::ExternFn(args, variadic, box res) => {
352 let inputs = if variadic {
353 args.into_iter()
354 .map(|f| syn::BareFnArg {
355 attrs: Default::default(),
356 name: None,
357 ty: (*f).into(),
358 })
359 .collect()
360 } else {
361 args.into_iter()
362 .map(|f| syn::BareFnArg {
363 attrs: Default::default(),
364 name: None,
365 ty: (*f).into(),
366 })
367 .chain(
368 [syn::BareFnArg {
369 attrs: Default::default(),
370 name: None,
371 ty: Type::Verbatim("...".to_token_stream()),
372 }]
373 .into_iter(),
374 )
375 .collect()
376 };
377
378 Type::BareFn(syn::TypeBareFn {
379 lifetimes: None,
380 unsafety: Some(Default::default()),
381 abi: Some(syn::Abi {
382 extern_token: Default::default(),
383 name: Some(syn::parse_str::<syn::LitStr>("\"C\"").unwrap()),
384 }),
385 fn_token: Default::default(),
386 paren_token: Default::default(),
387 inputs,
388 variadic: None,
389 output: syn::ReturnType::Type(Default::default(), Box::new(res.into())),
390 })
391 }
392 }
393 }
394}
395
396impl From<Type> for RustType {
397 fn from(ty: Type) -> Self {
398 match ty {
399 Type::Path(syn::TypePath {
400 path: Path { segments, .. },
401 ..
402 }) if segments.last().is_some_and(|segment| {
403 segment.ident == "Option" && !segment.arguments.is_empty()
404 }) =>
405 {
406 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
407 &segments.last().unwrap().arguments else {
408 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
409 };
410
411 let syn::GenericArgument::Type(ty) = &args[0] else {
412 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
413 };
414
415 RustType::Option(Box::new(ty.clone().into()))
416 }
417 Type::Path(syn::TypePath {
418 path: Path { segments, .. },
419 ..
420 }) if segments
421 .last()
422 .is_some_and(|segment| segment.ident == "Vec" && !segment.arguments.is_empty()) =>
423 {
424 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) =
425 &segments.last().unwrap().arguments else {
426 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
427 };
428
429 let syn::GenericArgument::Type(ty) = &args[0] else {
430 panic!("found use of unsupported syntactic construct {} in code", segments.to_token_stream().to_string())
431 };
432
433 RustType::Vec(Box::new(ty.clone().into()))
434 }
435 Type::Path(syn::TypePath {
436 path: ref path @ Path { ref segments, .. },
437 ..
438 }) if segments
439 .last()
440 .is_some_and(|segment| segment.arguments.is_empty()) =>
441 {
442 let ident = &segments.last().unwrap().ident;
443
444 if !(segments.len() == 1)
445 && !(segments.len() > 1 && segments[0].ident == "libc")
446 && !(segments.len() > 3
447 && segments[0].ident == "std"
448 && segments[1].ident == "os"
449 && segments[2].ident == "raw")
450 {
451 log::warn!("found use of non libc type {:?}, optimistically assuming nothing crazy is going on",
452 path.to_token_stream().to_string()
453 )
454 }
455
456 use RustType::*;
457
458 match ident.to_string().as_str() {
459 "isize" => Isize,
460 "i32" => I32,
461 "size_t" => SizeT,
462 "u8" => U8,
463 "usize" => Usize,
464 "c_float" => CFloat(CFloatSize::Float),
465 "c_double" => CFloat(CFloatSize::Double),
466
467 "c_char" => CInt {
468 unsigned: false,
469 size: CIntegralSize::Char,
470 },
471 "c_schar" => CInt {
472 unsigned: false,
473 size: CIntegralSize::Char,
474 },
475 "c_uchar" => CInt {
476 unsigned: true,
477 size: CIntegralSize::Char,
478 },
479 "c_short" => CInt {
480 unsigned: false,
481 size: CIntegralSize::Short,
482 },
483 "c_ushort" => CInt {
484 unsigned: true,
485 size: CIntegralSize::Short,
486 },
487
488 "c_int" => CInt {
489 unsigned: false,
490 size: CIntegralSize::Int,
491 },
492 "c_uint" => CInt {
493 unsigned: true,
494 size: CIntegralSize::Int,
495 },
496
497 "c_long" => CInt {
498 unsigned: false,
499 size: CIntegralSize::Long,
500 },
501 "c_ulong" => CInt {
502 unsigned: true,
503 size: CIntegralSize::Long,
504 },
505
506 "c_longlong" => CInt {
507 unsigned: false,
508 size: CIntegralSize::Long,
509 },
510 "c_ulonglong" => CInt {
511 unsigned: true,
512 size: CIntegralSize::Long,
513 },
514
515 "c_void" => CVoid,
516
517 _txt => RustType::CAlias(ident.clone()),
518 }
519 }
520 Type::Ptr(syn::TypePtr {
521 const_token,
522 mutability,
523 elem: box ty,
524 ..
525 }) => {
526 match (const_token, mutability) {
527 (None, Some(_)) => RustType::Pointer(Box::new(ty.into())), (Some(_), None) => RustType::Pointer(Box::new(ty.into())), (_, _) => {
530 todo!()
531 }
532 }
533 }
534 Type::Reference(syn::TypeReference {
535 lifetime: None,
536 mutability,
537 elem: box elem,
538 ..
539 }) => RustType::Reference(mutability.into(), Box::new(elem.into())),
540
541 Type::Tuple(syn::TypeTuple { elems, .. }) if elems.len() == 0 => RustType::Unit,
542
543 Type::BareFn(syn::TypeBareFn {
544 unsafety: Some(_),
545 abi: Some(_),
546 inputs,
547 output,
548 variadic: None,
549 ..
550 }) => {
551 let mut variadic = false;
552 let inputs = inputs
553 .into_iter()
554 .map(|v| match v.ty {
555 Type::Verbatim(_v) => {
556 variadic = true;
557 None
558 }
559 _ => Some(Box::new(v.ty.into())),
560 })
561 .flatten()
562 .collect();
563
564 let output = match output {
565 syn::ReturnType::Default => RustType::Unit,
566 syn::ReturnType::Type(_, box ty) => ty.into(),
567 };
568 RustType::ExternFn(inputs, variadic, Box::new(output))
569 }
570
571 Type::Array(TypeArray {
572 elem: box ty,
573 len:
574 syn::Expr::Lit(syn::ExprLit {
575 lit: syn::Lit::Int(i),
576 ..
577 }),
578 ..
579 }) => RustType::Array(Box::new(ty.into()), i.base10_parse().unwrap()),
580 Type::Never(_) => RustType::Never,
581 _ => panic!("unsupported type {:?}", ty.to_token_stream().to_string()),
582 }
583 }
584}
585
586#[derive(Clone, Debug, Hash)]
587pub struct RustStruct {
588 name: syn::Ident,
589 fields: Vec<(syn::Ident, RustType)>,
590}
591
592impl std::fmt::Display for RustStruct {
593 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
594 write!(
595 f,
596 "struct {} {{ {} }}",
597 self.name,
598 self.fields
599 .iter()
600 .map(|(name, ty)| format!("{}: {}", name, ty))
601 .collect::<Vec<_>>()
602 .join(",")
603 )
604 }
605}
606
607impl RustStruct {
608 pub fn name(&self) -> &syn::Ident {
609 &self.name
610 }
611
612 pub fn fields(&self) -> &Vec<(syn::Ident, RustType)> {
613 &self.fields
614 }
615
616 pub fn uses(&self) -> HashSet<syn::Ident> {
618 let mut uses = HashSet::new();
619 for (_, ty) in self.fields.iter() {
620 ty.uses(&mut uses);
621 }
622 uses
623 }
624
625 pub fn resolve(&mut self, ctxt: &ProgramTypeContext) -> HashSet<syn::Ident> {
627 let mut acc = HashSet::new();
628 for (_, ty) in self.fields.iter_mut() {
629 acc.extend(ty.resolve(ctxt).into_iter())
630 }
631 acc
632 }
633}
634
635impl From<syn::ItemStruct> for RustStruct {
636 fn from(i: syn::ItemStruct) -> Self {
637 if i.attrs.len() == 0 {
638 log::warn!("skipping unknown attributes {:?}", i.attrs)
639 }
640 if i.generics.params.len() > 0 {
641 panic!("found unsupported use of generic struct {:?} - @Bryan, if you are implementing lifetimes or something, talk to me first.",
642 i.generics)
643 }
644
645 let syn::Fields::Named(FieldsNamed { named, .. }) = i.fields else {
646 panic!("found unsupported struct declaration {}", i.to_token_stream().to_string())
647 };
648 let name = i.ident;
649 let fields = named
650 .into_iter()
651 .map(|v| (v.ident.unwrap(), v.ty.into()))
652 .collect();
653 RustStruct { name, fields }
654 }
655}
656
657impl RustStruct {}
658
659#[derive(Clone, Debug, Hash)]
660pub enum RustTypeConstraint {
661 Index(RustType, RustType),
663 IndexMut(RustType, RustType),
665}
666
667impl Into<syn::TypeParamBound> for RustTypeConstraint {
668 fn into(self) -> syn::TypeParamBound {
669 match self {
670 RustTypeConstraint::Index(t1, t2) => {
672 let mut path = syn::Path {
673 leading_colon: None,
674 segments: Punctuated::new(),
675 };
676 let mut args = Punctuated::new();
677 args.push(syn::GenericArgument::Type(t1.into()));
678 args.push(syn::GenericArgument::Binding(syn::Binding {
679 ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
680 eq_token: Default::default(),
681 ty: t2.into(),
682 }));
683
684 path.segments.push(syn::PathSegment {
685 ident: syn::parse_str::<syn::Ident>("Index").unwrap(),
686 arguments: syn::PathArguments::AngleBracketed(
687 syn::AngleBracketedGenericArguments {
688 colon2_token: None,
689 lt_token: Default::default(),
690 args,
691 gt_token: Default::default(),
692 },
693 ),
694 });
695
696 syn::TypeParamBound::Trait(syn::TraitBound {
697 paren_token: None,
698 modifier: syn::TraitBoundModifier::None,
699 lifetimes: None,
700 path,
701 })
702 }
703 RustTypeConstraint::IndexMut(t1, t2) => {
704 let mut path = syn::Path {
705 leading_colon: None,
706 segments: Punctuated::new(),
707 };
708 let mut args = Punctuated::new();
709 args.push(syn::GenericArgument::Type(t1.into()));
710 args.push(syn::GenericArgument::Binding(syn::Binding {
711 ident: syn::parse_str::<syn::Ident>("Output").unwrap(),
712 eq_token: Default::default(),
713 ty: t2.into(),
714 }));
715
716 path.segments.push(syn::PathSegment {
717 ident: syn::parse_str::<syn::Ident>("IndexMut").unwrap(),
718 arguments: syn::PathArguments::AngleBracketed(
719 syn::AngleBracketedGenericArguments {
720 colon2_token: None,
721 lt_token: Default::default(),
722 args,
723 gt_token: Default::default(),
724 },
725 ),
726 });
727
728 syn::TypeParamBound::Trait(syn::TraitBound {
729 paren_token: None,
730 modifier: syn::TraitBoundModifier::None,
731 lifetimes: None,
732 path,
733 })
734 }
735 }
736 }
737}
738
739impl From<syn::TypeParamBound> for RustTypeConstraint {
740 fn from(ty: syn::TypeParamBound) -> Self {
741 match ty {
742 syn::TypeParamBound::Trait(syn::TraitBound {
743 path: syn::Path { segments, .. },
744 ..
745 }) if segments.len() == 1 => {
746 let segment = segments[0].clone();
747 let trait_name = segment.ident.to_string();
748 match (trait_name.as_str(), segment.arguments) {
749 ("Index", syn::PathArguments::AngleBracketed(args)) => {
750 let in_ty = match args.args[0].clone() {
751 syn::GenericArgument::Type(ty) => ty,
752 _ => panic!("invalid constraint structure {:#?}", segments),
753 };
754 let out_ty = match args.args[1].clone() {
755 syn::GenericArgument::Binding(binding) => binding.ty,
756 _ => panic!("invalid constraint structure {:#?}", segments),
757 };
758 RustTypeConstraint::Index(in_ty.into(), out_ty.into())
759 }
760 ("IndexMut", syn::PathArguments::AngleBracketed(args)) => {
761 let in_ty = match args.args[0].clone() {
762 syn::GenericArgument::Type(ty) => ty,
763 _ => panic!("invalid constraint structure {:#?}", segments),
764 };
765 let out_ty = match args.args[1].clone() {
766 syn::GenericArgument::Binding(binding) => binding.ty,
767 _ => panic!("invalid constraint structure {:#?}", segments),
768 };
769 RustTypeConstraint::IndexMut(in_ty.into(), out_ty.into())
770 }
771 _ => panic!("unsupported type constraint {:#?}", segments),
772 }
773 }
774 ty => panic!("unsupported type constraint {:#?}", ty),
775 }
776 }
777}
778
779impl std::fmt::Display for RustTypeConstraint {
780 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
781 match self {
782 RustTypeConstraint::Index(ind_ty, out_ty) => write!(f, "Index<{},{}>", ind_ty, out_ty),
783 RustTypeConstraint::IndexMut(ind_ty, out_ty) => {
784 write!(f, "IndexMut<{},{}>", ind_ty, out_ty)
785 }
786 }
787 }
788}
789
790#[derive(Clone)]
791pub struct RustTypeSignature {
792 name: String,
793 constraints: Vec<(TVar, Vec<RustTypeConstraint>)>,
794 args: Vec<(String, RustType)>,
795 out_ty: Option<RustType>,
796}
797
798impl RustTypeSignature {
799 pub fn constraints(&self) -> &Vec<(TVar, Vec<RustTypeConstraint>)> {
800 &self.constraints
801 }
802
803 pub fn args(&self) -> &Vec<(String, RustType)> {
804 &self.args
805 }
806}
807
808impl From<syn::Signature> for RustTypeSignature {
809 fn from(sig: syn::Signature) -> Self {
810 let name = sig.ident.to_string();
811 let constraints = sig
812 .generics
813 .type_params()
814 .into_iter()
815 .map(|param| {
816 let tvar = param.ident.to_string().into();
817 let bounds = param
818 .bounds
819 .clone()
820 .into_pairs()
821 .map(|v| v.into_value())
822 .map(|v| v.into())
823 .collect();
824 (tvar, bounds)
825 })
826 .collect::<Vec<_>>();
827 let args = sig
828 .inputs
829 .into_pairs()
830 .map(|v| match v.into_value() {
831 syn::FnArg::Typed(syn::PatType {
832 pat: box syn::Pat::Ident(syn::PatIdent { ident, .. }),
833 ty: box ty,
834 ..
835 }) => (ident.to_string(), ty.into()),
836 v => panic!("unsupported pattern {:?} in signature", v),
837 })
838 .collect();
839 let out_ty = match sig.output {
840 syn::ReturnType::Default => None,
841 syn::ReturnType::Type(_, box ty) => Some(ty.into()),
842 };
843 RustTypeSignature {
844 name,
845 constraints,
846 args,
847 out_ty,
848 }
849 }
850}
851
852impl std::fmt::Display for RustTypeSignature {
853 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
854 write!(f, "fn {}<", self.name)?;
855
856 {
857 let constraints = self.constraints.iter().map(|v| Some(v)).intersperse(None);
858 for constraint in constraints {
859 match constraint {
860 None => write!(f, ", ")?,
861 Some((tvar, constraints)) => {
862 write!(f, "{}: ", tvar)?;
863 let constraints = constraints
864 .iter()
865 .map(|v| format!("{}", v))
866 .intersperse(" + ".into());
867 for opt in constraints {
868 write!(f, "{}", opt)?
869 }
870 }
871 }
872 }
873 }
874 write!(f, ">(")?;
875 {
876 let mut args = self.args.iter();
877 let mut next = args.next();
878 while next.is_some() {
879 let (name, ty) = next.unwrap();
880 write!(f, "{}: {}", name, ty)?;
881 next = args.next();
882 if next.is_some() {
883 write!(f, ",")?;
884 }
885 }
886 }
887 write!(f, ")")?;
888 match &self.out_ty {
889 None => (),
890 Some(ty) => writeln!(f, " -> {}", ty)?,
891 }
892
893 Ok(())
894 }
895}
896
897#[derive(Debug, Default, Clone)]
898pub struct CTypeContextCollector {
899 aliases: HashMap<syn::Ident, RustType>,
900 structs: HashMap<syn::Ident, RustStruct>,
901}
902
903impl CTypeContextCollector {
904 pub fn to_type_context(self) -> ProgramTypeContext {
905 (self.aliases, self.structs)
906 }
907}
908
909impl<'ast> syn::visit::Visit<'ast> for CTypeContextCollector {
910 fn visit_item_type(&mut self, i: &'ast syn::ItemType) {
911 let typ: RustType = (&*i.ty).clone().into();
912
913 self.aliases.insert(i.ident.clone(), typ);
914 }
915
916 fn visit_item_struct(&mut self, i: &'ast syn::ItemStruct) {
917 self.structs.insert(i.ident.clone(), i.clone().into());
918 }
919
920 fn visit_item_enum(&mut self, i: &'ast syn::ItemEnum) {
921 panic!(
922 "found use of unsupported enum {:?} declaration",
923 i.to_token_stream().to_string()
924 )
925 }
926}
927
/// Returns `true` if `checking` is reachable from `current` by following
/// `usage_map` edges, where each node maps to the names it uses.
///
/// Nodes already in `path` count as visited and are not expanded; callers normally
/// pass an empty set. A node with no entry in `usage_map` has no outgoing edges.
///
/// One visited set is shared across the whole search, so each node is expanded at
/// most once. This gives the same answer as exploring every simple path, but
/// without the exponential cost of cloning the path per branch.
fn check_recursive<T: Eq + Hash + Clone>(
    checking: &T,
    current: T,
    mut path: HashSet<T>,
    usage_map: &HashMap<T, HashSet<T>>,
) -> bool {
    /// Depth-first reachability from `node` to `target`, marking expanded nodes in `visited`.
    fn reaches<T: Eq + Hash + Clone>(
        target: &T,
        node: &T,
        visited: &mut HashSet<T>,
        usage_map: &HashMap<T, HashSet<T>>,
    ) -> bool {
        if visited.contains(node) {
            return false;
        }
        let Some(usages) = usage_map.get(node) else {
            return false;
        };
        visited.insert(node.clone());
        usages.contains(target)
            || usages
                .iter()
                .any(|next| reaches(target, next, visited, usage_map))
    }

    reaches(checking, &current, &mut path, usage_map)
}
953
954pub fn normalize_type_context(ctxt: &mut ProgramTypeContext) -> HashSet<syn::Ident> {
955 let mut usage_map = HashMap::new();
956 let ref_ctx = (ctxt.0.clone(), ctxt.1.clone());
957 for (_name, st) in ctxt.0.iter_mut() {
958 st.resolve(&ref_ctx);
959 }
960 for (name, st) in ctxt.1.iter_mut() {
961 st.resolve(&ref_ctx);
962 usage_map.insert(name.clone(), st.uses());
963 }
964
965 let mut recursive = HashSet::new();
966
967 for (name, uses) in usage_map.iter() {
968 if uses.contains(name) || check_recursive(name, name.clone(), HashSet::new(), &usage_map) {
969 recursive.insert(name.clone());
970 }
971 }
972 for (id, alias) in ctxt.0.iter().clone() {
973 let mut uses = HashSet::new();
974 alias.uses(&mut uses);
975 if recursive.intersection(&uses).any(|_| true) {
976 recursive.insert(id.clone());
977 }
978 }
979 recursive
980}
981
982impl UnifyValue for RustType {
983 type Error = Error;
984
985 fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
986 match (value1, value2) {
987 (t1, t2) if t1 == t2 => Ok(t1.clone()),
988 (RustType::TVar(t1), RustType::TVar(t2)) if t1 == t2 => Ok(RustType::TVar(*t1)),
989 (RustType::Pointer(box x), RustType::Pointer(box y)) => {
990 let contents = Self::unify_values(x, y)?;
991 Ok(RustType::Pointer(Box::new(contents)))
992 }
993 (t1, t2) => Err(Error::UnUnifiableTypes(t1.clone(), t2.clone())),
994 }
995 }
996}