harn-vm 0.7.56

Async bytecode virtual machine for the Harn programming language
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::rc::Rc;

use crate::value::{VmError, VmValue};
use crate::vm::Vm;

// Per-thread cache of compiled regexes. Entries are inserted and looked up by
// `get_cached_regex`, which derives the string key from the flags and the
// pattern and empties the whole map once it holds 128 entries.
thread_local! {
    static REGEX_CACHE: RefCell<HashMap<String, regex::Regex>> = RefCell::new(HashMap::new());
}

/// Returns a compiled regex for `pattern` built with `flags`, reusing the
/// thread-local `REGEX_CACHE` when the same flags/pattern pair was seen before.
///
/// On a cache miss the regex is compiled with `build_regex`; once the cache
/// already holds 128 entries it is emptied before the new entry is stored.
///
/// # Errors
///
/// Returns `VmError::Thrown` carrying an `"Invalid regex: ..."` string when
/// `build_regex` rejects the pattern or one of the flags.
fn get_cached_regex(pattern: &str, flags: &str) -> Result<regex::Regex, VmError> {
    // Prefix the key with the byte length of `flags` so that the split between
    // flags and pattern is unambiguous. A bare `"{flags}\0{pattern}"` key
    // collides when either part contains a NUL (flags "i\0x" + pattern "y"
    // vs. flags "i" + pattern "x\0y"), which could hand back a cached regex
    // for a call whose flags should have been rejected.
    let cache_key = format!("{}\0{flags}{pattern}", flags.len());
    REGEX_CACHE.with(|cache| {
        let mut cache = cache.borrow_mut();
        if let Some(re) = cache.get(&cache_key) {
            return Ok(re.clone());
        }
        let re = build_regex(pattern, flags).map_err(|e| {
            VmError::Thrown(VmValue::String(Rc::from(format!("Invalid regex: {e}"))))
        })?;
        // Coarse eviction: drop everything once the size limit is reached.
        if cache.len() >= 128 {
            cache.clear();
        }
        cache.insert(cache_key, re.clone());
        Ok(re)
    })
}

/// Compiles `pattern` after applying each character of `flags` to a
/// `regex::RegexBuilder`.
///
/// Recognised flags: `i` (case-insensitive), `m` (multi-line), `s` (dot
/// matches newline), `x` (ignore whitespace). Any other character aborts with
/// an error message naming the offending flag; compilation failures are
/// returned as the regex error rendered to a string.
fn build_regex(pattern: &str, flags: &str) -> Result<regex::Regex, String> {
    let mut builder = regex::RegexBuilder::new(pattern);
    // Stop at the first unknown flag; known flags toggle builder options.
    flags.chars().try_for_each(|flag| {
        match flag {
            'i' => builder.case_insensitive(true),
            'm' => builder.multi_line(true),
            's' => builder.dot_matches_new_line(true),
            'x' => builder.ignore_whitespace(true),
            _ => {
                return Err(format!(
                    "unsupported regex flag '{flag}', expected one of i/m/s/x"
                ));
            }
        };
        Ok(())
    })?;
    match builder.build() {
        Ok(compiled) => Ok(compiled),
        Err(problem) => Err(problem.to_string()),
    }
}

