roto 0.10.0

a statically-typed, compiled, embedded scripting language
Documentation
//! Information generated by the type checker

use std::{collections::HashMap, fmt::Debug};

use crate::{
    ast::Identifier,
    ice,
    parser::meta::MetaId,
    runtime::{
        Rt, RuntimeFunctionRef,
        layout::{Layout, LayoutBuilder},
    },
    value::ErasedList,
};

use super::{
    expr::ResolvedPath,
    scope::{
        DeclarationKind, ResolvedName, ScopeGraph, ScopeRef, TypeOrStub,
    },
    types::{
        Function, IntKind, IntSize, Primitive, Signature, Type,
        TypeDefinition, TypeName,
    },
    unionfind::UnionFind,
};

/// The output of the type checker that is used for lowering
///
/// Most maps are keyed by the [`MetaId`] of the AST node the information
/// belongs to. Types stored here may still contain type variables; use the
/// accessor methods (e.g. `type_of`), which resolve them through the
/// union-find structure.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// The unionfind structure that maps type variables to types
    pub(super) unionfind: UnionFind,

    /// All declarations in the program, extracted from the scope graph
    pub scope_graph: ScopeGraph,

    /// Map from resolved type names to their type definitions
    pub(super) types: HashMap<ResolvedName, TypeDefinition>,

    /// The types we inferred for each Expr
    ///
    /// This might not be fully resolved yet.
    pub(super) expr_types: HashMap<MetaId, Type>,

    /// The fully qualified (and hence unique) name for each identifier.
    pub(super) resolved_names: HashMap<MetaId, ResolvedName>,

    /// The scope of each function, keyed by the function's node id
    pub(super) function_scopes: HashMap<MetaId, ScopeRef>,

    /// The function that is called on each function call
    pub(super) function_calls: HashMap<MetaId, Function>,

    /// The signature of each function, keyed by the function's node id
    pub(super) function_signatures: HashMap<MetaId, Signature>,

    /// The signature of each runtime function, keyed by its reference
    pub(super) runtime_function_signatures:
        HashMap<RuntimeFunctionRef, Signature>,

    /// What each path expression was resolved to (for example, an enum
    /// variant constructor), keyed by the id of the path's node.
    pub(super) path_kinds: HashMap<MetaId, ResolvedPath>,

    /// Whether the node with the given id diverges
    pub(super) diverges: HashMap<MetaId, bool>,

    /// Type for return/accept/reject that it constructs and returns.
    pub(super) return_types: HashMap<MetaId, Type>,

    /// Numeric ids handed out for (resolved) types by `TypeInfo::type_id`
    pub(super) type_ids: HashMap<Type, usize>,
}

impl TypeInfo {
    /// Creates an empty `TypeInfo` in which every map is empty and the
    /// union-find and scope graph are in their default state.
    ///
    /// This is the same value as `TypeInfo::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl TypeInfo {
    /// Returns a numeric id for `ty`.
    ///
    /// The type is passed through [`TypeInfo::resolve`] before lookup. Each
    /// distinct resolved type gets an id the first time it is requested.
    /// Ids start at 1, and later requests for the same resolved type return
    /// the same id.
    pub fn type_id(&mut self, ty: &Type) -> usize {
        let ty = self.resolve(ty);
        let next_id = self.type_ids.len() + 1;
        *self.type_ids.entry(ty).or_insert(next_id)
    }

    /// Returns the fully qualified name recorded for the identifier `x`.
    ///
    /// # Panics
    ///
    /// Panics if no name was recorded for `x`.
    pub fn resolved_name(
        &self,
        x: impl Into<MetaId> + Debug,
    ) -> ResolvedName {
        self.resolved_names[&x.into()]
    }

    /// Returns (a clone of) the signature recorded for the function `x`.
    ///
    /// # Panics
    ///
    /// Panics if no signature was recorded for `x`.
    pub fn function_signature(&self, x: impl Into<MetaId>) -> Signature {
        self.function_signatures[&x.into()].clone()
    }

    /// Returns the type inferred for the expression `x`, resolved through
    /// the union-find.
    ///
    /// # Panics
    ///
    /// Panics if no type was recorded for `x`.
    pub fn type_of(&mut self, x: impl Into<MetaId> + Debug) -> Type {
        let ty = self.expr_types[&x.into()].clone();
        self.resolve(&ty)
    }

