use anyhow::{anyhow, Context, Result};
use serde_json::Value as JsonValue;
use starlark::environment::{FrozenModule, Globals, Module};
use starlark::eval::{Evaluator, ReturnFileLoader};
use starlark::syntax::{AstModule, Dialect};
use starlark::values::dict::AllocDict;
use starlark::values::list::AllocList;
use starlark::values::{Heap, Value};
use std::collections::HashMap;
use std::path::Path;
/// A compiled and frozen Starlark module whose top-level bindings can be invoked
/// as JSON-in/JSON-out transforms via [`StarlarkTransforms::call`].
#[derive(Debug)]
pub(crate) struct StarlarkTransforms {
    /// The evaluated, frozen module produced by `compile`; `call` looks its
    /// top-level bindings up by name.
    module: FrozenModule,
}
impl StarlarkTransforms {
    /// Parses, evaluates and freezes `source` as a Starlark module.
    ///
    /// `filename` names the module in diagnostics. When `base_dir` is `Some`,
    /// `load()` statements are resolved relative to it; when it is `None`, any
    /// `load()` in the source is rejected.
    ///
    /// # Errors
    /// Fails if the source or any module it loads cannot be read, parsed,
    /// evaluated or frozen, or if the `load()` graph contains a cycle.
    pub(crate) fn compile(source: &str, filename: &str, base_dir: Option<&Path>) -> Result<Self> {
        // Chain of module ids currently being loaded; used for cycle detection.
        let mut load_chain = Vec::new();
        let module = compile_module(String::from(source), filename, base_dir, &mut load_chain)?;
        Ok(Self { module })
    }

