ezlua 0.5.4

Ergonomic, efficient and zero-cost Rust bindings to Lua 5.4
Documentation
use super::*;

use core::cell::RefCell;
#[cfg(not(target_os = "windows"))]
use std::os::unix::thread::{JoinHandleExt, RawPthread as RawHandle};
#[cfg(target_os = "windows")]
use std::os::windows::io::{AsRawHandle, RawHandle};
use std::sync::{mpsc::*, *};
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Userdata wrapper around an OS thread spawned through the `spawn` closure
/// registered in `init`.
struct LuaThread {
    /// Raw OS handle of the thread (`RawPthread` on non-Windows targets, the
    /// Windows `RawHandle` otherwise); exposed to Lua as a `usize`.
    handle: RawHandle,
    /// Join handle of the thread; joining yields the thread's
    /// `LuaResult<CoroutineWithRef>`.
    join: JoinHandle<LuaResult<CoroutineWithRef>>,
}

impl UserData for LuaThread {
    const TYPE_NAME: &'static str = "LLuaThread";

    /// Registers the `handle`, `name` and `id` entries on the getter registry.
    fn getter(fields: UserdataRegistry<Self>) -> Result<()> {
        // The raw OS handle is handed to Lua as a plain integer.
        fields.set_closure("handle", |thread: &Self| thread.handle as usize)?;
        // `name` is `None` for unnamed threads.
        fields.add_method("name", |_lua, thread, ()| thread.join.thread().name())?;
        fields.set_closure("id", |thread: &Self| {
            let id = thread.join.thread().id();
            id.as_u64().get()
        })?;

        Ok(())
    }

    /// Registers the `join` and `unpark` methods.
    fn methods(mt: UserdataRegistry<Self>) -> Result<()> {
        // `join` takes the userdata by value because `JoinHandle::join`
        // consumes the handle.
        let join_fn = mt
            .state()
            .new_closure1(|lua, OwnedUserdata::<Self>(thread)| {
                // Outer layer: the join itself (thread panic);
                // inner layer: the `LuaResult` produced by the thread body.
                let outcome = thread.join.join().lua_result();
                outcome??.take(lua)
            })?;
        mt.set("join", join_fn)?;

        mt.set_closure("unpark", |thread: &Self| thread.join.thread().unpark())?;

        Ok(())
    }
}

/// "No handle" value for `RawHandle` on Windows: a null pointer.
#[cfg(target_os = "windows")]
const RAW_NULL: RawHandle = std::ptr::null_mut();
/// "No handle" value for `RawHandle` on non-Windows targets: `0`.
#[cfg(not(target_os = "windows"))]
const RAW_NULL: RawHandle = 0;

/// Lua-visible mutex: a newtype over `Mutex<()>`, so it guards no data and
/// only provides locking.
#[derive(Default, Deref, AsRef)]
struct LuaMutex(Mutex<()>);
/// Userdata holding the `MutexGuard` obtained from locking a `LuaMutex`.
/// (The `Gaurd` spelling is the identifier used throughout this file.)
struct LuaMutexGaurd<'a>(MutexGuard<'a, ()>);

impl<'a> UserData for LuaMutexGaurd<'a> {
    const TYPE_NAME: &'static str = "LuaMutexGaurd";

    /// Registers the guard's single method, `unlock`.
    fn methods(methods: UserdataRegistry<Self>) -> LuaResult<()> {
        // `unlock` is bound directly to `ValRef::close_and_remove_metatable`.
        // NOTE(review): presumably this drops the wrapped guard and thereby
        // releases the mutex — confirm against that function's definition.
        methods.set_closure("unlock", ValRef::close_and_remove_metatable)?;
        Ok(())
    }
}

impl UserData for LuaMutex {
    const TYPE_NAME: &'static str = "LuaMutex";

    /// Registers `is_poisoned`, `lock` and `try_lock`.
    fn methods(mt: UserdataRegistry<Self>) -> Result<()> {
        // Forwarded straight to the inner `Mutex<()>`.
        mt.as_deref()
            .add_deref("is_poisoned", Mutex::<()>::is_poisoned)?;

        // Blocking lock: the guard becomes a new userdata created together
        // with `ArgRef(1)`.
        // NOTE(review): `ArgRef(1)` presumably ties the mutex userdata to the
        // guard so the mutex outlives it — confirm.
        mt.add_method("lock", |lua, mutex, ()| {
            let guard = mutex.0.lock().lua_result()?;
            lua.new_userdata_with_values(LuaMutexGaurd(guard), ArgRef(1))
        })?;

        // Non-blocking lock: a failed attempt is wrapped in `NilError`
        // rather than propagated with `?`.
        mt.add_method("try_lock", |lua, mutex, ()| {
            let attempt = mutex.0.try_lock().lua_result();
            NilError(attempt.map(|guard| {
                lua.new_userdata_with_values(LuaMutexGaurd(guard), ArgRef(1))
            }))
        })?;

        Ok(())
    }
}

