Skip to main content

lux_lib/lua_rockspec/build/
builtin.rs

1use itertools::Itertools;
2use serde::{de, Deserialize, Deserializer};
3use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
4use thiserror::Error;
5
6use crate::{
7    build::utils::c_dylib_extension,
8    lua_rockspec::{
9        deserialize_vec_from_lua_array_or_string, normalize_lua_value, DisplayAsLuaValue,
10        PartialOverride, PerPlatform, PlatformOverridable,
11    },
12};
13
14use super::{DisplayLuaKV, DisplayLuaValue};
15
16#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
17pub struct BuiltinBuildSpec {
18    /// Keys are module names in the format normally used by the `require()` function
19    pub modules: HashMap<LuaModule, ModuleSpec>,
20}
21
22#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
23pub struct LuaModule(String);
24
25impl LuaModule {
26    pub fn to_lua_path(&self) -> PathBuf {
27        self.to_file_path(".lua")
28    }
29
30    pub fn to_lua_init_path(&self) -> PathBuf {
31        self.to_path_buf().join("init.lua")
32    }
33
34    pub fn to_lib_path(&self) -> PathBuf {
35        self.to_file_path(&format!(".{}", c_dylib_extension()))
36    }
37
38    fn to_path_buf(&self) -> PathBuf {
39        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
40    }
41
42    fn to_file_path(&self, extension: &str) -> PathBuf {
43        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
44    }
45
46    pub fn from_pathbuf(path: PathBuf) -> Self {
47        let extension = path
48            .extension()
49            .map(|ext| ext.to_string_lossy().to_string())
50            .unwrap_or("".into());
51        let module = path
52            .to_string_lossy()
53            .trim_end_matches(format!("init.{extension}").as_str())
54            .trim_end_matches(format!(".{extension}").as_str())
55            .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
56            .replace(std::path::MAIN_SEPARATOR_STR, ".");
57        LuaModule(module)
58    }
59
60    pub fn join(&self, other: &LuaModule) -> LuaModule {
61        LuaModule(format!("{}.{}", self.0, other.0))
62    }
63
64    pub fn as_str(&self) -> &str {
65        self.0.as_str()
66    }
67}
68
69#[derive(Error, Debug)]
70#[error("could not parse lua module from {0}.")]
71pub struct ParseLuaModuleError(String);
72
73impl FromStr for LuaModule {
74    type Err = ParseLuaModuleError;
75
76    // NOTE: We may want to add some validations
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        Ok(LuaModule(s.into()))
79    }
80}
81
82impl Display for LuaModule {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        self.0.fmt(f)
85    }
86}
87
88#[derive(Debug, PartialEq, Clone)]
89pub enum ModuleSpec {
90    /// Pathnames of Lua files or C sources, for modules based on a single source file.
91    SourcePath(PathBuf),
92    /// Pathnames of C sources of a simple module written in C composed of multiple files.
93    SourcePaths(Vec<PathBuf>),
94    ModulePaths(ModulePaths),
95}
96
97impl ModuleSpec {
98    pub fn from_internal(
99        internal: ModuleSpecInternal,
100    ) -> Result<ModuleSpec, ModulePathsMissingSources> {
101        match internal {
102            ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
103            ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
104            ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
105                ModulePaths::from_internal(module_paths)?,
106            )),
107        }
108    }
109}
110
111impl<'de> Deserialize<'de> for ModuleSpec {
112    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
113    where
114        D: Deserializer<'de>,
115    {
116        Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
117            .map_err(de::Error::custom)
118    }
119}
120
121impl TryFrom<ModuleSpecInternal> for ModuleSpec {
122    type Error = ModulePathsMissingSources;
123
124    fn try_from(internal: ModuleSpecInternal) -> Result<Self, Self::Error> {
125        Self::from_internal(internal)
126    }
127}
128
129#[derive(Debug, PartialEq, Clone)]
130pub enum ModuleSpecInternal {
131    SourcePath(PathBuf),
132    SourcePaths(Vec<PathBuf>),
133    ModulePaths(ModulePathsInternal),
134}
135
136impl<'de> Deserialize<'de> for ModuleSpecInternal {
137    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138    where
139        D: Deserializer<'de>,
140    {
141        let value = normalize_lua_value(serde_value::Value::deserialize(deserializer)?);
142        match value {
143            serde_value::Value::String(s) => Ok(Self::SourcePath(PathBuf::from(s))),
144            serde_value::Value::Seq(_) => {
145                let src_paths: Vec<PathBuf> =
146                    value.deserialize_into().map_err(de::Error::custom)?;
147                Ok(Self::SourcePaths(src_paths))
148            }
149            serde_value::Value::Map(_) => {
150                let module_paths: ModulePathsInternal =
151                    value.deserialize_into().map_err(de::Error::custom)?;
152                Ok(Self::ModulePaths(module_paths))
153            }
154            _ => Err(de::Error::custom(format!(
155                "expected a string, list, or table for module spec, got: {value:?}"
156            ))),
157        }
158    }
159}
160
161impl DisplayAsLuaValue for ModuleSpecInternal {
162    fn display_lua_value(&self) -> DisplayLuaValue {
163        match self {
164            ModuleSpecInternal::SourcePath(path) => {
165                DisplayLuaValue::String(path.to_string_lossy().into())
166            }
167            ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
168                paths
169                    .iter()
170                    .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
171                    .collect(),
172            ),
173            ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
174        }
175    }
176}
177
178fn deserialize_definitions<'de, D>(
179    deserializer: D,
180) -> Result<Vec<(String, Option<String>)>, D::Error>
181where
182    D: Deserializer<'de>,
183{
184    let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
185    values
186        .iter()
187        .map(|val| {
188            if let Some((key, value)) = val.split_once('=') {
189                Ok((key.into(), Some(value.into())))
190            } else {
191                Ok((val.into(), None))
192            }
193        })
194        .try_collect()
195}
196
197#[derive(Error, Debug)]
198#[error("cannot resolve ambiguous platform override for `build.modules`.")]
199pub struct ModuleSpecAmbiguousPlatformOverride;
200
201impl PartialOverride for ModuleSpecInternal {
202    type Err = ModuleSpecAmbiguousPlatformOverride;
203
204    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
205        match (override_spec, self) {
206            (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
207                Ok(b.to_owned())
208            }
209            (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
210                Ok(b.to_owned())
211            }
212            (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
213                ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
214            ),
215            _ => Err(ModuleSpecAmbiguousPlatformOverride),
216        }
217    }
218}
219
220#[derive(Error, Debug)]
221#[error("could not resolve platform override for `build.modules`. THIS IS A BUG!")]
222pub struct BuildModulesPlatformOverride;
223
224impl PlatformOverridable for ModuleSpecInternal {
225    type Err = BuildModulesPlatformOverride;
226
227    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
228    where
229        T: PlatformOverridable,
230    {
231        Err(BuildModulesPlatformOverride)
232    }
233}
234
235#[derive(Error, Debug)]
236#[error("missing or empty field `sources`")]
237pub struct ModulePathsMissingSources;
238
/// Validated table form of a module specification; `sources` is guaranteed
/// to be non-empty when built through `ModulePaths::from_internal`.
#[derive(Clone, Debug, PartialEq)]
pub struct ModulePaths {
    /// Path names of C sources, mandatory field
    pub sources: Vec<PathBuf>,
    /// External libraries to be linked
    pub libraries: Vec<PathBuf>,
    /// C defines, e.g. { "FOO=bar", "USE_BLA" }
    pub defines: Vec<(String, Option<String>)>,
    /// Directories to be added to the compiler's headers lookup directory list.
    pub incdirs: Vec<PathBuf>,
    /// Directories to be added to the linker's library lookup directory list.
    pub libdirs: Vec<PathBuf>,
}
252
253impl ModulePaths {
254    fn from_internal(
255        internal: ModulePathsInternal,
256    ) -> Result<ModulePaths, ModulePathsMissingSources> {
257        if internal.sources.is_empty() {
258            Err(ModulePathsMissingSources)
259        } else {
260            Ok(ModulePaths {
261                sources: internal.sources,
262                libraries: internal.libraries,
263                defines: internal.defines,
264                incdirs: internal.incdirs,
265                libdirs: internal.libdirs,
266            })
267        }
268    }
269}
270
271impl TryFrom<ModulePathsInternal> for ModulePaths {
272    type Error = ModulePathsMissingSources;
273
274    fn try_from(internal: ModulePathsInternal) -> Result<Self, Self::Error> {
275        Self::from_internal(internal)
276    }
277}
278
279impl<'de> Deserialize<'de> for ModulePaths {
280    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
281    where
282        D: Deserializer<'de>,
283    {
284        Self::from_internal(ModulePathsInternal::deserialize(deserializer)?)
285            .map_err(de::Error::custom)
286    }
287}
288
289#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
290pub struct ModulePathsInternal {
291    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
292    pub sources: Vec<PathBuf>,
293    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
294    pub libraries: Vec<PathBuf>,
295    #[serde(default, deserialize_with = "deserialize_definitions")]
296    pub defines: Vec<(String, Option<String>)>,
297    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
298    pub incdirs: Vec<PathBuf>,
299    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
300    pub libdirs: Vec<PathBuf>,
301}
302
303impl DisplayAsLuaValue for ModulePathsInternal {
304    fn display_lua_value(&self) -> DisplayLuaValue {
305        DisplayLuaValue::Table(vec![
306            DisplayLuaKV {
307                key: "sources".into(),
308                value: DisplayLuaValue::List(
309                    self.sources
310                        .iter()
311                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
312                        .collect(),
313                ),
314            },
315            DisplayLuaKV {
316                key: "libraries".into(),
317                value: DisplayLuaValue::List(
318                    self.libraries
319                        .iter()
320                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
321                        .collect(),
322                ),
323            },
324            DisplayLuaKV {
325                key: "defines".into(),
326                value: DisplayLuaValue::List(
327                    self.defines
328                        .iter()
329                        .map(|(k, v)| {
330                            if let Some(v) = v {
331                                DisplayLuaValue::String(format!("{k}={v}"))
332                            } else {
333                                DisplayLuaValue::String(k.clone())
334                            }
335                        })
336                        .collect(),
337                ),
338            },
339            DisplayLuaKV {
340                key: "incdirs".into(),
341                value: DisplayLuaValue::List(
342                    self.incdirs
343                        .iter()
344                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
345                        .collect(),
346                ),
347            },
348            DisplayLuaKV {
349                key: "libdirs".into(),
350                value: DisplayLuaValue::List(
351                    self.libdirs
352                        .iter()
353                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
354                        .collect(),
355                ),
356            },
357        ])
358    }
359}
360
361impl PartialOverride for ModulePathsInternal {
362    type Err = Infallible;
363
364    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
365        Ok(Self {
366            sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
367            libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
368            defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
369            incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
370            libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
371        })
372    }
373}
374
375impl PlatformOverridable for ModulePathsInternal {
376    type Err = Infallible;
377
378    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
379    where
380        T: PlatformOverridable,
381        T: Default,
382    {
383        Ok(PerPlatform::default())
384    }
385}
386
/// Returns an owned copy of `override_vec`, or of `base` when `override_vec`
/// is empty.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
393
#[cfg(test)]
mod tests {
    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use super::*;

