prql-compiler 0.4.2

PRQL is a modern language for transforming data — a simple, powerful, pipelined SQL replacement.
Documentation
use std::cmp::Ordering;
use std::collections::HashSet;

use anyhow::Result;

use crate::ast::pl::*;
use crate::error::{Error, Reason, WithErrorInfo};

use super::Context;

pub fn resolve_type(node: &Expr, context: &Context) -> Result<Ty> {
    if let Some(ty) = &node.ty {
        return Ok(ty.clone());
    }

    Ok(match &node.kind {
        ExprKind::Literal(ref literal) => match literal {
            Literal::Null => Ty::Infer,
            Literal::Integer(_) => TyLit::Integer.into(),
            Literal::Float(_) => TyLit::Float.into(),
            Literal::Boolean(_) => TyLit::Bool.into(),
            Literal::String(_) => TyLit::String.into(),
            Literal::Date(_) => TyLit::Date.into(),
            Literal::Time(_) => TyLit::Time.into(),
            Literal::Timestamp(_) => TyLit::Timestamp.into(),
            Literal::ValueAndUnit(_) => Ty::Infer, // TODO
            Literal::Relation(_) => unreachable!(),
        },

        ExprKind::Ident(_) | ExprKind::Pipeline(_) | ExprKind::FuncCall(_) => Ty::Infer,

        ExprKind::SString(_) => Ty::Infer,
        ExprKind::FString(_) => TyLit::String.into(),
        ExprKind::Range(_) => Ty::Infer, // TODO

        ExprKind::TransformCall(call) => Ty::Table(call.infer_type(context)?),
        ExprKind::List(_) => Ty::Literal(TyLit::List),

        _ => Ty::Infer,
    })
}

#[allow(dead_code)]
fn too_many_arguments(call: &FuncCall, expected_len: usize, passed_len: usize) -> Error {
    let err = Error::new(Reason::Expected {
        who: Some(format!("{}", call.name)),
        expected: format!("{} arguments", expected_len),
        found: format!("{}", passed_len),
    });
    if passed_len >= 2 {
        err.with_help(format!(
            "If you are calling a function, you may want to add parentheses `{} [{:?} {:?}]`",
            call.name, call.args[0], call.args[1]
        ))
    } else {
        err
    }
}

/// Validates that found node has expected type. Returns assumed type of the node.
pub fn validate_type<F>(found: &Expr, expected: &Ty, who: F) -> Result<Ty, Error>
where
    F: FnOnce() -> Option<String>,
{
    let found_ty = found.ty.clone().unwrap();

    // infer
    if let Ty::Infer = expected {
        return Ok(found_ty);
    }
    if let Ty::Infer = found_ty {
        return Ok(if let Ty::Table(_) = expected {
            // inferred tables are needed for table s-strings
            // override the empty frame with frame of a table literal

            let input_name = (found.alias)
                .clone()
                .unwrap_or_else(|| format!("_literal_{}", found.id.unwrap()));

            Ty::Table(Frame {
                inputs: vec![FrameInput {
                    id: found.id.unwrap(),
                    name: input_name.clone(),
                    table: None,
                }],
                columns: vec![FrameColumn::All {
                    input_name,
                    except: HashSet::new(),
                }],
                ..Default::default()
            })
        } else {
            expected.clone()
        });
    }

    let expected_is_above = matches!(
        expected.partial_cmp(&found_ty),
        Some(Ordering::Equal | Ordering::Greater)
    );
    if !expected_is_above {
        let e = Err(Error::new(Reason::Expected {
            who: who(),
            expected: format!("type `{}`", expected),
            found: format!("type `{}`", found_ty),
        })
        .with_span(found.span));
        if matches!(found_ty, Ty::Function(_)) && !matches!(expected, Ty::Function(_)) {
            let func_name = found.kind.as_closure().and_then(|c| c.name.as_ref());
            let to_what = func_name
                .map(|n| format!("to function {n}"))
                .unwrap_or_else(|| "in this function call?".to_string());

            return e.with_help(format!("Have you forgotten an argument {to_what}?"));
        };
        return e;
    }
    Ok(found_ty)
}

pub fn type_of_closure(closure: &Closure) -> TyFunc {
    TyFunc {
        args: closure
            .params
            .iter()
            .map(|a| a.ty.clone().unwrap_or(Ty::Infer))
            .collect(),
        return_ty: Box::new(closure.body_ty.clone().unwrap_or(Ty::Infer)),
    }
}