use once_cell::sync::Lazy;
use std::thread;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::sync::Mutex;
use rhai::{ASTNode, Dynamic, Engine, EvalAltResult, Expr, ParseError, Scope, AST};
use std::{collections::{HashSet, HashMap}, ops::Deref, time::Duration};
use crate::cell_value::CellValue;
/// Every distinct thread that has evaluated a [`CellExpr`]. `CellExpr::check_thread_allowed`
/// adds the calling thread here and prints a warning to stderr once the set has grown past six.
static ACCESSOR_THREADS: Lazy<Mutex<HashSet<thread::ThreadId>>> =
    Lazy::new(|| Mutex::new(HashSet::new()));
/// A value bound to a variable name when a [`CellExpr`] is evaluated.
///
/// `CellExpr::evaluate` converts each argument to a Rhai value with `rhai::serde::to_dynamic`.
/// Because of `#[serde(untagged)]`, no variant name is emitted, so `Vector` and `Matrix`
/// serialize as a plain sequence and a sequence of sequences respectively.
///
/// Untagged deserialization tries the variants in declaration order: `Matrix`, then `Vector`,
/// then `Value`. Keep that order when editing this enum.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(untagged)]
pub enum CellArgument {
/// A two-dimensional group of cells. Presumably a cell range stored as a list of rows;
/// TODO confirm the orientation against callers.
Matrix(Vec<Vec<CellValue>>),
/// A one-dimensional group of cells, presumably a single-row or single-column range.
Vector(Vec<CellValue>),
/// A single cell's value.
Value(CellValue),
}
/// A single cell expression, compiled once with Rhai, together with the engine that will run it.
pub struct CellExpr {
/// Engine that compiled the expression. The custom functions (`sum`, `sleep_then`) are
/// registered on it inside `CellExpr::evaluate`, not at construction time.
engine: Engine,
/// Outcome of compiling the expression. A parse failure is stored rather than returned from
/// `CellExpr::new`, and it is surfaced as a `CellValue::Error` when the expression is evaluated.
ast: Result<AST, ParseError>,
}
/// Reasons `CellExpr::evaluate` refuses to produce any value at all.
///
/// Problems inside the expression itself are not reported through this type. That covers parse
/// errors, runtime errors, and values that cannot be converted to or from Rhai; all of these come
/// back as `Ok(CellValue::Error(..))` instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CellExprEvalError {
    /// At least one supplied variable already holds an error value. For vector and matrix
    /// arguments, this means at least one of their elements does.
    VariableDependsOnError,
}

impl std::fmt::Display for CellExprEvalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VariableDependsOnError => {
                f.write_str("expression depends on a variable that holds an error")
            }
        }
    }
}