    /// Loads and steps `code` as a Lua chunk, then deserializes the global
    /// named `key` into `T`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let chunk = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, chunk.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            let global = ctx.globals().get_value(ctx, key);
            from_value(global).map_err(ottavino::Error::from)
        })
    }

    /// Looks up the spec registered under `name`, panicking when it is absent.
    fn module_spec<'a>(build_spec: &'a BuiltinBuildSpec, name: &str) -> &'a ModuleSpec {
        let module = LuaModule::from_str(name).unwrap();
        build_spec.modules.get(&module).unwrap()
    }

    #[tokio::test]
    pub async fn parse_lua_module_from_path() {
        let cases = [
            ("foo/init.lua", "foo"),
            ("foo/bar.lua", "foo.bar"),
            ("foo/bar/init.lua", "foo.bar"),
            ("foo/bar/baz.lua", "foo.bar.baz"),
        ];
        for (path, expected) in cases {
            let lua_module = LuaModule::from_pathbuf(path.into());
            assert_eq!(lua_module.as_str(), expected);
        }
    }

    #[tokio::test]
    pub async fn modules_spec_from_lua() {
        // One module per accepted shape: string, list and table.
        let lua_content = "
        build = {\n
            modules = {\n
                foo = 'lua/foo/init.lua',\n
                bar = {\n
                  'lua/bar.lua',\n
                  'lua/bar/internal.lua',\n
                },\n
                baz = {\n
                    sources = {\n
                        'lua/baz.lua',\n
                    },\n
                    defines = { 'USE_BAZ' },\n
                },\n
                foo = 'lua/foo/init.lua',
            },\n
        }\n
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content, "build").unwrap();
        assert_eq!(
            *module_spec(&build_spec, "foo"),
            ModuleSpec::SourcePath("lua/foo/init.lua".into())
        );
        assert_eq!(
            *module_spec(&build_spec, "bar"),
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        assert!(matches!(
            module_spec(&build_spec, "baz"),
            ModuleSpec::ModulePaths { .. }
        ));

        // A table without `sources` must be rejected.
        let lua_content_no_sources = "
        build = {\n
            modules = {\n
                baz = {\n
                    defines = { 'USE_BAZ' },\n
                },\n
            },\n
        }\n
        ";
        let result: Result<BuiltinBuildSpec, _> = exec_lua(lua_content_no_sources, "build");
        let _err = result.unwrap_err();

        // Defines are split at the first `=`; entries without one carry no value.
        let lua_content_complex_defines = "
        build = {\n
            modules = {\n
                baz = {\n
                    sources = {\n
                        'lua/baz.lua',\n
                    },\n
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },\n
                },\n
            },\n
        }\n
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content_complex_defines, "build").unwrap();
        let ModuleSpec::ModulePaths(paths) = module_spec(&build_spec, "baz") else {
            panic!()
        };
        assert_eq!(
            paths.defines,
            vec![
                ("USE_BAZ".into(), Some("1".into())),
                ("ENABLE_LOGGING".into(), Some("true".into())),
                ("LINK_STATIC".into(), None)
            ]
        );
    }
}