rib/inferred_type/
mod.rs

1// Copyright 2024-2025 Golem Cloud
2//
3// Licensed under the Golem Source License v1.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://license.golem.cloud/LICENSE
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub use type_internal::*;
16
17pub(crate) use all_of::*;
18pub(crate) use type_origin::*;
19pub(crate) use unification::*;
20
21mod all_of;
22mod type_internal;
23mod type_origin;
24mod unification;
25
26use crate::instance_type::InstanceType;
27use crate::rib_source_span::SourceSpan;
28use crate::type_inference::GetTypeHint;
29use crate::TypeName;
30use bigdecimal::BigDecimal;
31use golem_wasm_ast::analysis::*;
32use std::fmt::{Display, Formatter};
33use std::hash::{Hash, Hasher};
34use std::ops::Deref;
35
36#[derive(Debug, Clone, Eq, PartialOrd, Ord)]
37pub struct InferredType {
38    pub inner: Box<TypeInternal>,
39    pub origin: TypeOrigin,
40}
41
42impl InferredType {
43    pub fn originated_at(&self, source_span: &SourceSpan) -> InferredType {
44        self.add_origin(TypeOrigin::OriginatedAt(source_span.clone()))
45    }
46
47    pub fn origin(&self) -> TypeOrigin {
48        self.origin.clone()
49    }
50
51    pub fn source_span(&self) -> Option<SourceSpan> {
52        let origin = self.origin();
53
54        match origin {
55            TypeOrigin::Default(_) => None,
56            TypeOrigin::NoOrigin => None,
57            TypeOrigin::Declared(_) => None,
58            TypeOrigin::Multiple(origins) => {
59                let mut source_span = None;
60                for origin in origins {
61                    if let TypeOrigin::OriginatedAt(loc) = origin {
62                        source_span = Some(loc.clone());
63                        break;
64                    }
65                }
66                source_span
67            }
68            TypeOrigin::OriginatedAt(_) => None,
69        }
70    }
71
72    pub fn as_number(&self) -> Result<InferredNumber, String> {
73        fn go(with_origin: &InferredType, found: &mut Vec<InferredNumber>) -> Result<(), String> {
74            match with_origin.inner.deref() {
75                TypeInternal::S8 => {
76                    found.push(InferredNumber::S8);
77                    Ok(())
78                }
79                TypeInternal::U8 => {
80                    found.push(InferredNumber::U8);
81                    Ok(())
82                }
83                TypeInternal::S16 => {
84                    found.push(InferredNumber::S16);
85                    Ok(())
86                }
87                TypeInternal::U16 => {
88                    found.push(InferredNumber::U16);
89                    Ok(())
90                }
91                TypeInternal::S32 => {
92                    found.push(InferredNumber::S32);
93                    Ok(())
94                }
95                TypeInternal::U32 => {
96                    found.push(InferredNumber::U32);
97                    Ok(())
98                }
99                TypeInternal::S64 => {
100                    found.push(InferredNumber::S64);
101                    Ok(())
102                }
103                TypeInternal::U64 => {
104                    found.push(InferredNumber::U64);
105                    Ok(())
106                }
107                TypeInternal::F32 => {
108                    found.push(InferredNumber::F32);
109                    Ok(())
110                }
111                TypeInternal::F64 => {
112                    found.push(InferredNumber::F64);
113                    Ok(())
114                }
115                TypeInternal::AllOf(all_variables) => {
116                    let mut previous: Option<InferredNumber> = None;
117                    for variable in all_variables {
118                        go(variable, found)?;
119
120                        if let Some(current) = found.first() {
121                            match &previous {
122                                None => {
123                                    previous = Some(current.clone());
124                                    found.push(current.clone());
125                                }
126                                Some(previous) => {
127                                    if previous != current {
128                                        return Err(format!(
129                                            "expected the same type of number. But found {current}, {previous}"
130                                        ));
131                                    }
132
133                                    found.push(current.clone());
134                                }
135                            }
136                        } else {
137                            return Err("failed to get a number".to_string());
138                        }
139                    }
140
141                    Ok(())
142                }
143                TypeInternal::Range { .. } => Err("used as range".to_string()),
144                TypeInternal::Bool => Err(format!("used as {}", "bool")),
145                TypeInternal::Chr => Err(format!("used as {}", "char")),
146                TypeInternal::Str => Err(format!("used as {}", "string")),
147                TypeInternal::List(_) => Err(format!("used as {}", "list")),
148                TypeInternal::Tuple(_) => Err(format!("used as {}", "tuple")),
149                TypeInternal::Record(_) => Err(format!("used as {}", "record")),
150                TypeInternal::Flags(_) => Err(format!("used as {}", "flags")),
151                TypeInternal::Enum(_) => Err(format!("used as {}", "enum")),
152                TypeInternal::Option(_) => Err(format!("used as {}", "option")),
153                TypeInternal::Result { .. } => Err(format!("used as {}", "result")),
154                TypeInternal::Variant(_) => Err(format!("used as {}", "variant")),
155                TypeInternal::Unknown => Err("found unknown".to_string()),
156                TypeInternal::Sequence(_) => {
157                    Err(format!("used as {}", "function-multi-parameter-return"))
158                }
159                TypeInternal::Resource { .. } => Err(format!("used as {}", "resource")),
160                TypeInternal::Instance { .. } => Err(format!("used as {}", "instance")),
161            }
162        }
163
164        let mut found: Vec<InferredNumber> = vec![];
165        go(self, &mut found)?;
166        found.first().cloned().ok_or("Failed".to_string())
167    }
168
169    pub fn bool() -> InferredType {
170        InferredType {
171            inner: Box::new(TypeInternal::Bool),
172            origin: TypeOrigin::NoOrigin,
173        }
174    }
175
176    pub fn char() -> InferredType {
177        InferredType {
178            inner: Box::new(TypeInternal::Chr),
179            origin: TypeOrigin::NoOrigin,
180        }
181    }
182
183    pub fn contains_only_number(&self) -> bool {
184        match self.inner.deref() {
185            TypeInternal::S8
186            | TypeInternal::U8
187            | TypeInternal::S16
188            | TypeInternal::U16
189            | TypeInternal::S32
190            | TypeInternal::U32
191            | TypeInternal::S64
192            | TypeInternal::U64
193            | TypeInternal::F32
194            | TypeInternal::F64 => true,
195            TypeInternal::Bool => false,
196            TypeInternal::Chr => false,
197            TypeInternal::Str => false,
198            TypeInternal::List(_) => false,
199            TypeInternal::Tuple(_) => false,
200            TypeInternal::Record(_) => false,
201            TypeInternal::Flags(_) => false,
202            TypeInternal::Enum(_) => false,
203            TypeInternal::Option(_) => false,
204            TypeInternal::Result { .. } => false,
205            TypeInternal::Variant(_) => false,
206            TypeInternal::Resource { .. } => false,
207            TypeInternal::Range { .. } => false,
208            TypeInternal::Instance { .. } => false,
209            TypeInternal::Unknown => false,
210            TypeInternal::Sequence(_) => false,
211            TypeInternal::AllOf(types) => types.iter().all(|t| t.contains_only_number()),
212        }
213    }
214
215    pub fn declared_at(&self, source_span: SourceSpan) -> InferredType {
216        self.add_origin(TypeOrigin::Declared(source_span.clone()))
217    }
218
219    pub fn as_default(&self, default_type: DefaultType) -> InferredType {
220        let new_origin = TypeOrigin::Default(default_type);
221
222        InferredType {
223            inner: self.inner.clone(),
224            origin: self.origin.add_origin(new_origin),
225        }
226    }
227
228    pub fn enum_(cases: Vec<String>) -> InferredType {
229        InferredType {
230            inner: Box::new(TypeInternal::Enum(cases)),
231            origin: TypeOrigin::NoOrigin,
232        }
233    }
234
235    pub fn f32() -> InferredType {
236        InferredType {
237            inner: Box::new(TypeInternal::F32),
238            origin: TypeOrigin::NoOrigin,
239        }
240    }
241
242    pub fn f64() -> InferredType {
243        InferredType {
244            inner: Box::new(TypeInternal::F64),
245            origin: TypeOrigin::NoOrigin,
246        }
247    }
248
249    pub fn flags(flags: Vec<String>) -> InferredType {
250        InferredType {
251            inner: Box::new(TypeInternal::Flags(flags)),
252            origin: TypeOrigin::NoOrigin,
253        }
254    }
255
256    pub fn instance(instance_type: InstanceType) -> InferredType {
257        InferredType {
258            inner: Box::new(TypeInternal::Instance {
259                instance_type: Box::new(instance_type),
260            }),
261            origin: TypeOrigin::NoOrigin,
262        }
263    }
264
265    pub fn internal_type(&self) -> &TypeInternal {
266        self.inner.as_ref()
267    }
268
269    pub fn internal_type_mut(&mut self) -> &mut TypeInternal {
270        self.inner.as_mut()
271    }
272
273    pub fn list(inner: InferredType) -> InferredType {
274        InferredType {
275            inner: Box::new(TypeInternal::List(inner)),
276            origin: TypeOrigin::NoOrigin,
277        }
278    }
279
280    pub fn new(inferred_type: TypeInternal, origin: TypeOrigin) -> InferredType {
281        InferredType {
282            inner: Box::new(inferred_type),
283            origin,
284        }
285    }
286
287    pub fn option(inner: InferredType) -> InferredType {
288        InferredType {
289            inner: Box::new(TypeInternal::Option(inner)),
290            origin: TypeOrigin::NoOrigin,
291        }
292    }
293
294    pub fn range(from: InferredType, to: Option<InferredType>) -> InferredType {
295        InferredType {
296            inner: Box::new(TypeInternal::Range { from, to }),
297            origin: TypeOrigin::NoOrigin,
298        }
299    }
300
301    pub fn eliminate_default(inferred_types: Vec<&InferredType>) -> Vec<&InferredType> {
302        inferred_types
303            .into_iter()
304            .filter(|&t| !t.origin.is_default())
305            .collect::<Vec<_>>()
306    }
307
308    pub fn record(fields: Vec<(String, InferredType)>) -> InferredType {
309        InferredType {
310            inner: Box::new(TypeInternal::Record(fields)),
311            origin: TypeOrigin::NoOrigin,
312        }
313    }
314
315    pub fn resolved(inferred_type: TypeInternal) -> InferredType {
316        InferredType {
317            inner: Box::new(inferred_type),
318            origin: TypeOrigin::NoOrigin,
319        }
320    }
321
322    pub fn resource(resource_id: u64, resource_mode: u8) -> InferredType {
323        InferredType {
324            inner: Box::new(TypeInternal::Resource {
325                resource_id,
326                resource_mode,
327            }),
328            origin: TypeOrigin::NoOrigin,
329        }
330    }
331
332    pub fn result(ok: Option<InferredType>, error: Option<InferredType>) -> InferredType {
333        InferredType {
334            inner: Box::new(TypeInternal::Result { ok, error }),
335            origin: TypeOrigin::NoOrigin,
336        }
337    }
338
339    pub fn sequence(inferred_types: Vec<InferredType>) -> InferredType {
340        InferredType {
341            inner: Box::new(TypeInternal::Sequence(inferred_types)),
342            origin: TypeOrigin::NoOrigin,
343        }
344    }
345
346    pub fn string() -> InferredType {
347        InferredType {
348            inner: Box::new(TypeInternal::Str),
349            origin: TypeOrigin::NoOrigin,
350        }
351    }
352
353    pub fn s8() -> InferredType {
354        InferredType {
355            inner: Box::new(TypeInternal::S8),
356            origin: TypeOrigin::NoOrigin,
357        }
358    }
359
360    pub fn s16() -> InferredType {
361        InferredType {
362            inner: Box::new(TypeInternal::S16),
363            origin: TypeOrigin::NoOrigin,
364        }
365    }
366
367    pub fn s32() -> InferredType {
368        InferredType {
369            inner: Box::new(TypeInternal::S32),
370            origin: TypeOrigin::NoOrigin,
371        }
372    }
373
374    pub fn s64() -> InferredType {
375        InferredType {
376            inner: Box::new(TypeInternal::S64),
377            origin: TypeOrigin::NoOrigin,
378        }
379    }
380
381    pub fn tuple(inner: Vec<InferredType>) -> InferredType {
382        InferredType {
383            inner: Box::new(TypeInternal::Tuple(inner)),
384            origin: TypeOrigin::NoOrigin,
385        }
386    }
387
388    pub fn u8() -> InferredType {
389        InferredType {
390            inner: Box::new(TypeInternal::U8),
391            origin: TypeOrigin::NoOrigin,
392        }
393    }
394
395    pub fn unit() -> InferredType {
396        InferredType::tuple(vec![])
397    }
398
399    pub fn unknown() -> InferredType {
400        InferredType {
401            inner: Box::new(TypeInternal::Unknown),
402            origin: TypeOrigin::NoOrigin,
403        }
404    }
405
406    pub fn u16() -> InferredType {
407        InferredType {
408            inner: Box::new(TypeInternal::U16),
409            origin: TypeOrigin::NoOrigin,
410        }
411    }
412
413    pub fn u32() -> InferredType {
414        InferredType {
415            inner: Box::new(TypeInternal::U32),
416            origin: TypeOrigin::NoOrigin,
417        }
418    }
419
420    pub fn u64() -> InferredType {
421        InferredType {
422            inner: Box::new(TypeInternal::U64),
423            origin: TypeOrigin::NoOrigin,
424        }
425    }
426
427    pub fn variant(fields: Vec<(String, Option<InferredType>)>) -> InferredType {
428        InferredType {
429            inner: Box::new(TypeInternal::Variant(fields)),
430            origin: TypeOrigin::NoOrigin,
431        }
432    }
433
434    pub fn override_origin(&self, origin: TypeOrigin) -> InferredType {
435        InferredType {
436            inner: self.inner.clone(),
437            origin,
438        }
439    }
440
441    pub fn add_origin(&self, origin: TypeOrigin) -> InferredType {
442        let mut inferred_type = self.clone();
443        inferred_type.add_origin_mut(origin.clone());
444        inferred_type
445    }
446
447    pub fn add_origin_mut(&mut self, origin: TypeOrigin) {
448        self.origin = self.origin.add_origin(origin);
449    }
450
451    pub fn without_origin(inferred_type: TypeInternal) -> InferredType {
452        InferredType {
453            inner: Box::new(inferred_type),
454            origin: TypeOrigin::NoOrigin,
455        }
456    }
457
458    pub fn printable(&self) -> String {
459        // Try a fully blown type name or if it fails,
460        // get the `kind` of inferred type
461        TypeName::try_from(self.clone())
462            .map(|tn| tn.to_string())
463            .unwrap_or(self.get_type_hint().to_string())
464    }
465
466    pub fn all_of(types: Vec<InferredType>) -> InferredType {
467        get_merge_task(&types).complete()
468    }
469
470    pub fn is_unit(&self) -> bool {
471        match self.inner.deref() {
472            TypeInternal::Sequence(types) => types.is_empty(),
473            _ => false,
474        }
475    }
476    pub fn is_unknown(&self) -> bool {
477        matches!(self.inner.deref(), TypeInternal::Unknown)
478    }
479
480    pub fn is_valid_wit_type(&self) -> bool {
481        AnalysedType::try_from(self).is_ok()
482    }
483
484    pub fn is_all_of(&self) -> bool {
485        matches!(self.inner.deref(), TypeInternal::AllOf(_))
486    }
487
488    pub fn is_number(&self) -> bool {
489        matches!(
490            self.inner.deref(),
491            TypeInternal::S8
492                | TypeInternal::U8
493                | TypeInternal::S16
494                | TypeInternal::U16
495                | TypeInternal::S32
496                | TypeInternal::U32
497                | TypeInternal::S64
498                | TypeInternal::U64
499                | TypeInternal::F32
500                | TypeInternal::F64
501        )
502    }
503
504    pub fn is_string(&self) -> bool {
505        matches!(self.inner.deref(), TypeInternal::Str)
506    }
507
508    pub fn flatten_all_of_inferred_types(types: &Vec<InferredType>) -> Vec<InferredType> {
509        flatten_all_of_list(types)
510    }
511
512    // Here unification returns an inferred type, but it doesn't necessarily imply
513    // its valid type, which can be converted to a wasm type.
514    pub fn unify(&self) -> Result<InferredType, UnificationFailureInternal> {
515        unify(self).map(|x| x.inferred_type())
516    }
517
518    // There is only one way to merge types. If they are different, they are merged into AllOf
519    pub fn merge(&self, new_inferred_type: InferredType) -> InferredType {
520        match (self.inner.deref(), new_inferred_type.inner.deref()) {
521            (TypeInternal::Unknown, _) => new_inferred_type.add_origin(self.origin.clone()),
522
523            (TypeInternal::AllOf(existing_types), TypeInternal::AllOf(new_types)) => {
524                let mut all_types = new_types.clone();
525                all_types.extend(existing_types.clone());
526
527                InferredType::all_of(all_types)
528            }
529
530            (TypeInternal::AllOf(existing_types), _) => {
531                let mut all_types = existing_types.clone();
532                all_types.push(new_inferred_type);
533
534                InferredType::all_of(all_types)
535            }
536
537            (_, TypeInternal::AllOf(new_types)) => {
538                let mut all_types = new_types.clone();
539                all_types.push(self.clone());
540
541                InferredType::all_of(all_types)
542            }
543
544            (_, _) => {
545                if self != &new_inferred_type && !new_inferred_type.is_unknown() {
546                    InferredType::all_of(vec![self.clone(), new_inferred_type])
547                } else {
548                    self.clone().add_origin(new_inferred_type.origin.clone())
549                }
550            }
551        }
552    }
553
554    pub fn from_type_variant(type_variant: &TypeVariant) -> InferredType {
555        let cases = type_variant
556            .cases
557            .iter()
558            .map(|name_type_pair| {
559                (
560                    name_type_pair.name.clone(),
561                    name_type_pair.typ.as_ref().map(|t| t.into()),
562                )
563            })
564            .collect();
565
566        InferredType::from_variant_cases(cases)
567    }
568
569    pub fn from_variant_cases(cases: Vec<(String, Option<InferredType>)>) -> InferredType {
570        InferredType::without_origin(TypeInternal::Variant(cases))
571    }
572
573    pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType {
574        InferredType::without_origin(TypeInternal::Enum(type_enum.cases.clone()))
575    }
576}
577
578impl PartialEq for InferredType {
579    fn eq(&self, other: &Self) -> bool {
580        self.inner == other.inner
581    }
582}
583
584impl Hash for InferredType {
585    fn hash<H: Hasher>(&self, state: &mut H) {
586        self.inner.hash(state);
587    }
588}
589
/// The concrete numeric kinds an inferred type can resolve to
/// (see `InferredType::as_number`).
#[derive(Clone, Debug, PartialEq)]
pub enum InferredNumber {
    // Signed/unsigned integers, narrowest first.
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    // Floating point.
    F32,
    F64,
}
603
604impl From<&InferredNumber> for InferredType {
605    fn from(inferred_number: &InferredNumber) -> Self {
606        match inferred_number {
607            InferredNumber::S8 => InferredType::s8(),
608            InferredNumber::U8 => InferredType::u8(),
609            InferredNumber::S16 => InferredType::s16(),
610            InferredNumber::U16 => InferredType::u16(),
611            InferredNumber::S32 => InferredType::s32(),
612            InferredNumber::U32 => InferredType::u32(),
613            InferredNumber::S64 => InferredType::s64(),
614            InferredNumber::U64 => InferredType::u64(),
615            InferredNumber::F32 => InferredType::f32(),
616            InferredNumber::F64 => InferredType::f64(),
617        }
618    }
619}
620
621impl From<&DefaultType> for InferredType {
622    fn from(default_type: &DefaultType) -> Self {
623        match default_type {
624            DefaultType::String => InferredType::string().as_default(default_type.clone()),
625            DefaultType::F64 => InferredType::f64().as_default(default_type.clone()),
626            DefaultType::S32 => InferredType::s32().as_default(default_type.clone()),
627        }
628    }
629}
630
631impl From<&BigDecimal> for InferredType {
632    fn from(value: &BigDecimal) -> Self {
633        if value.fractional_digit_count() <= 0 {
634            // Rust inspired
635            // https://github.com/rust-lang/rfcs/blob/master/text/0212-restore-int-fallback.md#rationale-for-the-choice-of-defaulting-to-i32
636            InferredType::s32()
637        } else {
638            // more precision, almost same perf as f32
639            InferredType::f64()
640        }
641    }
642}
643
644#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
645pub struct RangeType {
646    from: Box<TypeInternal>,
647    to: Option<Box<TypeInternal>>,
648}
649
650impl Display for InferredNumber {
651    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
652        let type_name = TypeName::from(self);
653        write!(f, "{type_name}")
654    }
655}
656
657impl From<&AnalysedType> for InferredType {
658    fn from(analysed_type: &AnalysedType) -> Self {
659        match analysed_type {
660            AnalysedType::Bool(_) => InferredType::bool(),
661            AnalysedType::S8(_) => InferredType::s8(),
662            AnalysedType::U8(_) => InferredType::u8(),
663            AnalysedType::S16(_) => InferredType::s16(),
664            AnalysedType::U16(_) => InferredType::u16(),
665            AnalysedType::S32(_) => InferredType::s32(),
666            AnalysedType::U32(_) => InferredType::u32(),
667            AnalysedType::S64(_) => InferredType::s64(),
668            AnalysedType::U64(_) => InferredType::u64(),
669            AnalysedType::F32(_) => InferredType::f32(),
670            AnalysedType::F64(_) => InferredType::f64(),
671            AnalysedType::Chr(_) => InferredType::char(),
672            AnalysedType::Str(_) => InferredType::string(),
673            AnalysedType::List(t) => InferredType::list(t.inner.as_ref().into()),
674            AnalysedType::Tuple(ts) => {
675                InferredType::tuple(ts.items.iter().map(|t| t.into()).collect())
676            }
677            AnalysedType::Record(fs) => InferredType::record(
678                fs.fields
679                    .iter()
680                    .map(|name_type| (name_type.name.clone(), (&name_type.typ).into()))
681                    .collect(),
682            ),
683            AnalysedType::Flags(vs) => InferredType::flags(vs.names.clone()),
684            AnalysedType::Enum(vs) => InferredType::from_enum_cases(vs),
685            AnalysedType::Option(t) => InferredType::option(t.inner.as_ref().into()),
686            AnalysedType::Result(golem_wasm_ast::analysis::TypeResult { ok, err, .. }) => {
687                InferredType::result(
688                    ok.as_ref().map(|t| t.as_ref().into()),
689                    err.as_ref().map(|t| t.as_ref().into()),
690                )
691            }
692            AnalysedType::Variant(vs) => InferredType::from_type_variant(vs),
693            AnalysedType::Handle(golem_wasm_ast::analysis::TypeHandle { resource_id, mode }) => {
694                InferredType::resource(
695                    resource_id.0,
696                    match mode {
697                        AnalysedResourceMode::Owned => 0,
698                        AnalysedResourceMode::Borrowed => 1,
699                    },
700                )
701            }
702        }
703    }
704}