Skip to main content

lux_lib/lua_rockspec/build/
builtin.rs

1use itertools::Itertools;
2use serde::{de, Deserialize, Deserializer};
3use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
4use thiserror::Error;
5
6use crate::{
7    build::utils::c_dylib_extension,
8    lua_rockspec::{
9        deserialize_vec_from_lua_array_or_string, normalize_lua_value, DisplayAsLuaValue,
10        PartialOverride, PerPlatform, PlatformOverridable,
11    },
12};
13
14use super::{DisplayLuaKV, DisplayLuaValue};
15
16#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
17pub struct BuiltinBuildSpec {
18    /// Keys are module names in the format normally used by the `require()` function
19    pub modules: HashMap<LuaModule, ModuleSpec>,
20}
21
22#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
23pub struct LuaModule(String);
24
25impl LuaModule {
26    pub fn to_lua_path(&self) -> PathBuf {
27        self.to_file_path(".lua")
28    }
29
30    pub fn to_lua_init_path(&self) -> PathBuf {
31        self.to_path_buf().join("init.lua")
32    }
33
34    pub fn to_lib_path(&self) -> PathBuf {
35        self.to_file_path(&format!(".{}", c_dylib_extension()))
36    }
37
38    fn to_path_buf(&self) -> PathBuf {
39        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
40    }
41
42    fn to_file_path(&self, extension: &str) -> PathBuf {
43        PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
44    }
45
46    pub fn from_pathbuf(path: PathBuf) -> Self {
47        let extension = path
48            .extension()
49            .map(|ext| ext.to_string_lossy().to_string())
50            .unwrap_or("".into());
51        let module = path
52            .to_string_lossy()
53            .trim_end_matches(format!("init.{extension}").as_str())
54            .trim_end_matches(format!(".{extension}").as_str())
55            .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
56            .replace(std::path::MAIN_SEPARATOR_STR, ".");
57        LuaModule(module)
58    }
59
60    pub fn join(&self, other: &LuaModule) -> LuaModule {
61        LuaModule(format!("{}.{}", self.0, other.0))
62    }
63
64    pub fn as_str(&self) -> &str {
65        self.0.as_str()
66    }
67}
68
69#[derive(Error, Debug)]
70#[error("could not parse lua module from {0}.")]
71pub struct ParseLuaModuleError(String);
72
73impl FromStr for LuaModule {
74    type Err = ParseLuaModuleError;
75
76    // NOTE: We may want to add some validations
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        Ok(LuaModule(s.into()))
79    }
80}
81
82impl Display for LuaModule {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        self.0.fmt(f)
85    }
86}
87
88impl DisplayAsLuaValue for LuaModule {
89    fn display_lua_value(&self) -> DisplayLuaValue {
90        DisplayLuaValue::String(self.to_string())
91    }
92}
93
94impl DisplayAsLuaValue for HashMap<LuaModule, PathBuf> {
95    fn display_lua_value(&self) -> DisplayLuaValue {
96        use path_slash::PathBufExt as _;
97        DisplayLuaValue::Table(
98            self.iter()
99                .map(|(k, v)| DisplayLuaKV {
100                    key: k.to_string(),
101                    value: DisplayLuaValue::String(v.to_slash_lossy().into_owned()),
102                })
103                .collect_vec(),
104        )
105    }
106}
107
108#[derive(Debug, PartialEq, Clone)]
109pub enum ModuleSpec {
110    /// Pathnames of Lua files or C sources, for modules based on a single source file.
111    SourcePath(PathBuf),
112    /// Pathnames of C sources of a simple module written in C composed of multiple files.
113    SourcePaths(Vec<PathBuf>),
114    ModulePaths(ModulePaths),
115}
116
117impl ModuleSpec {
118    pub fn from_internal(
119        internal: ModuleSpecInternal,
120    ) -> Result<ModuleSpec, ModulePathsMissingSources> {
121        match internal {
122            ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
123            ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
124            ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
125                ModulePaths::from_internal(module_paths)?,
126            )),
127        }
128    }
129}
130
131impl<'de> Deserialize<'de> for ModuleSpec {
132    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
133    where
134        D: Deserializer<'de>,
135    {
136        Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
137            .map_err(de::Error::custom)
138    }
139}
140
141impl TryFrom<ModuleSpecInternal> for ModuleSpec {
142    type Error = ModulePathsMissingSources;
143
144    fn try_from(internal: ModuleSpecInternal) -> Result<Self, Self::Error> {
145        Self::from_internal(internal)
146    }
147}
148
149#[derive(Debug, PartialEq, Clone)]
150pub enum ModuleSpecInternal {
151    SourcePath(PathBuf),
152    SourcePaths(Vec<PathBuf>),
153    ModulePaths(ModulePathsInternal),
154}
155
156impl<'de> Deserialize<'de> for ModuleSpecInternal {
157    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158    where
159        D: Deserializer<'de>,
160    {
161        let value = normalize_lua_value(serde_value::Value::deserialize(deserializer)?);
162        match value {
163            serde_value::Value::String(s) => Ok(Self::SourcePath(PathBuf::from(s))),
164            serde_value::Value::Seq(_) => {
165                let src_paths: Vec<PathBuf> =
166                    value.deserialize_into().map_err(de::Error::custom)?;
167                Ok(Self::SourcePaths(src_paths))
168            }
169            serde_value::Value::Map(_) => {
170                let module_paths: ModulePathsInternal =
171                    value.deserialize_into().map_err(de::Error::custom)?;
172                Ok(Self::ModulePaths(module_paths))
173            }
174            _ => Err(de::Error::custom(format!(
175                "expected a string, list, or table for module spec, got: {value:?}"
176            ))),
177        }
178    }
179}
180
181impl DisplayAsLuaValue for ModuleSpecInternal {
182    fn display_lua_value(&self) -> DisplayLuaValue {
183        match self {
184            ModuleSpecInternal::SourcePath(path) => {
185                DisplayLuaValue::String(path.to_string_lossy().into())
186            }
187            ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
188                paths
189                    .iter()
190                    .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
191                    .collect(),
192            ),
193            ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
194        }
195    }
196}
197
198fn deserialize_definitions<'de, D>(
199    deserializer: D,
200) -> Result<Vec<(String, Option<String>)>, D::Error>
201where
202    D: Deserializer<'de>,
203{
204    let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
205    values
206        .iter()
207        .map(|val| {
208            if let Some((key, value)) = val.split_once('=') {
209                Ok((key.into(), Some(value.into())))
210            } else {
211                Ok((val.into(), None))
212            }
213        })
214        .try_collect()
215}
216
217#[derive(Error, Debug)]
218#[error("cannot resolve ambiguous platform override for `build.modules`.")]
219pub struct ModuleSpecAmbiguousPlatformOverride;
220
221impl PartialOverride for ModuleSpecInternal {
222    type Err = ModuleSpecAmbiguousPlatformOverride;
223
224    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
225        match (override_spec, self) {
226            (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
227                Ok(b.to_owned())
228            }
229            (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
230                Ok(b.to_owned())
231            }
232            (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
233                ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
234            ),
235            _ => Err(ModuleSpecAmbiguousPlatformOverride),
236        }
237    }
238}
239
240#[derive(Error, Debug)]
241#[error("could not resolve platform override for `build.modules`. THIS IS A BUG!")]
242pub struct BuildModulesPlatformOverride;
243
244impl PlatformOverridable for ModuleSpecInternal {
245    type Err = BuildModulesPlatformOverride;
246
247    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
248    where
249        T: PlatformOverridable,
250    {
251        Err(BuildModulesPlatformOverride)
252    }
253}
254
255#[derive(Error, Debug)]
256#[error("missing or empty field `sources`")]
257pub struct ModulePathsMissingSources;
258
259#[derive(Debug, PartialEq, Clone)]
260pub struct ModulePaths {
261    /// Path names of C sources, mandatory field
262    pub sources: Vec<PathBuf>,
263    /// External libraries to be linked
264    pub libraries: Vec<PathBuf>,
265    /// C defines, e.g. { "FOO=bar", "USE_BLA" }
266    pub defines: Vec<(String, Option<String>)>,
267    /// Directories to be added to the compiler's headers lookup directory list.
268    pub incdirs: Vec<PathBuf>,
269    /// Directories to be added to the linker's library lookup directory list.
270    pub libdirs: Vec<PathBuf>,
271}
272
273impl ModulePaths {
274    fn from_internal(
275        internal: ModulePathsInternal,
276    ) -> Result<ModulePaths, ModulePathsMissingSources> {
277        if internal.sources.is_empty() {
278            Err(ModulePathsMissingSources)
279        } else {
280            Ok(ModulePaths {
281                sources: internal.sources,
282                libraries: internal.libraries,
283                defines: internal.defines,
284                incdirs: internal.incdirs,
285                libdirs: internal.libdirs,
286            })
287        }
288    }
289}
290
291impl TryFrom<ModulePathsInternal> for ModulePaths {
292    type Error = ModulePathsMissingSources;
293
294    fn try_from(internal: ModulePathsInternal) -> Result<Self, Self::Error> {
295        Self::from_internal(internal)
296    }
297}
298
299impl<'de> Deserialize<'de> for ModulePaths {
300    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
301    where
302        D: Deserializer<'de>,
303    {
304        Self::from_internal(ModulePathsInternal::deserialize(deserializer)?)
305            .map_err(de::Error::custom)
306    }
307}
308
309#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
310pub struct ModulePathsInternal {
311    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
312    pub sources: Vec<PathBuf>,
313    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
314    pub libraries: Vec<PathBuf>,
315    #[serde(default, deserialize_with = "deserialize_definitions")]
316    pub defines: Vec<(String, Option<String>)>,
317    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
318    pub incdirs: Vec<PathBuf>,
319    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
320    pub libdirs: Vec<PathBuf>,
321}
322
323impl DisplayAsLuaValue for ModulePathsInternal {
324    fn display_lua_value(&self) -> DisplayLuaValue {
325        DisplayLuaValue::Table(vec![
326            DisplayLuaKV {
327                key: "sources".into(),
328                value: DisplayLuaValue::List(
329                    self.sources
330                        .iter()
331                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
332                        .collect(),
333                ),
334            },
335            DisplayLuaKV {
336                key: "libraries".into(),
337                value: DisplayLuaValue::List(
338                    self.libraries
339                        .iter()
340                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
341                        .collect(),
342                ),
343            },
344            DisplayLuaKV {
345                key: "defines".into(),
346                value: DisplayLuaValue::List(
347                    self.defines
348                        .iter()
349                        .map(|(k, v)| {
350                            if let Some(v) = v {
351                                DisplayLuaValue::String(format!("{k}={v}"))
352                            } else {
353                                DisplayLuaValue::String(k.clone())
354                            }
355                        })
356                        .collect(),
357                ),
358            },
359            DisplayLuaKV {
360                key: "incdirs".into(),
361                value: DisplayLuaValue::List(
362                    self.incdirs
363                        .iter()
364                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
365                        .collect(),
366                ),
367            },
368            DisplayLuaKV {
369                key: "libdirs".into(),
370                value: DisplayLuaValue::List(
371                    self.libdirs
372                        .iter()
373                        .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
374                        .collect(),
375                ),
376            },
377        ])
378    }
379}
380
381impl PartialOverride for ModulePathsInternal {
382    type Err = Infallible;
383
384    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
385        Ok(Self {
386            sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
387            libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
388            defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
389            incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
390            libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
391        })
392    }
393}
394
395impl PlatformOverridable for ModulePathsInternal {
396    type Err = Infallible;
397
398    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
399    where
400        T: PlatformOverridable,
401        T: Default,
402    {
403        Ok(PerPlatform::default())
404    }
405}
406
/// Returns an owned copy of `override_vec` unless it is empty, in which case it returns
/// an owned copy of `base`.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
413
#[cfg(test)]
mod tests {
    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use super::*;

    /// Loads `code` into a fresh `Lua::core()` state and steps it with a large fuel budget.
    /// Then deserializes the global variable named `key` into `T`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    // Nothing here awaits, so plain `#[test]` is enough; no async runtime is needed.
    #[test]
    fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into());
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into());
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    #[test]
    fn modules_spec_from_lua() {
        let lua_content = "
        build = {
            modules = {
                foo = 'lua/foo/init.lua',
                bar = {
                  'lua/bar.lua',
                  'lua/bar/internal.lua',
                },
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content, "build").unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        assert!(matches!(baz, ModuleSpec::ModulePaths(_)));

        // A table-form module without `sources` must be rejected.
        let lua_content_no_sources = "
        build = {
            modules = {
                baz = {
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let result: Result<BuiltinBuildSpec, _> = exec_lua(lua_content_no_sources, "build");
        assert!(result.is_err(), "expected an error for a module without sources");

        let lua_content_complex_defines = "
        build = {
            modules = {
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },
                },
            },
        }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content_complex_defines, "build").unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            other => panic!("expected ModuleSpec::ModulePaths, got {other:?}"),
        }
    }
}