lua_config/
lib.rs

1use std::error::Error;
2use std::str::from_utf8;
3
/// A value produced by evaluating a Lua configuration script.
///
/// Covers the subset of Lua values this crate converts. `PartialEq` is derived so values can be
/// compared directly; note that `Integer(1)` and `Number(1.0)` are different variants and
/// therefore compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub enum LuaType {
    /// Lua `nil`.
    Nil,
    /// Lua boolean.
    Boolean(bool),
    /// Lua integer subtype (64-bit).
    Integer(i64),
    /// Lua float subtype.
    Number(f64),
    /// Lua string (decoded to UTF-8).
    String(String),
    /// Lua table, with its keys read as strings.
    Table(LuaTable),
}

/// String-keyed map used both for Lua tables and for the top-level configuration data.
pub type LuaTable = std::collections::HashMap<String, LuaType>;
15
16impl LuaType {
17    #[cfg(not(feature = "crash_on_none"))]
18    pub fn to<T>(&self) -> Option<T>
19    where
20        T: LuaConvert,
21    {
22        T::from_lua_type(self)
23    }
24
25    #[cfg(feature = "crash_on_none")]
26    pub fn to<T>(&self) -> T
27    where
28        T: LuaConvert,
29    {
30        match T::from_lua_type(self) {
31            Some(value) => value,
32            None => panic!(
33                "Failed to convert LuaType to {}",
34                std::any::type_name::<T>()
35            ),
36        }
37    }
38
39    #[cfg(not(feature = "crash_on_none"))]
40    pub fn get(&self, key: &str) -> Option<&LuaType> {
41        if let LuaType::Table(table) = self {
42            table.get(key)
43        } else {
44            None
45        }
46    }
47
48    #[cfg(feature = "crash_on_none")]
49    pub fn get(&self, key: &str) -> &LuaType {
50        if let LuaType::Table(table) = self {
51            match table.get(key) {
52                Some(value) => value,
53                None => panic!("Key {} not found in table", key),
54            }
55        } else {
56            panic!("Value is not a table");
57        }
58    }
59}
60
61pub trait LuaConvert: Sized {
62    fn from_lua_type(lua_type: &LuaType) -> Option<Self>;
63}
64
65impl LuaConvert for i32 {
66    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
67        match lua_type {
68            LuaType::Integer(i) => i32::try_from(*i).ok(),
69            LuaType::Number(n) => Some(*n as i32),
70            _ => None,
71        }
72    }
73}
74
75impl LuaConvert for f32 {
76    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
77        match lua_type {
78            LuaType::Integer(i) => Some(*i as f32),
79            LuaType::Number(n) => Some(*n as f32),
80            _ => None,
81        }
82    }
83}
84
85impl LuaConvert for i64 {
86    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
87        match lua_type {
88            LuaType::Integer(i) => Some(*i),
89            LuaType::Number(n) => Some(*n as i64),
90            _ => None,
91        }
92    }
93}
94
95impl LuaConvert for f64 {
96    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
97        match lua_type {
98            LuaType::Number(n) => Some(*n),
99            LuaType::Integer(i) => Some(*i as f64),
100            _ => None,
101        }
102    }
103}
104
105impl LuaConvert for bool {
106    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
107        if let LuaType::Boolean(b) = lua_type {
108            Some(*b)
109        } else {
110            None
111        }
112    }
113}
114
115impl LuaConvert for String {
116    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
117        if let LuaType::String(s) = lua_type {
118            Some(s.clone())
119        } else {
120            None
121        }
122    }
123}
124
125impl LuaConvert for LuaTable {
126    fn from_lua_type(lua_type: &LuaType) -> Option<Self> {
127        if let LuaType::Table(table) = lua_type {
128            Some(table.clone())
129        } else {
130            None
131        }
132    }
133}
134
135fn print_lua_type(value: LuaType, f: &mut std::fmt::Formatter, depth: usize) -> std::fmt::Result {
136    match value {
137        LuaType::Nil => write!(f, "nil"),
138        LuaType::Boolean(b) => write!(f, "Boolean({})", b),
139        LuaType::Integer(n) => write!(f, "Integer({})", n),
140        LuaType::Number(n) => write!(f, "Number({})", n),
141        LuaType::String(s) => write!(f, "String(\"{}\")", s),
142        LuaType::Table(map) => {
143            write!(f, "{{")?;
144            for (key, value) in map.iter() {
145                write!(f, "\n{}{} = ", " ".repeat(depth * 4), key)?;
146                print_lua_type(value.clone(), f, depth + 1)?;
147            }
148            write!(f, "\n{}}}", " ".repeat((depth - 1) * 4))
149        }
150    }
151}
152
153impl std::fmt::Display for LuaType {
154    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
155        print_lua_type(self.clone(), f, 1)
156    }
157}
158
159#[derive(Debug, Clone)]
160pub struct LuaConfig {
161    pub data: LuaTable,
162    config: String,
163    default: Option<String>,
164    app_version: String,
165}
166
167impl LuaConfig {
168    pub fn from_string(file: String) -> Self {
169        LuaConfig {
170            data: std::collections::HashMap::new(),
171            config: file,
172            default: None,
173            app_version: "0.0.0".to_string(),
174        }
175    }
176
177    pub fn with_version(mut self, version: &str) -> Self {
178        self.app_version = version.to_string();
179        self
180    }
181
182    pub fn from_file(path: &str) -> Result<Self, Box<dyn Error>> {
183        let file = std::fs::read_to_string(path)?;
184        Ok(LuaConfig::from_string(file))
185    }
186
187    pub fn with_default(mut self, default: &[u8]) -> Result<Self, Box<dyn Error>> {
188        self.default = Some(from_utf8(default)?.to_string());
189        Ok(self)
190    }
191
192    pub fn execute(mut self) -> Result<Self, Box<dyn Error>> {
193        let lua = rlua::Lua::new();
194        let config_values = self.get_hashmap_by_function(&lua, &self.config, "Config")?;
195
196        if self.default.is_some() {
197            let default_values =
198                self.get_hashmap_by_function(&lua, &self.default.clone().unwrap(), "Default")?;
199
200            // Recursivly return error if any value in the config_values is not present in the default_values
201            fn check_config_table(
202                config: &std::collections::HashMap<String, LuaType>,
203                default: &std::collections::HashMap<String, LuaType>,
204            ) -> Result<(), Box<dyn Error>> {
205                for (key, value) in config.iter() {
206                    if let Some(default_value) = default.get(key) {
207                        match value {
208                            LuaType::Table(table) => {
209                                if let LuaType::Table(default_table) = default_value {
210                                    check_config_table(table, default_table)?;
211                                } else {
212                                    return Err(format!("Key {} is not a table", key).into());
213                                }
214                            }
215                            _ => {}
216                        }
217                    } else {
218                        return Err(format!("Key {} not found in default config", key).into());
219                    }
220                }
221                Ok(())
222            }
223            check_config_table(&config_values, &default_values)?;
224
225            // Recursivly merge the default_values into the config_values
226            fn merge_tables(
227                config: std::collections::HashMap<String, LuaType>,
228                default: std::collections::HashMap<String, LuaType>,
229            ) -> std::collections::HashMap<String, LuaType> {
230                let mut result = default.clone();
231                for (key, value) in config {
232                    match value {
233                        LuaType::Table(table) => {
234                            if let LuaType::Table(default_table) = default.get(&key).unwrap() {
235                                result.insert(
236                                    key,
237                                    LuaType::Table(merge_tables(table, default_table.clone())),
238                                );
239                            }
240                        }
241                        _ => {
242                            result.insert(key, value);
243                        }
244                    }
245                }
246                result
247            }
248
249            self.data = merge_tables(config_values, default_values);
250        } else {
251            self.data = config_values;
252        }
253
254        Ok(self)
255    }
256
257    #[cfg(not(feature = "crash_on_none"))]
258    pub fn get(&self, key: &str) -> Option<&LuaType> {
259        self.get_lua_type(key)
260    }
261
262    #[cfg(feature = "crash_on_none")]
263    pub fn get(&self, key: &str) -> &LuaType {
264        match self.get_lua_type(key) {
265            Some(value) => value,
266            None => panic!("Key {} not found in config", key),
267        }
268    }
269
270    pub fn get_lua_type(&self, key: &str) -> Option<&LuaType> {
271        self.data.get(key)
272    }
273
274    fn declare_lua_functions(&self, ctx: &rlua::Context) -> Result<(), rlua::Error> {
275        let globals = ctx.globals();
276
277        let fetch_data = ctx.create_function(|lua_ctx, url: String| {
278            let response = reqwest::blocking::get(url).expect("Failed to fetch data");
279            let table = LuaConfig::lua_table_from_json(lua_ctx, &response.text().unwrap())
280                .expect("Failed to convert JSON to Lua table");
281            Ok(table)
282        })?;
283        globals.set("fetch_data", fetch_data)?;
284
285        let app_version = self.app_version.clone();
286        let version = ctx.create_function(move |_, ()| Ok(app_version.clone()))?;
287        globals.set("version", version)?;
288
289        let build = ctx.create_function(|_, ()| {
290            if cfg!(debug_assertions) {
291                Ok("debug")
292            } else {
293                Ok("release")
294            }
295        })?;
296        globals.set("build", build)?;
297
298        Ok(())
299    }
300
301    fn lua_table_from_json<'lua>(
302        lua: &'lua rlua::Lua,
303        json: &str,
304    ) -> Result<rlua::Table<'lua>, Box<dyn Error>> {
305        let json = json::parse(json)?;
306
307        fn convert_json_to_lua<'lua>(
308            lua: &'lua rlua::Lua,
309            json_value: &json::JsonValue,
310        ) -> Result<rlua::Value<'lua>, Box<dyn Error>> {
311            match json_value {
312                json::JsonValue::Null => Ok(rlua::Value::Nil),
313                json::JsonValue::String(s) => Ok(rlua::Value::String(lua.create_string(s)?)),
314                json::JsonValue::Number(n) => Ok(rlua::Value::Number(
315                    n.as_fixed_point_i64(0).unwrap_or_default() as f64,
316                )),
317                json::JsonValue::Boolean(b) => Ok(rlua::Value::Boolean(*b)),
318                json::JsonValue::Object(obj) => {
319                    let table = lua.create_table()?;
320                    for (key, value) in obj.iter() {
321                        table.set(key, convert_json_to_lua(lua, value)?)?;
322                    }
323                    Ok(rlua::Value::Table(table))
324                }
325                json::JsonValue::Array(arr) => {
326                    let table = lua.create_table()?;
327                    for (i, value) in arr.iter().enumerate() {
328                        table.set(i + 1, convert_json_to_lua(lua, value)?)?;
329                    }
330                    Ok(rlua::Value::Table(table))
331                }
332                _ => unimplemented!("This datatype is not implemented yet"),
333            }
334        }
335
336        let lua_value = convert_json_to_lua(lua, &json)?;
337
338        if let rlua::Value::Table(table) = lua_value {
339            Ok(table)
340        } else {
341            Err("Root element is not a table".into())
342        }
343    }
344
345    fn get_hashmap_by_function<'lua>(
346        &self,
347        lua: &'lua rlua::Lua,
348        code: &str,
349        function_name: &str,
350    ) -> Result<std::collections::HashMap<String, LuaType>, Box<dyn Error>> {
351        let ctx = lua.load(code);
352        self.declare_lua_functions(&lua).unwrap();
353
354        ctx.exec()?;
355        let globals = lua.globals();
356        let func = match globals.get::<_, rlua::Function>(function_name) {
357            Ok(f) => f,
358            Err(e) => {
359                return Err(format!("Error getting function {}: {}", function_name, e).into());
360            }
361        };
362        let table = match func.call::<_, rlua::Table>(()) {
363            Ok(t) => t,
364            Err(e) => {
365                return Err(format!("Error calling function {}: {}", function_name, e).into());
366            }
367        };
368
369        if table.is_empty() {
370            return Err(format!("Function {} returned an empty table", function_name).into());
371        }
372
373        let mut values = std::collections::HashMap::new();
374        for pair in table.pairs::<String, rlua::Value>() {
375            let (key, value) = pair?;
376            let value = LuaConfig::value_to_lua_type(&value);
377            values.insert(key, value);
378        }
379
380        Ok(values)
381    }
382
383    fn value_to_lua_type(value: &rlua::Value) -> LuaType {
384        match value {
385            rlua::Value::Nil => LuaType::Nil,
386            rlua::Value::Boolean(b) => LuaType::Boolean(*b),
387            rlua::Value::Integer(n) => LuaType::Integer(*n),
388            rlua::Value::Number(n) => LuaType::Number(*n),
389            rlua::Value::String(s) => LuaType::String(s.to_str().unwrap_or_default().to_owned()),
390            rlua::Value::Table(table) => {
391                let mut map = std::collections::HashMap::new();
392                for pair in table.clone().pairs::<String, rlua::Value>() {
393                    if let Ok((key, value)) = pair {
394                        map.insert(key, LuaConfig::value_to_lua_type(&value));
395                    }
396                }
397                LuaType::Table(map)
398            }
399            _ => unimplemented!("Conversion for this Lua type is not implemented yet"),
400        }
401    }
402}
403
404impl std::fmt::Display for LuaConfig {
405    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
406        for (key, value) in self.data.iter() {
407            write!(f, "{} = {}\n", key, value)?;
408        }
409        Ok(())
410    }
411}