rsheet_lib 0.2.0

Libraries to help implementing cs6991-24T1-ass2
Documentation
//! This module supports evaluating a single cell expression as a Rhai
//! script, via the `CellExpr::evaluate` function.

use once_cell::sync::Lazy;
use std::thread;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::sync::Mutex;

use rhai::{ASTNode, Dynamic, Engine, EvalAltResult, Expr, ParseError, Scope, AST};
use std::{collections::{HashSet, HashMap}, ops::Deref, time::Duration};

use crate::cell_value::CellValue;

/// Process-wide set of the ids of threads that have been recorded as
/// evaluating cell expressions. Created empty on first use and guarded by a
/// mutex so that any thread may add itself.
static ACCESSOR_THREADS: Lazy<Mutex<HashSet<thread::ThreadId>>> =
    Lazy::new(|| Mutex::new(HashSet::new()));


/// The `CellArgument` enum represents the inputs to the calculation of a cell.
/// These are generally represented by a cell range. For example, `A3` indicates
/// a single cell; `A3:A5` represents a vector of values, and `A3:C5` represents
/// a matrix of values.
///
/// The enum is `#[serde(untagged)]`, so no variant name is written when it is
/// serialized. With untagged enums serde tries the variants in declaration
/// order when deserializing, so the order below should not be changed casually.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(untagged)]
pub enum CellArgument {
    /// A two-level nesting of values (a vector of vectors), e.g. `A3:C5`.
    Matrix(Vec<Vec<CellValue>>),
    /// A flat list of values, e.g. `A3:A5`.
    Vector(Vec<CellValue>),
    /// The value of a single cell, e.g. `A3`.
    Value(CellValue),
}

/// Each cell name is assigned to a cell expression. Evaluating this expression
/// gives the value (`CellValue`) of the cell at the current point in time,
/// i.e. given the current values of all other cells this cell depends on.
pub struct CellExpr {
    /// The Rhai engine that compiled the expression and that will later
    /// evaluate it.
    engine: Engine,
    /// Outcome of compiling the expression: the parsed AST, or the parse error
    /// if the expression text was not valid.
    ast: Result<AST, ParseError>,
}

/// Reasons why a cell expression could not be evaluated at all.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be printed for users and propagated through `Box<dyn Error>`.
#[derive(Debug, PartialEq, Eq)]
pub enum CellExprEvalError {
    /// The cell expression couldn't be evaluated because it contains a variable
    /// whose value is an error.
    VariableDependsOnError,
}

impl std::fmt::Display for CellExprEvalError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VariableDependsOnError => {
                f.write_str("cell expression depends on a variable whose value is an error")
            }
        }
    }
}

impl std::error::Error for CellExprEvalError {}

impl CellExpr {
    /// Compiles `expr` as a Rhai expression using a fresh [`Engine`].
    ///
    /// Construction itself never fails: the outcome of compilation (the AST or
    /// the parse error) is stored and inspected later by
    /// [`CellExpr::find_variable_names`] and [`CellExpr::evaluate`].
    pub fn new(expr: &str) -> Self {
        let engine = Engine::new();
        let ast = engine.compile_expression(expr);

        Self { engine, ast }
    }

