1use itertools::Itertools;
2use path_slash::PathBufExt;
3use serde::{de, Deserialize, Deserializer};
4use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
5use thiserror::Error;
6
7use crate::{
8 build::utils::c_dylib_extension,
9 lua_rockspec::{
10 deserialize_vec_from_lua_array_or_string, normalize_lua_value, DisplayAsLuaValue,
11 PartialOverride, PerPlatform, PlatformOverridable,
12 },
13};
14
15use super::{DisplayLuaKV, DisplayLuaValue};
16
17#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
18pub struct BuiltinBuildSpec {
19 pub modules: HashMap<LuaModule, ModuleSpec>,
21}
22
23#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
24pub struct LuaModule(String);
25
26impl LuaModule {
27 pub fn to_lua_path(&self) -> PathBuf {
28 self.to_file_path(".lua")
29 }
30
31 pub fn to_lua_init_path(&self) -> PathBuf {
32 self.to_path_buf().join("init.lua")
33 }
34
35 pub fn to_lib_path(&self) -> PathBuf {
36 self.to_file_path(&format!(".{}", c_dylib_extension()))
37 }
38
39 fn to_path_buf(&self) -> PathBuf {
40 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
41 }
42
43 pub(crate) fn to_file_path(&self, extension: &str) -> PathBuf {
44 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
45 }
46
47 pub fn from_pathbuf(path: PathBuf) -> Result<Self, ParseLuaModuleError> {
48 let extension = path
49 .extension()
50 .map(|ext| ext.to_string_lossy().to_string())
51 .unwrap_or("".into());
52 let module = path
53 .to_string_lossy()
54 .trim_end_matches(format!("init.{extension}").as_str())
55 .trim_end_matches(format!(".{extension}").as_str())
56 .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
57 .replace(std::path::MAIN_SEPARATOR_STR, ".");
58 if module.is_empty() {
59 Err(ParseLuaModuleError::EmptyModule(
60 path.to_slash_lossy().to_string(),
61 ))
62 } else {
63 Ok(LuaModule(module))
64 }
65 }
66
67 pub fn join(&self, other: &LuaModule) -> LuaModule {
68 LuaModule(format!("{}.{}", self.0, other.0))
69 }
70
71 pub fn as_str(&self) -> &str {
72 self.0.as_str()
73 }
74}
75
76#[derive(Error, Debug)]
77pub enum ParseLuaModuleError {
78 #[error("could not parse Lua module from '{0}'.")]
79 FromString(String),
80 #[error("path '{0}' resulted in an empty Lua module.")]
81 EmptyModule(String),
82}
83
84impl FromStr for LuaModule {
85 type Err = ParseLuaModuleError;
86
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
89 if s.is_empty() {
90 Err(ParseLuaModuleError::EmptyModule(s.into()))
91 } else {
92 Ok(LuaModule(s.into()))
93 }
94 }
95}
96
97impl Display for LuaModule {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 self.0.fmt(f)
100 }
101}
102
103impl DisplayAsLuaValue for LuaModule {
104 fn display_lua_value(&self) -> DisplayLuaValue {
105 DisplayLuaValue::String(self.to_string())
106 }
107}
108
109impl DisplayAsLuaValue for HashMap<LuaModule, PathBuf> {
110 fn display_lua_value(&self) -> DisplayLuaValue {
111 use path_slash::PathBufExt as _;
112 DisplayLuaValue::Table(
113 self.iter()
114 .map(|(k, v)| DisplayLuaKV {
115 key: k.to_string(),
116 value: DisplayLuaValue::String(v.to_slash_lossy().into_owned()),
117 })
118 .collect_vec(),
119 )
120 }
121}
122
123#[derive(Debug, PartialEq, Clone)]
124pub enum ModuleSpec {
125 SourcePath(PathBuf),
127 SourcePaths(Vec<PathBuf>),
129 ModulePaths(ModulePaths),
130}
131
132impl ModuleSpec {
133 pub fn from_internal(
134 internal: ModuleSpecInternal,
135 ) -> Result<ModuleSpec, ModulePathsMissingSources> {
136 match internal {
137 ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
138 ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
139 ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
140 ModulePaths::from_internal(module_paths)?,
141 )),
142 }
143 }
144}
145
146impl<'de> Deserialize<'de> for ModuleSpec {
147 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
148 where
149 D: Deserializer<'de>,
150 {
151 Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
152 .map_err(de::Error::custom)
153 }
154}
155
156impl TryFrom<ModuleSpecInternal> for ModuleSpec {
157 type Error = ModulePathsMissingSources;
158
159 fn try_from(internal: ModuleSpecInternal) -> Result<Self, Self::Error> {
160 Self::from_internal(internal)
161 }
162}
163
164#[derive(Debug, PartialEq, Clone)]
165pub enum ModuleSpecInternal {
166 SourcePath(PathBuf),
167 SourcePaths(Vec<PathBuf>),
168 ModulePaths(ModulePathsInternal),
169}
170
171impl<'de> Deserialize<'de> for ModuleSpecInternal {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let value = normalize_lua_value(serde_value::Value::deserialize(deserializer)?);
177 match value {
178 serde_value::Value::String(s) => Ok(Self::SourcePath(PathBuf::from(s))),
179 serde_value::Value::Seq(_) => {
180 let src_paths: Vec<PathBuf> =
181 value.deserialize_into().map_err(de::Error::custom)?;
182 Ok(Self::SourcePaths(src_paths))
183 }
184 serde_value::Value::Map(_) => {
185 let module_paths: ModulePathsInternal =
186 value.deserialize_into().map_err(de::Error::custom)?;
187 Ok(Self::ModulePaths(module_paths))
188 }
189 _ => Err(de::Error::custom(format!(
190 "expected a string, list, or table for module spec, got: {value:?}"
191 ))),
192 }
193 }
194}
195
196impl DisplayAsLuaValue for ModuleSpecInternal {
197 fn display_lua_value(&self) -> DisplayLuaValue {
198 match self {
199 ModuleSpecInternal::SourcePath(path) => {
200 DisplayLuaValue::String(path.to_string_lossy().into())
201 }
202 ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
203 paths
204 .iter()
205 .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
206 .collect(),
207 ),
208 ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
209 }
210 }
211}
212
213fn deserialize_definitions<'de, D>(
214 deserializer: D,
215) -> Result<Vec<(String, Option<String>)>, D::Error>
216where
217 D: Deserializer<'de>,
218{
219 let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
220 values
221 .iter()
222 .map(|val| {
223 if let Some((key, value)) = val.split_once('=') {
224 Ok((key.into(), Some(value.into())))
225 } else {
226 Ok((val.into(), None))
227 }
228 })
229 .try_collect()
230}
231
232#[derive(Error, Debug)]
233#[error("cannot resolve ambiguous platform override for `build.modules`.")]
234pub struct ModuleSpecAmbiguousPlatformOverride;
235
236impl PartialOverride for ModuleSpecInternal {
237 type Err = ModuleSpecAmbiguousPlatformOverride;
238
239 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
240 match (override_spec, self) {
241 (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
242 Ok(b.to_owned())
243 }
244 (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
245 Ok(b.to_owned())
246 }
247 (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
248 ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
249 ),
250 _ => Err(ModuleSpecAmbiguousPlatformOverride),
251 }
252 }
253}
254
255#[derive(Error, Debug)]
256#[error("could not resolve platform override for `build.modules`. THIS IS A BUG!")]
257pub struct BuildModulesPlatformOverride;
258
259impl PlatformOverridable for ModuleSpecInternal {
260 type Err = BuildModulesPlatformOverride;
261
262 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
263 where
264 T: PlatformOverridable,
265 {
266 Err(BuildModulesPlatformOverride)
267 }
268}
269
270#[derive(Error, Debug)]
271#[error("missing or empty field `sources`")]
272pub struct ModulePathsMissingSources;
273
/// Validated table form of a module spec; `sources` is guaranteed non-empty.
#[derive(Clone, Debug, PartialEq)]
pub struct ModulePaths {
    /// Source files of the module (never empty).
    pub sources: Vec<PathBuf>,
    /// Entries of the `libraries` field (presumably libraries to link — verify with the builder).
    pub libraries: Vec<PathBuf>,
    /// `(name, value)` pairs parsed from `NAME` / `NAME=VALUE` strings.
    pub defines: Vec<(String, Option<String>)>,
    /// Entries of the `incdirs` field (presumably include directories).
    pub incdirs: Vec<PathBuf>,
    /// Entries of the `libdirs` field (presumably library search directories).
    pub libdirs: Vec<PathBuf>,
}
287
288impl ModulePaths {
289 fn from_internal(
290 internal: ModulePathsInternal,
291 ) -> Result<ModulePaths, ModulePathsMissingSources> {
292 if internal.sources.is_empty() {
293 Err(ModulePathsMissingSources)
294 } else {
295 Ok(ModulePaths {
296 sources: internal.sources,
297 libraries: internal.libraries,
298 defines: internal.defines,
299 incdirs: internal.incdirs,
300 libdirs: internal.libdirs,
301 })
302 }
303 }
304}
305
306impl TryFrom<ModulePathsInternal> for ModulePaths {
307 type Error = ModulePathsMissingSources;
308
309 fn try_from(internal: ModulePathsInternal) -> Result<Self, Self::Error> {
310 Self::from_internal(internal)
311 }
312}
313
314impl<'de> Deserialize<'de> for ModulePaths {
315 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
316 where
317 D: Deserializer<'de>,
318 {
319 Self::from_internal(ModulePathsInternal::deserialize(deserializer)?)
320 .map_err(de::Error::custom)
321 }
322}
323
324#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
325pub struct ModulePathsInternal {
326 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
327 pub sources: Vec<PathBuf>,
328 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
329 pub libraries: Vec<PathBuf>,
330 #[serde(default, deserialize_with = "deserialize_definitions")]
331 pub defines: Vec<(String, Option<String>)>,
332 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
333 pub incdirs: Vec<PathBuf>,
334 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
335 pub libdirs: Vec<PathBuf>,
336}
337
338impl DisplayAsLuaValue for ModulePathsInternal {
339 fn display_lua_value(&self) -> DisplayLuaValue {
340 DisplayLuaValue::Table(vec![
341 DisplayLuaKV {
342 key: "sources".into(),
343 value: DisplayLuaValue::List(
344 self.sources
345 .iter()
346 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
347 .collect(),
348 ),
349 },
350 DisplayLuaKV {
351 key: "libraries".into(),
352 value: DisplayLuaValue::List(
353 self.libraries
354 .iter()
355 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
356 .collect(),
357 ),
358 },
359 DisplayLuaKV {
360 key: "defines".into(),
361 value: DisplayLuaValue::List(
362 self.defines
363 .iter()
364 .map(|(k, v)| {
365 if let Some(v) = v {
366 DisplayLuaValue::String(format!("{k}={v}"))
367 } else {
368 DisplayLuaValue::String(k.clone())
369 }
370 })
371 .collect(),
372 ),
373 },
374 DisplayLuaKV {
375 key: "incdirs".into(),
376 value: DisplayLuaValue::List(
377 self.incdirs
378 .iter()
379 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
380 .collect(),
381 ),
382 },
383 DisplayLuaKV {
384 key: "libdirs".into(),
385 value: DisplayLuaValue::List(
386 self.libdirs
387 .iter()
388 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
389 .collect(),
390 ),
391 },
392 ])
393 }
394}
395
396impl PartialOverride for ModulePathsInternal {
397 type Err = Infallible;
398
399 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
400 Ok(Self {
401 sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
402 libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
403 defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
404 incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
405 libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
406 })
407 }
408}
409
410impl PlatformOverridable for ModulePathsInternal {
411 type Err = Infallible;
412
413 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
414 where
415 T: PlatformOverridable,
416 T: Default,
417 {
418 Ok(PerPlatform::default())
419 }
420}
421
/// Returns an owned copy of `override_vec`, or of `base` when the override is
/// empty.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
428
#[cfg(test)]
mod tests {
    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use super::*;

