1mod args;
2
3use std::cmp::max;
4use crate::args::{to_args, Arg, Definition};
5use anyhow::anyhow;
6use log::{debug, info, warn};
7use mlua::{Function, Lua, Value, Variadic};
8use rigz_core::{Argument, InitializationArgs, Module, RuntimeStatus};
9use serde::Deserialize;
10use std::collections::HashMap;
11use std::fs::File;
12use std::io::Read;
13use std::path::PathBuf;
14
15#[derive(Copy, Clone, Debug, Default, Deserialize)]
16pub enum FunctionFormat {
17 #[default]
18 StructFunction, }
21
22impl From<String> for FunctionFormat {
23 fn from(value: String) -> Self {
24 match value.as_str() {
25 "StructFunction" => FunctionFormat::StructFunction,
26 _ => {
27 warn!("Unsupported Format: {}, defaulting to Args", value);
28 FunctionFormat::StructFunction
29 }
30 }
31 }
32}
33
34pub struct LuaModule {
35 pub(crate) name: String,
36 pub(crate) function_format: FunctionFormat,
37 pub(crate) module_root: PathBuf,
38 pub(crate) lua: Lua,
39 pub(crate) source_files: Vec<PathBuf>,
40 pub(crate) input_files: HashMap<String, Vec<File>>,
41}
42
43fn inspect(lua: &Lua, value: Value) -> mlua::Result<String> {
44 let result = match value {
45 Value::Table(t) => {
46 let mut result = String::new();
47 let len = t.clone().pairs::<Value, Value>().count();
48 let pairs = t.pairs::<Value, Value>();
49 let mut index = 1;
50 let mut is_array = false;
51 for pair in pairs {
52 let (key, value) = pair?;
53 result.push(' ');
54 let key_str = inspect(lua, key)?;
55
56 if index == 1 {
57 if key_str.as_str() == "1" {
58 is_array = true;
59 result.push('[');
60 } else {
61 result.push('{');
62 }
63 }
64 if !is_array {
65 result.push_str(key_str.as_str());
66 result.push_str(" = ");
67 }
68 result.push_str(inspect(lua, value)?.as_str());
69 if index < len {
70 result.push(',');
71 }
72 index += 1;
73 }
74
75 if is_array {
76 result.push(']');
77 } else {
78 result.push('}');
79 }
80 result
81 }
82 _ => value.to_string()?
83 };
84 Ok(result)
85}
86
87impl LuaModule {
88 pub fn new(
89 name: String,
90 module_root: PathBuf,
91 source_files: Vec<PathBuf>,
92 input_files: HashMap<String, Vec<File>>,
93 config: Option<serde_value::Value>,
94 ) -> Box<dyn Module> {
95 let function_format = match config {
96 None => FunctionFormat::default(),
97 Some(f) => {
98 if let serde_value::Value::Map(mut map) = f {
99 let f = map.remove(&serde_value::Value::String("function_format".into()));
100 if f.is_none() {
101 FunctionFormat::default()
102 } else {
103 f.unwrap()
104 .deserialize_into()
105 .unwrap_or(FunctionFormat::default())
106 }
107 } else {
108 FunctionFormat::default()
109 }
110 }
111 };
112 Box::new(LuaModule {
113 name,
114 function_format,
115 module_root,
116 input_files,
117 lua: Lua::new(),
118 source_files,
119 })
120 }
121
122 pub(crate) fn invoke_function(
123 &self,
124 name: &str,
125 args: Vec<Arg>,
126 context: Definition,
127 previous_value: Arg,
128 ) -> RuntimeStatus<Arg> {
129 let lua = &self.lua;
130 let table = lua.globals();
131
132 lua
133 .scope(|_| {
134 let function: Function = match table.get::<_, Function>(name) {
135 Ok(f) => f,
136 Err(e) => {
137 warn!("Function Not Found: {} - {}", name, e);
138 return Ok(RuntimeStatus::NotFound);
139 }
140 };
141
142 let status = match self.function_format {
143 FunctionFormat::StructFunction => {
144 let table = lua.create_table()?;
145 table.set("name", name)?;
146 table.set("args", args)?;
147 table.set("previous_value", previous_value)?;
148 table.set("context", context)?;
149 RuntimeStatus::Ok(function.call(table)?)
150 }
151 };
152 Ok(status)
153 })
154 .unwrap_or(RuntimeStatus::Err("Lua Execution Failed".to_string()))
155 }
156
157 fn load_source_files(&self) -> anyhow::Result<()> {
158 if self.source_files.is_empty() {
159 warn!("No source files configured for module {}", self.name);
160 }
161
162 for file in &self.source_files {
163 let ext = file
164 .extension()
165 .map(|o| o.to_str().unwrap_or("<invalid>"))
166 .unwrap_or("<none>");
167 if ext != "lua" {
168 continue;
169 }
170 let current_file = file.to_str().unwrap_or("<unknown>");
171 info!("{} loading {}", self.name, current_file);
172 let contents = load_file(file)?;
173 match self.lua.scope(|_| {
174 let global = self.lua.globals();
175 let chunk = self.lua.load(contents);
176 chunk.exec()?;
177 global.set("__module_name", self.name.as_str())?;
178 Ok(())
179 }) {
180 Ok(_) => continue,
181 Err(e) => {
182 return Err(anyhow!(
183 "Failed to load file: {} - {} {}",
184 self.name,
185 current_file,
186 e
187 ))
188 }
189 }
190 }
191
192 Ok(())
193 }
194}
195
196fn load_file(path_buf: &PathBuf) -> anyhow::Result<String> {
197 let mut contents = String::new();
198 let mut file = File::open(path_buf)?;
199 file.read_to_string(&mut contents)?;
200 Ok(contents)
201}
202
203impl Module for LuaModule {
204 fn name(&self) -> &str {
205 self.name.as_str()
206 }
207
208 fn root(&self) -> PathBuf {
209 self.module_root.clone()
210 }
211
212 fn function_call(
213 &self,
214 name: &str,
215 arguments: Vec<Argument>,
216 definition: rigz_core::Definition,
217 prior_result: Argument,
218 ) -> RuntimeStatus<Argument> {
219 match self.invoke_function(
220 name,
221 to_args(arguments),
222 definition.into(),
223 prior_result.into(),
224 ) {
225 RuntimeStatus::Ok(a) => RuntimeStatus::Ok(a.into()),
226 RuntimeStatus::NotFound => RuntimeStatus::NotFound,
227 RuntimeStatus::Err(e) => RuntimeStatus::Err(e),
228 }
229 }
230
231 fn initialize(&self, args: InitializationArgs) -> RuntimeStatus<()> {
232 match self.load_source_files() {
233 Ok(_) => {}
234 Err(e) => return RuntimeStatus::Err(format!("Failed to load source files - {}", e)),
235 };
236
237
238 if self.input_files.is_empty() {
239 debug!("No input files passed into module {}", self.name)
240 }
241
242 match self.lua.scope(|_| {
243 let global = self.lua.globals();
244 global.set("inspect", self.lua.create_function(inspect)?)?;
245 global.set("__module_name", self.name.as_str())?;
246 Ok(())
247 }) {
248 Ok(_) => RuntimeStatus::Ok(()),
249 Err(e) => RuntimeStatus::Err(format!("Initialization Failed: {} - {}", self.name, e)),
250 }
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls Lua's built-in `print` through the module without any source files and expects
    /// an `Ok` status carrying `Argument::None`.
    #[test]
    fn it_works() {
        let module = LuaModule::new(
            String::from("hello_world"),
            PathBuf::default(),
            Vec::new(),
            HashMap::new(),
            None,
        );

        let greeting = vec![Argument::String("Hello World".into())];
        let outcome = module.function_call(
            "print",
            greeting,
            rigz_core::Definition::None,
            Argument::None,
        );

        assert_eq!(outcome, RuntimeStatus::Ok(Argument::None));
    }
}