luallaby 0.1.0

**Work in progress:** a pure-Rust Lua interpreter/compiler.
Documentation
use std::iter::once;
use std::rc::Rc;

use super::StdLib;
use crate::error::Result;
use crate::vm::{FuncBuiltin, Value, VM};
use crate::{Error, LuaError};

/// Largest code point handled by this module (2^31 - 1): `char` rejects
/// arguments above it and `next` rejects decoded values above it.
const MAX: u32 = 0x7FFFFFFF;

/// Registers the `utf8` module on `stdlib`: the functions `char`, `codes`,
/// `codepoint`, `len` and `offset`, plus the `charpattern` constant.
///
/// Any error reported by a registration call is propagated to the caller.
pub(super) fn module(stdlib: &mut StdLib) -> Result<()> {
    // Raw bytes of the pattern `[\0-\x7F\xC2-\xFD][\x80-\xBF]*`: one lead
    // byte followed by any number of continuation bytes.
    const CHARPATTERN: &[u8] = b"[\x00-\x7F\xC2-\xFD][\x80-\xBF]*";

    stdlib
        .module("utf8")
        .func("char", char)?
        .cons("charpattern", Value::str_bytes(CHARPATTERN.to_vec()))?
        .func("codes", codes)?
        .func("codepoint", codepoint)?
        .func("len", len)?
        .func("offset", offset)?;

    Ok(())
}

/// Reports whether `b` is a UTF-8 continuation byte, i.e. has the bit
/// pattern `10xxxxxx`.
#[inline]
fn is_cont_byte(b: u8) -> bool {
    // `10xxxxxx` spans exactly the values 0x80 through 0xBF.
    matches!(b, 0x80..=0xBF)
}

/// Turns a one-based, possibly negative position argument into a zero-based
/// offset into a string of `len` bytes.
///
/// `Value::Nil` selects `default`; any other value is coerced to an integer,
/// and a failed coercion is returned as the error. Non-negative positions are
/// shifted down by one (so `0` becomes `-1`); negative positions count back
/// from the end by having `len` added with wrapping arithmetic. The result is
/// not range-checked: callers must reject negative or too-large offsets.
#[inline]
fn index(val: Value, len: usize, default: i64) -> Result<i64> {
    let position = if matches!(val, Value::Nil) {
        default
    } else {
        val.to_number_coerce()?.coerce_int()?
    };

    if position < 0 {
        // Relative to the end: -1 maps to the last byte.
        return Ok(position.wrapping_add_unsigned(len as u64));
    }
    Ok(position - 1)
}

/// Decodes the UTF-8 sequence that starts at the first byte of `s`.
///
/// Returns the decoded code point together with the number of bytes it
/// occupies. Returns `None` when `s` is empty, starts with a continuation
/// byte, is missing a continuation byte, announces more than six bytes, is
/// an overlong encoding, or decodes to a value above `MAX`. Unless `lax` is
/// set, values that are not Unicode scalar values (surrogates and anything
/// above U+10FFFF) are rejected as well.
fn next(s: &[u8], lax: bool) -> Option<(u32, usize)> {
    // Smallest code point that genuinely needs a sequence of `n` bytes,
    // indexed by `n`; a decoded value below it is an overlong encoding.
    const MINS: [u32; 7] = [0, 0, 0x80, 0x800, 0x10000, 0x200000, 0x4000000];

    let c = *s.first()?;
    if c < 0x80 {
        return Some((c as u32, 1));
    }
    if is_cont_byte(c) {
        return None;
    }

    let mut u = 0u32;
    let mut count = 1;
    let mut need = c;

    // Every leading one bit after the first announces a continuation byte.
    while need & 0x40 > 0 {
        if count >= s.len() || !is_cont_byte(s[count]) {
            return None;
        }
        u = (u << 6) | (s[count] as u32 & 0x3F);

        need <<= 1;
        count += 1;
    }

    // Lead bytes 0xFE and 0xFF announce more than six bytes. Reject them
    // before the shifts below: their shift amounts would otherwise reach the
    // operand width, which panics in builds with overflow checks enabled.
    if count > 6 {
        return None;
    }

    // Shift the length marker out of the lead byte, then place its payload
    // bits directly above the bits gathered from the continuation bytes.
    let c: u8 = c << (count + 1);
    u |= (c as u32) << ((count.saturating_sub(2)) * 5 + 3);

    if u > MAX || u < MINS[count] || (!lax && char::from_u32(u).is_none()) {
        None
    } else {
        Some((u, count))
    }
}

