runmat-vm 0.4.4

RunMat virtual machine and bytecode interpreter
Documentation
use crate::bytecode::{EndExpr, UserFunction};
use crate::interpreter::errors::mex;
use runmat_builtins::Value;
use runmat_runtime::RuntimeError;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;

/// Error returned by [`value_to_f64`] when a value cannot be read as a real scalar.
///
/// It carries no payload; callers map it to a context-specific runtime error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValueToF64Error;

/// Boxed (non-`Send`) future returned by the `end`-expression callbacks below.
type EndCallbackFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, RuntimeError>> + 'a>>;

/// Callback that tries to run a builtin from inside an `end` expression.
///
/// It receives the function name and the already-evaluated arguments. The
/// evaluator treats a resolved `Ok(None)` as "not a builtin" and then falls
/// back to user-defined functions.
pub type BuiltinEndCallback<'a> =
    dyn Fn(&'a str, Vec<Value>) -> EndCallbackFuture<'a, Option<Value>> + 'a;

/// Callback that runs a user-defined function from inside an `end` expression.
///
/// Arguments, in order: the function name, the evaluated arguments, the table
/// of user functions, and the current variable slots.
pub type UserEndCallback<'a> = dyn Fn(
        &'a str,
        Vec<Value>,
        &'a HashMap<String, UserFunction>,
        &'a [Value],
    ) -> EndCallbackFuture<'a, Value>
    + 'a;

/// Reads a scalar-like [`Value`] as an `f64`.
///
/// Accepted inputs:
/// - plain numbers;
/// - integers (via `to_f64`);
/// - booleans (`true` → `1.0`, `false` → `0.0`);
/// - single-element real tensors;
/// - complex scalars and single-element complex tensors whose imaginary part
///   has magnitude below `1e-12`.
///
/// # Errors
///
/// Returns [`ValueToF64Error`] for every other value.
pub fn value_to_f64(v: &Value) -> Result<f64, ValueToF64Error> {
    // Imaginary parts smaller than this are treated as numerical noise.
    const IMAG_TOLERANCE: f64 = 1e-12;
    let negligible = |im: f64| im.abs() < IMAG_TOLERANCE;

    match v {
        Value::Num(n) => Ok(*n),
        Value::Int(i) => Ok(i.to_f64()),
        Value::Bool(flag) => Ok(f64::from(u8::from(*flag))),
        Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
        Value::Complex(re, im) if negligible(*im) => Ok(*re),
        Value::ComplexTensor(ct) if ct.data.len() == 1 && negligible(ct.data[0].1) => {
            let (re, _) = ct.data[0];
            Ok(re)
        }
        _ => Err(ValueToF64Error),
    }
}

