rigz_lua/
lib.rs

1mod args;
2
3use std::cmp::max;
4use crate::args::{to_args, Arg, Definition};
5use anyhow::anyhow;
6use log::{debug, info, warn};
7use mlua::{Function, Lua, Value, Variadic};
8use rigz_core::{Argument, InitializationArgs, Module, RuntimeStatus};
9use serde::Deserialize;
10use std::collections::HashMap;
11use std::fs::File;
12use std::io::Read;
13use std::path::PathBuf;
14
15#[derive(Copy, Clone, Debug, Default, Deserialize)]
16pub enum FunctionFormat {
17    #[default]
18    StructFunction,        // { name, args, context, prior }
19                           // Dynamic https://gitlab.com/inapinch/rigz/rigz/-/issues/2
20}
21
22impl From<String> for FunctionFormat {
23    fn from(value: String) -> Self {
24        match value.as_str() {
25            "StructFunction" => FunctionFormat::StructFunction,
26            _ => {
27                warn!("Unsupported Format: {}, defaulting to Args", value);
28                FunctionFormat::StructFunction
29            }
30        }
31    }
32}
33
34pub struct LuaModule {
35    pub(crate) name: String,
36    pub(crate) function_format: FunctionFormat,
37    pub(crate) module_root: PathBuf,
38    pub(crate) lua: Lua,
39    pub(crate) source_files: Vec<PathBuf>,
40    pub(crate) input_files: HashMap<String, Vec<File>>,
41}
42
43fn inspect(lua: &Lua, value: Value) -> mlua::Result<String> {
44    let result = match value {
45        Value::Table(t) => {
46            let mut result = String::new();
47            let len = t.clone().pairs::<Value, Value>().count();
48            let pairs = t.pairs::<Value, Value>();
49            let mut index = 1;
50            let mut is_array = false;
51            for pair in pairs {
52                let (key, value) = pair?;
53                result.push(' ');
54                let key_str = inspect(lua, key)?;
55
56                if index == 1 {
57                    if key_str.as_str() == "1" {
58                        is_array = true;
59                        result.push('[');
60                    } else {
61                        result.push('{');
62                    }
63                }
64                if !is_array {
65                    result.push_str(key_str.as_str());
66                    result.push_str(" = ");
67                }
68                result.push_str(inspect(lua, value)?.as_str());
69                if index < len {
70                    result.push(',');
71                }
72                index += 1;
73            }
74
75            if is_array {
76                result.push(']');
77            } else {
78                result.push('}');
79            }
80            result
81        }
82        _ => value.to_string()?
83    };
84    Ok(result)
85}
86
87impl LuaModule {
88    pub fn new(
89        name: String,
90        module_root: PathBuf,
91        source_files: Vec<PathBuf>,
92        input_files: HashMap<String, Vec<File>>,
93        config: Option<serde_value::Value>,
94    ) -> Box<dyn Module> {
95        let function_format = match config {
96            None => FunctionFormat::default(),
97            Some(f) => {
98                if let serde_value::Value::Map(mut map) = f {
99                    let f = map.remove(&serde_value::Value::String("function_format".into()));
100                    if f.is_none() {
101                        FunctionFormat::default()
102                    } else {
103                        f.unwrap()
104                            .deserialize_into()
105                            .unwrap_or(FunctionFormat::default())
106                    }
107                } else {
108                    FunctionFormat::default()
109                }
110            }
111        };
112        Box::new(LuaModule {
113            name,
114            function_format,
115            module_root,
116            input_files,
117            lua: Lua::new(),
118            source_files,
119        })
120    }
121
122    pub(crate) fn invoke_function(
123        &self,
124        name: &str,
125        args: Vec<Arg>,
126        context: Definition,
127        previous_value: Arg,
128    ) -> RuntimeStatus<Arg> {
129        let lua = &self.lua;
130        let table = lua.globals();
131        
132        lua
133            .scope(|_| {
134                let function: Function = match table.get::<_, Function>(name) {
135                    Ok(f) => f,
136                    Err(e) => {
137                        warn!("Function Not Found: {} - {}", name, e);
138                        return Ok(RuntimeStatus::NotFound);
139                    }
140                };
141
142                let status = match self.function_format {
143                    FunctionFormat::StructFunction => {
144                        let table = lua.create_table()?;
145                        table.set("name", name)?;
146                        table.set("args", args)?;
147                        table.set("previous_value", previous_value)?;
148                        table.set("context", context)?;
149                        RuntimeStatus::Ok(function.call(table)?)
150                    }
151                };
152                Ok(status)
153            })
154            .unwrap_or(RuntimeStatus::Err("Lua Execution Failed".to_string()))
155    }
156
157    fn load_source_files(&self) -> anyhow::Result<()> {
158        if self.source_files.is_empty() {
159            warn!("No source files configured for module {}", self.name);
160        }
161
162        for file in &self.source_files {
163            let ext = file
164                .extension()
165                .map(|o| o.to_str().unwrap_or("<invalid>"))
166                .unwrap_or("<none>");
167            if ext != "lua" {
168                continue;
169            }
170            let current_file = file.to_str().unwrap_or("<unknown>");
171            info!("{} loading {}", self.name, current_file);
172            let contents = load_file(file)?;
173            match self.lua.scope(|_| {
174                let global = self.lua.globals();
175                let chunk = self.lua.load(contents);
176                chunk.exec()?;
177                global.set("__module_name", self.name.as_str())?;
178                Ok(())
179            }) {
180                Ok(_) => continue,
181                Err(e) => {
182                    return Err(anyhow!(
183                        "Failed to load file: {} - {} {}",
184                        self.name,
185                        current_file,
186                        e
187                    ))
188                }
189            }
190        }
191
192        Ok(())
193    }
194}
195
196fn load_file(path_buf: &PathBuf) -> anyhow::Result<String> {
197    let mut contents = String::new();
198    let mut file = File::open(path_buf)?;
199    file.read_to_string(&mut contents)?;
200    Ok(contents)
201}
202
203impl Module for LuaModule {
204    fn name(&self) -> &str {
205        self.name.as_str()
206    }
207
208    fn root(&self) -> PathBuf {
209        self.module_root.clone()
210    }
211
212    fn function_call(
213        &self,
214        name: &str,
215        arguments: Vec<Argument>,
216        definition: rigz_core::Definition,
217        prior_result: Argument,
218    ) -> RuntimeStatus<Argument> {
219        match self.invoke_function(
220            name,
221            to_args(arguments),
222            definition.into(),
223            prior_result.into(),
224        ) {
225            RuntimeStatus::Ok(a) => RuntimeStatus::Ok(a.into()),
226            RuntimeStatus::NotFound => RuntimeStatus::NotFound,
227            RuntimeStatus::Err(e) => RuntimeStatus::Err(e),
228        }
229    }
230
231    fn initialize(&self, args: InitializationArgs) -> RuntimeStatus<()> {
232        match self.load_source_files() {
233            Ok(_) => {}
234            Err(e) => return RuntimeStatus::Err(format!("Failed to load source files - {}", e)),
235        };
236
237
238        if self.input_files.is_empty() {
239            debug!("No input files passed into module {}", self.name)
240        }
241
242        match self.lua.scope(|_| {
243            let global = self.lua.globals();
244            global.set("inspect", self.lua.create_function(inspect)?)?;
245            global.set("__module_name", self.name.as_str())?;
246            Ok(())
247        }) {
248            Ok(_) => RuntimeStatus::Ok(()),
249            Err(e) => RuntimeStatus::Err(format!("Initialization Failed: {} - {}", self.name, e)),
250        }
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls the Lua builtin `print` through the `Module` API, without
    /// initializing the module first. The test expects `Argument::None`.
    #[test]
    fn it_works() {
        let lua_module = LuaModule::new(
            String::from("hello_world"),
            PathBuf::new(),
            Vec::new(),
            HashMap::new(),
            None,
        );

        let greeting = vec![Argument::String("Hello World".into())];
        let outcome = lua_module.function_call(
            "print",
            greeting,
            rigz_core::Definition::None,
            Argument::None,
        );

        assert_eq!(outcome, RuntimeStatus::Ok(Argument::None));
    }
}
276}