1use data_encoding::BASE64;
2use mlua::{Lua, Value};
3use std::os::unix::fs::PermissionsExt;
4use std::time::{SystemTime, UNIX_EPOCH};
5use tracing::{error, info, warn};
6
/// Sequence number mixed into every `fs.tempdir` directory name, so that two directories
/// requested within the same clock reading still receive distinct paths.
static TEMPDIR_COUNTER: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
8
9pub fn register_log(lua: &Lua) -> mlua::Result<()> {
10 let log_table = lua.create_table()?;
11
12 let info_fn = lua.create_function(|_, msg: String| {
13 info!(target: "lua", "{}", msg);
14 Ok(())
15 })?;
16 log_table.set("info", info_fn)?;
17
18 let warn_fn = lua.create_function(|_, msg: String| {
19 warn!(target: "lua", "{}", msg);
20 Ok(())
21 })?;
22 log_table.set("warn", warn_fn)?;
23
24 let error_fn = lua.create_function(|_, msg: String| {
25 error!(target: "lua", "{}", msg);
26 Ok(())
27 })?;
28 log_table.set("error", error_fn)?;
29
30 lua.globals().set("log", log_table)?;
31 Ok(())
32}
33
34pub fn register_env(lua: &Lua) -> mlua::Result<()> {
35 let env_table = lua.create_table()?;
36
37 let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
38 Ok(val) => Ok(Some(val)),
39 Err(_) => Ok(None),
40 })?;
41 env_table.set("_process_get", process_get_fn)?;
42 env_table.set("_check_env", lua.create_table()?)?;
43
44 lua.globals().set("env", env_table)?;
45
46 lua.load(
47 r#"
48 function env.get(name)
49 local val = env._check_env[name]
50 if val ~= nil then return val end
51 return env._process_get(name)
52 end
53 "#,
54 )
55 .exec()?;
56
57 let set_fn = lua.create_function(|_, (key, val): (String, Option<String>)| {
59 match val {
60 Some(v) => unsafe { std::env::set_var(&key, &v) },
61 None => unsafe { std::env::remove_var(&key) },
62 }
63 Ok(())
64 })?;
65 lua.globals()
66 .get::<mlua::Table>("env")?
67 .set("set", set_fn)?;
68
69 let list_fn = lua.create_function(|lua, ()| {
71 let results = lua.create_table()?;
72 for (i, (key, val)) in (1..).zip(std::env::vars()) {
73 let entry = lua.create_table()?;
74 entry.set("key", key)?;
75 entry.set("value", val)?;
76 results.set(i, entry)?;
77 }
78 Ok(results)
79 })?;
80 lua.globals()
81 .get::<mlua::Table>("env")?
82 .set("list", list_fn)?;
83
84 Ok(())
85}
86
87pub fn register_sleep(lua: &Lua) -> mlua::Result<()> {
88 let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
89 let duration = std::time::Duration::from_secs_f64(seconds);
90 tokio::time::sleep(duration).await;
91 Ok(())
92 })?;
93 lua.globals().set("sleep", sleep_fn)?;
94 Ok(())
95}
96
97pub fn register_time(lua: &Lua) -> mlua::Result<()> {
98 let time_fn = lua.create_function(|_, ()| {
99 let secs = SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
102 .as_secs_f64();
103 Ok(secs)
104 })?;
105 lua.globals().set("time", time_fn)?;
106 Ok(())
107}
108
109pub fn register_fs(lua: &Lua) -> mlua::Result<()> {
110 use crate::lua::file_source::FileSourceHandle;
111
112 let fs_table = lua.create_table()?;
113
114 let read_fn = lua.create_function(|lua, path: String| -> mlua::Result<String> {
115 let bytes = match lua.app_data_ref::<FileSourceHandle>() {
116 Some(source) => source.read(&path).ok_or_else(|| {
117 mlua::Error::runtime(format!(
118 "fs.read: failed to read {path:?}: not found in file source"
119 ))
120 })?,
121 None => std::fs::read(&path).map_err(|e| {
122 mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}"))
123 })?,
124 };
125 String::from_utf8(bytes).map_err(|e| {
126 mlua::Error::runtime(format!("fs.read: invalid UTF-8 in {path:?}: {e}"))
127 })
128 })?;
129 fs_table.set("read", read_fn)?;
130
131 let read_bytes_fn = lua.create_function(|lua, path: String| {
133 let bytes = match lua.app_data_ref::<FileSourceHandle>() {
134 Some(source) => source.read(&path).ok_or_else(|| {
135 mlua::Error::runtime(format!(
136 "fs.read_bytes: failed to read {path:?}: not found in file source"
137 ))
138 })?,
139 None => std::fs::read(&path).map_err(|e| {
140 mlua::Error::runtime(format!(
141 "fs.read_bytes: failed to read {path:?}: {e}"
142 ))
143 })?,
144 };
145 lua.create_string(&bytes)
146 })?;
147 fs_table.set("read_bytes", read_bytes_fn)?;
148
149 let write_fn = lua.create_function(|_, (path, content): (String, String)| {
150 let p = std::path::Path::new(&path);
151 if let Some(parent) = p.parent() {
152 std::fs::create_dir_all(parent).map_err(|e| {
153 mlua::Error::runtime(format!(
154 "fs.write: failed to create directories for {path:?}: {e}"
155 ))
156 })?;
157 }
158 std::fs::write(&path, &content)
159 .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
160 })?;
161 fs_table.set("write", write_fn)?;
162
163 let write_bytes_fn = lua.create_function(|_, (path, data): (String, mlua::String)| {
165 let p = std::path::Path::new(&path);
166 if let Some(parent) = p.parent() {
167 std::fs::create_dir_all(parent).map_err(|e| {
168 mlua::Error::runtime(format!(
169 "fs.write_bytes: failed to create directories for {path:?}: {e}"
170 ))
171 })?;
172 }
173 std::fs::write(&path, data.as_bytes())
174 .map_err(|e| mlua::Error::runtime(format!("fs.write_bytes: failed to write {path:?}: {e}")))
175 })?;
176 fs_table.set("write_bytes", write_bytes_fn)?;
177
178 let remove_fn = lua.create_function(|_, path: String| {
179 let p = std::path::Path::new(&path);
180 let is_dir = match std::fs::symlink_metadata(&path) {
184 Ok(m) => m.file_type().is_dir(),
185 Err(_) => p.is_dir(),
186 };
187 if is_dir {
188 std::fs::remove_dir_all(&path).map_err(|e| {
189 mlua::Error::runtime(format!(
190 "fs.remove: failed to remove directory {path:?}: {e}"
191 ))
192 })
193 } else {
194 std::fs::remove_file(&path).map_err(|e| {
195 mlua::Error::runtime(format!("fs.remove: failed to remove {path:?}: {e}"))
196 })
197 }
198 })?;
199 fs_table.set("remove", remove_fn)?;
200
201 let list_fn =
202 lua.create_function(|lua, path: String| {
203 let entries = lua.create_table()?;
204 for (i, entry) in (1..).zip(std::fs::read_dir(&path).map_err(|e| {
205 mlua::Error::runtime(format!("fs.list: failed to list {path:?}: {e}"))
206 })?) {
207 let entry = entry.map_err(|e| {
208 mlua::Error::runtime(format!("fs.list: error reading entry in {path:?}: {e}"))
209 })?;
210 let info = lua.create_table()?;
211 let name = entry.file_name().to_string_lossy().to_string();
212 info.set("name", name)?;
213 let file_type = entry.file_type().map_err(|e| {
214 mlua::Error::runtime(format!("fs.list: failed to get file type: {e}"))
215 })?;
216 if file_type.is_dir() {
217 info.set("type", "directory")?;
218 } else if file_type.is_symlink() {
219 info.set("type", "symlink")?;
220 } else {
221 info.set("type", "file")?;
222 }
223 entries.set(i, info)?;
224 }
225 Ok(entries)
226 })?;
227 fs_table.set("list", list_fn)?;
228
229 let stat_fn = lua.create_function(|lua, path: String| {
230 let metadata = std::fs::metadata(&path)
231 .map_err(|e| mlua::Error::runtime(format!("fs.stat: failed to stat {path:?}: {e}")))?;
232 let is_symlink = std::fs::symlink_metadata(&path)
235 .map(|m| m.file_type().is_symlink())
236 .unwrap_or(false);
237 let info = lua.create_table()?;
238 info.set("size", metadata.len())?;
239 info.set("is_file", metadata.is_file())?;
240 info.set("is_dir", metadata.is_dir())?;
241 info.set("is_symlink", is_symlink)?;
242 if let Ok(modified) = metadata.modified()
243 && let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH)
244 {
245 info.set("modified", duration.as_secs_f64())?;
246 }
247 if let Ok(created) = metadata.created()
248 && let Ok(duration) = created.duration_since(std::time::UNIX_EPOCH)
249 {
250 info.set("created", duration.as_secs_f64())?;
251 }
252 Ok(info)
253 })?;
254 fs_table.set("stat", stat_fn)?;
255
256 let mkdir_fn = lua.create_function(|_, path: String| {
257 std::fs::create_dir_all(&path)
258 .map_err(|e| mlua::Error::runtime(format!("fs.mkdir: failed to create {path:?}: {e}")))
259 })?;
260 fs_table.set("mkdir", mkdir_fn)?;
261
262 let exists_fn =
263 lua.create_function(|_, path: String| Ok(std::path::Path::new(&path).exists()))?;
264 fs_table.set("exists", exists_fn)?;
265
266 let copy_fn = lua.create_function(|_, (src, dst): (String, String)| {
268 let bytes = std::fs::copy(&src, &dst).map_err(|e| {
269 mlua::Error::runtime(format!("fs.copy: failed to copy {src:?} to {dst:?}: {e}"))
270 })?;
271 Ok(bytes)
272 })?;
273 fs_table.set("copy", copy_fn)?;
274
275 let rename_fn = lua.create_function(|_, (src, dst): (String, String)| {
277 std::fs::rename(&src, &dst).map_err(|e| {
278 mlua::Error::runtime(format!(
279 "fs.rename: failed to rename {src:?} to {dst:?}: {e}"
280 ))
281 })
282 })?;
283 fs_table.set("rename", rename_fn)?;
284
285 let glob_fn = lua.create_function(|lua, pattern: String| {
287 let paths = glob::glob(&pattern).map_err(|e| {
288 mlua::Error::runtime(format!("fs.glob: invalid pattern {pattern:?}: {e}"))
289 })?;
290 let results = lua.create_table()?;
291 for (i, entry) in (1..).zip(paths) {
292 let path = entry
293 .map_err(|e| mlua::Error::runtime(format!("fs.glob: error reading entry: {e}")))?;
294 results.set(i, path.to_string_lossy().to_string())?;
295 }
296 Ok(results)
297 })?;
298 fs_table.set("glob", glob_fn)?;
299
300 let tempdir_fn = lua.create_function(|_, ()| {
302 let base = std::env::temp_dir();
303 let nanos: u64 = std::time::SystemTime::now()
304 .duration_since(std::time::UNIX_EPOCH)
305 .unwrap_or_default()
306 .as_nanos() as u64;
307 let seq = TEMPDIR_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
308 let dir = base.join(format!("assay-{nanos:x}-{seq}"));
309 std::fs::create_dir_all(&dir).map_err(|e| {
310 mlua::Error::runtime(format!("fs.tempdir: failed to create {dir:?}: {e}"))
311 })?;
312 Ok(dir.to_string_lossy().to_string())
313 })?;
314 fs_table.set("tempdir", tempdir_fn)?;
315
316 let chmod_fn = lua.create_function(|_, (path, mode): (String, u32)| {
318 let perms = std::fs::Permissions::from_mode(mode);
319 std::fs::set_permissions(&path, perms)
320 .map_err(|e| mlua::Error::runtime(format!("fs.chmod: failed to chmod {path:?}: {e}")))
321 })?;
322 fs_table.set("chmod", chmod_fn)?;
323
324 let readdir_fn = lua.create_function(|lua, args: mlua::MultiValue| {
327 let mut args_iter = args.into_iter();
328 let path: String = args_iter
329 .next()
330 .ok_or_else(|| mlua::Error::runtime("fs.readdir: path required"))
331 .and_then(|v| lua.unpack(v))?;
332
333 let max_depth: Option<usize> = if let Some(Value::Table(opts)) = args_iter.next() {
334 opts.get::<Option<usize>>("depth")?
335 } else {
336 None
337 };
338
339 let results = lua.create_table()?;
340 let mut i = 1u64;
341 let base = std::path::PathBuf::from(&path);
342
343 fn walk(
344 base: &std::path::Path,
345 dir: &std::path::Path,
346 results: &mlua::Table,
347 lua: &Lua,
348 i: &mut u64,
349 depth: usize,
350 max_depth: Option<usize>,
351 ) -> mlua::Result<()> {
352 let entries = std::fs::read_dir(dir).map_err(|e| {
353 mlua::Error::runtime(format!("fs.readdir: failed to read {dir:?}: {e}"))
354 })?;
355 for entry in entries {
356 let entry = entry.map_err(|e| {
357 mlua::Error::runtime(format!("fs.readdir: error reading entry: {e}"))
358 })?;
359 let file_type = entry.file_type().map_err(|e| {
360 mlua::Error::runtime(format!("fs.readdir: failed to get file type: {e}"))
361 })?;
362 let rel_path = entry
363 .path()
364 .strip_prefix(base)
365 .unwrap_or(&entry.path())
366 .to_string_lossy()
367 .to_string();
368 let info = lua.create_table()?;
369 info.set("path", rel_path)?;
370 if file_type.is_dir() {
371 info.set("type", "directory")?;
372 } else if file_type.is_symlink() {
373 info.set("type", "symlink")?;
374 } else {
375 info.set("type", "file")?;
376 }
377 results.set(*i, info)?;
378 *i += 1;
379 if file_type.is_dir() && (max_depth.is_none() || depth < max_depth.unwrap()) {
380 walk(base, &entry.path(), results, lua, i, depth + 1, max_depth)?;
381 }
382 }
383 Ok(())
384 }
385
386 walk(&base, &base, &results, lua, &mut i, 1, max_depth)?;
387 Ok(results)
388 })?;
389 fs_table.set("readdir", readdir_fn)?;
390
391 let lines_fn = lua.create_function(|lua, path: String| {
397 use std::io::BufRead;
398 let file = std::fs::File::open(&path).map_err(|e| {
399 mlua::Error::runtime(format!("fs.lines: failed to open {path:?}: {e}"))
400 })?;
401 let iter = std::sync::Arc::new(std::sync::Mutex::new(
402 std::io::BufReader::new(file).lines(),
403 ));
404 lua.create_function(move |_, ()| {
405 let mut it = iter
406 .lock()
407 .map_err(|e| mlua::Error::runtime(format!("fs.lines: lock poisoned: {e}")))?;
408 match it.next() {
409 Some(Ok(line)) => Ok(Some(line)),
410 Some(Err(e)) => Err(mlua::Error::runtime(format!("fs.lines: read error: {e}"))),
411 None => Ok(None),
412 }
413 })
414 })?;
415 fs_table.set("lines", lines_fn)?;
416
417 let sub_in_file_fn =
425 lua.create_function(|lua, (path, pattern, repl): (String, String, mlua::Value)| {
426 let content = std::fs::read_to_string(&path).map_err(|e| {
427 mlua::Error::runtime(format!(
428 "fs.sub_in_file: failed to read {path:?}: {e}"
429 ))
430 })?;
431 let string_table: mlua::Table = lua.globals().get("string")?;
432 let gsub: mlua::Function = string_table.get("gsub")?;
433 let (new_content, count): (String, u64) = gsub.call((content, pattern, repl))?;
434 if count > 0 {
435 std::fs::write(&path, &new_content).map_err(|e| {
436 mlua::Error::runtime(format!(
437 "fs.sub_in_file: failed to write {path:?}: {e}"
438 ))
439 })?;
440 }
441 Ok(count)
442 })?;
443 fs_table.set("sub_in_file", sub_in_file_fn)?;
444
445 lua.globals().set("fs", fs_table)?;
446 Ok(())
447}
448
449pub fn register_string_helpers(lua: &Lua) -> mlua::Result<()> {
450 let string_table: mlua::Table = lua.globals().get("string")?;
451
452 let split_fn = lua.create_function(|lua, args: mlua::MultiValue| {
459 let mut args_iter = args.into_iter();
460 let s: String = args_iter
461 .next()
462 .ok_or_else(|| mlua::Error::runtime("string.split: string required"))
463 .and_then(|v| lua.unpack(v))?;
464 let sep: Option<String> = match args_iter.next() {
465 Some(mlua::Value::Nil) | None => None,
466 Some(v) => Some(lua.unpack(v)?),
467 };
468 let results = lua.create_table()?;
469 match sep {
470 Some(ref sep_str) if !sep_str.is_empty() => {
471 for (i, part) in (1..).zip(s.split(sep_str.as_str())) {
472 results.set(i, part)?;
473 }
474 }
475 _ => {
476 for (i, part) in (1..).zip(s.split_whitespace()) {
477 results.set(i, part)?;
478 }
479 }
480 }
481 Ok(results)
482 })?;
483 string_table.set("split", split_fn)?;
484
485 Ok(())
486}
487
488pub fn register_base64(lua: &Lua) -> mlua::Result<()> {
489 let b64_table = lua.create_table()?;
490
491 let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
492 b64_table.set("encode", encode_fn)?;
493
494 let decode_fn = lua.create_function(|_, input: String| {
495 let bytes = BASE64
496 .decode(input.as_bytes())
497 .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
498 String::from_utf8(bytes)
499 .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
500 })?;
501 b64_table.set("decode", decode_fn)?;
502
503 lua.globals().set("base64", b64_table)?;
504 Ok(())
505}
506
507pub fn register_regex(lua: &Lua) -> mlua::Result<()> {
508 let regex_table = lua.create_table()?;
509
510 let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
511 let re = regex_lite::Regex::new(&pattern)
512 .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
513 Ok(re.is_match(&text))
514 })?;
515 regex_table.set("match", match_fn)?;
516
517 let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
518 let re = regex_lite::Regex::new(&pattern)
519 .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
520 match re.captures(&text) {
521 Some(caps) => {
522 let result = lua.create_table()?;
523 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
524 result.set("match", full_match.to_string())?;
525 let groups = lua.create_table()?;
526 for i in 1..caps.len() {
527 if let Some(m) = caps.get(i) {
528 groups.set(i, m.as_str().to_string())?;
529 }
530 }
531 result.set("groups", groups)?;
532 Ok(Value::Table(result))
533 }
534 None => Ok(Value::Nil),
535 }
536 })?;
537 regex_table.set("find", find_fn)?;
538
539 let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
540 let re = regex_lite::Regex::new(&pattern)
541 .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
542 let results = lua.create_table()?;
543 for (i, m) in re.find_iter(&text).enumerate() {
544 results.set(i + 1, m.as_str().to_string())?;
545 }
546 Ok(results)
547 })?;
548 regex_table.set("find_all", find_all_fn)?;
549
550 let replace_fn = lua.create_function(
551 |_, (text, pattern, replacement): (String, String, String)| {
552 let re = regex_lite::Regex::new(&pattern).map_err(|e| {
553 mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
554 })?;
555 Ok(re.replace_all(&text, replacement.as_str()).into_owned())
556 },
557 )?;
558 regex_table.set("replace", replace_fn)?;
559
560 lua.globals().set("regex", regex_table)?;
561 Ok(())
562}
563
564pub fn register_async(lua: &Lua) -> mlua::Result<()> {
565 let async_table = lua.create_table()?;
566
567 let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
568 let thread = lua.create_thread(func)?;
569 let async_thread = thread.into_async::<mlua::MultiValue>(())?;
570 let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
571 tokio::task::spawn_local(async move {
572 let values = async_thread.await.map_err(|e| e.to_string())?;
573 Ok(values.into_vec())
574 });
575
576 let handle = lua.create_table()?;
577 let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
578 let cell_clone = cell.clone();
579
580 let await_fn = lua.create_async_function(move |lua, ()| {
581 let cell = cell_clone.clone();
582 async move {
583 let join_handle = cell
584 .borrow_mut()
585 .take()
586 .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
587 let result = join_handle.await.map_err(|e| {
588 mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
589 })?;
590 match result {
591 Ok(values) => {
592 let tbl = lua.create_table()?;
593 for (i, v) in values.into_iter().enumerate() {
594 tbl.set(i + 1, v)?;
595 }
596 Ok(Value::Table(tbl))
597 }
598 Err(msg) => Err(mlua::Error::runtime(msg)),
599 }
600 }
601 })?;
602 handle.set("await", await_fn)?;
603
604 Ok(handle)
605 })?;
606 async_table.set("spawn", spawn_fn)?;
607
608 let spawn_interval_fn =
609 lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
610 if seconds <= 0.0 {
611 return Err(mlua::Error::runtime(
612 "async.spawn_interval: interval must be positive",
613 ));
614 }
615
616 let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
617 let cancel_clone = cancel.clone();
618
619 tokio::task::spawn_local({
620 let cancel = cancel_clone.clone();
621 async move {
622 let mut interval =
623 tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
624 interval.tick().await;
625 loop {
626 interval.tick().await;
627 if cancel.load(std::sync::atomic::Ordering::Relaxed) {
628 break;
629 }
630 if let Err(e) = func.call_async::<()>(()).await {
631 error!("async.spawn_interval: callback error: {e}");
632 break;
633 }
634 }
635 }
636 });
637
638 let handle = lua.create_table()?;
639 let cancel_fn = lua.create_function(move |_, ()| {
640 cancel.store(true, std::sync::atomic::Ordering::Relaxed);
641 Ok(())
642 })?;
643 handle.set("cancel", cancel_fn)?;
644
645 Ok(handle)
646 })?;
647 async_table.set("spawn_interval", spawn_interval_fn)?;
648
649 lua.globals().set("async", async_table)?;
650 Ok(())
651}
652
#[cfg(test)]
mod tests {
    use data_encoding::BASE64;

    /// A non-empty ASCII string encodes to the known padded form and decodes back unchanged.
    #[test]
    fn test_base64_roundtrip() {
        const INPUT: &str = "hello world";
        let encoded = BASE64.encode(INPUT.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let raw = BASE64.decode(encoded.as_bytes()).unwrap();
        let text = String::from_utf8(raw).unwrap();
        assert_eq!(text, INPUT);
    }

    /// The empty input encodes to the empty string and decodes back to zero bytes.
    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }
}