lutra-compiler 0.6.0

Compiler for Lutra query language
Documentation
use lutra_bin::ident;

use crate::pr;
use crate::printer::common::{PrintSourceExt, Separated};
use crate::printer::{PrintSource, Printer};

impl PrintSource for pr::Ty {
    /// Writes the source form of this type into the printer.
    ///
    /// A `None` coming from the printer or from a nested `print` call is
    /// propagated to the caller unchanged.
    #[tracing::instrument(name = "t", skip_all)]
    fn print<'c>(&self, p: &mut Printer<'c>) -> Option<()> {
        tracing::trace!("ty {}", self.kind.as_ref());

        // Every arm evaluates to the overall result of printing that kind.
        match &self.kind {
            pr::TyKind::Ident(path) => {
                if !path.starts_with_part("std") {
                    return p.push(path.to_string());
                }

                // Paths that start with `std` are printed without their first step.
                for (idx, step) in path.as_steps()[1..].iter().enumerate() {
                    if idx != 0 {
                        p.push("::")?;
                    }
                    p.push(ident::display(step))?;
                }
                Some(())
            }

            pr::TyKind::Primitive(prim) => p.push(prim.to_string()),

            // `{a, b, c}`
            pr::TyKind::Tuple(fields) => Separated {
                nodes: fields,
                sep_inline: ", ",
                sep_line_end: ",",
            }
            .between("{", "}", self.span)
            .single_line_end(",")
            .print(p),

            // `[item]`
            pr::TyKind::Array(item) => item.as_ref().between("[", "]", self.span).print(p),

            // `inner?`
            pr::TyKind::Option(inner) => {
                inner.print(p)?;
                p.push("?")
            }

            // `enum {a, b, c}`
            pr::TyKind::Enum(variants) => {
                p.push("enum ")?;

                Separated {
                    nodes: variants,
                    sep_inline: ", ",
                    sep_line_end: ",",
                }
                .between("{", "}", self.span)
                .single_line_end(",")
                .print(p)
            }

            pr::TyKind::Func(func) => print_ty_func(func, None, p),

            // `{for <var>: <var_ty> in <tuple> do [<name>: ]<body_ty>}`
            pr::TyKind::TupleComprehension(comp) => {
                p.push("{for ")?;
                p.push(ident::display(&comp.variable_name))?;
                p.push(": ")?;
                p.push(ident::display(&comp.variable_ty))?;
                p.push(" in ")?;
                comp.tuple.print(p)?;
                p.push(" do ")?;
                if let Some(body_name) = &comp.body_name {
                    p.push(ident::display(body_name))?;
                    p.push(": ")?;
                }
                comp.body_ty.print(p)?;
                p.push("}")
            }
        }
    }

    /// Returns the span stored on this type node, if any.
    fn span(&self) -> Option<crate::Span> {
        self.span
    }
}

/// Prints a function type: the `func ` keyword, an optional name, the
/// parenthesized parameter list, an optional `: <return type>` and, when the
/// function has type parameters, a `where` clause on a new line.
///
/// A `None` from the printer or from a nested `print` call is propagated.
pub(super) fn print_ty_func<'c>(
    func: &pr::TyFunc,
    name: Option<&str>,
    p: &mut Printer<'c>,
) -> Option<()> {
    p.push("func ")?;
    if let Some(fn_name) = name {
        p.push(ident::display(fn_name))?;
    }

    // Parameter list.
    let params = Separated {
        nodes: &func.params,
        sep_inline: ", ",
        sep_line_end: ",",
    };
    params.between("(", ")", None).print(p)?;

    // Return type, when present.
    if let Some(ret) = func.body.as_ref() {
        p.push(": ")?;
        ret.print(p)?;
    }

    // Nothing more to print without type parameters.
    if func.ty_params.is_empty() {
        return Some(());
    }

    p.new_line();
    p.push("where ")?;

    // The type parameters are printed one indentation level deeper.
    p.indent();
    let ty_params = Separated {
        nodes: &func.ty_params,
        sep_inline: ", ",
        sep_line_end: ",",
    };
    ty_params.print(p)?;
    p.dedent();

    Some(())
}

impl PrintSource for pr::TyFuncParam {
    /// Prints a function parameter as `[const ][<label>: ][<type>]`, where
    /// each bracketed part is emitted only when it is set on the parameter.
    fn print<'c>(&self, p: &mut Printer<'c>) -> Option<()> {
        if self.constant {
            p.push("const ")?;
        }

        if let Some(label) = self.label.as_ref() {
            p.push(ident::display(label))?;
            p.push(": ")?;
        }

        // A parameter without a type contributes nothing further.
        match &self.ty {
            Some(ty) => ty.print(p),
            None => Some(()),
        }
    }

    /// Returns the span of the parameter's type, or `None` when the parameter
    /// has no type or the type has no span.
    fn span(&self) -> Option<crate::Span> {
        self.ty.as_ref()?.span
    }
}

