lux_lib/lua_rockspec/build/
builtin.rs

1use itertools::Itertools;
2use mlua::{IntoLua, UserData};
3use serde::{de, Deserialize, Deserializer};
4use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
5use thiserror::Error;
6
7use crate::{
8    build::utils::c_dylib_extension,
9    lua_rockspec::{
10        deserialize_vec_from_lua_array_or_string, DisplayAsLuaValue, FromPlatformOverridable,
11        PartialOverride, PerPlatform, PlatformOverridable,
12    },
13};
14
15use super::{DisplayLuaKV, DisplayLuaValue};
16
17#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
18pub struct BuiltinBuildSpec {
19    /// Keys are module names in the format normally used by the `require()` function
20    pub modules: HashMap<LuaModule, ModuleSpec>,
21}
22
23impl IntoLua for BuiltinBuildSpec {
24    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
25        self.modules.into_lua(lua)
26    }
27}
28
29#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
30pub struct LuaModule(String);
31
32impl IntoLua for LuaModule {
33    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
34        self.0.into_lua(lua)
35    }
36}
37
38impl LuaModule {
39    pub fn to_lua_path(&self) -> PathBuf {
40        self.to_file_path(".lua")
41    }
42
43    pub fn to_lua_init_path(&self) -> PathBuf {
44        self.to_path_buf().join("init.lua")
45    }
46
47    pub fn to_lib_path(&self) -> PathBuf {
48        self.to_file_path(&format!(".{}", c_dylib_extension()))
49    }
50
51    fn to_path_buf(&self) -> PathBuf {
52        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
53    }
54
55    fn to_file_path(&self, extension: &str) -> PathBuf {
56        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
57    }
58
59    pub fn from_pathbuf(path: PathBuf) -> Self {
60        let extension = path
61            .extension()
62            .map(|ext| ext.to_string_lossy().to_string())
63            .unwrap_or("".into());
64        let module = path
65            .to_string_lossy()
66            .trim_end_matches(format!("init.{}", extension).as_str())
67            .trim_end_matches(format!(".{}", extension).as_str())
68            .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
69            .replace(std::path::MAIN_SEPARATOR_STR, ".");
70        LuaModule(module)
71    }
72
73    pub fn join(&self, other: &LuaModule) -> LuaModule {
74        LuaModule(format!("{}.{}", self.0, other.0))
75    }
76
77    pub fn as_str(&self) -> &str {
78        self.0.as_str()
79    }
80}
81
82#[derive(Error, Debug)]
83#[error("could not parse lua module from {0}.")]
84pub struct ParseLuaModuleError(String);
85
86impl FromStr for LuaModule {
87    type Err = ParseLuaModuleError;
88
89    // NOTE: We may want to add some validations
90    fn from_str(s: &str) -> Result<Self, Self::Err> {
91        Ok(LuaModule(s.into()))
92    }
93}
94
95impl Display for LuaModule {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        self.0.fmt(f)
98    }
99}
100
101#[derive(Debug, PartialEq, Clone)]
102pub enum ModuleSpec {
103    /// Pathnames of Lua files or C sources, for modules based on a single source file.
104    SourcePath(PathBuf),
105    /// Pathnames of C sources of a simple module written in C composed of multiple files.
106    SourcePaths(Vec<PathBuf>),
107    ModulePaths(ModulePaths),
108}
109
110impl IntoLua for ModuleSpec {
111    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
112        let table = lua.create_table()?;
113
114        match self {
115            ModuleSpec::SourcePath(path_buf) => table.set("source", path_buf)?,
116            ModuleSpec::SourcePaths(path_bufs) => table.set("sources", path_bufs)?,
117            ModuleSpec::ModulePaths(module_paths) => table.set("modules", module_paths)?,
118        }
119
120        Ok(mlua::Value::Table(table))
121    }
122}
123
124impl ModuleSpec {
125    pub fn from_internal(
126        internal: ModuleSpecInternal,
127    ) -> Result<ModuleSpec, ModulePathsMissingSources> {
128        match internal {
129            ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
130            ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
131            ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
132                ModulePaths::from_internal(module_paths)?,
133            )),
134        }
135    }
136}
137
138impl<'de> Deserialize<'de> for ModuleSpec {
139    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
140    where
141        D: Deserializer<'de>,
142    {
143        Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
144            .map_err(de::Error::custom)
145    }
146}
147
148impl FromPlatformOverridable<ModuleSpecInternal, Self> for ModuleSpec {
149    type Err = ModulePathsMissingSources;
150
151    fn from_platform_overridable(internal: ModuleSpecInternal) -> Result<Self, Self::Err> {
152        Self::from_internal(internal)
153    }
154}
155
156#[derive(Debug, PartialEq, Clone)]
157pub enum ModuleSpecInternal {
158    SourcePath(PathBuf),
159    SourcePaths(Vec<PathBuf>),
160    ModulePaths(ModulePathsInternal),
161}
162
163impl<'de> Deserialize<'de> for ModuleSpecInternal {
164    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165    where
166        D: Deserializer<'de>,
167    {
168        let value = serde_json::Value::deserialize(deserializer)?;
169        if value.is_string() {
170            let src_path = serde_json::from_value(value).map_err(de::Error::custom)?;
171            Ok(Self::SourcePath(src_path))
172        } else if value.is_array() {
173            let src_paths = serde_json::from_value(value).map_err(de::Error::custom)?;
174            Ok(Self::SourcePaths(src_paths))
175        } else {
176            let module_paths = serde_json::from_value(value).map_err(de::Error::custom)?;
177            Ok(Self::ModulePaths(module_paths))
178        }
179    }
180}
181
182impl DisplayAsLuaValue for ModuleSpecInternal {
183    fn display_lua_value(&self) -> DisplayLuaValue {
184        match self {
185            ModuleSpecInternal::SourcePath(path) => {
186                DisplayLuaValue::String(path.to_string_lossy().into())
187            }
188            ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
189                paths
190                    .iter()
191                    .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
192                    .collect(),
193            ),
194            ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
195        }
196    }
197}
198
199fn deserialize_definitions<'de, D>(
200    deserializer: D,
201) -> Result<Vec<(String, Option<String>)>, D::Error>
202where
203    D: Deserializer<'de>,
204{
205    let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
206    values
207        .iter()
208        .map(|val| {
209            if let Some((key, value)) = val.split_once('=') {
210                Ok((key.into(), Some(value.into())))
211            } else {
212                Ok((val.into(), None))
213            }
214        })
215        .try_collect()
216}
217
218#[derive(Error, Debug)]
219#[error("cannot resolve ambiguous platform override for `build.modules`.")]
220pub struct ModuleSpecAmbiguousPlatformOverride;
221
222impl PartialOverride for ModuleSpecInternal {
223    type Err = ModuleSpecAmbiguousPlatformOverride;
224
225    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
226        match (override_spec, self) {
227            (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
228                Ok(b.to_owned())
229            }
230            (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
231                Ok(b.to_owned())
232            }
233            (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
234                ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
235            ),
236            _ => Err(ModuleSpecAmbiguousPlatformOverride),
237        }
238    }
239}
240
241#[derive(Error, Debug)]
242#[error("could not resolve platform override for `build.modules`. This is a bug!")]
243pub struct BuildModulesPlatformOverride;
244
245impl PlatformOverridable for ModuleSpecInternal {
246    type Err = BuildModulesPlatformOverride;
247
248    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
249    where
250        T: PlatformOverridable,
251    {
252        Err(BuildModulesPlatformOverride)
253    }
254}
255
256#[derive(Error, Debug)]
257#[error("missing or empty field `sources`")]
258pub struct ModulePathsMissingSources;
259
260#[derive(Debug, PartialEq, Clone)]
261pub struct ModulePaths {
262    /// Path names of C sources, mandatory field
263    pub sources: Vec<PathBuf>,
264    /// External libraries to be linked
265    pub libraries: Vec<PathBuf>,
266    /// C defines, e.g. { "FOO=bar", "USE_BLA" }
267    pub defines: Vec<(String, Option<String>)>,
268    /// Directories to be added to the compiler's headers lookup directory list.
269    pub incdirs: Vec<PathBuf>,
270    /// Directories to be added to the linker's library lookup directory list.
271    pub libdirs: Vec<PathBuf>,
272}
273
274impl UserData for ModulePaths {
275    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
276        methods.add_method("sources", |_, this, _: ()| Ok(this.sources.clone()));
277        methods.add_method("libraries", |_, this, _: ()| Ok(this.libraries.clone()));
278        methods.add_method("defines", |_, this, _: ()| {
279            Ok(this
280                .defines
281                .iter()
282                .cloned()
283                .collect::<HashMap<_, Option<_>>>())
284        });
285        methods.add_method("incdirs", |_, this, _: ()| Ok(this.incdirs.clone()));
286        methods.add_method("libdirs", |_, this, _: ()| Ok(this.libdirs.clone()));
287    }
288}
289
290impl ModulePaths {
291    fn from_internal(
292        internal: ModulePathsInternal,
293    ) -> Result<ModulePaths, ModulePathsMissingSources> {
294        if internal.sources.is_empty() {
295            Err(ModulePathsMissingSources)
296        } else {
297            Ok(ModulePaths {
298                sources: internal.sources,
299                libraries: internal.libraries,
300                defines: internal.defines,
301                incdirs: internal.incdirs,
302                libdirs: internal.libdirs,
303            })
304        }
305    }
306}
307
308impl<'de> Deserialize<'de> for ModulePaths {
309    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
310    where
311        D: Deserializer<'de>,
312    {
313        let internal = ModulePathsInternal::deserialize(deserializer)?;
314        Self::from_internal(internal).map_err(de::Error::custom)
315    }
316}
317
318#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
319pub struct ModulePathsInternal {
320    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
321    pub sources: Vec<PathBuf>,
322    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
323    pub libraries: Vec<PathBuf>,
324    #[serde(default, deserialize_with = "deserialize_definitions")]
325    pub defines: Vec<(String, Option<String>)>,
326    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
327    pub incdirs: Vec<PathBuf>,
328    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
329    pub libdirs: Vec<PathBuf>,
330}
331
332impl DisplayAsLuaValue for ModulePathsInternal {
333    fn display_lua_value(&self) -> DisplayLuaValue {
334        DisplayLuaValue::Table(vec![
335            DisplayLuaKV {
336                key: "sources".into(),
337                value: DisplayLuaValue::List(
338                    self.sources
339                        .iter()
340                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
341                        .collect(),
342                ),
343            },
344            DisplayLuaKV {
345                key: "libraries".into(),
346                value: DisplayLuaValue::List(
347                    self.libraries
348                        .iter()
349                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
350                        .collect(),
351                ),
352            },
353            DisplayLuaKV {
354                key: "defines".into(),
355                value: DisplayLuaValue::List(
356                    self.defines
357                        .iter()
358                        .map(|(k, v)| {
359                            if let Some(v) = v {
360                                DisplayLuaValue::String(format!("{}={}", k, v))
361                            } else {
362                                DisplayLuaValue::String(k.clone())
363                            }
364                        })
365                        .collect(),
366                ),
367            },
368            DisplayLuaKV {
369                key: "incdirs".into(),
370                value: DisplayLuaValue::List(
371                    self.incdirs
372                        .iter()
373                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
374                        .collect(),
375                ),
376            },
377            DisplayLuaKV {
378                key: "libdirs".into(),
379                value: DisplayLuaValue::List(
380                    self.libdirs
381                        .iter()
382                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
383                        .collect(),
384                ),
385            },
386        ])
387    }
388}
389
390impl PartialOverride for ModulePathsInternal {
391    type Err = Infallible;
392
393    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
394        Ok(Self {
395            sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
396            libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
397            defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
398            incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
399            libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
400        })
401    }
402}
403
404impl PlatformOverridable for ModulePathsInternal {
405    type Err = Infallible;
406
407    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
408    where
409        T: PlatformOverridable,
410        T: Default,
411    {
412        Ok(PerPlatform::default())
413    }
414}
415
/// Returns a copy of `override_vec` when it is non-empty, otherwise a copy of `base`.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
422
#[cfg(test)]
mod tests {
    use mlua::{Lua, LuaSerdeExt};