/// Registers the regex builtins on `vm`: `regex_match`, `regex_replace`,
/// `regex_replace_all`, `regex_captures` and `regex_split`.
///
/// Every builtin converts its arguments with `display()` and obtains the
/// compiled regex through `get_cached_regex`, so an invalid pattern or flag
/// surfaces as the error that function returns.
pub(crate) fn register_regex_builtins(vm: &mut Vm) {
    // regex_match(pattern, text[, flags]): list of all matched substrings, or
    // nil when nothing matched or fewer than two arguments were given.
    vm.register_builtin("regex_match", |args, _out| {
        if args.len() < 2 {
            return Ok(VmValue::Nil);
        }
        let pattern = args[0].display();
        let haystack = args[1].display();
        let flags = args
            .get(2)
            .map(|flag_arg| flag_arg.display())
            .unwrap_or_default();
        let regex = get_cached_regex(&pattern, &flags)?;

        let mut hits: Vec<VmValue> = Vec::new();
        for found in regex.find_iter(&haystack) {
            hits.push(VmValue::String(Rc::from(found.as_str())));
        }
        if hits.is_empty() {
            Ok(VmValue::Nil)
        } else {
            Ok(VmValue::List(Rc::new(hits)))
        }
    });

    // `regex_replace` and `regex_replace_all` share one implementation that
    // replaces every match through the `regex` crate (so `$1` and `${name}`
    // backrefs work). The `_all` name exists purely as a discoverability alias.
    // Arguments are (pattern, replacement, text); fewer than three yields nil.
    fn replace_all_impl(args: &[VmValue]) -> Result<VmValue, VmError> {
        match args {
            [pattern, replacement, text, ..] => {
                let pattern = pattern.display();
                let replacement = replacement.display();
                let text = text.display();
                let regex = get_cached_regex(&pattern, "")?;
                let replaced = regex.replace_all(&text, replacement.as_str()).into_owned();
                Ok(VmValue::String(Rc::from(replaced)))
            }
            _ => Ok(VmValue::Nil),
        }
    }
    vm.register_builtin("regex_replace", |args, _out| replace_all_impl(args));
    vm.register_builtin("regex_replace_all", |args, _out| replace_all_impl(args));

    // regex_captures(pattern, text): one dict per match holding "match" (the
    // whole match), "groups" (positional groups, nil for groups that did not
    // participate) and one key per named group that matched. Fewer than two
    // arguments yields an empty list.
    vm.register_builtin("regex_captures", |args, _out| {
        if args.len() < 2 {
            return Ok(VmValue::List(Rc::new(Vec::new())));
        }
        let pattern = args[0].display();
        let haystack = args[1].display();
        let regex = get_cached_regex(&pattern, "")?;

        let rows: Vec<VmValue> = regex
            .captures_iter(&haystack)
            .map(|caps| {
                let mut entry = BTreeMap::new();

                let whole = caps.get(0).map_or("", |m| m.as_str());
                entry.insert("match".to_string(), VmValue::String(Rc::from(whole)));

                // Skip group 0 (the whole match); keep positions of the rest.
                let positional: Vec<VmValue> = caps
                    .iter()
                    .skip(1)
                    .map(|slot| match slot {
                        Some(m) => VmValue::String(Rc::from(m.as_str())),
                        None => VmValue::Nil,
                    })
                    .collect();
                entry.insert("groups".to_string(), VmValue::List(Rc::new(positional)));

                // Named groups are added last, so a group literally named
                // "match" or "groups" replaces the entry stored above.
                entry.extend(regex.capture_names().flatten().filter_map(|name| {
                    caps.name(name)
                        .map(|m| (name.to_string(), VmValue::String(Rc::from(m.as_str()))))
                }));

                VmValue::Dict(Rc::new(entry))
            })
            .collect();
        Ok(VmValue::List(Rc::new(rows)))
    });

    // regex_split(text, pattern[, flags]): note the text comes first here.
    // Returns the pieces between matches as a list, or nil when fewer than
    // two arguments were given.
    vm.register_builtin("regex_split", |args, _out| {
        if args.len() < 2 {
            return Ok(VmValue::Nil);
        }
        let haystack = args[0].display();
        let pattern = args[1].display();
        let flags = args
            .get(2)
            .map(|flag_arg| flag_arg.display())
            .unwrap_or_default();
        let regex = get_cached_regex(&pattern, &flags)?;

        let pieces: Vec<VmValue> = regex
            .split(&haystack)
            .map(|piece| VmValue::String(Rc::from(piece)))
            .collect();
        Ok(VmValue::List(Rc::new(pieces)))
    });
}

