mlua-mathlib 0.1.0

Math library for mlua — RNG, distributions, and descriptive statistics
Documentation
use mlua::prelude::*;
use rand_distr::{Beta, Distribution, Exp, Gamma, Normal, Poisson, Uniform};

use crate::rng::LuaRng;

pub(crate) fn register(lua: &Lua, t: &LuaTable) -> LuaResult<()> {
    t.set(
        "normal_sample",
        lua.create_function(
            |_, (rng, mean, stddev): (LuaUserDataRef<LuaRng>, f64, f64)| {
                let dist = Normal::new(mean, stddev)
                    .map_err(|e| LuaError::runtime(format!("normal_sample: {e}")))?;
                Ok(dist.sample(&mut *rng.0.borrow_mut()))
            },
        )?,
    )?;

    t.set(
        "beta_sample",
        lua.create_function(
            |_, (rng, alpha, beta): (LuaUserDataRef<LuaRng>, f64, f64)| {
                let dist = Beta::new(alpha, beta)
                    .map_err(|e| LuaError::runtime(format!("beta_sample: {e}")))?;
                Ok(dist.sample(&mut *rng.0.borrow_mut()))
            },
        )?,
    )?;

    t.set(
        "gamma_sample",
        lua.create_function(
            |_, (rng, shape, scale): (LuaUserDataRef<LuaRng>, f64, f64)| {
                let dist = Gamma::new(shape, scale)
                    .map_err(|e| LuaError::runtime(format!("gamma_sample: {e}")))?;
                Ok(dist.sample(&mut *rng.0.borrow_mut()))
            },
        )?,
    )?;

    t.set(
        "exp_sample",
        lua.create_function(|_, (rng, lambda): (LuaUserDataRef<LuaRng>, f64)| {
            let dist =
                Exp::new(lambda).map_err(|e| LuaError::runtime(format!("exp_sample: {e}")))?;
            Ok(dist.sample(&mut *rng.0.borrow_mut()))
        })?,
    )?;

    t.set(
        "poisson_sample",
        lua.create_function(|_, (rng, lambda): (LuaUserDataRef<LuaRng>, f64)| {
            let dist = Poisson::new(lambda)
                .map_err(|e| LuaError::runtime(format!("poisson_sample: {e}")))?;
            let val: f64 = dist.sample(&mut *rng.0.borrow_mut());
            Ok(val as u64)
        })?,
    )?;

    t.set(
        "uniform_sample",
        lua.create_function(|_, (rng, low, high): (LuaUserDataRef<LuaRng>, f64, f64)| {
            let dist = Uniform::new(low, high)
                .map_err(|e| LuaError::runtime(format!("uniform_sample: {e}")))?;
            Ok(dist.sample(&mut *rng.0.borrow_mut()))
        })?,
    )?;

    Ok(())
}