use std::collections::HashMap;
use std::hash::Hash;
use tabulon::{Tabula, VarResolveError, VarResolver, Parser, PreparedExpr};
/// Test resolver that maps variable identifiers to `u64` keys through a
/// lookup table.
struct U64Resolver {
    /// Identifier text -> numeric key.
    map: HashMap<String, u64>,
}

impl VarResolver<u64> for U64Resolver {
    /// Looks `ident` up in the table and returns its key.
    ///
    /// # Errors
    ///
    /// Returns `VarResolveError::Unknown` carrying the identifier text when
    /// the table has no entry for `ident`.
    fn resolve(&self, ident: &str) -> Result<u64, VarResolveError> {
        match self.map.get(ident) {
            Some(&id) => Ok(id),
            None => Err(VarResolveError::Unknown(ident.to_string())),
        }
    }
}
/// Checks that an expression whose variables resolve to `u64` keys can be
/// parsed, compiled and evaluated.
///
/// `(A + B) * C` is evaluated with the values 100, 20 and 1.5 registered under
/// the keys of `A`, `B` and `C`; the expected result is 180.
#[test]
fn u64_var_keys_work() {
    // Identifier -> key table handed to the resolver.
    let ids: HashMap<String, u64> = [("A", 1_u64), ("B", 2_u64), ("C", 3_u64)]
        .iter()
        .map(|&(name, id)| (name.to_string(), id))
        .collect();
    let resolver = U64Resolver { map: ids };

    let parser = Parser::new("(A + B) * C").unwrap();
    let prepared: PreparedExpr<u64> = parser.parse_with_var_resolver(&resolver).unwrap();

    let mut eng = Tabula::new();
    let compiled = eng.compile_prepared_ref(&prepared).unwrap();

    // Key -> runtime value.
    let values_by_id: HashMap<u64, f64> = [(1_u64, 100.0_f64), (2_u64, 20.0_f64), (3_u64, 1.5_f64)]
        .iter()
        .copied()
        .collect();

    // Supply the values in the order the compiled expression lists its keys.
    let mut ordered_values: Vec<&f64> = Vec::new();
    for key in compiled.vars().iter() {
        ordered_values.push(values_by_id.get(key).expect("missing value for key"));
    }

    let out = compiled.eval(&ordered_values).unwrap();
    assert!((out - 180.0).abs() < 1e-9);
}
/// Enum used as the variable-key type in the enum-keyed test.
///
/// The derives make it comparable, hashable (so it can key a `HashMap`),
/// cloneable and printable with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum VarKey {
    /// First key variant.
    Str,
    /// Second key variant.
    BonusStr,
    /// Third key variant.
    MulStr,
}
/// Stateless test resolver that maps a fixed set of identifiers to `VarKey`
/// variants.
struct EnumResolver;

impl VarResolver<VarKey> for EnumResolver {
    /// Maps `"A"`, `"B"` and `"C"` to `VarKey::Str`, `VarKey::BonusStr` and
    /// `VarKey::MulStr` respectively.
    ///
    /// # Errors
    ///
    /// Any other identifier yields `VarResolveError::Invalid` carrying the
    /// identifier text.
    fn resolve(&self, ident: &str) -> Result<VarKey, VarResolveError> {
        let key = match ident {
            "A" => VarKey::Str,
            "B" => VarKey::BonusStr,
            "C" => VarKey::MulStr,
            unrecognized => {
                return Err(VarResolveError::Invalid(unrecognized.to_string()));
            }
        };
        Ok(key)
    }
}
/// Checks that an expression whose variables resolve to `VarKey` enum values
/// can be parsed, compiled and evaluated.
///
/// `(A + B) * C` is evaluated with 100, 20 and 1.5 registered under
/// `VarKey::Str`, `VarKey::BonusStr` and `VarKey::MulStr`; the expected result
/// is 180.
#[test]
fn enum_var_keys_work() {
    let parser = Parser::new("(A + B) * C").unwrap();
    let prepared: PreparedExpr<VarKey> = parser.parse_with_var_resolver(&EnumResolver).unwrap();

    let mut eng = Tabula::new();
    let compiled = eng.compile_prepared_ref(&prepared).unwrap();

    // Key -> runtime value.
    let values: HashMap<VarKey, f64> = vec![
        (VarKey::Str, 100.0_f64),
        (VarKey::BonusStr, 20.0_f64),
        (VarKey::MulStr, 1.5_f64),
    ]
    .into_iter()
    .collect();

    // Supply the values in the order the compiled expression lists its keys.
    let mut ordered_values: Vec<&f64> = Vec::new();
    for key in compiled.vars().iter() {
        ordered_values.push(values.get(key).expect("missing value for enum key"));
    }

    let out = compiled.eval(&ordered_values).unwrap();
    assert!((out - 180.0).abs() < 1e-9);
}