/// Evaluates an `end`-relative index expression to a floating-point value.
///
/// - `end_value` is substituted for every [`EndExpr::End`] leaf.
/// - `vars` supplies the values read by [`EndExpr::Var`]; they must convert
///   via [`value_to_f64`].
/// - Calls ([`EndExpr::Call`]) first go through `call_builtin`. If that
///   resolves to `None` and `name` is a key of `functions`, `call_user` is
///   invoked instead.
///
/// Operands of every binary operator are evaluated left to right, so side
/// effects of nested calls happen in source order.
///
/// The returned future is boxed because the evaluator recurses into itself.
///
/// # Errors
///
/// - `MissingNumericIndex` if a variable slot is not present in `vars`.
/// - `UnsupportedIndexType` if a variable or call result is not a real scalar.
/// - `UndefinedFunction` if a call target is neither a builtin nor a user function.
/// - `IndexOutOfBounds` on division (`/` or `\`) by zero.
/// - Any error produced by the callbacks.
pub fn eval_end_expr_value<'a>(
    expr: &'a EndExpr,
    end_value: f64,
    vars: &'a [Value],
    functions: &'a HashMap<String, UserFunction>,
    call_builtin: &'a BuiltinEndCallback<'a>,
    call_user: &'a UserEndCallback<'a>,
) -> Pin<Box<dyn Future<Output = Result<f64, RuntimeError>> + 'a>> {
    Box::pin(async move {
        // Recurse into a sub-expression with the same evaluation context.
        let eval = move |sub: &'a EndExpr| {
            eval_end_expr_value(sub, end_value, vars, functions, call_builtin, call_user)
        };

        match expr {
            EndExpr::End => Ok(end_value),
            EndExpr::Const(v) => Ok(*v),
            EndExpr::Var(i) => {
                let v = vars.get(*i).ok_or_else(|| {
                    mex("MissingNumericIndex", "missing variable for end expression")
                })?;
                value_to_f64(v)
                    .map_err(|_| mex("UnsupportedIndexType", "end expression must be numeric"))
            }
            EndExpr::Call(name, args) => {
                let mut argv: Vec<Value> = Vec::with_capacity(args.len());
                for a in args {
                    argv.push(Value::Num(eval(a).await?));
                }
                // The builtin dispatcher is consulted first; `None` means it does
                // not handle `name`, so fall back to user-defined functions.
                let v = if let Some(v) = call_builtin(name, argv.clone()).await? {
                    v
                } else if functions.contains_key(name) {
                    call_user(name, argv, functions, vars).await?
                } else {
                    return Err(mex(
                        "UndefinedFunction",
                        &format!("Undefined function in end expression: {name}"),
                    ));
                };
                value_to_f64(&v)
                    .map_err(|_| mex("UnsupportedIndexType", "end call must return scalar"))
            }
            EndExpr::Add(a, b) => {
                let lhs = eval(a).await?;
                let rhs = eval(b).await?;
                Ok(lhs + rhs)
            }
            EndExpr::Sub(a, b) => {
                let lhs = eval(a).await?;
                let rhs = eval(b).await?;
                Ok(lhs - rhs)
            }
            EndExpr::Mul(a, b) => {
                let lhs = eval(a).await?;
                let rhs = eval(b).await?;
                Ok(lhs * rhs)
            }
            EndExpr::Div(a, b) => {
                // Evaluate the numerator first so operands run in source order,
                // matching every other binary operator.
                let lhs = eval(a).await?;
                let denom = eval(b).await?;
                if denom == 0.0 {
                    // `x / 0` would be Inf or NaN; reported as an out-of-bounds index.
                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
                }
                Ok(lhs / denom)
            }
            EndExpr::LeftDiv(a, b) => {
                // `a \ b` is `b / a`; operands are still evaluated left to right.
                let denom = eval(a).await?;
                let rhs = eval(b).await?;
                if denom == 0.0 {
                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
                }
                Ok(rhs / denom)
            }
            EndExpr::Pow(a, b) => {
                let base = eval(a).await?;
                let exponent = eval(b).await?;
                Ok(base.powf(exponent))
            }
            EndExpr::Neg(a) => Ok(-eval(a).await?),
            EndExpr::Pos(a) => Ok(eval(a).await?),
            EndExpr::Floor(a) => Ok(eval(a).await?.floor()),
            EndExpr::Ceil(a) => Ok(eval(a).await?.ceil()),
            EndExpr::Round(a) => Ok(eval(a).await?.round()),
            // `fix` rounds toward zero, which is exactly `f64::trunc`.
            EndExpr::Fix(a) => Ok(eval(a).await?.trunc()),
        }
    })
}

pub async fn resolve_range_end_index<'a>(
    dim_len: usize,
    end_expr: &'a EndExpr,
    vars: &'a [Value],
    functions: &'a HashMap<String, UserFunction>,
    call_builtin: &'a BuiltinEndCallback<'a>,
    call_user: &'a UserEndCallback<'a>,
) -> Result<i64, RuntimeError> {
    let value = eval_end_expr_value(
        end_expr,
        dim_len as f64,
        vars,
        functions,
        call_builtin,
        call_user,
    )
    .await?;
    Ok(value.floor() as i64)
}