    use super::*;

    /// Evaluates `lua_content` and deserializes the global `build` table it defines.
    fn build_spec_from_lua(lua: &Lua, lua_content: &str) -> mlua::Result<BuiltinBuildSpec> {
        lua.load(lua_content).exec()?;
        lua.from_value(lua.globals().get("build")?)
    }

    #[test]
    fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into());
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into());
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    #[test]
    fn modules_spec_from_lua() {
        let lua = Lua::new();

        let lua_content = "
        build = {
            modules = {
                foo = 'lua/foo/init.lua',
                bar = {
                  'lua/bar.lua',
                  'lua/bar/internal.lua',
                },
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let build_spec = build_spec_from_lua(&lua, lua_content).unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        assert!(matches!(baz, ModuleSpec::ModulePaths { .. }));

        // A table-style entry without `sources` must be rejected.
        let lua_content_no_sources = "
        build = {
            modules = {
                baz = {
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let err = build_spec_from_lua(&lua, lua_content_no_sources).unwrap_err();
        assert!(
            err.to_string().contains("sources"),
            "unexpected error: {err}"
        );

        let lua_content_complex_defines = "
        build = {
            modules = {
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },
                },
            },
        }
        ";
        let build_spec = build_spec_from_lua(&lua, lua_content_complex_defines).unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            _ => panic!("expected a table-style module spec for `baz`"),
        }
    }
}