1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#![allow(clippy::cargo_common_metadata)]

use std::time::Duration;

use mlua::prelude::*;
use mlua_luau_scheduler::Functions;

use tokio::time::{sleep, Instant};

use lune_utils::TableBuilder;

/**
    Builds and returns the `task` standard library table
    (`cancel`, `defer`, `delay`, `spawn`, `wait`).

    As a side effect, `resume` and `wrap` on the global `coroutine` table are
    replaced with the scheduler's versions from [`Functions`], so that they are
    compatible with the scheduler.

    # Errors

    Errors when out of memory, or if default Lua globals are missing.
*/
pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
    let fns = Functions::new(lua)?;
    let globals = lua.globals();

    // `task.wait` is a native async function. `task.delay` is a small Lua chunk
    // whose environment (built with `build_readonly`) contains only the
    // functions that chunk calls.
    let wait_fn = lua.create_async_function(wait)?;
    let delay_env = TableBuilder::new(lua)?
        .with_value("select", globals.get::<_, LuaFunction>("select")?)?
        .with_value("spawn", fns.spawn.clone())?
        .with_value("defer", fns.defer.clone())?
        .with_value("wait", wait_fn.clone())?
        .build_readonly()?;
    let delay_fn = lua
        .load(DELAY_IMPL_LUA)
        .set_name("task.delay")
        .set_environment(delay_env)
        .into_function()?;

    // Swap the global coroutine helpers for scheduler-compatible ones.
    let coroutine: LuaTable = globals.get("coroutine")?;
    coroutine.set("resume", fns.resume.clone())?;
    coroutine.set("wrap", fns.wrap.clone())?;

    // Remaining scheduler functions are moved (not cloned) into the task table.
    let task = TableBuilder::new(lua)?
        .with_value("cancel", fns.cancel)?
        .with_value("defer", fns.defer)?
        .with_value("delay", delay_fn)?
        .with_value("spawn", fns.spawn)?
        .with_value("wait", wait_fn)?;
    task.build_readonly()
}

/// Lua source for `task.delay(duration, functionOrThread, ...)`.
///
/// `module` compiles this chunk under the name `task.delay`. The chunk's
/// environment contains only `select`, `spawn`, `defer` and `wait`.
/// It defers a wrapper function that first calls `wait` with the arguments,
/// then `spawn`s everything from the second argument onward.
/// NOTE(review): `wait` is handed all arguments, but its Rust signature reads
/// only an `Option<f64>`, so presumably only the first (the duration) is used.
/// The chunk returns whatever `defer` returns, which is the handle for the
/// deferred wrapper, not for the spawned function/thread itself.
const DELAY_IMPL_LUA: &str = r"
return defer(function(...)
    wait(select(1, ...))
    spawn(select(2, ...))
end, ...)
";

/**
    Implementation of `task.wait`: sleeps for the given number of seconds
    and returns how many seconds actually elapsed while sleeping.

    The duration comes straight from Lua, so it is converted without panicking:

    - a missing, zero, negative, or NaN duration is treated as zero seconds;
    - a duration too large to represent (including positive infinity) is
      clamped to `Duration::MAX`.

    `Duration::from_secs_f64` would panic on either case.

    # Errors

    Never errors; the `LuaResult` return type is kept for `create_async_function`.
*/
async fn wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
    let duration = secs_to_duration(secs.unwrap_or_default());

    let start = Instant::now();
    sleep(duration).await;

    Ok(start.elapsed().as_secs_f64())
}

/// Converts a possibly user-supplied number of seconds into a `Duration`
/// without panicking: NaN and non-positive values map to zero, and values that
/// overflow `Duration` (including `+inf`) saturate to `Duration::MAX`.
fn secs_to_duration(secs: f64) -> Duration {
    if secs.is_nan() || secs <= 0.0 {
        Duration::ZERO
    } else {
        // Only overflow can fail here; an (effectively) endless sleep is the
        // closest meaningful behavior for it.
        Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX)
    }
}