/// `utf8.char`: encodes every argument as a UTF-8 byte sequence and returns
/// the concatenation as a byte string.
///
/// Each argument is coerced to an integer; a value outside `0..=MAX` yields
/// `LuaError::ValueRange`. Sequences of up to six bytes are produced, so
/// values beyond the Unicode range are encoded too.
fn char(vm: &mut VM) -> Result<Value> {
    // Longest sequence this encoder can emit.
    const WIDTH: usize = 6;

    let mut out = Vec::new();
    for arg in vm.arg_split(0) {
        let code = arg.to_int_coerce()?;
        if !(0..=MAX as i64).contains(&code) {
            return err!(LuaError::ValueRange);
        }

        // The sequence is built right to left at the tail of `buf`.
        let mut buf = [0u8; WIDTH];
        let mut rest = code as u32;
        let mut used = 1;
        if rest < 0x80 {
            buf[WIDTH - 1] = rest as u8;
        } else {
            // Largest value that still fits in the lead byte; it loses one
            // bit for every continuation byte emitted.
            let mut lead_max = 0x3F;
            while rest > lead_max {
                buf[WIDTH - used] = (0x80 | (rest & 0x3F)) as u8;
                rest >>= 6;
                lead_max >>= 1;
                used += 1;
            }
            // Length marker (leading ones) combined with the remaining bits.
            buf[WIDTH - used] = ((!lead_max << 1) | rest) as u8;
        }
        out.extend_from_slice(&buf[WIDTH - used..]);
    }
    Ok(Value::str_bytes(out))
}

fn codes(vm: &mut VM) -> Result<Value> {
    let s = vm.arg_string_coerce(0)?;
    let lax = vm.arg_or_nil(1).is_truthy();

    Ok(Value::Mult(vec![
        Value::Func(vm.alloc_builtin(FuncBuiltin {
            module: "utf8",
            name: "__utf8_codes",
            func: Rc::new(move |vm| {
                let ctrl = vm.arg_int_coerce(1)?;
                let start = if ctrl <= 0 {
                    0
                } else if ctrl as usize >= s.len() {
                    return Ok(Value::Nil);
                } else {
                    let ctrl = (ctrl - 1) as usize;
                    let (_, read) = next(&s[ctrl..], lax)
                        .ok_or_else(|| Error::from_lua(LuaError::Utf8Invalid))?;
                    ctrl + read
                };
                Ok(Value::Mult(if start == s.len() {
                    vec![Value::Nil]
                } else {
                    let (code, _) = next(&s[start..], lax)
                        .ok_or_else(|| Error::from_lua(LuaError::Utf8Invalid))?;
                    vec![Value::int((start + 1) as i64), Value::int(code as i64)]
                }))
            }),
        })),
        Value::Nil,
        Value::int(-1),
    ]))
}

/// `utf8.codepoint`: returns the code points of all characters that start
/// between byte positions `i` and `j` (both inclusive) of the string argument.
///
/// `i` defaults to 1 and `j` defaults to `i`. A truthy fourth argument
/// enables lax decoding.
///
/// # Errors
///
/// * `LuaError::Utf8OutOfBounds` when `i` lies before the first byte or `j`
///   lies past the last byte. This includes the empty string with default
///   arguments, which previously reached the decoder and was reported as
///   invalid UTF-8.
/// * `LuaError::Utf8Invalid` when a sequence in the range cannot be decoded.
fn codepoint(vm: &mut VM) -> Result<Value> {
    let s = vm.arg_string_coerce(0)?;
    let mut i = index(vm.arg_or_nil(1), s.len(), 1)?;
    // `i` is already zero-based while `index` expects a one-based default.
    let j = index(vm.arg_or_nil(2), s.len(), i + 1)?;
    let lax = vm.arg_or_nil(3).is_truthy();

    // Compare as signed integers so a negative `j` (an empty range) is
    // accepted without ever being cast to `usize`.
    if i < 0 || j >= s.len() as i64 {
        return err!(LuaError::Utf8OutOfBounds);
    }

    let mut codes = Vec::new();
    while i <= j {
        let (code, read) =
            next(&s[i as usize..], lax).ok_or_else(|| Error::from_lua(LuaError::Utf8Invalid))?;
        codes.push(Value::int(code as i64));
        i += read as i64;
    }

    Ok(Value::Mult(codes))
}

