Skip to main content

rib/inferred_type/
mod.rs

1pub use type_internal::*;
2
3pub(crate) use all_of::*;
4pub(crate) use type_origin::*;
5pub(crate) use unification::*;
6
7mod all_of;
8mod type_internal;
9mod type_origin;
10mod unification;
11
12use crate::instance_type::InstanceType;
13use crate::rib_source_span::SourceSpan;
14use crate::type_inference::GetTypeHint;
15use crate::wit_type::*;
16use crate::TypeName;
17use bigdecimal::BigDecimal;
18use std::fmt::{Display, Formatter};
19use std::hash::{Hash, Hasher};
20use std::ops::Deref;
21
22#[derive(Debug, Clone, Eq, PartialOrd, Ord)]
23pub struct InferredType {
24    pub inner: Box<TypeInternal>,
25    pub origin: TypeOrigin,
26}
27
28impl InferredType {
29    pub fn originated_at(&self, source_span: &SourceSpan) -> InferredType {
30        self.add_origin(TypeOrigin::OriginatedAt(source_span.clone()))
31    }
32
33    pub fn origin(&self) -> TypeOrigin {
34        self.origin.clone()
35    }
36
37    pub fn source_span(&self) -> Option<SourceSpan> {
38        let origin = self.origin();
39
40        match origin {
41            TypeOrigin::Default(_) => None,
42            TypeOrigin::NoOrigin => None,
43            TypeOrigin::Declared(_) => None,
44            TypeOrigin::Multiple(origins) => {
45                let mut source_span = None;
46                for origin in origins {
47                    if let TypeOrigin::OriginatedAt(loc) = origin {
48                        source_span = Some(loc.clone());
49                        break;
50                    }
51                }
52                source_span
53            }
54            TypeOrigin::OriginatedAt(_) => None,
55        }
56    }
57
58    pub fn as_number(&self) -> Result<InferredNumber, String> {
59        fn go(with_origin: &InferredType, found: &mut Vec<InferredNumber>) -> Result<(), String> {
60            match with_origin.inner.deref() {
61                TypeInternal::S8 => {
62                    found.push(InferredNumber::S8);
63                    Ok(())
64                }
65                TypeInternal::U8 => {
66                    found.push(InferredNumber::U8);
67                    Ok(())
68                }
69                TypeInternal::S16 => {
70                    found.push(InferredNumber::S16);
71                    Ok(())
72                }
73                TypeInternal::U16 => {
74                    found.push(InferredNumber::U16);
75                    Ok(())
76                }
77                TypeInternal::S32 => {
78                    found.push(InferredNumber::S32);
79                    Ok(())
80                }
81                TypeInternal::U32 => {
82                    found.push(InferredNumber::U32);
83                    Ok(())
84                }
85                TypeInternal::S64 => {
86                    found.push(InferredNumber::S64);
87                    Ok(())
88                }
89                TypeInternal::U64 => {
90                    found.push(InferredNumber::U64);
91                    Ok(())
92                }
93                TypeInternal::F32 => {
94                    found.push(InferredNumber::F32);
95                    Ok(())
96                }
97                TypeInternal::F64 => {
98                    found.push(InferredNumber::F64);
99                    Ok(())
100                }
101                TypeInternal::AllOf(all_variables) => {
102                    let mut previous: Option<InferredNumber> = None;
103                    for variable in all_variables {
104                        go(variable, found)?;
105
106                        if let Some(current) = found.first() {
107                            match &previous {
108                                None => {
109                                    previous = Some(current.clone());
110                                    found.push(current.clone());
111                                }
112                                Some(previous) => {
113                                    if previous != current {
114                                        return Err(format!(
115                                            "expected the same type of number. But found {current}, {previous}"
116                                        ));
117                                    }
118
119                                    found.push(current.clone());
120                                }
121                            }
122                        } else {
123                            return Err("failed to get a number".to_string());
124                        }
125                    }
126
127                    Ok(())
128                }
129                TypeInternal::Range { .. } => Err("used as range".to_string()),
130                TypeInternal::Bool => Err(format!("used as {}", "bool")),
131                TypeInternal::Chr => Err(format!("used as {}", "char")),
132                TypeInternal::Str => Err(format!("used as {}", "string")),
133                TypeInternal::List(_) => Err(format!("used as {}", "list")),
134                TypeInternal::Tuple(_) => Err(format!("used as {}", "tuple")),
135                TypeInternal::Record(_) => Err(format!("used as {}", "record")),
136                TypeInternal::Flags(_) => Err(format!("used as {}", "flags")),
137                TypeInternal::Enum(_) => Err(format!("used as {}", "enum")),
138                TypeInternal::Option(_) => Err(format!("used as {}", "option")),
139                TypeInternal::Result { .. } => Err(format!("used as {}", "result")),
140                TypeInternal::Variant(_) => Err(format!("used as {}", "variant")),
141                TypeInternal::Unknown => Err("found unknown".to_string()),
142                TypeInternal::Sequence(_) => {
143                    Err(format!("used as {}", "function-multi-parameter-return"))
144                }
145                TypeInternal::Resource { .. } => Err(format!("used as {}", "resource")),
146                TypeInternal::Instance { .. } => Err(format!("used as {}", "instance")),
147            }
148        }
149
150        let mut found: Vec<InferredNumber> = vec![];
151        go(self, &mut found)?;
152        found.first().cloned().ok_or("Failed".to_string())
153    }
154
155    pub fn bool() -> InferredType {
156        InferredType {
157            inner: Box::new(TypeInternal::Bool),
158            origin: TypeOrigin::NoOrigin,
159        }
160    }
161
162    pub fn char() -> InferredType {
163        InferredType {
164            inner: Box::new(TypeInternal::Chr),
165            origin: TypeOrigin::NoOrigin,
166        }
167    }
168
169    pub fn contains_only_number(&self) -> bool {
170        match self.inner.deref() {
171            TypeInternal::S8
172            | TypeInternal::U8
173            | TypeInternal::S16
174            | TypeInternal::U16
175            | TypeInternal::S32
176            | TypeInternal::U32
177            | TypeInternal::S64
178            | TypeInternal::U64
179            | TypeInternal::F32
180            | TypeInternal::F64 => true,
181            TypeInternal::Bool => false,
182            TypeInternal::Chr => false,
183            TypeInternal::Str => false,
184            TypeInternal::List(_) => false,
185            TypeInternal::Tuple(_) => false,
186            TypeInternal::Record(_) => false,
187            TypeInternal::Flags(_) => false,
188            TypeInternal::Enum(_) => false,
189            TypeInternal::Option(_) => false,
190            TypeInternal::Result { .. } => false,
191            TypeInternal::Variant(_) => false,
192            TypeInternal::Resource { .. } => false,
193            TypeInternal::Range { .. } => false,
194            TypeInternal::Instance { .. } => false,
195            TypeInternal::Unknown => false,
196            TypeInternal::Sequence(_) => false,
197            TypeInternal::AllOf(types) => types.iter().all(|t| t.contains_only_number()),
198        }
199    }
200
201    pub fn declared_at(&self, source_span: SourceSpan) -> InferredType {
202        self.add_origin(TypeOrigin::Declared(source_span.clone()))
203    }
204
205    pub fn as_default(&self, default_type: DefaultType) -> InferredType {
206        let new_origin = TypeOrigin::Default(default_type);
207
208        InferredType {
209            inner: self.inner.clone(),
210            origin: self.origin.add_origin(new_origin),
211        }
212    }
213
214    pub fn enum_(cases: Vec<String>) -> InferredType {
215        InferredType {
216            inner: Box::new(TypeInternal::Enum(cases)),
217            origin: TypeOrigin::NoOrigin,
218        }
219    }
220
221    pub fn f32() -> InferredType {
222        InferredType {
223            inner: Box::new(TypeInternal::F32),
224            origin: TypeOrigin::NoOrigin,
225        }
226    }
227
228    pub fn f64() -> InferredType {
229        InferredType {
230            inner: Box::new(TypeInternal::F64),
231            origin: TypeOrigin::NoOrigin,
232        }
233    }
234
235    pub fn flags(flags: Vec<String>) -> InferredType {
236        InferredType {
237            inner: Box::new(TypeInternal::Flags(flags)),
238            origin: TypeOrigin::NoOrigin,
239        }
240    }
241
242    pub fn instance(instance_type: InstanceType) -> InferredType {
243        InferredType {
244            inner: Box::new(TypeInternal::Instance {
245                instance_type: Box::new(instance_type),
246            }),
247            origin: TypeOrigin::NoOrigin,
248        }
249    }
250
251    pub fn internal_type(&self) -> &TypeInternal {
252        self.inner.as_ref()
253    }
254
255    pub fn internal_type_mut(&mut self) -> &mut TypeInternal {
256        self.inner.as_mut()
257    }
258
259    pub fn list(inner: InferredType) -> InferredType {
260        InferredType {
261            inner: Box::new(TypeInternal::List(inner)),
262            origin: TypeOrigin::NoOrigin,
263        }
264    }
265
266    pub fn new(inferred_type: TypeInternal, origin: TypeOrigin) -> InferredType {
267        InferredType {
268            inner: Box::new(inferred_type),
269            origin,
270        }
271    }
272
273    pub fn option(inner: InferredType) -> InferredType {
274        InferredType {
275            inner: Box::new(TypeInternal::Option(inner)),
276            origin: TypeOrigin::NoOrigin,
277        }
278    }
279
280    pub fn range(from: InferredType, to: Option<InferredType>) -> InferredType {
281        InferredType {
282            inner: Box::new(TypeInternal::Range { from, to }),
283            origin: TypeOrigin::NoOrigin,
284        }
285    }
286
287    pub fn eliminate_default(inferred_types: Vec<&InferredType>) -> Vec<&InferredType> {
288        inferred_types
289            .into_iter()
290            .filter(|&t| !t.origin.is_default())
291            .collect::<Vec<_>>()
292    }
293
294    pub fn record(fields: Vec<(String, InferredType)>) -> InferredType {
295        InferredType {
296            inner: Box::new(TypeInternal::Record(fields)),
297            origin: TypeOrigin::NoOrigin,
298        }
299    }
300
301    pub fn resolved(inferred_type: TypeInternal) -> InferredType {
302        InferredType {
303            inner: Box::new(inferred_type),
304            origin: TypeOrigin::NoOrigin,
305        }
306    }
307
308    pub fn resource(
309        resource_id: u64,
310        resource_mode: u8,
311        owner: Option<String>,
312        name: Option<String>,
313    ) -> InferredType {
314        InferredType {
315            inner: Box::new(TypeInternal::Resource {
316                resource_id,
317                resource_mode,
318                owner,
319                name,
320            }),
321            origin: TypeOrigin::NoOrigin,
322        }
323    }
324
325    pub fn result(ok: Option<InferredType>, error: Option<InferredType>) -> InferredType {
326        InferredType {
327            inner: Box::new(TypeInternal::Result { ok, error }),
328            origin: TypeOrigin::NoOrigin,
329        }
330    }
331
332    pub fn sequence(inferred_types: Vec<InferredType>) -> InferredType {
333        InferredType {
334            inner: Box::new(TypeInternal::Sequence(inferred_types)),
335            origin: TypeOrigin::NoOrigin,
336        }
337    }
338
339    pub fn string() -> InferredType {
340        InferredType {
341            inner: Box::new(TypeInternal::Str),
342            origin: TypeOrigin::NoOrigin,
343        }
344    }
345
346    pub fn s8() -> InferredType {
347        InferredType {
348            inner: Box::new(TypeInternal::S8),
349            origin: TypeOrigin::NoOrigin,
350        }
351    }
352
353    pub fn s16() -> InferredType {
354        InferredType {
355            inner: Box::new(TypeInternal::S16),
356            origin: TypeOrigin::NoOrigin,
357        }
358    }
359
360    pub fn s32() -> InferredType {
361        InferredType {
362            inner: Box::new(TypeInternal::S32),
363            origin: TypeOrigin::NoOrigin,
364        }
365    }
366
367    pub fn s64() -> InferredType {
368        InferredType {
369            inner: Box::new(TypeInternal::S64),
370            origin: TypeOrigin::NoOrigin,
371        }
372    }
373
374    pub fn tuple(inner: Vec<InferredType>) -> InferredType {
375        InferredType {
376            inner: Box::new(TypeInternal::Tuple(inner)),
377            origin: TypeOrigin::NoOrigin,
378        }
379    }
380
381    pub fn u8() -> InferredType {
382        InferredType {
383            inner: Box::new(TypeInternal::U8),
384            origin: TypeOrigin::NoOrigin,
385        }
386    }
387
388    pub fn unit() -> InferredType {
389        InferredType::tuple(vec![])
390    }
391
392    pub fn unknown() -> InferredType {
393        InferredType {
394            inner: Box::new(TypeInternal::Unknown),
395            origin: TypeOrigin::NoOrigin,
396        }
397    }
398
399    pub fn u16() -> InferredType {
400        InferredType {
401            inner: Box::new(TypeInternal::U16),
402            origin: TypeOrigin::NoOrigin,
403        }
404    }
405
406    pub fn u32() -> InferredType {
407        InferredType {
408            inner: Box::new(TypeInternal::U32),
409            origin: TypeOrigin::NoOrigin,
410        }
411    }
412
413    pub fn u64() -> InferredType {
414        InferredType {
415            inner: Box::new(TypeInternal::U64),
416            origin: TypeOrigin::NoOrigin,
417        }
418    }
419
420    pub fn variant(fields: Vec<(String, Option<InferredType>)>) -> InferredType {
421        InferredType {
422            inner: Box::new(TypeInternal::Variant(fields)),
423            origin: TypeOrigin::NoOrigin,
424        }
425    }
426
427    pub fn override_origin(&self, origin: TypeOrigin) -> InferredType {
428        InferredType {
429            inner: self.inner.clone(),
430            origin,
431        }
432    }
433
434    pub fn add_origin(&self, origin: TypeOrigin) -> InferredType {
435        let mut inferred_type = self.clone();
436        inferred_type.add_origin_mut(origin.clone());
437        inferred_type
438    }
439
440    pub fn add_origin_mut(&mut self, origin: TypeOrigin) {
441        self.origin = self.origin.add_origin(origin);
442    }
443
444    pub fn without_origin(inferred_type: TypeInternal) -> InferredType {
445        InferredType {
446            inner: Box::new(inferred_type),
447            origin: TypeOrigin::NoOrigin,
448        }
449    }
450
451    pub fn printable(&self) -> String {
452        // Try a fully blown type name or if it fails,
453        // get the `kind` of inferred type
454        TypeName::try_from(self.clone())
455            .map(|tn| tn.to_string())
456            .unwrap_or(self.get_type_hint().to_string())
457    }
458
459    pub fn all_of(types: Vec<InferredType>) -> InferredType {
460        get_merge_task(&types).complete()
461    }
462
463    pub fn is_unit(&self) -> bool {
464        match self.inner.deref() {
465            TypeInternal::Sequence(types) => types.is_empty(),
466            _ => false,
467        }
468    }
469    pub fn is_unknown(&self) -> bool {
470        matches!(self.inner.deref(), TypeInternal::Unknown)
471    }
472
473    pub fn is_valid_wit_type(&self) -> bool {
474        WitType::try_from(self).is_ok()
475    }
476
477    pub fn is_all_of(&self) -> bool {
478        matches!(self.inner.deref(), TypeInternal::AllOf(_))
479    }
480
481    pub fn is_number(&self) -> bool {
482        matches!(
483            self.inner.deref(),
484            TypeInternal::S8
485                | TypeInternal::U8
486                | TypeInternal::S16
487                | TypeInternal::U16
488                | TypeInternal::S32
489                | TypeInternal::U32
490                | TypeInternal::S64
491                | TypeInternal::U64
492                | TypeInternal::F32
493                | TypeInternal::F64
494        )
495    }
496
497    pub fn is_string(&self) -> bool {
498        matches!(self.inner.deref(), TypeInternal::Str)
499    }
500
501    pub fn flatten_all_of_inferred_types(types: &Vec<InferredType>) -> Vec<InferredType> {
502        flatten_all_of_list(types)
503    }
504
505    // Here unification returns an inferred type, but it doesn't necessarily imply
506    // its valid type, which can be converted to a wasm type.
507    pub fn unify(&self) -> Result<InferredType, UnificationFailureInternal> {
508        unify(self).map(|x| x.inferred_type())
509    }
510
511    // There is only one way to merge types. If they are different, they are merged into AllOf
512    pub fn merge(&self, new_inferred_type: InferredType) -> InferredType {
513        match (self.inner.deref(), new_inferred_type.inner.deref()) {
514            (TypeInternal::Unknown, _) => new_inferred_type.add_origin(self.origin.clone()),
515
516            (TypeInternal::AllOf(existing_types), TypeInternal::AllOf(new_types)) => {
517                let mut all_types = new_types.clone();
518                all_types.extend(existing_types.clone());
519
520                InferredType::all_of(all_types)
521            }
522
523            (TypeInternal::AllOf(existing_types), _) => {
524                let mut all_types = existing_types.clone();
525                all_types.push(new_inferred_type);
526
527                InferredType::all_of(all_types)
528            }
529
530            (_, TypeInternal::AllOf(new_types)) => {
531                let mut all_types = new_types.clone();
532                all_types.push(self.clone());
533
534                InferredType::all_of(all_types)
535            }
536
537            (_, _) => {
538                if self != &new_inferred_type && !new_inferred_type.is_unknown() {
539                    InferredType::all_of(vec![self.clone(), new_inferred_type])
540                } else {
541                    self.clone().add_origin(new_inferred_type.origin.clone())
542                }
543            }
544        }
545    }
546
547    pub fn from_type_variant(type_variant: &TypeVariant) -> InferredType {
548        let cases = type_variant
549            .cases
550            .iter()
551            .map(|name_type_pair| {
552                (
553                    name_type_pair.name.clone(),
554                    name_type_pair.typ.as_ref().map(|t| t.into()),
555                )
556            })
557            .collect();
558
559        InferredType::from_variant_cases(cases)
560    }
561
562    pub fn from_variant_cases(cases: Vec<(String, Option<InferredType>)>) -> InferredType {
563        InferredType::without_origin(TypeInternal::Variant(cases))
564    }
565
566    pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType {
567        InferredType::without_origin(TypeInternal::Enum(type_enum.cases.clone()))
568    }
569}
570
571impl PartialEq for InferredType {
572    fn eq(&self, other: &Self) -> bool {
573        self.inner == other.inner
574    }
575}
576
577impl Hash for InferredType {
578    fn hash<H: Hasher>(&self, state: &mut H) {
579        self.inner.hash(state);
580    }
581}
582
/// A concrete numeric type that an `InferredType` can resolve to (see `InferredType::as_number`).
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InferredNumber {
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    F32,
    F64,
}
596
597impl From<&InferredNumber> for InferredType {
598    fn from(inferred_number: &InferredNumber) -> Self {
599        match inferred_number {
600            InferredNumber::S8 => InferredType::s8(),
601            InferredNumber::U8 => InferredType::u8(),
602            InferredNumber::S16 => InferredType::s16(),
603            InferredNumber::U16 => InferredType::u16(),
604            InferredNumber::S32 => InferredType::s32(),
605            InferredNumber::U32 => InferredType::u32(),
606            InferredNumber::S64 => InferredType::s64(),
607            InferredNumber::U64 => InferredType::u64(),
608            InferredNumber::F32 => InferredType::f32(),
609            InferredNumber::F64 => InferredType::f64(),
610        }
611    }
612}
613
614impl From<&DefaultType> for InferredType {
615    fn from(default_type: &DefaultType) -> Self {
616        match default_type {
617            DefaultType::String => InferredType::string().as_default(default_type.clone()),
618            DefaultType::F64 => InferredType::f64().as_default(default_type.clone()),
619            DefaultType::S32 => InferredType::s32().as_default(default_type.clone()),
620        }
621    }
622}
623
624impl From<&BigDecimal> for InferredType {
625    fn from(value: &BigDecimal) -> Self {
626        if value.fractional_digit_count() <= 0 {
627            // Rust inspired
628            // https://github.com/rust-lang/rfcs/blob/master/text/0212-restore-int-fallback.md#rationale-for-the-choice-of-defaulting-to-i32
629            InferredType::s32()
630        } else {
631            // more precision, almost same perf as f32
632            InferredType::f64()
633        }
634    }
635}
636
637#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
638pub struct RangeType {
639    from: Box<TypeInternal>,
640    to: Option<Box<TypeInternal>>,
641}
642
643impl Display for InferredNumber {
644    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
645        let type_name = TypeName::from(self);
646        write!(f, "{type_name}")
647    }
648}
649
650impl From<&WitType> for InferredType {
651    fn from(analysed_type: &WitType) -> Self {
652        match analysed_type {
653            WitType::Bool(_) => InferredType::bool(),
654            WitType::S8(_) => InferredType::s8(),
655            WitType::U8(_) => InferredType::u8(),
656            WitType::S16(_) => InferredType::s16(),
657            WitType::U16(_) => InferredType::u16(),
658            WitType::S32(_) => InferredType::s32(),
659            WitType::U32(_) => InferredType::u32(),
660            WitType::S64(_) => InferredType::s64(),
661            WitType::U64(_) => InferredType::u64(),
662            WitType::F32(_) => InferredType::f32(),
663            WitType::F64(_) => InferredType::f64(),
664            WitType::Chr(_) => InferredType::char(),
665            WitType::Str(_) => InferredType::string(),
666            WitType::List(t) => InferredType::list(t.inner.as_ref().into()),
667            WitType::Tuple(ts) => InferredType::tuple(ts.items.iter().map(|t| t.into()).collect()),
668            WitType::Record(fs) => InferredType::record(
669                fs.fields
670                    .iter()
671                    .map(|name_type| (name_type.name.clone(), (&name_type.typ).into()))
672                    .collect(),
673            ),
674            WitType::Flags(vs) => InferredType::flags(vs.names.clone()),
675            WitType::Enum(vs) => InferredType::from_enum_cases(vs),
676            WitType::Option(t) => InferredType::option(t.inner.as_ref().into()),
677            WitType::Result(TypeResult { ok, err, .. }) => InferredType::result(
678                ok.as_ref().map(|t| t.as_ref().into()),
679                err.as_ref().map(|t| t.as_ref().into()),
680            ),
681            WitType::Variant(vs) => InferredType::from_type_variant(vs),
682            WitType::Handle(TypeHandle {
683                resource_id,
684                mode,
685                name,
686                owner,
687            }) => InferredType::resource(
688                resource_id.0,
689                match mode {
690                    AnalysedResourceMode::Owned => 0,
691                    AnalysedResourceMode::Borrowed => 1,
692                },
693                owner.clone(),
694                name.clone(),
695            ),
696        }
697    }
698}