rib/inferred_type/
mod.rs

1// Copyright 2024-2025 Golem Cloud
2//
3// Licensed under the Golem Source License v1.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://license.golem.cloud/LICENSE
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub use type_internal::*;
16
17pub(crate) use all_of::*;
18pub(crate) use type_origin::*;
19pub(crate) use unification::*;
20
21mod all_of;
22mod type_internal;
23mod type_origin;
24mod unification;
25
26use crate::instance_type::InstanceType;
27use crate::rib_source_span::SourceSpan;
28use crate::type_inference::GetTypeHint;
29use crate::TypeName;
30use bigdecimal::BigDecimal;
31use golem_wasm_ast::analysis::*;
32use std::fmt::{Display, Formatter};
33use std::hash::{Hash, Hasher};
34use std::ops::Deref;
35
36#[derive(Debug, Clone, Eq, PartialOrd, Ord)]
37pub struct InferredType {
38    pub inner: Box<TypeInternal>,
39    pub origin: TypeOrigin,
40}
41
42impl InferredType {
43    pub fn originated_at(&self, source_span: &SourceSpan) -> InferredType {
44        self.add_origin(TypeOrigin::OriginatedAt(source_span.clone()))
45    }
46
47    pub fn origin(&self) -> TypeOrigin {
48        self.origin.clone()
49    }
50
51    pub fn source_span(&self) -> Option<SourceSpan> {
52        let origin = self.origin();
53
54        match origin {
55            TypeOrigin::Default(_) => None,
56            TypeOrigin::NoOrigin => None,
57            TypeOrigin::Declared(_) => None,
58            TypeOrigin::Multiple(origins) => {
59                let mut source_span = None;
60                for origin in origins {
61                    if let TypeOrigin::OriginatedAt(loc) = origin {
62                        source_span = Some(loc.clone());
63                        break;
64                    }
65                }
66                source_span
67            }
68            TypeOrigin::OriginatedAt(_) => None,
69        }
70    }
71
72    pub fn as_number(&self) -> Result<InferredNumber, String> {
73        fn go(with_origin: &InferredType, found: &mut Vec<InferredNumber>) -> Result<(), String> {
74            match with_origin.inner.deref() {
75                TypeInternal::S8 => {
76                    found.push(InferredNumber::S8);
77                    Ok(())
78                }
79                TypeInternal::U8 => {
80                    found.push(InferredNumber::U8);
81                    Ok(())
82                }
83                TypeInternal::S16 => {
84                    found.push(InferredNumber::S16);
85                    Ok(())
86                }
87                TypeInternal::U16 => {
88                    found.push(InferredNumber::U16);
89                    Ok(())
90                }
91                TypeInternal::S32 => {
92                    found.push(InferredNumber::S32);
93                    Ok(())
94                }
95                TypeInternal::U32 => {
96                    found.push(InferredNumber::U32);
97                    Ok(())
98                }
99                TypeInternal::S64 => {
100                    found.push(InferredNumber::S64);
101                    Ok(())
102                }
103                TypeInternal::U64 => {
104                    found.push(InferredNumber::U64);
105                    Ok(())
106                }
107                TypeInternal::F32 => {
108                    found.push(InferredNumber::F32);
109                    Ok(())
110                }
111                TypeInternal::F64 => {
112                    found.push(InferredNumber::F64);
113                    Ok(())
114                }
115                TypeInternal::AllOf(all_variables) => {
116                    let mut previous: Option<InferredNumber> = None;
117                    for variable in all_variables {
118                        go(variable, found)?;
119
120                        if let Some(current) = found.first() {
121                            match &previous {
122                                None => {
123                                    previous = Some(current.clone());
124                                    found.push(current.clone());
125                                }
126                                Some(previous) => {
127                                    if previous != current {
128                                        return Err(format!(
129                                            "expected the same type of number. But found {}, {}",
130                                            current, previous
131                                        ));
132                                    }
133
134                                    found.push(current.clone());
135                                }
136                            }
137                        } else {
138                            return Err("failed to get a number".to_string());
139                        }
140                    }
141
142                    Ok(())
143                }
144                TypeInternal::Range { .. } => Err("used as range".to_string()),
145                TypeInternal::Bool => Err(format!("used as {}", "bool")),
146                TypeInternal::Chr => Err(format!("used as {}", "char")),
147                TypeInternal::Str => Err(format!("used as {}", "string")),
148                TypeInternal::List(_) => Err(format!("used as {}", "list")),
149                TypeInternal::Tuple(_) => Err(format!("used as {}", "tuple")),
150                TypeInternal::Record(_) => Err(format!("used as {}", "record")),
151                TypeInternal::Flags(_) => Err(format!("used as {}", "flags")),
152                TypeInternal::Enum(_) => Err(format!("used as {}", "enum")),
153                TypeInternal::Option(_) => Err(format!("used as {}", "option")),
154                TypeInternal::Result { .. } => Err(format!("used as {}", "result")),
155                TypeInternal::Variant(_) => Err(format!("used as {}", "variant")),
156                TypeInternal::Unknown => Err("found unknown".to_string()),
157                TypeInternal::Sequence(_) => {
158                    Err(format!("used as {}", "function-multi-parameter-return"))
159                }
160                TypeInternal::Resource { .. } => Err(format!("used as {}", "resource")),
161                TypeInternal::Instance { .. } => Err(format!("used as {}", "instance")),
162            }
163        }
164
165        let mut found: Vec<InferredNumber> = vec![];
166        go(self, &mut found)?;
167        found.first().cloned().ok_or("Failed".to_string())
168    }
169
170    pub fn bool() -> InferredType {
171        InferredType {
172            inner: Box::new(TypeInternal::Bool),
173            origin: TypeOrigin::NoOrigin,
174        }
175    }
176
177    pub fn char() -> InferredType {
178        InferredType {
179            inner: Box::new(TypeInternal::Chr),
180            origin: TypeOrigin::NoOrigin,
181        }
182    }
183
184    pub fn contains_only_number(&self) -> bool {
185        match self.inner.deref() {
186            TypeInternal::S8
187            | TypeInternal::U8
188            | TypeInternal::S16
189            | TypeInternal::U16
190            | TypeInternal::S32
191            | TypeInternal::U32
192            | TypeInternal::S64
193            | TypeInternal::U64
194            | TypeInternal::F32
195            | TypeInternal::F64 => true,
196            TypeInternal::Bool => false,
197            TypeInternal::Chr => false,
198            TypeInternal::Str => false,
199            TypeInternal::List(_) => false,
200            TypeInternal::Tuple(_) => false,
201            TypeInternal::Record(_) => false,
202            TypeInternal::Flags(_) => false,
203            TypeInternal::Enum(_) => false,
204            TypeInternal::Option(_) => false,
205            TypeInternal::Result { .. } => false,
206            TypeInternal::Variant(_) => false,
207            TypeInternal::Resource { .. } => false,
208            TypeInternal::Range { .. } => false,
209            TypeInternal::Instance { .. } => false,
210            TypeInternal::Unknown => false,
211            TypeInternal::Sequence(_) => false,
212            TypeInternal::AllOf(types) => types.iter().all(|t| t.contains_only_number()),
213        }
214    }
215
216    pub fn declared_at(&self, source_span: SourceSpan) -> InferredType {
217        self.add_origin(TypeOrigin::Declared(source_span.clone()))
218    }
219
220    pub fn as_default(&self, default_type: DefaultType) -> InferredType {
221        let new_origin = TypeOrigin::Default(default_type);
222
223        InferredType {
224            inner: self.inner.clone(),
225            origin: self.origin.add_origin(new_origin),
226        }
227    }
228
229    pub fn enum_(cases: Vec<String>) -> InferredType {
230        InferredType {
231            inner: Box::new(TypeInternal::Enum(cases)),
232            origin: TypeOrigin::NoOrigin,
233        }
234    }
235
236    pub fn f32() -> InferredType {
237        InferredType {
238            inner: Box::new(TypeInternal::F32),
239            origin: TypeOrigin::NoOrigin,
240        }
241    }
242
243    pub fn f64() -> InferredType {
244        InferredType {
245            inner: Box::new(TypeInternal::F64),
246            origin: TypeOrigin::NoOrigin,
247        }
248    }
249
250    pub fn flags(flags: Vec<String>) -> InferredType {
251        InferredType {
252            inner: Box::new(TypeInternal::Flags(flags)),
253            origin: TypeOrigin::NoOrigin,
254        }
255    }
256
257    pub fn instance(instance_type: InstanceType) -> InferredType {
258        InferredType {
259            inner: Box::new(TypeInternal::Instance {
260                instance_type: Box::new(instance_type),
261            }),
262            origin: TypeOrigin::NoOrigin,
263        }
264    }
265
266    pub fn internal_type(&self) -> &TypeInternal {
267        self.inner.as_ref()
268    }
269
270    pub fn internal_type_mut(&mut self) -> &mut TypeInternal {
271        self.inner.as_mut()
272    }
273
274    pub fn list(inner: InferredType) -> InferredType {
275        InferredType {
276            inner: Box::new(TypeInternal::List(inner)),
277            origin: TypeOrigin::NoOrigin,
278        }
279    }
280
281    pub fn new(inferred_type: TypeInternal, origin: TypeOrigin) -> InferredType {
282        InferredType {
283            inner: Box::new(inferred_type),
284            origin,
285        }
286    }
287
288    pub fn option(inner: InferredType) -> InferredType {
289        InferredType {
290            inner: Box::new(TypeInternal::Option(inner)),
291            origin: TypeOrigin::NoOrigin,
292        }
293    }
294
295    pub fn range(from: InferredType, to: Option<InferredType>) -> InferredType {
296        InferredType {
297            inner: Box::new(TypeInternal::Range { from, to }),
298            origin: TypeOrigin::NoOrigin,
299        }
300    }
301
302    pub fn eliminate_default(inferred_types: Vec<&InferredType>) -> Vec<&InferredType> {
303        inferred_types
304            .into_iter()
305            .filter(|&t| !t.origin.is_default())
306            .collect::<Vec<_>>()
307    }
308
309    pub fn record(fields: Vec<(String, InferredType)>) -> InferredType {
310        InferredType {
311            inner: Box::new(TypeInternal::Record(fields)),
312            origin: TypeOrigin::NoOrigin,
313        }
314    }
315
316    pub fn resolved(inferred_type: TypeInternal) -> InferredType {
317        InferredType {
318            inner: Box::new(inferred_type),
319            origin: TypeOrigin::NoOrigin,
320        }
321    }
322
323    pub fn resource(resource_id: u64, resource_mode: u8) -> InferredType {
324        InferredType {
325            inner: Box::new(TypeInternal::Resource {
326                resource_id,
327                resource_mode,
328            }),
329            origin: TypeOrigin::NoOrigin,
330        }
331    }
332
333    pub fn result(ok: Option<InferredType>, error: Option<InferredType>) -> InferredType {
334        InferredType {
335            inner: Box::new(TypeInternal::Result { ok, error }),
336            origin: TypeOrigin::NoOrigin,
337        }
338    }
339
340    pub fn sequence(inferred_types: Vec<InferredType>) -> InferredType {
341        InferredType {
342            inner: Box::new(TypeInternal::Sequence(inferred_types)),
343            origin: TypeOrigin::NoOrigin,
344        }
345    }
346
347    pub fn string() -> InferredType {
348        InferredType {
349            inner: Box::new(TypeInternal::Str),
350            origin: TypeOrigin::NoOrigin,
351        }
352    }
353
354    pub fn s8() -> InferredType {
355        InferredType {
356            inner: Box::new(TypeInternal::S8),
357            origin: TypeOrigin::NoOrigin,
358        }
359    }
360
361    pub fn s16() -> InferredType {
362        InferredType {
363            inner: Box::new(TypeInternal::S16),
364            origin: TypeOrigin::NoOrigin,
365        }
366    }
367
368    pub fn s32() -> InferredType {
369        InferredType {
370            inner: Box::new(TypeInternal::S32),
371            origin: TypeOrigin::NoOrigin,
372        }
373    }
374
375    pub fn s64() -> InferredType {
376        InferredType {
377            inner: Box::new(TypeInternal::S64),
378            origin: TypeOrigin::NoOrigin,
379        }
380    }
381
382    pub fn tuple(inner: Vec<InferredType>) -> InferredType {
383        InferredType {
384            inner: Box::new(TypeInternal::Tuple(inner)),
385            origin: TypeOrigin::NoOrigin,
386        }
387    }
388
389    pub fn u8() -> InferredType {
390        InferredType {
391            inner: Box::new(TypeInternal::U8),
392            origin: TypeOrigin::NoOrigin,
393        }
394    }
395
396    pub fn unit() -> InferredType {
397        InferredType::tuple(vec![])
398    }
399
400    pub fn unknown() -> InferredType {
401        InferredType {
402            inner: Box::new(TypeInternal::Unknown),
403            origin: TypeOrigin::NoOrigin,
404        }
405    }
406
407    pub fn u16() -> InferredType {
408        InferredType {
409            inner: Box::new(TypeInternal::U16),
410            origin: TypeOrigin::NoOrigin,
411        }
412    }
413
414    pub fn u32() -> InferredType {
415        InferredType {
416            inner: Box::new(TypeInternal::U32),
417            origin: TypeOrigin::NoOrigin,
418        }
419    }
420
421    pub fn u64() -> InferredType {
422        InferredType {
423            inner: Box::new(TypeInternal::U64),
424            origin: TypeOrigin::NoOrigin,
425        }
426    }
427
428    pub fn variant(fields: Vec<(String, Option<InferredType>)>) -> InferredType {
429        InferredType {
430            inner: Box::new(TypeInternal::Variant(fields)),
431            origin: TypeOrigin::NoOrigin,
432        }
433    }
434
435    pub fn override_origin(&self, origin: TypeOrigin) -> InferredType {
436        InferredType {
437            inner: self.inner.clone(),
438            origin,
439        }
440    }
441
442    pub fn add_origin(&self, origin: TypeOrigin) -> InferredType {
443        let mut inferred_type = self.clone();
444        inferred_type.add_origin_mut(origin.clone());
445        inferred_type
446    }
447
448    pub fn add_origin_mut(&mut self, origin: TypeOrigin) {
449        self.origin = self.origin.add_origin(origin);
450    }
451
452    pub fn without_origin(inferred_type: TypeInternal) -> InferredType {
453        InferredType {
454            inner: Box::new(inferred_type),
455            origin: TypeOrigin::NoOrigin,
456        }
457    }
458
459    pub fn printable(&self) -> String {
460        // Try a fully blown type name or if it fails,
461        // get the `kind` of inferred type
462        TypeName::try_from(self.clone())
463            .map(|tn| tn.to_string())
464            .unwrap_or(self.get_type_hint().to_string())
465    }
466
467    pub fn all_of(types: Vec<InferredType>) -> InferredType {
468        get_merge_task(&types).complete()
469    }
470
471    pub fn is_unit(&self) -> bool {
472        match self.inner.deref() {
473            TypeInternal::Sequence(types) => types.is_empty(),
474            _ => false,
475        }
476    }
477    pub fn is_unknown(&self) -> bool {
478        matches!(self.inner.deref(), TypeInternal::Unknown)
479    }
480
481    pub fn is_valid_wit_type(&self) -> bool {
482        AnalysedType::try_from(self).is_ok()
483    }
484
485    pub fn is_all_of(&self) -> bool {
486        matches!(self.inner.deref(), TypeInternal::AllOf(_))
487    }
488
489    pub fn is_number(&self) -> bool {
490        matches!(
491            self.inner.deref(),
492            TypeInternal::S8
493                | TypeInternal::U8
494                | TypeInternal::S16
495                | TypeInternal::U16
496                | TypeInternal::S32
497                | TypeInternal::U32
498                | TypeInternal::S64
499                | TypeInternal::U64
500                | TypeInternal::F32
501                | TypeInternal::F64
502        )
503    }
504
505    pub fn is_string(&self) -> bool {
506        matches!(self.inner.deref(), TypeInternal::Str)
507    }
508
509    pub fn flatten_all_of_inferred_types(types: &Vec<InferredType>) -> Vec<InferredType> {
510        flatten_all_of_list(types)
511    }
512
513    // Here unification returns an inferred type, but it doesn't necessarily imply
514    // its valid type, which can be converted to a wasm type.
515    pub fn unify(&self) -> Result<InferredType, UnificationFailureInternal> {
516        unify(self).map(|x| x.inferred_type())
517    }
518
519    // There is only one way to merge types. If they are different, they are merged into AllOf
520    pub fn merge(&self, new_inferred_type: InferredType) -> InferredType {
521        match (self.inner.deref(), new_inferred_type.inner.deref()) {
522            (TypeInternal::Unknown, _) => new_inferred_type.add_origin(self.origin.clone()),
523
524            (TypeInternal::AllOf(existing_types), TypeInternal::AllOf(new_types)) => {
525                let mut all_types = new_types.clone();
526                all_types.extend(existing_types.clone());
527
528                InferredType::all_of(all_types)
529            }
530
531            (TypeInternal::AllOf(existing_types), _) => {
532                let mut all_types = existing_types.clone();
533                all_types.push(new_inferred_type);
534
535                InferredType::all_of(all_types)
536            }
537
538            (_, TypeInternal::AllOf(new_types)) => {
539                let mut all_types = new_types.clone();
540                all_types.push(self.clone());
541
542                InferredType::all_of(all_types)
543            }
544
545            (_, _) => {
546                if self != &new_inferred_type && !new_inferred_type.is_unknown() {
547                    InferredType::all_of(vec![self.clone(), new_inferred_type])
548                } else {
549                    self.clone().add_origin(new_inferred_type.origin.clone())
550                }
551            }
552        }
553    }
554
555    pub fn from_type_variant(type_variant: &TypeVariant) -> InferredType {
556        let cases = type_variant
557            .cases
558            .iter()
559            .map(|name_type_pair| {
560                (
561                    name_type_pair.name.clone(),
562                    name_type_pair.typ.as_ref().map(|t| t.into()),
563                )
564            })
565            .collect();
566
567        InferredType::from_variant_cases(cases)
568    }
569
570    pub fn from_variant_cases(cases: Vec<(String, Option<InferredType>)>) -> InferredType {
571        InferredType::without_origin(TypeInternal::Variant(cases))
572    }
573
574    pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType {
575        InferredType::without_origin(TypeInternal::Enum(type_enum.cases.clone()))
576    }
577}
578
579impl PartialEq for InferredType {
580    fn eq(&self, other: &Self) -> bool {
581        self.inner == other.inner
582    }
583}
584
585impl Hash for InferredType {
586    fn hash<H: Hasher>(&self, state: &mut H) {
587        self.inner.hash(state);
588    }
589}
590
/// A concrete numeric kind, as produced by `InferredType::as_number`.
///
/// Equality is total (every variant equals itself), so the type also implements
/// `Eq` and `Hash` and can be used as a set or map key.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub enum InferredNumber {
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    F32,
    F64,
}
604
605impl From<&InferredNumber> for InferredType {
606    fn from(inferred_number: &InferredNumber) -> Self {
607        match inferred_number {
608            InferredNumber::S8 => InferredType::s8(),
609            InferredNumber::U8 => InferredType::u8(),
610            InferredNumber::S16 => InferredType::s16(),
611            InferredNumber::U16 => InferredType::u16(),
612            InferredNumber::S32 => InferredType::s32(),
613            InferredNumber::U32 => InferredType::u32(),
614            InferredNumber::S64 => InferredType::s64(),
615            InferredNumber::U64 => InferredType::u64(),
616            InferredNumber::F32 => InferredType::f32(),
617            InferredNumber::F64 => InferredType::f64(),
618        }
619    }
620}
621
622impl From<&DefaultType> for InferredType {
623    fn from(default_type: &DefaultType) -> Self {
624        match default_type {
625            DefaultType::String => InferredType::string().as_default(default_type.clone()),
626            DefaultType::F64 => InferredType::f64().as_default(default_type.clone()),
627            DefaultType::S32 => InferredType::s32().as_default(default_type.clone()),
628        }
629    }
630}
631
632impl From<&BigDecimal> for InferredType {
633    fn from(value: &BigDecimal) -> Self {
634        if value.fractional_digit_count() <= 0 {
635            // Rust inspired
636            // https://github.com/rust-lang/rfcs/blob/master/text/0212-restore-int-fallback.md#rationale-for-the-choice-of-defaulting-to-i32
637            InferredType::s32()
638        } else {
639            // more precision, almost same perf as f32
640            InferredType::f64()
641        }
642    }
643}
644
645#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
646pub struct RangeType {
647    from: Box<TypeInternal>,
648    to: Option<Box<TypeInternal>>,
649}
650
651impl Display for InferredNumber {
652    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
653        let type_name = TypeName::from(self);
654        write!(f, "{}", type_name)
655    }
656}
657
658impl From<&AnalysedType> for InferredType {
659    fn from(analysed_type: &AnalysedType) -> Self {
660        match analysed_type {
661            AnalysedType::Bool(_) => InferredType::bool(),
662            AnalysedType::S8(_) => InferredType::s8(),
663            AnalysedType::U8(_) => InferredType::u8(),
664            AnalysedType::S16(_) => InferredType::s16(),
665            AnalysedType::U16(_) => InferredType::u16(),
666            AnalysedType::S32(_) => InferredType::s32(),
667            AnalysedType::U32(_) => InferredType::u32(),
668            AnalysedType::S64(_) => InferredType::s64(),
669            AnalysedType::U64(_) => InferredType::u64(),
670            AnalysedType::F32(_) => InferredType::f32(),
671            AnalysedType::F64(_) => InferredType::f64(),
672            AnalysedType::Chr(_) => InferredType::char(),
673            AnalysedType::Str(_) => InferredType::string(),
674            AnalysedType::List(t) => InferredType::list(t.inner.as_ref().into()),
675            AnalysedType::Tuple(ts) => {
676                InferredType::tuple(ts.items.iter().map(|t| t.into()).collect())
677            }
678            AnalysedType::Record(fs) => InferredType::record(
679                fs.fields
680                    .iter()
681                    .map(|name_type| (name_type.name.clone(), (&name_type.typ).into()))
682                    .collect(),
683            ),
684            AnalysedType::Flags(vs) => InferredType::flags(vs.names.clone()),
685            AnalysedType::Enum(vs) => InferredType::from_enum_cases(vs),
686            AnalysedType::Option(t) => InferredType::option(t.inner.as_ref().into()),
687            AnalysedType::Result(golem_wasm_ast::analysis::TypeResult { ok, err, .. }) => {
688                InferredType::result(
689                    ok.as_ref().map(|t| t.as_ref().into()),
690                    err.as_ref().map(|t| t.as_ref().into()),
691                )
692            }
693            AnalysedType::Variant(vs) => InferredType::from_type_variant(vs),
694            AnalysedType::Handle(golem_wasm_ast::analysis::TypeHandle { resource_id, mode }) => {
695                InferredType::resource(
696                    resource_id.0,
697                    match mode {
698                        AnalysedResourceMode::Owned => 0,
699                        AnalysedResourceMode::Borrowed => 1,
700                    },
701                )
702            }
703        }
704    }
705}