/// Condition variable exposed to Lua as userdata.
#[derive(Debug)]
struct LuaCondVar {
    /// Registry reference of the value most recently stored by `push_some`
    /// (`Reference(0)` until then); also the mutex the condvar waits on.
    lock: Mutex<Reference>,
    /// Condition variable behind `wait` / `notify_one` / `notify_all`.
    cvar: Condvar,
}

impl Default for LuaCondVar {
    fn default() -> Self {
        Self {
            lock: Reference(0).into(),
            cvar: Default::default(),
        }
    }
}

impl UserData for LuaCondVar {
    const TYPE_NAME: &'static str = "thread.CondVar";

    /// Registers `wait`, `notify_one` and `notify_all`.
    fn methods(mt: UserdataRegistry<Self>) -> Result<()> {
        mt.add_method("wait", |s, this, tm| this.wait(s, tm))?;

        // Both notify variants store the value first; waiters are only
        // woken when storing succeeded.
        mt.add_method("notify_one", |s, this, val: ValRef| -> Result<()> {
            this.push_some(s, val)?;
            this.cvar.notify_one();
            Ok(())
        })?;
        mt.add_method("notify_all", |s, this, val: ValRef| -> Result<()> {
            this.push_some(s, val)?;
            this.cvar.notify_all();
            Ok(())
        })?;

        Ok(())
    }

    /// Installs `__close`, which unreferences the stored registry reference.
    fn metatable(mt: UserdataRegistry<Self>) -> LuaResult<()> {
        mt.set_closure("__close", |lua: &LuaState, this: OwnedUserdata<Self>| {
            let registry = lua.registry();
            // Panics if the mutex is poisoned.
            let slot = this.0.lock.lock().unwrap();
            registry.unreference(*slot);
        })?;

        Ok(())
    }
}

impl LuaCondVar {
    /// Replaces the stored registry reference with a fresh reference to `val`.
    ///
    /// The previous reference is unreferenced first. Fails if the mutex is
    /// poisoned or the registry cannot create the new reference.
    fn push_some(&self, s: &LuaState, val: ValRef) -> Result<()> {
        let mut slot = self.lock.lock().lua_result()?;
        let registry = s.registry();

        // Leave the default reference in the slot while the old one is released.
        let previous = core::mem::take(&mut *slot);
        registry.unreference(previous);
        *slot = registry.reference(val)?;

        Ok(())
    }

    /// Blocks on the condition variable, then returns the stored value.
    ///
    /// With `timeout` set, the wait is bounded and a unit value is returned
    /// when it times out. Without it, the wait is unbounded.
    fn wait<'a>(&self, s: &'a LuaState, timeout: Option<Duration>) -> Result<ValRef<'a>> {
        let guard = self.lock.lock().lua_result()?;

        let guard = match timeout {
            Some(tm) => {
                let (guard, status) = self
                    .cvar
                    .wait_timeout(guard, tm)
                    .map_err(LuaError::runtime_debug)?;
                if status.timed_out() {
                    return s.new_val(());
                }
                guard
            }
            None => self
                .cvar
                .wait(guard)
                .map_err(LuaError::runtime_debug)?,
        };

        // The reference is read, not taken: several threads may be waiting
        // for the same value.
        s.registry().raw_geti(guard.0)
    }
}

impl UserData for Sender<Reference> {
    const TYPE_NAME: &'static str = "thread.Sender";

    /// Registers `send`, which creates a registry reference for the value
    /// and sends that reference over the channel.
    fn methods(methods: UserdataRegistry<Self>) -> LuaResult<()> {
        methods.set_closure("send", |lua: &LuaState, this: &Self, val: ValRef| {
            let refer = lua.registry().reference(val)?;
            this.send(refer).lua_result()
        })?;

        Ok(())
    }
}

impl UserData for Receiver<Reference> {
    const TYPE_NAME: &'static str = "thread.Receiver";
    type Trans = RefCell<Self>;

