use std::time::Duration;
use boa_engine::{Context, Source};
use crate::error::{Error, Result};
/// Loop-iteration cap installed (via Boa's `set_loop_iteration_limit`) on every
/// context produced by [`limited_context`].
const LOOP_ITERATION_LIMIT: u64 = 5_000_000;
/// How long [`JsFunction::call_str_async`] waits for an evaluation before giving up.
const EVAL_TIMEOUT: Duration = Duration::from_secs(5);

/// Creates a brand-new Boa [`Context`] whose runtime limits cap loop iterations at
/// [`LOOP_ITERATION_LIMIT`], so a script that loops past the cap fails with an error
/// rather than running unbounded.
fn limited_context() -> Context {
    let mut ctx = Context::default();
    let limits = ctx.runtime_limits_mut();
    limits.set_loop_iteration_limit(LOOP_ITERATION_LIMIT);
    ctx
}
/// JavaScript source text paired with the name of a function it defines.
///
/// Only the text is stored; the function is invoked by name, on demand, by evaluating
/// the source again in a fresh interpreter context.
#[derive(Clone, Debug)]
pub(crate) struct JsFunction {
    /// The full script that defines the function.
    source: String,
    /// Identifier used to call the function after `source` has been evaluated.
    name: String,
}
impl JsFunction {
pub fn compile(source: &str, name: &str) -> Result<Self> {
let mut context = limited_context();
context
.eval(Source::from_bytes(source))
.map_err(|e| Error::Cipher(format!("failed to compile function `{name}`: {e}")))?;
Ok(Self {
source: source.to_string(),
name: name.to_string(),
})
}
pub fn call_str(&self, input: &str) -> Result<String> {
let input_json = serde_json::to_string(input)
.map_err(|e| Error::Cipher(format!("failed to encode input: {e}")))?;
let script = format!(
"{src}\n{name}({arg})",
src = self.source,
name = self.name,
arg = input_json
);
let mut context = limited_context();
let value = context
.eval(Source::from_bytes(&script))
.map_err(|e| Error::Cipher(format!("failed to evaluate `{}`: {e}", self.name)))?;
let js_string = value.to_string(&mut context).map_err(|e| {
Error::Cipher(format!("result of `{}` is not a string: {e}", self.name))
})?;
Ok(js_string.to_std_string_lossy())
}
pub async fn call_str_async(&self, input: &str) -> Result<String> {
let this = self.clone();
let input = input.to_string();
let join = tokio::task::spawn_blocking(move || this.call_str(&input));
match tokio::time::timeout(EVAL_TIMEOUT, join).await {
Ok(Ok(result)) => result,
Ok(Err(join_err)) => Err(Error::Cipher(format!(
"cipher evaluation thread failed: {join_err}"
))),
Err(_elapsed) => Err(Error::Cipher(
"cipher evaluation exceeded time limit".into(),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;

    /// Decoder used by several tests: returns its input reversed.
    const REVERSE_SRC: &str =
        r#"function decode(a){a=a.split("");a.reverse();return a.join("")}"#;

    #[test]
    fn runs_extracted_function_against_input() {
        let f = JsFunction::compile(REVERSE_SRC, "decode").unwrap();
        assert_eq!(f.call_str("abc").unwrap(), "cba");
    }

    #[test]
    fn input_with_quotes_and_backslashes_is_passed_verbatim() {
        // The input is spliced into the script as a JSON literal; these characters
        // must survive the round trip unchanged.
        let f = JsFunction::compile(REVERSE_SRC, "decode").unwrap();
        assert_eq!(f.call_str(r#"a"b\c"#).unwrap(), r#"c\b"a"#);
    }

    #[test]
    fn compile_error_is_cipher_error() {
        assert!(matches!(
            JsFunction::compile("not js ((", "f"),
            Err(Error::Cipher(_))
        ));
    }

    #[test]
    fn call_is_reusable_and_send() {
        fn assert_send<T: Send>() {}
        assert_send::<JsFunction>();
        // Each call evaluates in a fresh context, so repeated calls must not interfere.
        let f = JsFunction::compile(REVERSE_SRC, "decode").unwrap();
        assert_eq!(f.call_str("abc").unwrap(), "cba");
        assert_eq!(f.call_str("xyz").unwrap(), "zyx");
    }

    #[test]
    fn calling_undefined_function_is_cipher_error() {
        let f = JsFunction::compile("var unrelated = 1;", "missing").unwrap();
        assert!(matches!(f.call_str("x"), Err(Error::Cipher(_))));
    }

    #[test]
    fn infinite_loop_in_body_is_bounded_by_iteration_limit() {
        let src = r#"function evil(a){var i=0;while(true){i=i+1}return a}"#;
        let f = JsFunction::compile(src, "evil").unwrap();
        assert!(matches!(f.call_str("seed"), Err(Error::Cipher(_))));
    }

    #[tokio::test]
    async fn async_call_runs_off_executor_and_returns_result() {
        let f = JsFunction::compile(REVERSE_SRC, "decode").unwrap();
        assert_eq!(f.call_str_async("abc").await.unwrap(), "cba");
    }
}