numbat 1.23.0

A statically typed programming language for scientific computations with first-class support for physical dimensions and units.
Documentation
//! Parser for simple quantity literals.
//!
//! This module provides a restricted parser that only accepts simple quantity literals
//! of the form `<number>`, `<number> <unit>` or `<number> * <unit>` (each optionally
//! preceded by one or more `-` signs), not arbitrary expressions.

use crate::Type;
use crate::ast::{BinaryOperator, Expression, Statement, UnaryOperator};
use crate::number::Number;
use crate::parser::{ParseError, parse};
use crate::prefix_transformer::Transformer;
use crate::quantity::Quantity;
use crate::typechecker::TypeChecker;
use crate::typechecker::type_scheme::TypeScheme;
use crate::unit::Unit;

/// Error type for parsing quantity literals.
///
/// Every variant carries a human-readable message. The `Display` implementation
/// prefixes that message with the error category, and the type implements
/// [`std::error::Error`] so it can be propagated with `?` into generic error
/// types such as `Box<dyn std::error::Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantityLiteralError {
    /// Input could not be parsed at all
    ParseError(String),
    /// Input was parsed but doesn't match the expected `<number> [<unit>]` pattern
    InvalidPattern(String),
    /// Name resolution error (unknown unit)
    NameResolutionError(String),
}

impl std::fmt::Display for QuantityLiteralError {
    /// Writes `<category>: <message>`, e.g. `Parse error: Empty input`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Select the category label once so the output format lives in one place.
        let (label, msg) = match self {
            Self::ParseError(msg) => ("Parse error", msg),
            Self::InvalidPattern(msg) => ("Invalid pattern", msg),
            Self::NameResolutionError(msg) => ("Name resolution error", msg),
        };
        write!(f, "{label}: {msg}")
    }
}

// No underlying source error is stored, so the default trait methods suffice.
impl std::error::Error for QuantityLiteralError {}

/// Parse a quantity literal and evaluate it to a [`Quantity`] together with its type.
///
/// The input is first parsed and validated by [`parse_quantity_ast`], then passed
/// through `transformer` (presumably this is what resolves unit names into unit
/// identifiers; confirm against `Transformer`). Afterwards the type is looked up
/// via `typechecker` and the value is built using `unit_lookup`.
///
/// Accepted inputs include:
/// - `1.5`
/// - `-3.14`
/// - `1.5 km`
/// - `1.5 * km`
/// - `100 m`
///
/// # Errors
///
/// Returns any error produced by [`parse_quantity_ast`], an `InvalidPattern`
/// error if the transformed expression does not have the expected shape, or a
/// `NameResolutionError` if either `typechecker` or `unit_lookup` does not know
/// the unit.
pub fn parse_quantity_literal(
    input: &str,
    transformer: &Transformer,
    typechecker: &TypeChecker,
    unit_lookup: impl Fn(&str) -> Option<Unit>,
) -> Result<(Quantity, TypeScheme), QuantityLiteralError> {
    let mut literal = parse_quantity_ast(input)?;
    transformer.transform_expression(&mut literal);

    // The type is resolved before the value so that type errors are reported first.
    let scheme = get_expression_type(&literal, typechecker)?;
    evaluate_quantity_expression(&literal, &unit_lookup).map(|value| (value, scheme))
}