    /// Registers `recv`, `try_recv` and `recv_timeout`.
    ///
    /// Each received registry reference is resolved with `take_reference`.
    fn methods(methods: UserdataRegistry<Self>) -> LuaResult<()> {
        methods.add_method_mut("recv", |lua: &LuaState, this, ()| {
            let refer = this.recv().lua_result()?;
            lua.registry().take_reference(refer)
        })?;

        methods.add_method_mut("try_recv", |lua: &LuaState, this, ()| {
            let received = this.try_recv();
            // An empty channel is not an error: return a unit value instead.
            if matches!(received, Err(TryRecvError::Empty)) {
                return lua.new_val(());
            }
            let refer = received.lua_result()?;
            lua.registry().take_reference(refer)
        })?;

        methods.add_method_mut("recv_timeout", |lua: &LuaState, this, tm| {
            let refer = this.recv_timeout(tm).lua_result()?;
            lua.registry().take_reference(refer)
        })?;

        Ok(())
    }

    /// Sets `__call` to the `recv` entry found under `__method`.
    fn metatable(mt: UserdataRegistry<Self>) -> LuaResult<()> {
        let method_table = mt.get("__method")?;
        let recv_fn = method_table.get("recv")?;
        mt.set("__call", recv_fn)?;

        Ok(())
    }
}

/// Builds the table that exposes threading primitives to Lua.
///
/// Entries set on the returned table:
/// - `spawn(routine, name?)`: runs `routine.val(1)` with `pcall` on a new OS
///   thread (named `name`, or `"<lua>"` when absent) and yields a `LuaThread`
/// - `sleep(ms)`: sleeps the current thread for `ms` milliseconds
/// - `park()`, `yield_now()`: forward to `std::thread`
/// - `id()`, `name()`: id and name of the current thread
/// - `mutex()`, `condvar()`: create a `LuaMutex` / `LuaCondVar`
/// - `sync.channel()`: creates an mpsc channel carrying registry references
///
/// # Errors
/// Propagates any error from creating the tables or setting their entries.
pub fn init(lua: &LuaState) -> Result<LuaTable> {
    // Nine keys are set below, so size the hash part for all of them up
    // front; the previous hint of 4 was lower than the number of entries.
    let t = lua.new_table_with_size(0, 9)?;
    t.set_closure("spawn", |routine: Coroutine, name: Option<&str>| {
        thread::Builder::new()
            .name(name.unwrap_or("<lua>").into())
            .spawn(move || {
                // Store the call result in the coroutine's registry so it
                // can be retrieved when the thread is joined.
                let result = routine
                    .val(1)
                    .pcall::<_, ValRef>(())
                    .and_then(|res| routine.registry().reference(res));
                result.map(|refer| CoroutineWithRef(routine, refer))
            })
            .map(|join| {
                #[cfg(target_os = "windows")]
                let handle = join.as_raw_handle();
                #[cfg(not(target_os = "windows"))]
                let handle = join.as_pthread_t();

                LuaThread { join, handle }
            })
    })?;
    t.set_closure("sleep", |time: u64| {
        thread::sleep(Duration::from_millis(time))
    })?;
    t.set_closure("park", thread::park)?;
    t.set_closure("id", || thread::current().id().as_u64().get())?;
    t.set_closure("yield_now", thread::yield_now)?;
    t.set_closure("mutex", LuaMutex::default)?;
    t.set_closure("condvar", LuaCondVar::default)?;
    t.set(
        "name",
        lua.new_closure0(|s| s.new_val(thread::current().name()))?,
    )?;

    let sync = lua.new_table()?;
    sync.set_closure("channel", channel::<Reference>)?;
    t.set("sync", sync)?;

    Ok(t)
}

/// Wraps `sender` in a Lua function.
///
/// The function converts its arguments to `T`, sends the value over the
/// channel, and returns `true` when the send succeeded and `false` otherwise.
pub fn wrap_sender<'l, T: FromLuaMulti<'l> + Send + 'static>(
    lua: &'l LuaState,
    sender: Sender<T>,
) -> LuaResult<LuaFunction> {
    // `is_ok()` turns a send failure into a boolean result instead of a Lua error.
    lua.new_function(move |_, val: T| sender.send(val).is_ok())
}

/// Wraps `recver` in a Lua function that blocks for the next message.
///
/// The function pushes the received value, or a unit value when `recv`
/// returns an error.
pub fn wrap_receiver<'l, T: ToLuaMulti + Send + 'l + 'static>(
    lua: &'l LuaState,
    recver: Receiver<T>,
) -> LuaResult<LuaFunction> {
    let recver = RwLock::new(recver);
    // The write guard is held for the duration of the blocking `recv`.
    lua.new_function(move |l, ()| match recver.write().unwrap().recv() {
        Ok(x) => l.pushed(x),
        Err(_) => l.pushed(()),
    })
}