// Unit tests for the regex builtins. Each test registers the builtins on a
// fresh `Vm` and invokes them directly through the builtin table.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Vm;
    use std::rc::Rc;

    /// Builds a `Vm` with the regex builtins registered.
    fn vm() -> Vm {
        let mut vm = Vm::new();
        register_regex_builtins(&mut vm);
        vm
    }

    /// Looks up builtin `name` and calls it with `args` and a scratch output
    /// buffer. Panics if the builtin is not registered.
    fn call(vm: &mut Vm, name: &str, args: Vec<VmValue>) -> Result<VmValue, VmError> {
        let f = vm.builtins.get(name).unwrap().clone();
        let mut out = String::new();
        f(&args, &mut out)
    }

    /// Shorthand for a `VmValue::String` literal.
    fn s(v: &str) -> VmValue {
        VmValue::String(Rc::from(v))
    }

    /// Borrows the inner vector of a `VmValue::List`, panicking on any other
    /// variant.
    fn unwrap_list(v: &VmValue) -> &Vec<VmValue> {
        match v {
            VmValue::List(l) => l,
            _ => panic!("expected List, got {:?}", v.display()),
        }
    }

    #[test]
    fn match_basic() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_match",
            vec![s(r"\d+"), s("abc 123 def 456")],
        )
        .unwrap();
        let list = unwrap_list(&result);
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].display(), "123");
        assert_eq!(list[1].display(), "456");
    }

    #[test]
    fn match_no_match_returns_nil() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_match", vec![s(r"\d+"), s("no digits here")]).unwrap();
        assert!(matches!(result, VmValue::Nil));
    }

    #[test]
    fn match_empty_pattern() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_match", vec![s(""), s("abc")]).unwrap();
        let list = unwrap_list(&result);
        // An empty pattern matches at every position: before each of the
        // three characters and once at the end.
        assert_eq!(list.len(), 4);
    }

    #[test]
    fn match_missing_args_returns_nil() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_match", vec![s(r"\d+")]).unwrap();
        assert!(matches!(result, VmValue::Nil));
    }

    #[test]
    fn match_invalid_regex_errors() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_match", vec![s(r"[invalid"), s("text")]);
        assert!(result.is_err());
    }

    #[test]
    fn match_unicode() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_match", vec![s(r"\w+"), s("café résumé")]).unwrap();
        let list = unwrap_list(&result);
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].display(), "café");
        assert_eq!(list[1].display(), "résumé");
    }

    #[test]
    fn replace_basic() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_replace",
            vec![s(r"\d+"), s("NUM"), s("abc 123 def 456")],
        )
        .unwrap();
        assert_eq!(result.display(), "abc NUM def NUM");
    }

    #[test]
    fn replace_no_match_returns_original() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_replace",
            vec![s(r"\d+"), s("NUM"), s("no digits")],
        )
        .unwrap();
        assert_eq!(result.display(), "no digits");
    }

    #[test]
    fn replace_with_backreference() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_replace",
            vec![s(r"(\w+)\s(\w+)"), s("$2 $1"), s("hello world")],
        )
        .unwrap();
        assert_eq!(result.display(), "world hello");
    }

    #[test]
    fn replace_missing_args_returns_nil() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_replace", vec![s(r"\d+"), s("X")]).unwrap();
        assert!(matches!(result, VmValue::Nil));
    }

    #[test]
    fn captures_with_groups() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_captures",
            vec![s(r"(\d+)-(\w+)"), s("123-abc 456-def")],
        )
        .unwrap();
        let list = unwrap_list(&result);
        assert_eq!(list.len(), 2);

        let first = list[0].as_dict().unwrap();
        assert_eq!(first.get("match").unwrap().display(), "123-abc");
        let groups = unwrap_list(first.get("groups").unwrap());
        assert_eq!(groups[0].display(), "123");
        assert_eq!(groups[1].display(), "abc");
    }

    #[test]
    fn captures_named_groups() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_captures",
            vec![s(r"(?P<year>\d{4})-(?P<month>\d{2})"), s("2024-01")],
        )
        .unwrap();
        let list = unwrap_list(&result);
        assert_eq!(list.len(), 1);
        let cap = list[0].as_dict().unwrap();
        assert_eq!(cap.get("year").unwrap().display(), "2024");
        assert_eq!(cap.get("month").unwrap().display(), "01");
    }

    #[test]
    fn captures_no_match_returns_empty_list() {
        let mut vm = vm();
        let result = call(&mut vm, "regex_captures", vec![s(r"\d+"), s("no digits")]).unwrap();
        let list = unwrap_list(&result);
        assert!(list.is_empty());
    }

    #[test]
    fn captures_optional_group_nil() {
        let mut vm = vm();
        let result = call(
            &mut vm,
            "regex_captures",
            vec![s(r"(\d+)(?:-(\w+))?"), s("123")],
        )
        .unwrap();
        let list = unwrap_list(&result);
        assert_eq!(list.len(), 1);
        let groups = unwrap_list(list[0].as_dict().unwrap().get("groups").unwrap());
        assert_eq!(groups[0].display(), "123");
        assert!(matches!(groups[1], VmValue::Nil));
    }

    #[test]
    fn cache_returns_consistent_results() {
        let mut vm = vm();
        let a = call(&mut vm, "regex_match", vec![s(r"\d+"), s("42")]).unwrap();
        let b = call(&mut vm, "regex_match", vec![s(r"\d+"), s("42")]).unwrap();
        assert_eq!(a.display(), b.display());
    }

    #[test]
    fn cache_eviction_still_works() {
        // `get_cached_regex` only empties the cache once it already holds 128
        // entries, so more distinct patterns than that are needed for the
        // eviction path to run at all.
        const DISTINCT_PATTERNS: usize = 200;
        for i in 0..DISTINCT_PATTERNS {
            let pattern = format!("pat{i}");
            let _ = get_cached_regex(&pattern, "");
        }
        // Without any eviction the cache would hold at least one entry per
        // pattern inserted above.
        let cached = REGEX_CACHE.with(|cache| cache.borrow().len());
        assert!(cached < DISTINCT_PATTERNS);
        // A pattern that was evicted must still compile and work on re-request.
        let re = get_cached_regex("pat0", "").unwrap();
        assert!(re.is_match("pat0"));
    }
}