impl PrintSource for pr::TyTupleField {
    /// Prints a tuple field as `[<name>: ][..]<type>`.
    fn print<'c>(&self, p: &mut Printer<'c>) -> Option<()> {
        if let Some(field_name) = self.name.as_ref() {
            p.push(ident::display(field_name))?;
            p.push(": ")?;
        }

        // Unpacked fields are marked with a `..` prefix on the type.
        if self.unpack {
            p.push("..")?;
        }

        self.ty.print(p)
    }

    /// Returns the span of the field's type.
    fn span(&self) -> Option<crate::Span> {
        self.ty.span
    }
}
impl PrintSource for pr::TyEnumVariant {
    /// Prints an enum variant as `<name>`, followed by `: <type>` unless the
    /// variant's type is the empty tuple.
    fn print<'c>(&self, p: &mut Printer<'c>) -> Option<()> {
        p.push(ident::display(&self.name))?;

        match self.ty.kind.as_tuple() {
            // Unit variant: the name alone is the whole output.
            Some(fields) if fields.is_empty() => Some(()),
            _ => {
                p.push(": ")?;
                self.ty.print(p)
            }
        }
    }

    /// Returns the span of the variant's type.
    fn span(&self) -> Option<crate::Span> {
        self.ty.span
    }
}

impl PrintSource for pr::TyParam {
    /// Prints a type parameter: its name, followed by `: <domain>` unless the
    /// domain is `Open`.
    ///
    /// Domain forms:
    /// - `OneOf`: alternatives joined by ` | `, with a leading shorthand
    ///   (e.g. `AnyNumber`) when `compact_domain_shorthand` finds one.
    /// - `TupleHasFields`: `{<field>, <field>, ..}`.
    /// - `TupleLen`: `{.. <n> ..}` (not actual Lutra syntax).
    /// - `EnumVariants`: `enum {<variant>, <variant>, ..}`.
    ///
    /// A `None` from the printer or from a nested `print` call is propagated.
    ///
    /// # Panics
    ///
    /// Panics when a positional field of a `TupleHasFields` domain is not at
    /// the index equal to its position.
    fn print<'c>(&self, p: &mut Printer<'c>) -> Option<()> {
        p.push(ident::display(&self.name))?;
        match &self.domain {
            pr::TyDomain::Open => {}
            pr::TyDomain::OneOf(tys) => {
                p.push(": ")?;
                let (shorthand, remaining) = compact_domain_shorthand(tys);
                if let Some(s) = shorthand {
                    p.push(s)?;
                }
                for (i, ty) in remaining.iter().enumerate() {
                    // The shorthand counts as the first alternative.
                    if i > 0 || shorthand.is_some() {
                        p.push(" | ")?;
                    }
                    ty.print(p)?;
                }
            }
            pr::TyDomain::TupleHasFields(fields) => {
                p.push(": {")?;
                for (i, field) in fields.iter().enumerate() {
                    match &field.location {
                        pr::Lookup::Name(name) => {
                            p.push(ident::display(name))?;
                            p.push(": ")?;
                        }
                        // Bound as `pos` so the printer `p` is not shadowed.
                        pr::Lookup::Position(pos) => {
                            assert_eq!(i, *pos as usize); // TODO: print these fields when they are out of order
                        }
                    }
                    field.ty.print(p)?;
                    p.push(", ")?;
                }
                p.push("..}")?;
            }
            pr::TyDomain::TupleLen { n } => {
                // not an actual Lutra syntax (there is no syntax for this)
                p.push(": {.. ")?;
                p.push(n.to_string())?;
                p.push(" ..}")?;
            }
            pr::TyDomain::EnumVariants(variants) => {
                p.push(": enum {")?;
                for (i, variant) in variants.iter().enumerate() {
                    if i > 0 {
                        p.push(", ")?;
                    }
                    p.push(ident::display(&variant.name))?;

                    // Same rule as `TyEnumVariant::print`: a unit variant (its
                    // type is the empty tuple) is printed as the bare name,
                    // every other variant carries `: <type>`.
                    let is_unit = variant.ty.kind.as_tuple().is_some_and(|f| f.is_empty());
                    if !is_unit {
                        p.push(": ")?;
                        variant.ty.print(p)?;
                    }
                }
                p.push(", ..}")?;
            }
        }
        Some(())
    }

    /// Returns the span stored on this type parameter, if any.
    fn span(&self) -> Option<crate::Span> {
        self.span
    }
}

fn compact_domain_shorthand(tys: &[pr::Ty]) -> (Option<&'static str>, Vec<&pr::Ty>) {
    use crate::resolver;
    if let Some(remaining) = remove_shorthand(tys, resolver::DOMAIN_ANY_NUMBER) {
        return (Some("AnyNumber"), remaining);
    }
    (None, tys.iter().collect())
}

/// Removes from `tys` every identifier type whose last path step equals one
/// of `names`.
///
/// Returns the types that are left, in their original order, or `None` as
/// soon as one of `names` matches no remaining type.
fn remove_shorthand<'t>(tys: &'t [pr::Ty], names: &[&str]) -> Option<Vec<&'t pr::Ty>> {
    let mut remaining: Vec<&_> = tys.iter().collect();
    for name in names {
        let len_before = remaining.len();

        // HACK: comparing just the last ident is error-prone. But we want to support
        // both unresolved and resolved ASTs, as well as std and non-std namespaces.
        remaining.retain(|t| !t.kind.as_ident().is_some_and(|p| p.last() == *name));

        // Nothing was dropped for this name, so the shorthand does not apply.
        if remaining.len() == len_before {
            return None;
        }
    }
    Some(remaining)
}