/// `utf8.len`: counts the characters that start between byte positions `i`
/// and `j` (both inclusive) of the string argument.
///
/// `i` defaults to 1 and `j` defaults to -1. A truthy fourth argument enables
/// lax decoding. An empty string always yields 0 without looking at the
/// positions.
///
/// On success the character count is returned. If a sequence cannot be
/// decoded, the result is nil plus the one-based position *within the string*
/// of the first invalid byte.
///
/// # Errors
///
/// `LuaError::Utf8OutOfBounds` when `i` lies before the first byte or more
/// than one position past the end, or when `j` lies past the last byte. An
/// end position before the first byte is an empty range and yields 0.
fn len(vm: &mut VM) -> Result<Value> {
    let s = vm.arg_string_coerce(0)?;
    if s.is_empty() {
        return Ok(Value::int(0));
    }
    let mut i = index(vm.arg_or_nil(1), s.len(), 1)?;
    let j = index(vm.arg_or_nil(2), s.len(), -1)?;
    let lax = vm.arg_or_nil(3).is_truthy();

    // Compare as signed integers so a negative `j` never has to be cast.
    let size = s.len() as i64;
    if i < 0 || i > size || j >= size {
        return err!(LuaError::Utf8OutOfBounds);
    }

    let mut chars = 0;
    while i <= j {
        match next(&s[i as usize..], lax) {
            Some((_, read)) => {
                i += read as i64;
                chars += 1;
            }
            // `i` is the zero-based offset of the offending byte; report it
            // one-based and relative to the string, not to the start index.
            None => return Ok(Value::Mult(vec![Value::Nil, Value::int(i + 1)])),
        }
    }

    Ok(Value::int(chars))
}

/// `utf8.offset`: returns the one-based byte position where the `n`-th
/// character, counted from byte position `i`, starts.
///
/// * `n > 0` counts forward from `i`, with `n == 1` being the character at
///   `i` itself; `i` defaults to 1.
/// * `n < 0` counts backward from `i`; `i` defaults to one past the end.
/// * `n == 0` returns the start of the character containing byte `i`.
///
/// The position one past the end of the string is treated like a string
/// terminator: it is a valid `i` and counts as a character start in every
/// branch. Nil is returned when the requested character does not exist.
///
/// # Errors
///
/// * `LuaError::Utf8OutOfBounds` when `i` lies before the first byte or more
///   than one position past the end.
/// * `LuaError::Utf8IsContByte` when `n != 0` and `i` points at a
///   continuation byte.
fn offset(vm: &mut VM) -> Result<Value> {
    let s = vm.arg_string_coerce(0)?;

    let n = vm.arg_int_coerce(1)?;
    let i = index(
        vm.arg_or_nil(2),
        s.len(),
        if n >= 0 { 1 } else { s.len() as i64 + 1 },
    )?;
    if i < 0 || i as usize > s.len() {
        return err!(LuaError::Utf8OutOfBounds);
    }
    let i = i as usize;

    // The slot one past the end stands in for a string terminator, which is
    // never a continuation byte.
    let end = s.len();
    let starts_char = |idx: usize| idx == end || !is_cont_byte(s[idx]);

    let res = if n == 0 {
        // Walk back to the nearest character start, settling on the first
        // byte when everything before `i` is a continuation byte.
        Some((0..=i).rev().find(|&idx| starts_char(idx)).unwrap_or(0))
    } else {
        if !starts_char(i) {
            return err!(LuaError::Utf8IsContByte);
        }
        // `unsigned_abs` cannot overflow, unlike `abs` on the minimum
        // integer; `n` is non-zero here, so the subtraction cannot wrap.
        let skip = (n.unsigned_abs() - 1) as usize;
        if n < 0 {
            (0..i).rev().filter(|&idx| starts_char(idx)).nth(skip)
        } else {
            (i..=end).filter(|&idx| starts_char(idx)).nth(skip)
        }
    };

    Ok(match res {
        Some(idx) => Value::int(idx as i64 + 1),
        None => Value::Nil,
    })
}