    /// Runs `code` in a fresh Lua state and deserializes the global `key` into `T`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    #[test]
    fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into()).unwrap();
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into()).unwrap();
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into()).unwrap();
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into()).unwrap();
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    #[test]
    fn lua_module_from_str_rejects_empty() {
        assert!(matches!(
            LuaModule::from_str(""),
            Err(ParseLuaModuleError::EmptyModule(_))
        ));
        assert_eq!(LuaModule::from_str("foo.bar").unwrap().as_str(), "foo.bar");
    }

    #[test]
    fn override_vec_prefers_non_empty_override() {
        assert_eq!(override_vec(&[1], &[2]), vec![1]);
        assert_eq!(override_vec::<i32>(&[], &[2]), vec![2]);
    }

    #[test]
    fn modules_spec_from_lua() {
        let lua_content = "
        build = {
            modules = {
                foo = 'lua/foo/init.lua',
                bar = {
                    'lua/bar.lua',
                    'lua/bar/internal.lua',
                },
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content, "build").unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        assert!(matches!(baz, ModuleSpec::ModulePaths { .. }));

        // A table without `sources` must be rejected.
        let lua_content_no_sources = "
        build = {
            modules = {
                baz = {
                    defines = { 'USE_BAZ' },
                },
            },
        }
        ";
        let result: Result<BuiltinBuildSpec, _> = exec_lua(lua_content_no_sources, "build");
        assert!(result.is_err(), "expected missing `sources` to be an error");

        let lua_content_complex_defines = "
        build = {
            modules = {
                baz = {
                    sources = {
                        'lua/baz.lua',
                    },
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },
                },
            },
        }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content_complex_defines, "build").unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            other => panic!("expected ModuleSpec::ModulePaths, got {other:?}"),
        }
    }
}