    /// Calls the top-level binding `name` with `value` as the first positional
    /// argument, followed by `args`, and converts the result back to JSON.
    ///
    /// Returns `Ok(None)` when the module has no binding called `name`.
    ///
    /// # Errors
    /// Fails if the lookup fails, an argument cannot be converted to Starlark,
    /// the call raises (for example via `fail()` or because the binding is not
    /// callable), or the returned value has no JSON representation.
    pub(crate) fn call(
        &self,
        name: &str,
        value: &JsonValue,
        args: &[JsonValue],
    ) -> Result<Option<JsonValue>> {
        let function = match self.module.get_option(name)? {
            Some(function) => function,
            None => return Ok(None),
        };
        Module::with_temp_heap(|module| {
            let heap = module.heap();
            let callee = heap.access_owned_frozen_value(&function);
            // The transformed value always goes first, then the literal arguments.
            let positional = std::iter::once(value)
                .chain(args)
                .map(|json| json_to_starlark(heap, json))
                .collect::<Result<Vec<_>>>()?;
            let mut eval = Evaluator::new(&module);
            let returned = eval
                .eval_function(callee, &positional, &[])
                .map_err(starlark::Error::into_anyhow)?;
            returned
                .to_json_value()
                .map(Some)
                .map_err(|err| anyhow!("returned a value with no json representation: {err:#}"))
        })
    }
}
/// Parses, evaluates and freezes one Starlark module, after first compiling
/// every module it `load()`s.
///
/// * `source` / `filename` — the module text and the name used in diagnostics.
/// * `base_dir` — directory that every `load()` path is joined onto. `None` means
///   the source did not come from a file, so any `load()` is an error. Nested
///   loads are resolved against this same `base_dir`, not against the loading
///   module's own location.
/// * `visiting` — the chain of module ids currently being compiled, used to
///   report `load()` cycles. On success it is left as it was on entry.
///
/// # Errors
/// Returns an error if a `load()` appears without a `base_dir`, names a rooted
/// or absolute path, forms a cycle, or if any module fails to read, parse,
/// evaluate or freeze.
fn compile_module(
    source: String,
    filename: &str,
    base_dir: Option<&Path>,
    visiting: &mut Vec<String>,
) -> Result<FrozenModule> {
    let ast = AstModule::parse(filename, source, &Dialect::Standard)
        .map_err(starlark::Error::into_anyhow)
        .with_context(|| format!("parse starlark: {filename}"))?;
    let load_ids: Vec<String> = ast
        .loads()
        .iter()
        .map(|load| load.module_id.to_string())
        .collect();
    let mut loaded: Vec<(String, FrozenModule)> = Vec::new();
    for module_id in load_ids {
        // Several load() statements may name the same module; read and compile it once.
        if loaded.iter().any(|(id, _)| id == &module_id) {
            continue;
        }
        let Some(base) = base_dir else {
            return Err(anyhow!(
                "load(\"{module_id}\") in {filename}: load() requires a map spec loaded from a file"
            ));
        };
        let load_path = Path::new(&module_id);
        // `is_absolute` alone is not enough on Windows: `\lib.star` (root, no prefix)
        // and `C:lib.star` (prefix, no root) are not absolute, yet `Path::join` lets
        // both discard `base`. Rejecting any prefix or root component covers every
        // platform; on Unix it is equivalent to `is_absolute`.
        // NOTE(review): `..` segments are still accepted, so a load can reach files
        // outside `base`; confirm that is intended.
        let rooted = load_path.components().any(|component| {
            matches!(
                component,
                std::path::Component::Prefix(_) | std::path::Component::RootDir
            )
        });
        if rooted {
            return Err(anyhow!(
                "load(\"{module_id}\") in {filename}: absolute paths are not allowed; \
                 paths resolve relative to the spec file"
            ));
        }
        if visiting.iter().any(|seen| seen == &module_id) {
            return Err(anyhow!(
                "load() cycle: {} -> {module_id}",
                visiting.join(" -> ")
            ));
        }
        let resolved = base.join(load_path);
        let loaded_source = std::fs::read_to_string(&resolved)
            .with_context(|| format!("read starlark module: {}", resolved.display()))?;
        visiting.push(module_id.clone());
        let frozen = compile_module(loaded_source, &module_id, base_dir, visiting)?;
        visiting.pop();
        loaded.push((module_id, frozen));
    }
    let modules: HashMap<&str, &FrozenModule> = loaded
        .iter()
        .map(|(id, frozen)| (id.as_str(), frozen))
        .collect();
    Module::with_temp_heap(|module| {
        let loader = ReturnFileLoader { modules: &modules };
        {
            // Scope the evaluator so its borrow of `module` ends before `freeze`.
            let mut eval = Evaluator::new(&module);
            eval.set_loader(&loader);
            eval.eval_module(ast, &Globals::standard())
                .map_err(starlark::Error::into_anyhow)
                .with_context(|| format!("evaluate starlark: {filename}"))?;
        }
        module
            .freeze()
            .map_err(|err| anyhow!("freeze starlark module {filename}: {err:?}"))
    })
}
/// Converts a JSON value into an equivalent Starlark value allocated on `heap`.
///
/// `null`, booleans, strings, arrays and objects become `None`, `bool`, `str`,
/// `list` and `dict`. A number becomes an `int` when it fits `i64` or `u64`, and a
/// `float` otherwise. Object entries keep the JSON map's iteration order.
///
/// # Errors
/// Fails only when some number (possibly nested) cannot be read as `i64`, `u64`
/// or `f64`.
fn json_to_starlark<'v>(heap: Heap<'v>, value: &JsonValue) -> Result<Value<'v>> {
    let converted = match value {
        JsonValue::Null => Value::new_none(),
        JsonValue::Bool(flag) => Value::new_bool(*flag),
        JsonValue::Number(number) => {
            if let Some(signed) = number.as_i64() {
                heap.alloc(signed)
            } else if let Some(unsigned) = number.as_u64() {
                heap.alloc(unsigned)
            } else {
                let float = number
                    .as_f64()
                    .ok_or_else(|| anyhow!("number {number} has no starlark representation"))?;
                heap.alloc(float)
            }
        }
        JsonValue::String(text) => heap.alloc(text.as_str()),
        JsonValue::Array(items) => {
            let elements = items
                .iter()
                .map(|item| json_to_starlark(heap, item))
                .collect::<Result<Vec<_>>>()?;
            heap.alloc(AllocList(elements))
        }
        JsonValue::Object(map) => {
            let entries = map
                .iter()
                .map(|(key, field)| -> Result<(Value<'v>, Value<'v>)> {
                    let key = heap.alloc(key.as_str());
                    Ok((key, json_to_starlark(heap, field)?))
                })
                .collect::<Result<Vec<_>>>()?;
            heap.alloc(AllocDict(entries))
        }
    };
    Ok(converted)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Compiles an in-memory snippet. With no base directory, `load()` is rejected.
    fn inline(source: &str) -> StarlarkTransforms {
        StarlarkTransforms::compile(source, "test", None).unwrap()
    }

    /// Creates a temporary directory holding the given `(file name, contents)` pairs.
    fn dir_with(files: &[(&str, &str)]) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        for &(name, contents) in files {
            std::fs::write(dir.path().join(name), contents).unwrap();
        }
        dir
    }

    #[test]
    fn calls_a_simple_transform() {
        let host = inline("def cidr_host(v):\n return v.split(\"/\")[0]\n")
            .call("cidr_host", &json!("10.0.0.1/24"), &[])
            .unwrap();
        assert_eq!(host, Some(json!("10.0.0.1")));
    }

    #[test]
    fn unknown_name_is_none() {
        let transforms = inline("def f(v):\n return v\n");
        let missing = transforms.call("g", &json!(1), &[]).unwrap();
        assert!(missing.is_none(), "{missing:?}");
    }

    #[test]
    fn passes_literal_arguments() {
        let transforms = inline("def pad(v, width, fill):\n return fill * (width - len(v)) + v\n");
        let extra = [json!(3), json!("0")];
        let padded = transforms.call("pad", &json!("7"), &extra).unwrap();
        assert_eq!(padded, Some(json!("007")));
    }

    #[test]
    fn fail_surfaces_as_error() {
        let transforms = inline("def f(v):\n fail(\"no mapping for \" + v)\n");
        let err = transforms.call("f", &json!("vyos"), &[]).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("no mapping for vyos"), "{err:#}");
    }

    #[test]
    fn non_callable_binding_errors() {
        let transforms = inline("TABLE = {\"a\": 1}\n");
        let err = transforms.call("TABLE", &json!("a"), &[]).unwrap_err();
        assert_ne!(err.to_string(), "");
    }

    #[test]
    fn round_trips_types() {
        let transforms = inline("def id(v):\n return v\n");
        let samples = [
            json!(null),
            json!(true),
            json!(i64::MIN),
            json!(i64::MAX),
            json!(u64::MAX),
            json!(2.5),
            json!("s"),
            json!([1, "a", [true]]),
            json!({"k": {"nested": [1, 2]}}),
        ];
        for sample in samples {
            let echoed = transforms.call("id", &sample, &[]).unwrap();
            assert_eq!(echoed, Some(sample));
        }
    }

    #[test]
    fn typed_returns_map_to_json() {
        let profile = inline("def profile(platform):\n return {\"os\": platform, \"ports\": [1, 2]}\n")
            .call("profile", &json!("eos"), &[])
            .unwrap();
        let expected = json!({"os": "eos", "ports": [1, 2]});
        assert_eq!(profile, Some(expected));
    }

    #[test]
    fn function_return_has_no_json_representation() {
        let transforms = inline("def f(v):\n return lambda: v\n");
        let err = transforms.call("f", &json!(1), &[]).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("no json representation"), "{err:#}");
    }

    #[test]
    fn recursion_is_rejected() {
        let transforms = inline("def f(v):\n return f(v)\n");
        let err = transforms.call("f", &json!(1), &[]).unwrap_err();
        assert_ne!(err.to_string(), "");
    }

    #[test]
    fn load_requires_a_base_dir() {
        let source = "load(\"lib.star\", \"f\")\n";
        let err = StarlarkTransforms::compile(source, "test", None).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("load() requires a map spec loaded from a file"),
            "{err:#}"
        );
    }

    #[test]
    fn load_rejects_absolute_paths() {
        let dir = dir_with(&[]);
        let source = "load(\"/etc/lib.star\", \"f\")\n";
        let err = StarlarkTransforms::compile(source, "test", Some(dir.path())).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("absolute paths are not allowed"), "{err:#}");
    }

    #[test]
    fn load_resolves_relative_to_base_dir() {
        let dir = dir_with(&[("lib.star", "def shout(v):\n return v.upper()\n")]);
        let source = "load(\"lib.star\", \"shout\")\n\ndef f(v):\n return shout(v)\n";
        let transforms = StarlarkTransforms::compile(source, "test", Some(dir.path())).unwrap();
        let shouted = transforms.call("f", &json!("quiet"), &[]).unwrap();
        assert_eq!(shouted, Some(json!("QUIET")));
    }

    #[test]
    fn load_detects_cycles() {
        let dir = dir_with(&[
            ("a.star", "load(\"b.star\", \"g\")\n\ndef f(v):\n return g(v)\n"),
            ("b.star", "load(\"a.star\", \"f\")\n\ndef g(v):\n return f(v)\n"),
        ]);
        let source = "load(\"a.star\", \"f\")\n";
        let err = StarlarkTransforms::compile(source, "test", Some(dir.path())).unwrap_err();
        assert!(err.to_string().contains("load() cycle"), "{err:#}");
    }
}