impl std::error::Error for CellExprEvalError {}
impl CellExpr {
pub fn new(expr: &str) -> Self {
let engine = Engine::new();
let ast = engine.compile_expression(expr);
Self { engine, ast }
}
pub fn find_variable_names(&self) -> Vec<String> {
static RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^[A-Z]+[0-9]+(_[A-Z]+[0-9]+)?$").unwrap());
let mut variables = Vec::new();
if let Ok(ast) = &self.ast {
ast.walk(&mut |nodes| {
for node in nodes {
match node {
ASTNode::Expr(Expr::Variable(variable, _, _)) => {
let (_var_index, _namespace, _namespace_hash, name) = variable.deref();
if RE.is_match(name) {
variables.push(name.to_string());
}
}
_ => {}
}
}
true
});
}
variables
}
fn check_thread_allowed() {
let mut accessors = ACCESSOR_THREADS.lock().unwrap();
let current_thread = thread::current().id();
if accessors.len() > 6 {
eprintln!("Too many threads accessing the cell expressions.");
}
accessors.insert(current_thread);
}
pub fn evaluate(
mut self,
variables: &HashMap<String, CellArgument>,
) -> Result<CellValue, CellExprEvalError> {
Self::check_thread_allowed();
let any_error_variables =
variables
.iter()
.any(|(_variable, cell_argument)| match cell_argument {
CellArgument::Value(cell_value) => cell_value.is_error(),
CellArgument::Vector(vector) => vector.iter().any(CellValue::is_error),
CellArgument::Matrix(matrix) => {
matrix.iter().flatten().any(CellValue::is_error)
}
});
if any_error_variables {
return Err(CellExprEvalError::VariableDependsOnError);
}
let ast = match self.ast.as_mut() {
Ok(ast) => ast,
Err(e) => return Ok(CellValue::Error(e.to_string())),
};
let mut scope = Scope::new();
for (name, value) in variables {
let value = rhai::serde::to_dynamic(value);
match value {
Ok(value) => {
scope.push(name, value);
}
Err(_) => {
return Ok(CellValue::Error(format!(
"Unable to convert value {value:?} to Rhai."
)))
}
}
}
fn summer(vector: Vec<Dynamic>) -> Result<i64, Box<EvalAltResult>> {
let mut total = 0;
for item in vector {
if let Ok(i) = item.as_int() {
total += i;
} else if let Ok(l) = item.clone().into_array() {
total += summer(l)?;
} else {
return Err(format!("Unknown value: {:?}", item).into());
}
}
Ok(total)
}
fn sleep_then(millis: i64, value: Dynamic) -> Dynamic {
std::thread::sleep(Duration::from_millis(millis as u64));
value
}
self.engine.register_fn("sum", summer);
self.engine.register_fn("sleep_then", sleep_then);
let result = self
.engine
.eval_ast_with_scope::<rhai::Dynamic>(&mut scope, ast);
Ok(match result {
Ok(d) => match rhai::serde::from_dynamic(&d) {
Ok(ret) => ret,
Err(_) => CellValue::Error(String::from(
"Could not cast Rhai return back to Cell Value.",
)),
},
Err(e) => CellValue::Error(e.to_string()),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(name, argument)` pairs into the variable map that `CellExpr::evaluate` expects.
    fn bindings(pairs: Vec<(&str, CellArgument)>) -> HashMap<String, CellArgument> {
        pairs
            .into_iter()
            .map(|(name, argument)| (name.to_string(), argument))
            .collect()
    }

    /// Compiles `expr` and evaluates it with `pairs` bound as variables.
    fn run(expr: &str, pairs: Vec<(&str, CellArgument)>) -> Result<CellValue, CellExprEvalError> {
        CellExpr::new(expr).evaluate(&bindings(pairs))
    }

    /// The pre-existing error value injected by the dependency tests.
    fn existing_error() -> CellValue {
        CellValue::Error("some existing error".to_string())
    }

    /// Asserts that evaluation was refused because some input already held an error.
    fn assert_depends_on_error(outcome: Result<CellValue, CellExprEvalError>) {
        assert!(matches!(
            outcome,
            Err(CellExprEvalError::VariableDependsOnError)
        ));
    }

    #[test]
    fn test_run_cell_none() {
        assert_eq!(run("()", vec![]), Ok(CellValue::None));
    }

    #[test]
    fn test_run_cell_vector() {
        let column = CellArgument::Vector(vec![
            CellValue::Int(1),
            CellValue::Int(2),
            CellValue::Int(3),
        ]);
        assert_eq!(
            run("sum(A1_A3)", vec![("A1_A3", column)]),
            Ok(CellValue::Int(6))
        );
    }

    #[test]
    fn test_run_cell_value() {
        let expr = "A1 + A2 + A3";
        let outcome = run(
            expr,
            vec![
                ("A1", CellArgument::Value(CellValue::Int(1))),
                ("A2", CellArgument::Value(CellValue::Int(2))),
                ("A3", CellArgument::Value(CellValue::Int(3))),
            ],
        );
        let expected_names: Vec<String> = ["A1", "A2", "A3"]
            .iter()
            .copied()
            .map(String::from)
            .collect();
        assert_eq!(CellExpr::new(expr).find_variable_names(), expected_names);
        assert_eq!(outcome, Ok(CellValue::Int(6)));
    }

    #[test]
    fn test_run_cell_matrix() {
        let grid = CellArgument::Matrix(vec![
            vec![CellValue::Int(1), CellValue::Int(2)],
            vec![CellValue::Int(3), CellValue::Int(4)],
        ]);
        assert_eq!(
            run("sum(A1_B2)", vec![("A1_B2", grid)]),
            Ok(CellValue::Int(10))
        );
    }

    #[test]
    fn test_run_cell_error() {
        assert!(matches!(run("asdf", vec![]), Ok(CellValue::Error(_))));
    }

    #[test]
    fn test_depend_on_error() {
        assert_depends_on_error(run(
            "A1 + A2",
            vec![
                ("A1", CellArgument::Value(CellValue::Int(1))),
                ("A2", CellArgument::Value(existing_error())),
            ],
        ));
    }

    #[test]
    fn test_depend_on_error_vector() {
        assert_depends_on_error(run(
            "A1 + sum(A2_A3)",
            vec![
                ("A1", CellArgument::Value(CellValue::Int(1))),
                (
                    "A2_A3",
                    CellArgument::Vector(vec![CellValue::Int(10), existing_error()]),
                ),
            ],
        ));
    }

    #[test]
    fn test_depend_on_error_matrix() {
        assert_depends_on_error(run(
            "A1 + sum(A2_B3)",
            vec![
                ("A1", CellArgument::Value(CellValue::Int(1))),
                (
                    "A2_B3",
                    CellArgument::Matrix(vec![
                        vec![CellValue::Int(10), existing_error()],
                        vec![CellValue::Int(20), CellValue::Int(50)],
                    ]),
                ),
            ],
        ));
    }
}