use std::sync::{Arc, Mutex};
use ciborium::Value as CborValue;
use rhai::{Dynamic, Engine, EvalAltResult, Map as RhaiMap, Scope};
use vantage_core::{Result, error};
use vantage_types::Record;
use crate::{sort::SortDirection, vista::Vista};
/// Script-visible handle around a [`Vista`].
///
/// The vista lives behind an `Arc<Mutex<Option<_>>>`: cloning the handle
/// (which the builder functions do when they return) shares the same
/// underlying vista, and the `Option` lets the Rust side move the finished
/// vista out of the handle (leaving `None`) once a script has run.
#[derive(Clone)]
pub struct RhaiVista(pub Arc<Mutex<Option<Vista>>>);

impl RhaiVista {
    /// Wraps `vista` in a fresh, unshared handle.
    pub fn wrap(vista: Vista) -> Self {
        let slot = Mutex::new(Some(vista));
        Self(Arc::new(slot))
    }

    /// Runs `f` against the wrapped vista and returns a handle that shares it.
    ///
    /// # Errors
    /// Fails when the mutex is poisoned, when the vista has already been taken
    /// out of the handle, or when `f` returns an error (converted into a Rhai
    /// runtime error).
    pub fn apply<F>(&self, f: F) -> std::result::Result<RhaiVista, Box<EvalAltResult>>
    where
        F: FnOnce(&mut Vista) -> Result<()>,
    {
        with_inner(self, f)
    }
}
/// Callback used by the script-level `table(name)` function to turn a table
/// name into a fresh [`Vista`]. It must be `Send + Sync` so it can be shared by
/// the engine's registered closure.
pub type TargetResolver = Arc<dyn Fn(&str) -> Result<Vista> + Send + Sync + 'static>;
pub fn register_conventional_onto(engine: &mut Engine, resolver: TargetResolver) {
engine.register_type_with_name::<RhaiVista>("Vista");
engine.register_fn(
"table",
move |name: &str| -> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let vista = resolver(name).map_err(to_rhai_err)?;
Ok(RhaiVista::wrap(vista))
},
);
engine.register_fn(
"with_id",
|v: &mut RhaiVista, id: Dynamic| -> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let cbor = dynamic_to_cbor(id)?;
with_inner(v, |vista| vista.with_id(cbor).map(|_| ()))
},
);
engine.register_fn(
"add_condition_eq",
|v: &mut RhaiVista,
field: &str,
value: Dynamic|
-> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let cbor = dynamic_to_cbor(value)?;
let field = field.to_string();
with_inner(v, move |vista| vista.add_condition_eq(field, cbor))
},
);
engine.register_fn(
"add_order",
|v: &mut RhaiVista,
column: &str,
dir: &str|
-> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let direction = parse_dir(dir)?;
let column = column.to_string();
with_inner(v, move |vista| vista.add_order(&column, direction))
},
);
engine.register_fn(
"add_order",
|v: &mut RhaiVista, column: &str| -> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let column = column.to_string();
with_inner(v, move |vista| {
vista.add_order(&column, SortDirection::Ascending)
})
},
);
engine.register_fn(
"add_search",
|v: &mut RhaiVista, text: &str| -> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let text = text.to_string();
with_inner(v, move |vista| vista.add_search(text))
},
);
engine.register_fn(
"set_page_size",
|v: &mut RhaiVista, size: i64| -> std::result::Result<RhaiVista, Box<EvalAltResult>> {
if size <= 0 {
return Err("set_page_size: page size must be > 0".into());
}
with_inner(v, move |vista| vista.set_page_size(size as usize))
},
);
engine.register_fn(
"get_ref",
|v: &mut RhaiVista,
relation: &str,
row: RhaiMap|
-> std::result::Result<RhaiVista, Box<EvalAltResult>> {
let record = map_to_record(row)?;
let guard = lock(v)?;
let vista = guard
.as_ref()
.ok_or_else(|| Box::<EvalAltResult>::from("get_ref: vista already consumed"))?;
let target = vista.get_ref(relation, &record).map_err(to_rhai_err)?;
Ok(RhaiVista::wrap(target))
},
);
}
pub fn eval_ref_script(engine: &Engine, code: &str, row: &Record<CborValue>) -> Result<Vista> {
let mut scope = Scope::new();
scope.push_dynamic("row", record_to_dynamic(row));
let result: RhaiVista = engine
.eval_with_scope(&mut scope, code)
.map_err(|e| error!(format!("rhai reference build-script failed: {e}")))?;
result
.0
.lock()
.map_err(|_| error!("rhai reference build-script: result mutex poisoned"))?
.take()
.ok_or_else(|| error!("rhai reference build-script did not return a Vista"))
}
pub fn eval_modify_script(engine: &Engine, code: &str, vista: Vista) -> Result<Vista> {
let handle = RhaiVista::wrap(vista);
let mut scope = Scope::new();
scope.push("self", handle.clone());
engine
.run_with_scope(&mut scope, code)
.map_err(|e| error!(format!("rhai modify script failed: {e}")))?;
handle
.0
.lock()
.map_err(|_| error!("rhai modify script: result mutex poisoned"))?
.take()
.ok_or_else(|| error!("rhai modify script consumed `self`"))
}
/// Lock guard over a handle's vista slot.
type Guard<'a> = std::sync::MutexGuard<'a, Option<Vista>>;