    /// Returns whether the node `x` was recorded as diverging.
    ///
    /// This only reads the recorded flag, so a shared borrow is enough.
    ///
    /// # Panics
    ///
    /// Panics if nothing was recorded for `x`.
    pub fn diverges(&self, x: impl Into<MetaId>) -> bool {
        self.diverges[&x.into()]
    }

    /// Returns the type constructed and returned by the return/accept/reject
    /// node `x`, resolved through the union-find.
    ///
    /// # Panics
    ///
    /// Panics if no return type was recorded for `x`.
    pub fn return_type_of(&mut self, x: impl Into<MetaId>) -> Type {
        let ty = self.return_types[&x.into()].clone();
        self.resolve(&ty)
    }

    /// Returns the function called by the call node `x`.
    ///
    /// # Panics
    ///
    /// Panics if no function was recorded for `x`.
    pub fn function(&self, x: impl Into<MetaId>) -> &Function {
        &self.function_calls[&x.into()]
    }

    /// Returns the scope of the function `x`.
    ///
    /// # Panics
    ///
    /// Panics if no scope was recorded for `x`.
    pub fn function_scope(&self, x: impl Into<MetaId>) -> ScopeRef {
        self.function_scopes[&x.into()]
    }

    /// Returns what the path node `x` was resolved to.
    ///
    /// # Panics
    ///
    /// Panics if nothing was recorded for `x`.
    pub fn path_kind(&self, x: impl Into<MetaId>) -> &ResolvedPath {
        &self.path_kinds[&x.into()]
    }

    /// Builds the fully qualified name of `name`: its printed scope, a `.`,
    /// and then the identifier itself.
    pub fn full_name(&self, name: &ResolvedName) -> Identifier {
        let mut s = self.scope_graph.print_scope(name.scope);
        s.push('.');
        s.push_str(name.ident.as_str());
        s.into()
    }

    /// Returns (a clone of) the signature registered for the runtime
    /// function `func_ref`.
    ///
    /// # Panics
    ///
    /// Raises an internal compiler error if no signature was registered.
    pub fn runtime_function_signature(
        &self,
        func_ref: RuntimeFunctionRef,
    ) -> Signature {
        self.runtime_function_signatures
            .get(&func_ref)
            .cloned()
            .unwrap_or_else(|| {
                ice!("No signature registered for runtime function")
            })
    }