    /// Finds all variable names within a provided cell expression.
    ///
    /// Only identifiers shaped like `A1` or `A1_B2` are reported. A name is
    /// pushed every time a matching variable node is seen during the AST walk;
    /// no deduplication is performed. If the expression failed to compile, the
    /// returned list is empty.
    ///
    /// # Example
    ///
    /// ```
    /// # use rsheet_lib::cell_expr::CellExpr;
    /// let cell_expr = CellExpr::new("A1 + A2 * A3_A4");
    /// assert_eq!(cell_expr.find_variable_names(), vec!["A1", "A2", "A3_A4"]);
    /// ```
    pub fn find_variable_names(&self) -> Vec<String> {
        // Compiled once and shared by every call.
        static RE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"^[A-Z]+[0-9]+(_[A-Z]+[0-9]+)?$").unwrap());

        let mut variables = Vec::new();
        if let Ok(ast) = &self.ast {
            ast.walk(&mut |nodes| {
                for node in nodes {
                    if let ASTNode::Expr(Expr::Variable(variable, _, _)) = node {
                        let (_var_index, _namespace, _namespace_hash, name) = variable.deref();

                        if RE.is_match(name) {
                            variables.push(name.to_string());
                        }
                    }
                }

                true
            });
        }

        variables
    }

    /// This function checks that we've only been called from a maximum of a certain
    /// number of threads. We've hard coded that number to 6, since there's one test case
    /// with 5 separate threads; and we want to allow for one more thread for the worker thread.
    ///
    /// Exceeding the limit is not fatal: a warning is printed to stderr and
    /// execution continues.
    fn check_thread_allowed() {
        /// Largest number of distinct threads that may evaluate expressions
        /// before a warning is printed.
        const MAX_ACCESSOR_THREADS: usize = 6;

        // A poisoned lock only means another thread panicked while holding it;
        // the set of thread ids is still usable, so recover it instead of
        // panicking here as well.
        let mut accessors = ACCESSOR_THREADS
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Record the current thread *before* checking the limit, so the warning
        // fires on the very call that pushes the count past the maximum.
        accessors.insert(thread::current().id());
        if accessors.len() > MAX_ACCESSOR_THREADS {
            eprintln!("Too many threads accessing the cell expressions.");
        }
    }

    /// Evaluate a Rhai expression, given a mapping of all cell argument names to cell argument values.
    ///
    /// Two helper functions are made available to the expression:
    ///
    /// * `sum(array)` adds up the integers in an array, descending into nested
    ///   arrays. It fails on any other kind of element and on integer overflow.
    /// * `sleep_then(millis, value)` sleeps for `millis` milliseconds (negative
    ///   delays are treated as zero) and then yields `value`.
    ///
    /// Parse errors, conversion failures and Rhai runtime errors are *not*
    /// reported through `Err`; they are returned as `Ok(CellValue::Error(..))`
    /// carrying a description of the problem.
    ///
    /// # Errors
    ///
    /// Returns [`CellExprEvalError::VariableDependsOnError`] if any value in
    /// `variables` is itself an error. This check runs before anything else,
    /// including the check for parse errors.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::collections::HashMap;
    /// # use rsheet_lib::cell_expr::{CellExpr, CellArgument};
    /// # use rsheet_lib::cell_value::CellValue;
    /// assert_eq!(
    ///     CellExpr::new("A1 + A2 * sum(A3_A4)").evaluate(&HashMap::from([
    ///         ("A1".to_string(), CellArgument::Value(CellValue::Int(5))),
    ///         ("A2".to_string(), CellArgument::Value(CellValue::Int(10))),
    ///         ("A3_A4".to_string(), CellArgument::Vector(vec![CellValue::Int(7), CellValue::Int(9)])),
    ///     ])),
    ///     Ok(CellValue::Int(165)),
    /// );
    /// ```
    pub fn evaluate(
        mut self,
        variables: &HashMap<String, CellArgument>,
    ) -> Result<CellValue, CellExprEvalError> {
        Self::check_thread_allowed();

        // Check if any variable values are errors.
        let any_error_variables = variables.values().any(|cell_argument| match cell_argument {
            CellArgument::Value(cell_value) => cell_value.is_error(),
            CellArgument::Vector(vector) => vector.iter().any(CellValue::is_error),
            CellArgument::Matrix(matrix) => matrix.iter().flatten().any(CellValue::is_error),
        });
        if any_error_variables {
            return Err(CellExprEvalError::VariableDependsOnError);
        }

        // A parse error becomes the cell's value rather than an `Err`.
        let ast = match self.ast.as_ref() {
            Ok(ast) => ast,
            Err(e) => return Ok(CellValue::Error(e.to_string())),
        };

        // Expose every supplied cell argument to the script as a variable.
        let mut scope = Scope::new();
        for (name, argument) in variables {
            match rhai::serde::to_dynamic(argument) {
                Ok(dynamic) => {
                    scope.push(name, dynamic);
                }
                Err(_) => {
                    // Report the argument that could not be converted, not the
                    // conversion result.
                    return Ok(CellValue::Error(format!(
                        "Unable to convert value {argument:?} to Rhai."
                    )));
                }
            }
        }

        /// Sums the integers in `vector`, recursing into nested arrays.
        /// Fails on non-integer, non-array items and on `i64` overflow.
        fn summer(vector: Vec<Dynamic>) -> Result<i64, Box<EvalAltResult>> {
            let mut total: i64 = 0;
            for item in vector {
                let addend = if let Ok(i) = item.as_int() {
                    i
                } else if let Ok(nested) = item.clone().into_array() {
                    summer(nested)?
                } else {
                    return Err(format!("Unknown value: {:?}", item).into());
                };

                // Surface overflow as a script error instead of panicking
                // (debug builds) or silently wrapping (release builds).
                total = total
                    .checked_add(addend)
                    .ok_or_else(|| -> Box<EvalAltResult> {
                        String::from("Integer overflow while summing values.").into()
                    })?;
            }
            Ok(total)
        }

        /// millis is i64 for rhai compatibility
        fn sleep_then(millis: i64, value: Dynamic) -> Dynamic {
            // Clamp negative delays to zero: a plain `as u64` cast would wrap
            // them into an enormous duration and the thread would never wake.
            let millis = millis.max(0).unsigned_abs();
            std::thread::sleep(Duration::from_millis(millis));
            value
        }

        self.engine.register_fn("sum", summer);
        self.engine.register_fn("sleep_then", sleep_then);

        let result = self
            .engine
            .eval_ast_with_scope::<rhai::Dynamic>(&mut scope, ast);

        Ok(match result {
            Ok(d) => match rhai::serde::from_dynamic(&d) {
                Ok(ret) => ret,
                Err(_) => CellValue::Error(String::from(
                    "Could not cast Rhai return back to Cell Value.",
                )),
            },
            Err(e) => CellValue::Error(e.to_string()),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the variable map handed to `evaluate` from `(name, argument)` pairs.
    fn variables(entries: Vec<(&str, CellArgument)>) -> HashMap<String, CellArgument> {
        let mut map = HashMap::new();
        for (name, argument) in entries {
            map.insert(name.to_string(), argument);
        }
        map
    }

    /// Wraps one cell value as a single-cell argument.
    fn single(value: CellValue) -> CellArgument {
        CellArgument::Value(value)
    }

    /// The error value planted in inputs by the "depends on error" tests.
    fn existing_error() -> CellValue {
        CellValue::Error("some existing error".to_string())
    }

    /// Compiles `expr` and evaluates it against `vars`.
    fn run(
        expr: &str,
        vars: &HashMap<String, CellArgument>,
    ) -> Result<CellValue, CellExprEvalError> {
        CellExpr::new(expr).evaluate(vars)
    }

    /// Asserts that evaluation was refused because an input was an error.
    fn assert_depends_on_error(result: Result<CellValue, CellExprEvalError>) {
        assert_eq!(result, Err(CellExprEvalError::VariableDependsOnError));
    }

    #[test]
    fn test_run_cell_none() {
        let no_variables = HashMap::new();

        assert_eq!(run("()", &no_variables), Ok(CellValue::None));
    }

    #[test]
    fn test_run_cell_vector() {
        let cells = vec![CellValue::Int(1), CellValue::Int(2), CellValue::Int(3)];
        let vars = variables(vec![("A1_A3", CellArgument::Vector(cells))]);

        assert_eq!(run("sum(A1_A3)", &vars), Ok(CellValue::Int(6)));
    }

    #[test]
    fn test_run_cell_value() {
        let expr = "A1 + A2 + A3";
        let vars = variables(vec![
            ("A1", single(CellValue::Int(1))),
            ("A2", single(CellValue::Int(2))),
            ("A3", single(CellValue::Int(3))),
        ]);
        let result = run(expr, &vars);

        let expected_names = vec!["A1".to_string(), "A2".to_string(), "A3".to_string()];
        assert_eq!(CellExpr::new(expr).find_variable_names(), expected_names);

        assert_eq!(result, Ok(CellValue::Int(6)));
    }

    #[test]
    fn test_run_cell_matrix() {
        let first = vec![CellValue::Int(1), CellValue::Int(2)];
        let second = vec![CellValue::Int(3), CellValue::Int(4)];
        let vars = variables(vec![("A1_B2", CellArgument::Matrix(vec![first, second]))]);

        assert_eq!(run("sum(A1_B2)", &vars), Ok(CellValue::Int(10)));
    }

    #[test]
    fn test_run_cell_error() {
        // `asdf` is not a known variable, so evaluation yields an error value.
        let result = run("asdf", &HashMap::new());
        let is_error_value = matches!(result, Ok(CellValue::Error(_)));

        assert!(is_error_value);
    }

    #[test]
    fn test_depend_on_error() {
        let vars = variables(vec![
            ("A1", single(CellValue::Int(1))),
            ("A2", single(existing_error())),
        ]);

        assert_depends_on_error(run("A1 + A2", &vars));
    }

    #[test]
    fn test_depend_on_error_vector() {
        let tainted = CellArgument::Vector(vec![CellValue::Int(10), existing_error()]);
        let vars = variables(vec![
            ("A1", single(CellValue::Int(1))),
            ("A2_A3", tainted),
        ]);

        assert_depends_on_error(run("A1 + sum(A2_A3)", &vars));
    }

    #[test]
    fn test_depend_on_error_matrix() {
        let tainted = CellArgument::Matrix(vec![
            vec![CellValue::Int(10), existing_error()],
            vec![CellValue::Int(20), CellValue::Int(50)],
        ]);
        let vars = variables(vec![
            ("A1", single(CellValue::Int(1))),
            ("A2_B3", tainted),
        ]);

        assert_depends_on_error(run("A1 + sum(A2_B3)", &vars));
    }
}