/// Locks the vista slot of `v`, reporting mutex poisoning as a Rhai runtime error.
fn lock(v: &RhaiVista) -> std::result::Result<Guard<'_>, Box<EvalAltResult>> {
    match v.0.lock() {
        Ok(guard) => Ok(guard),
        Err(_) => Err("RhaiVista mutex poisoned".into()),
    }
}
fn with_inner<F>(v: &RhaiVista, f: F) -> std::result::Result<RhaiVista, Box<EvalAltResult>>
where
F: FnOnce(&mut Vista) -> Result<()>,
{
{
let mut guard = lock(v)?;
let vista = guard
.as_mut()
.ok_or_else(|| Box::<EvalAltResult>::from("vista already consumed in script"))?;
f(vista).map_err(to_rhai_err)?;
}
Ok(v.clone())
}
/// Parses a script-supplied sort direction, ignoring ASCII case.
///
/// Accepts `asc`/`ascending` and `desc`/`descending`. Anything else yields a
/// Rhai runtime error that quotes the input in lower case.
fn parse_dir(dir: &str) -> std::result::Result<SortDirection, Box<EvalAltResult>> {
    fn matches_any(input: &str, names: &[&str]) -> bool {
        names.iter().any(|name| input.eq_ignore_ascii_case(name))
    }

    if matches_any(dir, &["asc", "ascending"]) {
        Ok(SortDirection::Ascending)
    } else if matches_any(dir, &["desc", "descending"]) {
        Ok(SortDirection::Descending)
    } else {
        let shown = dir.to_ascii_lowercase();
        Err(format!("invalid sort direction '{shown}' (expected 'asc' or 'desc')").into())
    }
}
/// Converts a Vantage error into a Rhai runtime error carrying its display text.
fn to_rhai_err(e: vantage_core::VantageError) -> Box<EvalAltResult> {
    let message = e.to_string();
    message.into()
}
fn dynamic_to_cbor(d: Dynamic) -> std::result::Result<CborValue, Box<EvalAltResult>> {
if d.is_unit() {
Ok(CborValue::Null)
} else if d.is::<bool>() {
Ok(CborValue::Bool(d.cast::<bool>()))
} else if d.is::<i64>() {
Ok(CborValue::Integer(d.cast::<i64>().into()))
} else if d.is::<f64>() {
Ok(CborValue::Float(d.cast::<f64>()))
} else if d.is::<String>() {
Ok(CborValue::Text(d.cast::<String>()))
} else {
Err(format!(
"cannot convert rhai value of type '{}' into a condition value",
d.type_name()
)
.into())
}
}
/// Converts a CBOR value into the closest Rhai value.
///
/// - `Null` → unit; `Bool` → bool; `Float` → float.
/// - `Integer` → Rhai integer when it fits in `i64`. CBOR integers can lie
///   outside that range (up to `u64::MAX`, down to `-2^64`). Such values become
///   a (possibly lossy) float instead of being wrapped into a wrong, even
///   sign-flipped, `i64`.
/// - `Text` → string; `Bytes` → blob.
/// - `Array` and `Map` are converted recursively. Map entries whose key is not
///   `Text` are skipped.
/// - Any other CBOR value (e.g. tagged values) becomes unit.
fn cbor_to_dynamic(v: &CborValue) -> Dynamic {
    match v {
        CborValue::Null => Dynamic::UNIT,
        CborValue::Bool(b) => Dynamic::from_bool(*b),
        CborValue::Integer(i) => {
            let n: i128 = (*i).into();
            // `n as i64` would silently wrap out-of-range values; fall back to a float.
            i64::try_from(n).map_or_else(|_| Dynamic::from_float(n as f64), Dynamic::from_int)
        }
        CborValue::Float(f) => Dynamic::from_float(*f),
        CborValue::Text(s) => Dynamic::from(s.clone()),
        CborValue::Bytes(b) => Dynamic::from_blob(b.clone()),
        CborValue::Array(a) => {
            let arr: rhai::Array = a.iter().map(cbor_to_dynamic).collect();
            Dynamic::from_array(arr)
        }
        CborValue::Map(m) => {
            let mut map = RhaiMap::new();
            for (k, val) in m {
                // Rhai object maps only have string keys.
                if let CborValue::Text(key) = k {
                    map.insert(key.as_str().into(), cbor_to_dynamic(val));
                }
            }
            Dynamic::from_map(map)
        }
        _ => Dynamic::UNIT,
    }
}
/// Exposes a record to Rhai as an object map, converting every field with
/// `cbor_to_dynamic`.
fn record_to_dynamic(row: &Record<CborValue>) -> Dynamic {
    let fields = row.iter().fold(RhaiMap::new(), |mut acc, (name, value)| {
        acc.insert(name.as_str().into(), cbor_to_dynamic(value));
        acc
    });
    Dynamic::from_map(fields)
}
/// Converts a Rhai object map back into a record, field by field, using
/// `dynamic_to_cbor`. The first value that cannot be converted aborts with its error.
fn map_to_record(map: RhaiMap) -> std::result::Result<Record<CborValue>, Box<EvalAltResult>> {
    map.into_iter()
        .map(|(key, value)| dynamic_to_cbor(value).map(|cbor| (key.to_string(), cbor)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Column, VistaMetadata, mocks::MockShell};
    use vantage_dataset::ReadableValueSet;

    /// Shorthand for a CBOR text value.
    fn cbor_text(s: &str) -> CborValue {
        CborValue::Text(s.to_owned())
    }

    /// Builds a record from `(field, value)` pairs.
    fn record(pairs: &[(&str, CborValue)]) -> Record<CborValue> {
        pairs
            .iter()
            .map(|(field, value)| (String::from(*field), value.clone()))
            .collect()
    }

    /// A three-row `users` vista; Alice and Carol are VIPs, Bob is not.
    fn users_vista() -> Vista {
        let people = [("1", "Alice", true), ("2", "Bob", false), ("3", "Carol", true)];
        let source = people
            .iter()
            .fold(MockShell::new(), |shell, &(id, name, vip)| {
                shell.with_record(
                    id,
                    record(&[
                        ("id", cbor_text(id)),
                        ("name", cbor_text(name)),
                        ("vip_flag", CborValue::Bool(vip)),
                    ]),
                )
            });
        let metadata = VistaMetadata::new()
            .with_column(Column::new("id", "String").with_flag("id"))
            .with_column(Column::new("name", "String").with_flag("title"))
            .with_column(Column::new("vip_flag", "bool"))
            .with_id_column("id");
        Vista::new("users", Box::new(source.with_metadata(metadata)))
    }

    /// An engine whose resolver only knows the `users` table.
    fn engine() -> Engine {
        let resolver: TargetResolver = Arc::new(|name: &str| match name {
            "users" => Ok(users_vista()),
            _ => Err(error!("unknown table in test resolver", table = name)),
        });
        let mut engine = Engine::new();
        register_conventional_onto(&mut engine, resolver);
        engine
    }

    #[tokio::test]
    async fn script_narrows_target_with_literal_condition() {
        let parent = record(&[("id", cbor_text("1"))]);
        let script = r#"table("users").add_condition_eq("vip_flag", true)"#;
        let narrowed = eval_ref_script(&engine(), script, &parent).unwrap();

        let rows = narrowed.list_values().await.unwrap();
        assert_eq!(rows.len(), 2, "only the two VIP rows should survive");
        assert!(rows.contains_key("1") && rows.contains_key("3"));
    }

    #[tokio::test]
    async fn script_can_read_the_parent_row() {
        let parent = record(&[("id", cbor_text("3"))]);
        let script = r#"table("users").add_condition_eq("id", row.id)"#;
        let narrowed = eval_ref_script(&engine(), script, &parent).unwrap();

        let rows = narrowed.list_values().await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows["3"].get("name"), Some(&cbor_text("Carol")));
    }

    #[tokio::test]
    async fn modify_script_tweaks_an_existing_vista() {
        let script = r#"self.add_condition_eq("vip_flag", true)"#;
        let modified = eval_modify_script(&engine(), script, users_vista()).unwrap();

        let rows = modified.list_values().await.unwrap();
        assert_eq!(rows.len(), 2);
        assert!(rows.contains_key("1") && rows.contains_key("3"));
    }

    #[test]
    fn unknown_table_surfaces_resolver_error() {
        let empty_row = record(&[]);
        let Err(err) = eval_ref_script(&engine(), r#"table("ghosts")"#, &empty_row) else {
            panic!("expected the resolver to reject an unknown table");
        };
        assert!(err.to_string().contains("unknown table"));
    }
}