Skip to main content

assay/lua/builtins/
core.rs

1use data_encoding::BASE64;
2use mlua::{Lua, Value};
3use std::os::unix::fs::PermissionsExt;
4use std::time::{SystemTime, UNIX_EPOCH};
5use tracing::{error, info, warn};
6
/// Process-wide sequence number mixed into `fs.tempdir()` directory names, so
/// that two calls landing on the same clock reading still produce distinct names.
static TEMPDIR_COUNTER: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
8
9pub fn register_log(lua: &Lua) -> mlua::Result<()> {
10    let log_table = lua.create_table()?;
11
12    let info_fn = lua.create_function(|_, msg: String| {
13        info!(target: "lua", "{}", msg);
14        Ok(())
15    })?;
16    log_table.set("info", info_fn)?;
17
18    let warn_fn = lua.create_function(|_, msg: String| {
19        warn!(target: "lua", "{}", msg);
20        Ok(())
21    })?;
22    log_table.set("warn", warn_fn)?;
23
24    let error_fn = lua.create_function(|_, msg: String| {
25        error!(target: "lua", "{}", msg);
26        Ok(())
27    })?;
28    log_table.set("error", error_fn)?;
29
30    lua.globals().set("log", log_table)?;
31    Ok(())
32}
33
34pub fn register_env(lua: &Lua) -> mlua::Result<()> {
35    let env_table = lua.create_table()?;
36
37    let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
38        Ok(val) => Ok(Some(val)),
39        Err(_) => Ok(None),
40    })?;
41    env_table.set("_process_get", process_get_fn)?;
42    env_table.set("_check_env", lua.create_table()?)?;
43
44    lua.globals().set("env", env_table)?;
45
46    lua.load(
47        r#"
48        function env.get(name)
49            local val = env._check_env[name]
50            if val ~= nil then return val end
51            return env._process_get(name)
52        end
53        "#,
54    )
55    .exec()?;
56
57    // env.set(key, val) — set env var (nil val = unset)
58    let set_fn = lua.create_function(|_, (key, val): (String, Option<String>)| {
59        match val {
60            Some(v) => unsafe { std::env::set_var(&key, &v) },
61            None => unsafe { std::env::remove_var(&key) },
62        }
63        Ok(())
64    })?;
65    lua.globals()
66        .get::<mlua::Table>("env")?
67        .set("set", set_fn)?;
68
69    // env.list() — returns table of {key, value} for all env vars
70    let list_fn = lua.create_function(|lua, ()| {
71        let results = lua.create_table()?;
72        for (i, (key, val)) in (1..).zip(std::env::vars()) {
73            let entry = lua.create_table()?;
74            entry.set("key", key)?;
75            entry.set("value", val)?;
76            results.set(i, entry)?;
77        }
78        Ok(results)
79    })?;
80    lua.globals()
81        .get::<mlua::Table>("env")?
82        .set("list", list_fn)?;
83
84    Ok(())
85}
86
87pub fn register_sleep(lua: &Lua) -> mlua::Result<()> {
88    let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
89        let duration = std::time::Duration::from_secs_f64(seconds);
90        tokio::time::sleep(duration).await;
91        Ok(())
92    })?;
93    lua.globals().set("sleep", sleep_fn)?;
94    Ok(())
95}
96
97pub fn register_time(lua: &Lua) -> mlua::Result<()> {
98    let time_fn = lua.create_function(|_, ()| {
99        let secs = SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
102            .as_secs_f64();
103        Ok(secs)
104    })?;
105    lua.globals().set("time", time_fn)?;
106    Ok(())
107}
108
109pub fn register_fs(lua: &Lua) -> mlua::Result<()> {
110    use crate::lua::file_source::FileSourceHandle;
111
112    let fs_table = lua.create_table()?;
113
114    let read_fn = lua.create_function(|lua, path: String| -> mlua::Result<String> {
115        let bytes = match lua.app_data_ref::<FileSourceHandle>() {
116            Some(source) => source.read(&path).ok_or_else(|| {
117                mlua::Error::runtime(format!(
118                    "fs.read: failed to read {path:?}: not found in file source"
119                ))
120            })?,
121            None => std::fs::read(&path).map_err(|e| {
122                mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}"))
123            })?,
124        };
125        String::from_utf8(bytes).map_err(|e| {
126            mlua::Error::runtime(format!("fs.read: invalid UTF-8 in {path:?}: {e}"))
127        })
128    })?;
129    fs_table.set("read", read_fn)?;
130
131    // fs.read_bytes(path) → string (binary-safe; Lua strings can hold any bytes)
132    let read_bytes_fn = lua.create_function(|lua, path: String| {
133        let bytes = match lua.app_data_ref::<FileSourceHandle>() {
134            Some(source) => source.read(&path).ok_or_else(|| {
135                mlua::Error::runtime(format!(
136                    "fs.read_bytes: failed to read {path:?}: not found in file source"
137                ))
138            })?,
139            None => std::fs::read(&path).map_err(|e| {
140                mlua::Error::runtime(format!(
141                    "fs.read_bytes: failed to read {path:?}: {e}"
142                ))
143            })?,
144        };
145        lua.create_string(&bytes)
146    })?;
147    fs_table.set("read_bytes", read_bytes_fn)?;
148
149    let write_fn = lua.create_function(|_, (path, content): (String, String)| {
150        let p = std::path::Path::new(&path);
151        if let Some(parent) = p.parent() {
152            std::fs::create_dir_all(parent).map_err(|e| {
153                mlua::Error::runtime(format!(
154                    "fs.write: failed to create directories for {path:?}: {e}"
155                ))
156            })?;
157        }
158        std::fs::write(&path, &content)
159            .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
160    })?;
161    fs_table.set("write", write_fn)?;
162
163    // fs.write_bytes(path, data) → write binary data (Lua string with arbitrary bytes)
164    let write_bytes_fn = lua.create_function(|_, (path, data): (String, mlua::String)| {
165        let p = std::path::Path::new(&path);
166        if let Some(parent) = p.parent() {
167            std::fs::create_dir_all(parent).map_err(|e| {
168                mlua::Error::runtime(format!(
169                    "fs.write_bytes: failed to create directories for {path:?}: {e}"
170                ))
171            })?;
172        }
173        std::fs::write(&path, data.as_bytes())
174            .map_err(|e| mlua::Error::runtime(format!("fs.write_bytes: failed to write {path:?}: {e}")))
175    })?;
176    fs_table.set("write_bytes", write_bytes_fn)?;
177
178    let remove_fn = lua.create_function(|_, path: String| {
179        let p = std::path::Path::new(&path);
180        // Use symlink_metadata to detect symlinks without following them.
181        // A symlink to a directory should be removed as a file (unlink),
182        // not recursively delete the target directory.
183        let is_dir = match std::fs::symlink_metadata(&path) {
184            Ok(m) => m.file_type().is_dir(),
185            Err(_) => p.is_dir(),
186        };
187        if is_dir {
188            std::fs::remove_dir_all(&path).map_err(|e| {
189                mlua::Error::runtime(format!(
190                    "fs.remove: failed to remove directory {path:?}: {e}"
191                ))
192            })
193        } else {
194            std::fs::remove_file(&path).map_err(|e| {
195                mlua::Error::runtime(format!("fs.remove: failed to remove {path:?}: {e}"))
196            })
197        }
198    })?;
199    fs_table.set("remove", remove_fn)?;
200
201    let list_fn =
202        lua.create_function(|lua, path: String| {
203            let entries = lua.create_table()?;
204            for (i, entry) in (1..).zip(std::fs::read_dir(&path).map_err(|e| {
205                mlua::Error::runtime(format!("fs.list: failed to list {path:?}: {e}"))
206            })?) {
207                let entry = entry.map_err(|e| {
208                    mlua::Error::runtime(format!("fs.list: error reading entry in {path:?}: {e}"))
209                })?;
210                let info = lua.create_table()?;
211                let name = entry.file_name().to_string_lossy().to_string();
212                info.set("name", name)?;
213                let file_type = entry.file_type().map_err(|e| {
214                    mlua::Error::runtime(format!("fs.list: failed to get file type: {e}"))
215                })?;
216                if file_type.is_dir() {
217                    info.set("type", "directory")?;
218                } else if file_type.is_symlink() {
219                    info.set("type", "symlink")?;
220                } else {
221                    info.set("type", "file")?;
222                }
223                entries.set(i, info)?;
224            }
225            Ok(entries)
226        })?;
227    fs_table.set("list", list_fn)?;
228
229    let stat_fn = lua.create_function(|lua, path: String| {
230        let metadata = std::fs::metadata(&path)
231            .map_err(|e| mlua::Error::runtime(format!("fs.stat: failed to stat {path:?}: {e}")))?;
232        // Use symlink_metadata separately to correctly detect symlinks,
233        // since std::fs::metadata follows symlinks (is_symlink always false).
234        let is_symlink = std::fs::symlink_metadata(&path)
235            .map(|m| m.file_type().is_symlink())
236            .unwrap_or(false);
237        let info = lua.create_table()?;
238        info.set("size", metadata.len())?;
239        info.set("is_file", metadata.is_file())?;
240        info.set("is_dir", metadata.is_dir())?;
241        info.set("is_symlink", is_symlink)?;
242        if let Ok(modified) = metadata.modified()
243            && let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH)
244        {
245            info.set("modified", duration.as_secs_f64())?;
246        }
247        if let Ok(created) = metadata.created()
248            && let Ok(duration) = created.duration_since(std::time::UNIX_EPOCH)
249        {
250            info.set("created", duration.as_secs_f64())?;
251        }
252        Ok(info)
253    })?;
254    fs_table.set("stat", stat_fn)?;
255
256    let mkdir_fn = lua.create_function(|_, path: String| {
257        std::fs::create_dir_all(&path)
258            .map_err(|e| mlua::Error::runtime(format!("fs.mkdir: failed to create {path:?}: {e}")))
259    })?;
260    fs_table.set("mkdir", mkdir_fn)?;
261
262    let exists_fn =
263        lua.create_function(|_, path: String| Ok(std::path::Path::new(&path).exists()))?;
264    fs_table.set("exists", exists_fn)?;
265
266    // fs.copy(src, dst) — copy file, returns bytes copied
267    let copy_fn = lua.create_function(|_, (src, dst): (String, String)| {
268        let bytes = std::fs::copy(&src, &dst).map_err(|e| {
269            mlua::Error::runtime(format!("fs.copy: failed to copy {src:?} to {dst:?}: {e}"))
270        })?;
271        Ok(bytes)
272    })?;
273    fs_table.set("copy", copy_fn)?;
274
275    // fs.rename(src, dst) — atomic rename
276    let rename_fn = lua.create_function(|_, (src, dst): (String, String)| {
277        std::fs::rename(&src, &dst).map_err(|e| {
278            mlua::Error::runtime(format!(
279                "fs.rename: failed to rename {src:?} to {dst:?}: {e}"
280            ))
281        })
282    })?;
283    fs_table.set("rename", rename_fn)?;
284
285    // fs.glob(pattern) — glob pattern matching, returns array of path strings
286    let glob_fn = lua.create_function(|lua, pattern: String| {
287        let paths = glob::glob(&pattern).map_err(|e| {
288            mlua::Error::runtime(format!("fs.glob: invalid pattern {pattern:?}: {e}"))
289        })?;
290        let results = lua.create_table()?;
291        for (i, entry) in (1..).zip(paths) {
292            let path = entry
293                .map_err(|e| mlua::Error::runtime(format!("fs.glob: error reading entry: {e}")))?;
294            results.set(i, path.to_string_lossy().to_string())?;
295        }
296        Ok(results)
297    })?;
298    fs_table.set("glob", glob_fn)?;
299
300    // fs.tempdir() — create a temporary directory, returns path string
301    let tempdir_fn = lua.create_function(|_, ()| {
302        let base = std::env::temp_dir();
303        let nanos: u64 = std::time::SystemTime::now()
304            .duration_since(std::time::UNIX_EPOCH)
305            .unwrap_or_default()
306            .as_nanos() as u64;
307        let seq = TEMPDIR_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
308        let dir = base.join(format!("assay-{nanos:x}-{seq}"));
309        std::fs::create_dir_all(&dir).map_err(|e| {
310            mlua::Error::runtime(format!("fs.tempdir: failed to create {dir:?}: {e}"))
311        })?;
312        Ok(dir.to_string_lossy().to_string())
313    })?;
314    fs_table.set("tempdir", tempdir_fn)?;
315
316    // fs.chmod(path, mode) — set file permissions (octal integer, e.g. 493 = 0o755)
317    let chmod_fn = lua.create_function(|_, (path, mode): (String, u32)| {
318        let perms = std::fs::Permissions::from_mode(mode);
319        std::fs::set_permissions(&path, perms)
320            .map_err(|e| mlua::Error::runtime(format!("fs.chmod: failed to chmod {path:?}: {e}")))
321    })?;
322    fs_table.set("chmod", chmod_fn)?;
323
324    // fs.readdir(path, opts?) — recursive directory listing
325    // opts: { depth = N } — max recursion depth (nil = unlimited)
326    let readdir_fn = lua.create_function(|lua, args: mlua::MultiValue| {
327        let mut args_iter = args.into_iter();
328        let path: String = args_iter
329            .next()
330            .ok_or_else(|| mlua::Error::runtime("fs.readdir: path required"))
331            .and_then(|v| lua.unpack(v))?;
332
333        let max_depth: Option<usize> = if let Some(Value::Table(opts)) = args_iter.next() {
334            opts.get::<Option<usize>>("depth")?
335        } else {
336            None
337        };
338
339        let results = lua.create_table()?;
340        let mut i = 1u64;
341        let base = std::path::PathBuf::from(&path);
342
343        fn walk(
344            base: &std::path::Path,
345            dir: &std::path::Path,
346            results: &mlua::Table,
347            lua: &Lua,
348            i: &mut u64,
349            depth: usize,
350            max_depth: Option<usize>,
351        ) -> mlua::Result<()> {
352            let entries = std::fs::read_dir(dir).map_err(|e| {
353                mlua::Error::runtime(format!("fs.readdir: failed to read {dir:?}: {e}"))
354            })?;
355            for entry in entries {
356                let entry = entry.map_err(|e| {
357                    mlua::Error::runtime(format!("fs.readdir: error reading entry: {e}"))
358                })?;
359                let file_type = entry.file_type().map_err(|e| {
360                    mlua::Error::runtime(format!("fs.readdir: failed to get file type: {e}"))
361                })?;
362                let rel_path = entry
363                    .path()
364                    .strip_prefix(base)
365                    .unwrap_or(&entry.path())
366                    .to_string_lossy()
367                    .to_string();
368                let info = lua.create_table()?;
369                info.set("path", rel_path)?;
370                if file_type.is_dir() {
371                    info.set("type", "directory")?;
372                } else if file_type.is_symlink() {
373                    info.set("type", "symlink")?;
374                } else {
375                    info.set("type", "file")?;
376                }
377                results.set(*i, info)?;
378                *i += 1;
379                if file_type.is_dir() && (max_depth.is_none() || depth < max_depth.unwrap()) {
380                    walk(base, &entry.path(), results, lua, i, depth + 1, max_depth)?;
381                }
382            }
383            Ok(())
384        }
385
386        walk(&base, &base, &results, lua, &mut i, 1, max_depth)?;
387        Ok(results)
388    })?;
389    fs_table.set("readdir", readdir_fn)?;
390
391    // fs.lines(path) — stateful iterator yielding one line per call.
392    // Designed for `for line in fs.lines(path) do ... end`. Streams via
393    // BufReader so multi-GB files don't land in memory. Lines are
394    // stripped of their trailing `\n` (and `\r\n` on Windows files).
395    // Returns an iterator function; Lua's for-loop calls it until nil.
396    let lines_fn = lua.create_function(|lua, path: String| {
397        use std::io::BufRead;
398        let file = std::fs::File::open(&path).map_err(|e| {
399            mlua::Error::runtime(format!("fs.lines: failed to open {path:?}: {e}"))
400        })?;
401        let iter = std::sync::Arc::new(std::sync::Mutex::new(
402            std::io::BufReader::new(file).lines(),
403        ));
404        lua.create_function(move |_, ()| {
405            let mut it = iter
406                .lock()
407                .map_err(|e| mlua::Error::runtime(format!("fs.lines: lock poisoned: {e}")))?;
408            match it.next() {
409                Some(Ok(line)) => Ok(Some(line)),
410                Some(Err(e)) => Err(mlua::Error::runtime(format!("fs.lines: read error: {e}"))),
411                None => Ok(None),
412            }
413        })
414    })?;
415    fs_table.set("lines", lines_fn)?;
416
417    // fs.sub_in_file(path, pattern, repl) — `sed -i` equivalent.
418    // In-place search-and-replace using Lua's native pattern engine
419    // (same semantics as string.gsub, including %0-%9 backreferences
420    // and function replacements). Reads the file, substitutes, and
421    // only writes back if at least one match was made — so repeated
422    // calls on an already-substituted file are a no-op on disk.
423    // Returns the count of substitutions.
424    let sub_in_file_fn =
425        lua.create_function(|lua, (path, pattern, repl): (String, String, mlua::Value)| {
426            let content = std::fs::read_to_string(&path).map_err(|e| {
427                mlua::Error::runtime(format!(
428                    "fs.sub_in_file: failed to read {path:?}: {e}"
429                ))
430            })?;
431            let string_table: mlua::Table = lua.globals().get("string")?;
432            let gsub: mlua::Function = string_table.get("gsub")?;
433            let (new_content, count): (String, u64) = gsub.call((content, pattern, repl))?;
434            if count > 0 {
435                std::fs::write(&path, &new_content).map_err(|e| {
436                    mlua::Error::runtime(format!(
437                        "fs.sub_in_file: failed to write {path:?}: {e}"
438                    ))
439                })?;
440            }
441            Ok(count)
442        })?;
443    fs_table.set("sub_in_file", sub_in_file_fn)?;
444
445    lua.globals().set("fs", fs_table)?;
446    Ok(())
447}
448
449pub fn register_string_helpers(lua: &Lua) -> mlua::Result<()> {
450    let string_table: mlua::Table = lua.globals().get("string")?;
451
452    // string.split(s, sep?) — awk-style field split.
453    // When `sep` is nil or empty: splits on any run of whitespace and
454    // skips leading/trailing empty fields (matches awk default FS and
455    // Python's str.split() with no arg). When `sep` is provided: splits
456    // on the literal string (not a Lua pattern — use string.gmatch if
457    // you need pattern semantics). Returns a 1-indexed array table.
458    let split_fn = lua.create_function(|lua, args: mlua::MultiValue| {
459        let mut args_iter = args.into_iter();
460        let s: String = args_iter
461            .next()
462            .ok_or_else(|| mlua::Error::runtime("string.split: string required"))
463            .and_then(|v| lua.unpack(v))?;
464        let sep: Option<String> = match args_iter.next() {
465            Some(mlua::Value::Nil) | None => None,
466            Some(v) => Some(lua.unpack(v)?),
467        };
468        let results = lua.create_table()?;
469        match sep {
470            Some(ref sep_str) if !sep_str.is_empty() => {
471                for (i, part) in (1..).zip(s.split(sep_str.as_str())) {
472                    results.set(i, part)?;
473                }
474            }
475            _ => {
476                for (i, part) in (1..).zip(s.split_whitespace()) {
477                    results.set(i, part)?;
478                }
479            }
480        }
481        Ok(results)
482    })?;
483    string_table.set("split", split_fn)?;
484
485    Ok(())
486}
487
488pub fn register_base64(lua: &Lua) -> mlua::Result<()> {
489    let b64_table = lua.create_table()?;
490
491    let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
492    b64_table.set("encode", encode_fn)?;
493
494    let decode_fn = lua.create_function(|_, input: String| {
495        let bytes = BASE64
496            .decode(input.as_bytes())
497            .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
498        String::from_utf8(bytes)
499            .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
500    })?;
501    b64_table.set("decode", decode_fn)?;
502
503    lua.globals().set("base64", b64_table)?;
504    Ok(())
505}
506
507pub fn register_regex(lua: &Lua) -> mlua::Result<()> {
508    let regex_table = lua.create_table()?;
509
510    let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
511        let re = regex_lite::Regex::new(&pattern)
512            .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
513        Ok(re.is_match(&text))
514    })?;
515    regex_table.set("match", match_fn)?;
516
517    let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
518        let re = regex_lite::Regex::new(&pattern)
519            .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
520        match re.captures(&text) {
521            Some(caps) => {
522                let result = lua.create_table()?;
523                let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
524                result.set("match", full_match.to_string())?;
525                let groups = lua.create_table()?;
526                for i in 1..caps.len() {
527                    if let Some(m) = caps.get(i) {
528                        groups.set(i, m.as_str().to_string())?;
529                    }
530                }
531                result.set("groups", groups)?;
532                Ok(Value::Table(result))
533            }
534            None => Ok(Value::Nil),
535        }
536    })?;
537    regex_table.set("find", find_fn)?;
538
539    let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
540        let re = regex_lite::Regex::new(&pattern)
541            .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
542        let results = lua.create_table()?;
543        for (i, m) in re.find_iter(&text).enumerate() {
544            results.set(i + 1, m.as_str().to_string())?;
545        }
546        Ok(results)
547    })?;
548    regex_table.set("find_all", find_all_fn)?;
549
550    let replace_fn = lua.create_function(
551        |_, (text, pattern, replacement): (String, String, String)| {
552            let re = regex_lite::Regex::new(&pattern).map_err(|e| {
553                mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
554            })?;
555            Ok(re.replace_all(&text, replacement.as_str()).into_owned())
556        },
557    )?;
558    regex_table.set("replace", replace_fn)?;
559
560    lua.globals().set("regex", regex_table)?;
561    Ok(())
562}
563
564pub fn register_async(lua: &Lua) -> mlua::Result<()> {
565    let async_table = lua.create_table()?;
566
567    let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
568        let thread = lua.create_thread(func)?;
569        let async_thread = thread.into_async::<mlua::MultiValue>(())?;
570        let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
571            tokio::task::spawn_local(async move {
572                let values = async_thread.await.map_err(|e| e.to_string())?;
573                Ok(values.into_vec())
574            });
575
576        let handle = lua.create_table()?;
577        let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
578        let cell_clone = cell.clone();
579
580        let await_fn = lua.create_async_function(move |lua, ()| {
581            let cell = cell_clone.clone();
582            async move {
583                let join_handle = cell
584                    .borrow_mut()
585                    .take()
586                    .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
587                let result = join_handle.await.map_err(|e| {
588                    mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
589                })?;
590                match result {
591                    Ok(values) => {
592                        let tbl = lua.create_table()?;
593                        for (i, v) in values.into_iter().enumerate() {
594                            tbl.set(i + 1, v)?;
595                        }
596                        Ok(Value::Table(tbl))
597                    }
598                    Err(msg) => Err(mlua::Error::runtime(msg)),
599                }
600            }
601        })?;
602        handle.set("await", await_fn)?;
603
604        Ok(handle)
605    })?;
606    async_table.set("spawn", spawn_fn)?;
607
608    let spawn_interval_fn =
609        lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
610            if seconds <= 0.0 {
611                return Err(mlua::Error::runtime(
612                    "async.spawn_interval: interval must be positive",
613                ));
614            }
615
616            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
617            let cancel_clone = cancel.clone();
618
619            tokio::task::spawn_local({
620                let cancel = cancel_clone.clone();
621                async move {
622                    let mut interval =
623                        tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
624                    interval.tick().await;
625                    loop {
626                        interval.tick().await;
627                        if cancel.load(std::sync::atomic::Ordering::Relaxed) {
628                            break;
629                        }
630                        if let Err(e) = func.call_async::<()>(()).await {
631                            error!("async.spawn_interval: callback error: {e}");
632                            break;
633                        }
634                    }
635                }
636            });
637
638            let handle = lua.create_table()?;
639            let cancel_fn = lua.create_function(move |_, ()| {
640                cancel.store(true, std::sync::atomic::Ordering::Relaxed);
641                Ok(())
642            })?;
643            handle.set("cancel", cancel_fn)?;
644
645            Ok(handle)
646        })?;
647    async_table.set("spawn_interval", spawn_interval_fn)?;
648
649    lua.globals().set("async", async_table)?;
650    Ok(())
651}
652
#[cfg(test)]
mod tests {
    use data_encoding::BASE64;

    /// Encoding then decoding ASCII text reproduces the input, with `=` padding.
    #[test]
    fn test_base64_roundtrip() {
        const PLAIN: &str = "hello world";
        let encoded = BASE64.encode(PLAIN.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let raw = BASE64.decode(encoded.as_bytes()).unwrap();
        let text = String::from_utf8(raw).unwrap();
        assert_eq!(text, PLAIN);
    }

    /// The empty input maps to the empty output in both directions.
    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }
}