    /// Whether `ty` is an integer or float type, including unresolved
    /// integer and float type variables.
    pub fn is_numeric_type(&mut self, ty: &Type) -> bool {
        let ty = self.resolve(ty);
        match ty {
            Type::FloatVar(_) | Type::IntVar(_, _) => true,
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                type_def.is_float() || type_def.is_int()
            }
            _ => false,
        }
    }

    /// Whether `ty` is a float type, including unresolved float type
    /// variables.
    pub fn is_float_type(&mut self, ty: &Type) -> bool {
        let ty = self.resolve(ty);
        match ty {
            Type::FloatVar(_) => true,
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                type_def.is_float()
            }
            _ => false,
        }
    }

    /// Whether `ty` is an integer type (see [`TypeInfo::get_int_type`]).
    pub fn is_int_type(&mut self, ty: &Type) -> bool {
        self.get_int_type(ty).is_some()
    }

    /// Returns the kind and size of `ty` if it is an integer type.
    ///
    /// An unresolved integer type variable defaults to a signed 32-bit
    /// integer.
    pub fn get_int_type(&mut self, ty: &Type) -> Option<(IntKind, IntSize)> {
        let ty = self.resolve(ty);
        match ty {
            Type::IntVar(_, _) => Some((IntKind::Signed, IntSize::I32)),
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                if let TypeDefinition::Primitive(Primitive::Int(kind, size)) =
                    type_def
                {
                    Some((kind, size))
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Whether `ty` names the primitive AS number type.
    pub fn is_asn_type(&mut self, ty: &Type) -> bool {
        let ty = self.resolve(ty);
        match ty {
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                matches!(type_def, TypeDefinition::Primitive(Primitive::Asn))
            }
            _ => false,
        }
    }

    /// Whether `ty` names a list type.
    pub fn is_list_type(&mut self, ty: &Type) -> bool {
        let ty = self.resolve(ty);
        match ty {
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                matches!(type_def, TypeDefinition::List(_))
            }
            _ => false,
        }
    }

    /// Whether or not the type is passed around by reference or by value
    ///
    /// Roto always has by-value semantics, but we still have types that we
    /// store in stack slots and then operate on by pointer. That is what
    /// we mean here with a reference type.
    ///
    /// Records (anonymous or named), enums, lists, registered runtime types,
    /// IP addresses, prefixes and strings are reference types. Integers,
    /// floats, booleans and AS numbers are not. A type whose layout has size
    /// zero is never a reference type.
    ///
    /// This returns `None` if the type is uninhabited (e.g. `!`)
    pub(crate) fn is_reference_type(
        &mut self,
        ty: &Type,
        rt: &Rt,
    ) -> Option<bool> {
        let ty = self.resolve(ty);
        // Zero-sized values carry no data, so there is nothing to point to.
        if self.layout_of(&ty, rt)?.size() == 0 {
            return Some(false);
        }
        match ty {
            Type::Record(..) | Type::RecordVar(..) => Some(true),
            Type::Name(name) => {
                let type_def = self.resolve_type_name(&name);
                let is_ref = matches!(
                    type_def,
                    TypeDefinition::Enum(..)
                        | TypeDefinition::List(..)
                        | TypeDefinition::Record(..)
                        | TypeDefinition::Runtime(..)
                        | TypeDefinition::Primitive(
                            Primitive::IpAddr
                                | Primitive::Prefix
                                | Primitive::String,
                        )
                );
                Some(is_ref)
            }
            _ => Some(false),
        }
    }

    /// Looks up the type definition that the type name `ty` refers to.
    ///
    /// # Panics
    ///
    /// Raises an internal compiler error if the declaration is not a fully
    /// defined type (for example, a stub).
    pub fn resolve_type_name(&mut self, ty: &TypeName) -> TypeDefinition {
        let dec = self.scope_graph.get_declaration(ty.name);
        let DeclarationKind::Type(TypeOrStub::Type(type_def)) = dec.kind else {
            ice!("Type name does not refer to a type definition")
        };
        type_def
    }

    /// Compute the layout of a Roto type
    ///
    /// The layout of Roto types matches the C representation of Rust types,
    /// because we cannot rely on the Rust representation.
    ///
    /// The C representation is described in the [Rust reference].
    ///
    /// The general rules are as follows:
    ///
    ///  - The minimum layout of any type is a size of 0 and an alignment of 1
    ///  - Each primitive has a size and alignment equal to itself.
    ///  - Each composite type has the alignment of the most-aligned field in it.
    ///  - Fields are laid out in order, each padded to their alignment.
    ///  - The size **must** be a multiple of the alignment.
    ///
    /// For enums we use the `#[repr(C, u8)]` representation, because the
    /// other representations are platform-specific. This means that the tag
    /// for enums is a `u8` and therefore 1 byte.
    ///
    /// To implement these rules, we rely on the [`Layout`] struct from the Rust
    /// standard library. This also allows us to get the layout of some Rust
    /// types we rely on.
    ///
    /// This function returns `None` if the type is uninhabited.
    ///
    /// [Rust reference]: https://doc.rust-lang.org/reference/type-layout.html
    pub(crate) fn layout_of(&mut self, ty: &Type, rt: &Rt) -> Option<Layout> {
        let ty = self.resolve(ty);
        let layout = match ty {
            Type::ExplicitVar(_) => {
                ice!("Can't get the layout of an unconcrete type: {:?}", ty)
            }
            Type::Function(_, _) => {
                ice!("Can't get the layout of a function type")
            }
            Type::Unit => Layout::new(0, 1),
            Type::Var(_) | Type::Never => return None,
            // Unresolved numeric variables default to `i32` / `f64`.
            Type::IntVar(_, _) => Primitive::i32().layout(),
            Type::FloatVar(_) => Primitive::f64().layout(),
            Type::RecordVar(_, fields) | Type::Record(fields) => {
                // Any uninhabited field makes the whole record uninhabited.
                let layouts = fields
                    .iter()
                    .map(|(_, f)| self.layout_of(f, rt))
                    .collect::<Option<Vec<_>>>()?;
                Layout::concat(layouts)
            }
            Type::Name(type_name) => {
                let type_def = self.resolve_type_name(&type_name);
                match type_def {
                    TypeDefinition::List(_) => Layout::of::<ErasedList>(),
                    TypeDefinition::Enum(type_constructor, variants) => {
                        // Map the enum's type parameters to the concrete
                        // arguments of this instantiation.
                        let subs: Vec<_> = type_constructor
                            .arguments
                            .iter()
                            .zip(&type_name.arguments)
                            .collect();

                        // Each variant is laid out as a struct that starts
                        // with the `u8` tag. The enum is the union of those
                        // variant layouts.
                        //
                        // NOTE(review): placing each payload directly after
                        // its own tag is the `#[repr(u8)]` layout. Under
                        // `#[repr(C, u8)]` all payloads start at one shared
                        // offset after the tag, which can change offsets and
                        // sizes. Confirm which one the lowering and the Rust
                        // side assume.
                        let mut layout: Option<Layout> = None;
                        for variant in &variants {
                            let mut builder = LayoutBuilder::new();
                            builder.add(&Layout::of::<u8>());

                            let builder = variant.fields.iter().try_fold(
                                builder,
                                |mut b, t| {
                                    let t = t.substitute_many(&subs);
                                    let field_layout = self.layout_of(&t, rt)?;
                                    b.add(&field_layout);
                                    Some(b)
                                },
                            );

                            // If the variant contains uninhabited fields, the
                            // entire variant is uninhabited, so we don't need
                            // to consider it.
                            let Some(builder) = builder else {
                                continue;
                            };

                            let variant_layout = builder.finish();

                            layout = Some(match layout {
                                Some(l) => l.union(&variant_layout),
                                None => variant_layout,
                            });
                        }

                        // No inhabited variants means the enum itself is
                        // uninhabited.
                        layout?
                    }
                    TypeDefinition::Record(type_constructor, fields) => {
                        let subs: Vec<_> = type_constructor
                            .arguments
                            .iter()
                            .zip(&type_name.arguments)
                            .collect();

                        // If any of the fields of the record are uninhabited
                        // the entire record is uninhabited.
                        let mut builder = LayoutBuilder::new();
                        for (_, t) in fields {
                            let t = t.substitute_many(&subs);
                            builder.add(&self.layout_of(&t, rt)?);
                        }
                        builder.finish()
                    }
                    TypeDefinition::Runtime(_, type_id) => {
                        rt.get_runtime_type(type_id).unwrap().layout()
                    }
                    TypeDefinition::Primitive(primitive) => {
                        primitive.layout()
                    }
                }
            }
        };
        Some(layout)
    }

    /// Borrowing variant of [`TypeInfo::resolve`].
    ///
    /// If `t` is a type variable (`Var`, `IntVar`, `FloatVar` or
    /// `RecordVar`), this returns the type stored for its union-find
    /// representative. Otherwise it returns `t` itself. Nested types are not
    /// resolved.
    pub fn resolve_ref<'a>(&'a self, t: &'a Type) -> &'a Type {
        match t {
            Type::Var(x)
            | Type::IntVar(x, _)
            | Type::FloatVar(x)
            | Type::RecordVar(x, _) => self.unionfind.find_ref(*x),
            _ => t,
        }
    }

    /// Resolves the outermost type variable of `t` through the union-find.
    ///
    /// If `t` is a type variable (`Var`, `IntVar`, `FloatVar` or
    /// `RecordVar`), this returns a clone of the type stored for its
    /// representative. Otherwise it returns a clone of `t`. Nested types are
    /// not resolved.
    pub fn resolve(&mut self, t: &Type) -> Type {
        match t {
            // Only clone the resolved type; cloning `t` first would be
            // wasted work (and copies the field list of a `RecordVar`).
            Type::Var(x)
            | Type::IntVar(x, _)
            | Type::RecordVar(x, _)
            | Type::FloatVar(x) => self.unionfind.find(*x).clone(),
            _ => t.clone(),
        }
    }
}