1use itertools::Itertools;
2use mlua::{IntoLua, UserData};
3use serde::{de, Deserialize, Deserializer};
4use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
5use thiserror::Error;
6
7use crate::{
8 build::utils::c_dylib_extension,
9 lua_rockspec::{
10 deserialize_vec_from_lua_array_or_string, DisplayAsLuaValue, FromPlatformOverridable,
11 PartialOverride, PerPlatform, PlatformOverridable,
12 },
13};
14
15use super::{DisplayLuaKV, DisplayLuaValue};
16
17#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
18pub struct BuiltinBuildSpec {
19 pub modules: HashMap<LuaModule, ModuleSpec>,
21}
22
23impl IntoLua for BuiltinBuildSpec {
24 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
25 self.modules.into_lua(lua)
26 }
27}
28
29#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
30pub struct LuaModule(String);
31
32impl IntoLua for LuaModule {
33 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
34 self.0.into_lua(lua)
35 }
36}
37
38impl LuaModule {
39 pub fn to_lua_path(&self) -> PathBuf {
40 self.to_file_path(".lua")
41 }
42
43 pub fn to_lua_init_path(&self) -> PathBuf {
44 self.to_path_buf().join("init.lua")
45 }
46
47 pub fn to_lib_path(&self) -> PathBuf {
48 self.to_file_path(&format!(".{}", c_dylib_extension()))
49 }
50
51 fn to_path_buf(&self) -> PathBuf {
52 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
53 }
54
55 fn to_file_path(&self, extension: &str) -> PathBuf {
56 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
57 }
58
59 pub fn from_pathbuf(path: PathBuf) -> Self {
60 let extension = path
61 .extension()
62 .map(|ext| ext.to_string_lossy().to_string())
63 .unwrap_or("".into());
64 let module = path
65 .to_string_lossy()
66 .trim_end_matches(format!("init.{}", extension).as_str())
67 .trim_end_matches(format!(".{}", extension).as_str())
68 .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
69 .replace(std::path::MAIN_SEPARATOR_STR, ".");
70 LuaModule(module)
71 }
72
73 pub fn join(&self, other: &LuaModule) -> LuaModule {
74 LuaModule(format!("{}.{}", self.0, other.0))
75 }
76
77 pub fn as_str(&self) -> &str {
78 self.0.as_str()
79 }
80}
81
82#[derive(Error, Debug)]
83#[error("could not parse lua module from {0}.")]
84pub struct ParseLuaModuleError(String);
85
86impl FromStr for LuaModule {
87 type Err = ParseLuaModuleError;
88
89 fn from_str(s: &str) -> Result<Self, Self::Err> {
91 Ok(LuaModule(s.into()))
92 }
93}
94
95impl Display for LuaModule {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 self.0.fmt(f)
98 }
99}
100
101#[derive(Debug, PartialEq, Clone)]
102pub enum ModuleSpec {
103 SourcePath(PathBuf),
105 SourcePaths(Vec<PathBuf>),
107 ModulePaths(ModulePaths),
108}
109
110impl IntoLua for ModuleSpec {
111 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
112 let table = lua.create_table()?;
113
114 match self {
115 ModuleSpec::SourcePath(path_buf) => table.set("source", path_buf)?,
116 ModuleSpec::SourcePaths(path_bufs) => table.set("sources", path_bufs)?,
117 ModuleSpec::ModulePaths(module_paths) => table.set("modules", module_paths)?,
118 }
119
120 Ok(mlua::Value::Table(table))
121 }
122}
123
124impl ModuleSpec {
125 pub fn from_internal(
126 internal: ModuleSpecInternal,
127 ) -> Result<ModuleSpec, ModulePathsMissingSources> {
128 match internal {
129 ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
130 ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
131 ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
132 ModulePaths::from_internal(module_paths)?,
133 )),
134 }
135 }
136}
137
138impl<'de> Deserialize<'de> for ModuleSpec {
139 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
140 where
141 D: Deserializer<'de>,
142 {
143 Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
144 .map_err(de::Error::custom)
145 }
146}
147
148impl FromPlatformOverridable<ModuleSpecInternal, Self> for ModuleSpec {
149 type Err = ModulePathsMissingSources;
150
151 fn from_platform_overridable(internal: ModuleSpecInternal) -> Result<Self, Self::Err> {
152 Self::from_internal(internal)
153 }
154}
155
156#[derive(Debug, PartialEq, Clone)]
157pub enum ModuleSpecInternal {
158 SourcePath(PathBuf),
159 SourcePaths(Vec<PathBuf>),
160 ModulePaths(ModulePathsInternal),
161}
162
163impl<'de> Deserialize<'de> for ModuleSpecInternal {
164 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165 where
166 D: Deserializer<'de>,
167 {
168 let value = serde_json::Value::deserialize(deserializer)?;
169 if value.is_string() {
170 let src_path = serde_json::from_value(value).map_err(de::Error::custom)?;
171 Ok(Self::SourcePath(src_path))
172 } else if value.is_array() {
173 let src_paths = serde_json::from_value(value).map_err(de::Error::custom)?;
174 Ok(Self::SourcePaths(src_paths))
175 } else {
176 let module_paths = serde_json::from_value(value).map_err(de::Error::custom)?;
177 Ok(Self::ModulePaths(module_paths))
178 }
179 }
180}
181
182impl DisplayAsLuaValue for ModuleSpecInternal {
183 fn display_lua_value(&self) -> DisplayLuaValue {
184 match self {
185 ModuleSpecInternal::SourcePath(path) => {
186 DisplayLuaValue::String(path.to_string_lossy().into())
187 }
188 ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
189 paths
190 .iter()
191 .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
192 .collect(),
193 ),
194 ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
195 }
196 }
197}
198
199fn deserialize_definitions<'de, D>(
200 deserializer: D,
201) -> Result<Vec<(String, Option<String>)>, D::Error>
202where
203 D: Deserializer<'de>,
204{
205 let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
206 values
207 .iter()
208 .map(|val| {
209 if let Some((key, value)) = val.split_once('=') {
210 Ok((key.into(), Some(value.into())))
211 } else {
212 Ok((val.into(), None))
213 }
214 })
215 .try_collect()
216}
217
218#[derive(Error, Debug)]
219#[error("cannot resolve ambiguous platform override for `build.modules`.")]
220pub struct ModuleSpecAmbiguousPlatformOverride;
221
222impl PartialOverride for ModuleSpecInternal {
223 type Err = ModuleSpecAmbiguousPlatformOverride;
224
225 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
226 match (override_spec, self) {
227 (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
228 Ok(b.to_owned())
229 }
230 (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
231 Ok(b.to_owned())
232 }
233 (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
234 ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
235 ),
236 _ => Err(ModuleSpecAmbiguousPlatformOverride),
237 }
238 }
239}
240
241#[derive(Error, Debug)]
242#[error("could not resolve platform override for `build.modules`. This is a bug!")]
243pub struct BuildModulesPlatformOverride;
244
245impl PlatformOverridable for ModuleSpecInternal {
246 type Err = BuildModulesPlatformOverride;
247
248 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
249 where
250 T: PlatformOverridable,
251 {
252 Err(BuildModulesPlatformOverride)
253 }
254}
255
256#[derive(Error, Debug)]
257#[error("missing or empty field `sources`")]
258pub struct ModulePathsMissingSources;
259
260#[derive(Debug, PartialEq, Clone)]
261pub struct ModulePaths {
262 pub sources: Vec<PathBuf>,
264 pub libraries: Vec<PathBuf>,
266 pub defines: Vec<(String, Option<String>)>,
268 pub incdirs: Vec<PathBuf>,
270 pub libdirs: Vec<PathBuf>,
272}
273
274impl UserData for ModulePaths {
275 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
276 methods.add_method("sources", |_, this, _: ()| Ok(this.sources.clone()));
277 methods.add_method("libraries", |_, this, _: ()| Ok(this.libraries.clone()));
278 methods.add_method("defines", |_, this, _: ()| {
279 Ok(this
280 .defines
281 .iter()
282 .cloned()
283 .collect::<HashMap<_, Option<_>>>())
284 });
285 methods.add_method("incdirs", |_, this, _: ()| Ok(this.incdirs.clone()));
286 methods.add_method("libdirs", |_, this, _: ()| Ok(this.libdirs.clone()));
287 }
288}
289
290impl ModulePaths {
291 fn from_internal(
292 internal: ModulePathsInternal,
293 ) -> Result<ModulePaths, ModulePathsMissingSources> {
294 if internal.sources.is_empty() {
295 Err(ModulePathsMissingSources)
296 } else {
297 Ok(ModulePaths {
298 sources: internal.sources,
299 libraries: internal.libraries,
300 defines: internal.defines,
301 incdirs: internal.incdirs,
302 libdirs: internal.libdirs,
303 })
304 }
305 }
306}
307
308impl<'de> Deserialize<'de> for ModulePaths {
309 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
310 where
311 D: Deserializer<'de>,
312 {
313 let internal = ModulePathsInternal::deserialize(deserializer)?;
314 Self::from_internal(internal).map_err(de::Error::custom)
315 }
316}
317
318#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
319pub struct ModulePathsInternal {
320 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
321 pub sources: Vec<PathBuf>,
322 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
323 pub libraries: Vec<PathBuf>,
324 #[serde(default, deserialize_with = "deserialize_definitions")]
325 pub defines: Vec<(String, Option<String>)>,
326 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
327 pub incdirs: Vec<PathBuf>,
328 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
329 pub libdirs: Vec<PathBuf>,
330}
331
332impl DisplayAsLuaValue for ModulePathsInternal {
333 fn display_lua_value(&self) -> DisplayLuaValue {
334 DisplayLuaValue::Table(vec![
335 DisplayLuaKV {
336 key: "sources".into(),
337 value: DisplayLuaValue::List(
338 self.sources
339 .iter()
340 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
341 .collect(),
342 ),
343 },
344 DisplayLuaKV {
345 key: "libraries".into(),
346 value: DisplayLuaValue::List(
347 self.libraries
348 .iter()
349 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
350 .collect(),
351 ),
352 },
353 DisplayLuaKV {
354 key: "defines".into(),
355 value: DisplayLuaValue::List(
356 self.defines
357 .iter()
358 .map(|(k, v)| {
359 if let Some(v) = v {
360 DisplayLuaValue::String(format!("{}={}", k, v))
361 } else {
362 DisplayLuaValue::String(k.clone())
363 }
364 })
365 .collect(),
366 ),
367 },
368 DisplayLuaKV {
369 key: "incdirs".into(),
370 value: DisplayLuaValue::List(
371 self.incdirs
372 .iter()
373 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
374 .collect(),
375 ),
376 },
377 DisplayLuaKV {
378 key: "libdirs".into(),
379 value: DisplayLuaValue::List(
380 self.libdirs
381 .iter()
382 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
383 .collect(),
384 ),
385 },
386 ])
387 }
388}
389
390impl PartialOverride for ModulePathsInternal {
391 type Err = Infallible;
392
393 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
394 Ok(Self {
395 sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
396 libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
397 defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
398 incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
399 libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
400 })
401 }
402}
403
404impl PlatformOverridable for ModulePathsInternal {
405 type Err = Infallible;
406
407 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
408 where
409 T: PlatformOverridable,
410 T: Default,
411 {
412 Ok(PerPlatform::default())
413 }
414}
415
/// Returns an owned copy of `override_vec`, or of `base` when `override_vec` is empty.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
422
#[cfg(test)]
mod tests {
    use mlua::{Lua, LuaSerdeExt};

