Skip to main content

nil_lua/
lib.rs

1// Copyright (C) Call of Nil contributors
2// SPDX-License-Identifier: AGPL-3.0-only
3
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![doc(html_favicon_url = "https://nil.dev.br/favicon.png")]
6#![feature(iterator_try_collect)]
7
8pub mod client;
9pub mod error;
10pub mod io;
11pub mod script;
12
13use client::ClientUserData;
14use error::Result;
15use io::{Stdio, StdioMessage};
16use mlua::{LuaOptions, StdLib, Value, Variadic};
17use nil_client::Client;
18use script::ScriptOutput;
19use std::mem;
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use tokio::sync::mpsc::{self, UnboundedReceiver};
23
24pub struct Lua {
25  inner: mlua::Lua,
26  stdout: Stdio,
27  stderr: Stdio,
28}
29
30#[bon::bon]
31impl Lua {
32  #[builder]
33  pub fn new(
34    #[builder(start_fn)] client: &Arc<RwLock<Client>>,
35    #[builder(default = StdLib::ALL_SAFE)] libs: StdLib,
36  ) -> Result<Self> {
37    let lua = mlua::Lua::new_with(libs, LuaOptions::default())?;
38
39    let globals = lua.globals();
40    let client_data = ClientUserData::new(Arc::clone(client));
41    globals.set("client", lua.create_userdata(client_data)?)?;
42
43    let stdout_rx = pipe(&lua, "print", "println")?;
44    let stderr_rx = pipe(&lua, "eprint", "eprintln")?;
45
46    Ok(Self {
47      inner: lua,
48      stdout: Stdio::new(stdout_rx),
49      stderr: Stdio::new(stderr_rx),
50    })
51  }
52
53  pub async fn execute(&mut self, chunk: &str) -> Result<ScriptOutput> {
54    self.flush();
55    self.clear();
56
57    self.inner.load(chunk).exec_async().await?;
58
59    Ok(self.output())
60  }
61
62  fn output(&mut self) -> ScriptOutput {
63    self.flush();
64    self.stdout.buffer.sort();
65    self.stderr.buffer.sort();
66
67    ScriptOutput {
68      stdout: mem::take(&mut self.stdout.buffer),
69      stderr: mem::take(&mut self.stderr.buffer),
70    }
71  }
72
73  fn flush(&mut self) {
74    self.stdout.flush();
75    self.stderr.flush();
76  }
77
78  fn clear(&mut self) {
79    self.stdout.buffer.clear();
80    self.stderr.buffer.clear();
81  }
82}
83
84fn pipe(lua: &mlua::Lua, name: &str, name_ln: &str) -> Result<UnboundedReceiver<StdioMessage>> {
85  let (tx, rx) = mpsc::unbounded_channel();
86  let create_fn = |line_break: bool| {
87    let tx = tx.clone();
88    lua.create_function(move |_, values: Variadic<Value>| {
89      let mut string = values
90        .into_iter()
91        .map(|it| it.to_string())
92        .try_collect::<String>()?;
93
94      if line_break {
95        string.push('\n');
96      }
97
98      let _ = tx.send(StdioMessage::new(string));
99
100      Ok(())
101    })
102  };
103
104  let globals = lua.globals();
105  globals.set(name, create_fn(false)?)?;
106  globals.set(name_ln, create_fn(true)?)?;
107
108  Ok(rx)
109}