/// Determine the type scheme of a quantity literal expression.
///
/// - A bare scalar has the concrete scalar type.
/// - A negation has the type of its operand.
/// - A multiplication takes its type from the right-hand side, which must be a
///   unit identifier known to `typechecker`; the left-hand side is not inspected.
///
/// Any other expression yields an `InvalidPattern` error; an unknown unit yields
/// a `NameResolutionError`.
fn get_expression_type(
    expr: &Expression,
    typechecker: &TypeChecker,
) -> Result<TypeScheme, QuantityLiteralError> {
    match expr {
        Expression::Scalar(..) => Ok(TypeScheme::concrete(Type::scalar())),

        // Negation does not change the type: defer to the operand.
        Expression::UnaryOperator {
            op: UnaryOperator::Negate,
            expr: operand,
            ..
        } => get_expression_type(operand, typechecker),

        Expression::BinaryOperator {
            op: BinaryOperator::Mul,
            rhs,
            ..
        } => match rhs.as_ref() {
            Expression::UnitIdentifier { name, .. } => {
                let looked_up = typechecker.lookup_identifier_type(name);
                looked_up.ok_or_else(|| {
                    QuantityLiteralError::NameResolutionError(format!("Unknown unit: {name}"))
                })
            }
            _ => Err(QuantityLiteralError::InvalidPattern(
                "Expected unit identifier".to_string(),
            )),
        },

        _ => Err(QuantityLiteralError::InvalidPattern(
            "Unexpected expression type".to_string(),
        )),
    }
}

/// Compute the [`Quantity`] denoted by a quantity literal expression.
///
/// Supports a bare scalar, `scalar * unit`, and (possibly nested) negation of
/// either. Numeric values are converted via `f64`. Units are resolved through
/// `unit_lookup`; anything else is reported as an `InvalidPattern` error.
fn evaluate_quantity_expression(
    expr: &Expression,
    unit_lookup: &impl Fn(&str) -> Option<Unit>,
) -> Result<Quantity, QuantityLiteralError> {
    match expr {
        // `-<literal>`: evaluate the operand, then flip its sign.
        Expression::UnaryOperator {
            op: UnaryOperator::Negate,
            expr: operand,
            ..
        } => evaluate_quantity_expression(operand, unit_lookup).map(|quantity| -quantity),

        // `<number> * <unit>`: the scalar is checked before the unit is resolved.
        Expression::BinaryOperator {
            op: BinaryOperator::Mul,
            lhs,
            rhs,
            ..
        } => {
            let magnitude = Number::from_f64(extract_scalar(lhs)?);
            let unit = extract_unit(rhs, unit_lookup)?;
            Ok(Quantity::new(magnitude, unit))
        }

        // `<number>` on its own
        Expression::Scalar(_, value) => Ok(Quantity::from_scalar(value.to_f64())),

        _ => Err(QuantityLiteralError::InvalidPattern(
            "Unexpected expression type".to_string(),
        )),
    }
}

/// Return the numeric value of `expr` as an `f64` if it is a `Scalar` expression.
///
/// Any other kind of expression results in an `InvalidPattern` error.
fn extract_scalar(expr: &Expression) -> Result<f64, QuantityLiteralError> {
    if let Expression::Scalar(_, value) = expr {
        return Ok(value.to_f64());
    }

    Err(QuantityLiteralError::InvalidPattern(
        "Expected scalar".to_string(),
    ))
}

/// Resolve a `UnitIdentifier` expression to a concrete [`Unit`].
///
/// The identifier's `name` is passed to `unit_lookup`, and the identifier's
/// prefix is then applied to the unit that was found. Returns a
/// `NameResolutionError` if the lookup yields nothing, and an `InvalidPattern`
/// error if `expr` is not a unit identifier.
fn extract_unit(
    expr: &Expression,
    unit_lookup: &impl Fn(&str) -> Option<Unit>,
) -> Result<Unit, QuantityLiteralError> {
    if let Expression::UnitIdentifier { prefix, name, .. } = expr {
        return match unit_lookup(name) {
            Some(base_unit) => Ok(base_unit.with_prefix(*prefix)),
            None => Err(QuantityLiteralError::NameResolutionError(format!(
                "Unknown unit: {name}"
            ))),
        };
    }

    Err(QuantityLiteralError::InvalidPattern(
        "Expected unit identifier".to_string(),
    ))
}