    use super::*;

    // These tests are fully synchronous, so plain `#[test]` is used instead of an async runtime.

    #[test]
    fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into());
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into());
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    #[test]
    fn modules_spec_from_lua() {
        let lua_content = "
        build = {
            modules = {
                foo = 'lua/foo/init.lua',
                bar = {
                    'lua/bar.lua',
                    'lua/bar/internal.lua',
                },
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let build_spec: BuiltinBuildSpec =
            lua.from_value(lua.globals().get("build").unwrap()).unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => {
                assert_eq!(paths.sources, vec![PathBuf::from("lua/baz.lua")]);
                assert_eq!(paths.defines, vec![("USE_BAZ".to_string(), None)]);
            }
            other => panic!("expected ModuleSpec::ModulePaths, got {other:?}"),
        }

        // A table spec without `sources` must be rejected.
        let lua_content_no_sources = "
        build = {
            modules = {
                baz = {
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        lua.load(lua_content_no_sources).exec().unwrap();
        let result: mlua::Result<BuiltinBuildSpec> =
            lua.from_value(lua.globals().get("build").unwrap());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("missing or empty field `sources`"),
            "unexpected error: {err}"
        );

        let lua_content_complex_defines = "
        build = {
            modules = {
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },
                },
            },
        }
        ";
        lua.load(lua_content_complex_defines).exec().unwrap();
        let build_spec: BuiltinBuildSpec =
            lua.from_value(lua.globals().get("build").unwrap()).unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            other => panic!("expected ModuleSpec::ModulePaths, got {other:?}"),
        }
    }
}
527}