/// Parse `input` as a quantity literal and return its expression AST.
///
/// The full parser is run on the (trimmed) input, and the result is then
/// restricted: it must consist of exactly one statement, that statement must be
/// an expression, and the expression must have the shape of a simple quantity
/// literal.
///
/// Accepted:
/// - `1.5`
/// - `-3.14`
/// - `1.5 km`
/// - `1.5 * km`
/// - `100 m`
///
/// Rejected:
/// - `1 + 2`
/// - `1 km + 500 m`
/// - `100 km/h` (compound units not supported)
/// - `x` (variables not supported)
///
/// # Errors
///
/// Returns `ParseError` for empty input or input the parser rejects, and
/// `InvalidPattern` for input that parses but is not a simple quantity literal.
pub fn parse_quantity_ast(input: &str) -> Result<Expression<'_>, QuantityLiteralError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(QuantityLiteralError::ParseError("Empty input".to_string()));
    }

    // Render all parser errors as one comma-separated message.
    let describe = |errors: &[ParseError]| {
        errors
            .iter()
            .map(|e| format!("{:?}", e.kind))
            .collect::<Vec<_>>()
            .join(", ")
    };

    let mut statements = match parse(trimmed, 0) {
        Ok(statements) => statements,
        Err((_partial, errors)) => {
            return Err(QuantityLiteralError::ParseError(describe(&errors)));
        }
    };

    let count = statements.len();
    if count != 1 {
        return Err(QuantityLiteralError::InvalidPattern(format!(
            "Expected a single expression, got {} statements",
            count
        )));
    }

    // Exactly one statement is present here, so `pop` yields that statement.
    let literal = match statements.pop() {
        Some(Statement::Expression(literal)) => literal,
        _ => {
            return Err(QuantityLiteralError::InvalidPattern(
                "Expected an expression, not a declaration".to_string(),
            ));
        }
    };

    if is_valid_quantity_literal(&literal) {
        Ok(literal)
    } else {
        Err(QuantityLiteralError::InvalidPattern(
            "Expected '<number>' or '<number> <unit>'".to_string(),
        ))
    }
}

/// Check whether `expr` has the shape of a simple quantity literal.
///
/// After stripping any number of leading negations, the remaining expression
/// must be either a scalar or a multiplication (implicit or explicit) of a
/// scalar on the left with an identifier on the right.
fn is_valid_quantity_literal(expr: &Expression) -> bool {
    // Peel off all leading negations; they never affect validity.
    let mut core = expr;
    while let Expression::UnaryOperator {
        op: UnaryOperator::Negate,
        expr: operand,
        ..
    } = core
    {
        core = &**operand;
    }

    match core {
        Expression::Scalar(..) => true,
        Expression::BinaryOperator {
            op: BinaryOperator::Mul,
            lhs,
            rhs,
            ..
        } => lhs.is_scalar() && rhs.is_identifier(),
        _ => false,
    }
}

// Unit tests for `parse_quantity_ast`: accepted inputs are compared against
// inline snapshots of the resulting AST, rejected inputs are checked for the
// expected error variant.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ReplaceSpans;
    use insta::assert_snapshot;

    /// Pretty print expression with spans normalized for comparison
    fn pp(expr: &Expression) -> String {
        format!("{:?}", expr.replace_spans())
    }

    /// A decimal number on its own parses to a single `Scalar`.
    #[test]
    fn test_parse_scalar() {
        let expr = parse_quantity_ast("1.5").unwrap();
        assert_snapshot!(pp(&expr), @"Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(1.5))");
    }

    /// A leading minus is represented as a `Negate` wrapper around the scalar.
    #[test]
    fn test_parse_negative_scalar() {
        let expr = parse_quantity_ast("-3.14").unwrap();
        assert_snapshot!(pp(&expr), @"UnaryOperator { op: Negate, expr: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(3.14)), span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }");
    }

    /// `<number> <unit>` (implicit multiplication) parses to `Mul(Scalar, Identifier)`.
    #[test]
    fn test_parse_scalar_with_unit() {
        let expr = parse_quantity_ast("1.5 km").unwrap();
        assert_snapshot!(pp(&expr), @r#"BinaryOperator { op: Mul, lhs: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(1.5)), rhs: Identifier(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, "km"), span_op: Some(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }) }"#);
    }

    /// The negation wraps the whole `<number> <unit>` product, not just the number.
    #[test]
    fn test_parse_negative_scalar_with_unit() {
        let expr = parse_quantity_ast("-100 m").unwrap();
        assert_snapshot!(pp(&expr), @r#"UnaryOperator { op: Negate, expr: BinaryOperator { op: Mul, lhs: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(100.0)), rhs: Identifier(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, "m"), span_op: Some(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }) }, span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }"#);
    }

    /// An integer literal is accepted and shows up as a floating-point `Number`.
    #[test]
    fn test_parse_integer() {
        let expr = parse_quantity_ast("42").unwrap();
        assert_snapshot!(pp(&expr), @"Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(42.0))");
    }

    /// Repeated negation is accepted and yields nested `Negate` nodes.
    #[test]
    fn test_parse_double_negation() {
        let expr = parse_quantity_ast("--42").unwrap();
        assert_snapshot!(pp(&expr), @"UnaryOperator { op: Negate, expr: UnaryOperator { op: Negate, expr: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(42.0)), span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }, span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }");
    }

    /// Nested negations are also accepted in front of a `<number> <unit>` product.
    #[test]
    fn test_parse_triple_negation_with_unit() {
        let expr = parse_quantity_ast("---2 km").unwrap();
        assert_snapshot!(pp(&expr), @r#"UnaryOperator { op: Negate, expr: UnaryOperator { op: Negate, expr: UnaryOperator { op: Negate, expr: BinaryOperator { op: Mul, lhs: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(2.0)), rhs: Identifier(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, "km"), span_op: Some(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }) }, span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }, span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }, span_op: Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 } }"#);
    }

    /// Arithmetic other than `<number> * <unit>` is rejected as an invalid pattern.
    #[test]
    fn test_reject_addition() {
        let result = parse_quantity_ast("1 + 2");
        assert!(matches!(
            result,
            Err(QuantityLiteralError::InvalidPattern(_))
        ));
    }

    /// Compound units such as `km/h` are rejected as an invalid pattern.
    #[test]
    fn test_reject_compound_units() {
        let result = parse_quantity_ast("100 km/h");
        assert!(matches!(
            result,
            Err(QuantityLiteralError::InvalidPattern(_))
        ));
    }

    /// A bare identifier without a number is rejected as an invalid pattern.
    #[test]
    fn test_reject_variable() {
        let result = parse_quantity_ast("x");
        assert!(matches!(
            result,
            Err(QuantityLiteralError::InvalidPattern(_))
        ));
    }

    /// Empty and whitespace-only input are both reported as `ParseError`.
    #[test]
    fn test_reject_empty() {
        let result = parse_quantity_ast("");
        assert!(matches!(result, Err(QuantityLiteralError::ParseError(_))));

        let result = parse_quantity_ast("   ");
        assert!(matches!(result, Err(QuantityLiteralError::ParseError(_))));
    }

    /// An explicit `*` gives the same normalized AST as implicit multiplication.
    #[test]
    fn test_parse_explicit_multiplication() {
        let expr = parse_quantity_ast("1.5 * km").unwrap();
        assert_snapshot!(pp(&expr), @r#"BinaryOperator { op: Mul, lhs: Scalar(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, Number(1.5)), rhs: Identifier(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }, "km"), span_op: Some(Span { start: ByteIndex(0), end: ByteIndex(0), code_source_id: 0 }) }"#);
    }

    /// The right-hand side of the product must be an identifier, not another number.
    #[test]
    fn test_reject_scalar_times_scalar() {
        let result = parse_quantity_ast("2 * 3");
        assert!(matches!(
            result,
            Err(QuantityLiteralError::InvalidPattern(_))
        ));
    }

    /// Units raised to a power are rejected as an invalid pattern.
    #[test]
    fn test_reject_unit_squared() {
        let result = parse_quantity_ast("1 m^2");
        assert!(matches!(
            result,
            Err(QuantityLiteralError::InvalidPattern(_))
        ));
    }
}