1use itertools::Itertools;
2use serde::{de, Deserialize, Deserializer};
3use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
4use thiserror::Error;
5
6use crate::{
7 build::utils::c_dylib_extension,
8 lua_rockspec::{
9 deserialize_vec_from_lua_array_or_string, normalize_lua_value, DisplayAsLuaValue,
10 PartialOverride, PerPlatform, PlatformOverridable,
11 },
12};
13
14use super::{DisplayLuaKV, DisplayLuaValue};
15
/// Parsed contents of a rockspec's `build.modules` table.
///
/// Each module is keyed by its dotted Lua module name. Every value is validated
/// on deserialization (see `ModuleSpec`'s `Deserialize` impl).
#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
pub struct BuiltinBuildSpec {
    /// Module name -> how that module's sources are specified.
    pub modules: HashMap<LuaModule, ModuleSpec>,
}
21
/// A dot-separated Lua module name, e.g. `foo.bar`.
///
/// When converted to a file path, each `.` becomes the platform's main path
/// separator (see `to_lua_path`, `to_lua_init_path` and `to_lib_path`).
#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
pub struct LuaModule(String);
24
25impl LuaModule {
26 pub fn to_lua_path(&self) -> PathBuf {
27 self.to_file_path(".lua")
28 }
29
30 pub fn to_lua_init_path(&self) -> PathBuf {
31 self.to_path_buf().join("init.lua")
32 }
33
34 pub fn to_lib_path(&self) -> PathBuf {
35 self.to_file_path(&format!(".{}", c_dylib_extension()))
36 }
37
38 fn to_path_buf(&self) -> PathBuf {
39 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
40 }
41
42 fn to_file_path(&self, extension: &str) -> PathBuf {
43 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
44 }
45
46 pub fn from_pathbuf(path: PathBuf) -> Self {
47 let extension = path
48 .extension()
49 .map(|ext| ext.to_string_lossy().to_string())
50 .unwrap_or("".into());
51 let module = path
52 .to_string_lossy()
53 .trim_end_matches(format!("init.{extension}").as_str())
54 .trim_end_matches(format!(".{extension}").as_str())
55 .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
56 .replace(std::path::MAIN_SEPARATOR_STR, ".");
57 LuaModule(module)
58 }
59
60 pub fn join(&self, other: &LuaModule) -> LuaModule {
61 LuaModule(format!("{}.{}", self.0, other.0))
62 }
63
64 pub fn as_str(&self) -> &str {
65 self.0.as_str()
66 }
67}
68
/// Error type declared for `LuaModule`'s `FromStr` implementation.
///
/// NOTE(review): the `FromStr` impl in this file always returns `Ok`, so this
/// error is currently never constructed there.
#[derive(Error, Debug)]
#[error("could not parse lua module from {0}.")]
pub struct ParseLuaModuleError(String);
72
73impl FromStr for LuaModule {
74 type Err = ParseLuaModuleError;
75
76 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 Ok(LuaModule(s.into()))
79 }
80}
81
82impl Display for LuaModule {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 self.0.fmt(f)
85 }
86}
87
88impl DisplayAsLuaValue for LuaModule {
89 fn display_lua_value(&self) -> DisplayLuaValue {
90 DisplayLuaValue::String(self.to_string())
91 }
92}
93
94impl DisplayAsLuaValue for HashMap<LuaModule, PathBuf> {
95 fn display_lua_value(&self) -> DisplayLuaValue {
96 use path_slash::PathBufExt as _;
97 DisplayLuaValue::Table(
98 self.iter()
99 .map(|(k, v)| DisplayLuaKV {
100 key: k.to_string(),
101 value: DisplayLuaValue::String(v.to_slash_lossy().into_owned()),
102 })
103 .collect_vec(),
104 )
105 }
106}
107
/// Validated description of one entry in `build.modules`.
///
/// Built from `ModuleSpecInternal` by `ModuleSpec::from_internal`, which rejects
/// a table form whose `sources` is missing or empty.
#[derive(Debug, PartialEq, Clone)]
pub enum ModuleSpec {
    /// A single source file, written as a plain string in the rockspec.
    SourcePath(PathBuf),
    /// Several source files, written as a Lua list of strings.
    SourcePaths(Vec<PathBuf>),
    /// A table with `sources` plus optional `libraries`, `defines`, `incdirs`
    /// and `libdirs`.
    ModulePaths(ModulePaths),
}
116
117impl ModuleSpec {
118 pub fn from_internal(
119 internal: ModuleSpecInternal,
120 ) -> Result<ModuleSpec, ModulePathsMissingSources> {
121 match internal {
122 ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
123 ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
124 ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
125 ModulePaths::from_internal(module_paths)?,
126 )),
127 }
128 }
129}
130
131impl<'de> Deserialize<'de> for ModuleSpec {
132 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
133 where
134 D: Deserializer<'de>,
135 {
136 Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
137 .map_err(de::Error::custom)
138 }
139}
140
141impl TryFrom<ModuleSpecInternal> for ModuleSpec {
142 type Error = ModulePathsMissingSources;
143
144 fn try_from(internal: ModuleSpecInternal) -> Result<Self, Self::Error> {
145 Self::from_internal(internal)
146 }
147}
148
/// Unvalidated form of a `build.modules` entry, exactly as written in the rockspec.
///
/// Deserialized from a string (`SourcePath`), a list (`SourcePaths`) or a table
/// (`ModulePaths`). Unlike `ModuleSpec`, a table without `sources` is accepted
/// here. Platform overrides are merged on this form (it implements
/// `PartialOverride`) before `ModuleSpec::from_internal` validates it.
#[derive(Debug, PartialEq, Clone)]
pub enum ModuleSpecInternal {
    /// A single source file path.
    SourcePath(PathBuf),
    /// A list of source file paths.
    SourcePaths(Vec<PathBuf>),
    /// The table form; its `sources` may still be empty at this stage.
    ModulePaths(ModulePathsInternal),
}
155
156impl<'de> Deserialize<'de> for ModuleSpecInternal {
157 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158 where
159 D: Deserializer<'de>,
160 {
161 let value = normalize_lua_value(serde_value::Value::deserialize(deserializer)?);
162 match value {
163 serde_value::Value::String(s) => Ok(Self::SourcePath(PathBuf::from(s))),
164 serde_value::Value::Seq(_) => {
165 let src_paths: Vec<PathBuf> =
166 value.deserialize_into().map_err(de::Error::custom)?;
167 Ok(Self::SourcePaths(src_paths))
168 }
169 serde_value::Value::Map(_) => {
170 let module_paths: ModulePathsInternal =
171 value.deserialize_into().map_err(de::Error::custom)?;
172 Ok(Self::ModulePaths(module_paths))
173 }
174 _ => Err(de::Error::custom(format!(
175 "expected a string, list, or table for module spec, got: {value:?}"
176 ))),
177 }
178 }
179}
180
181impl DisplayAsLuaValue for ModuleSpecInternal {
182 fn display_lua_value(&self) -> DisplayLuaValue {
183 match self {
184 ModuleSpecInternal::SourcePath(path) => {
185 DisplayLuaValue::String(path.to_string_lossy().into())
186 }
187 ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
188 paths
189 .iter()
190 .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
191 .collect(),
192 ),
193 ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
194 }
195 }
196}
197
198fn deserialize_definitions<'de, D>(
199 deserializer: D,
200) -> Result<Vec<(String, Option<String>)>, D::Error>
201where
202 D: Deserializer<'de>,
203{
204 let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
205 values
206 .iter()
207 .map(|val| {
208 if let Some((key, value)) = val.split_once('=') {
209 Ok((key.into(), Some(value.into())))
210 } else {
211 Ok((val.into(), None))
212 }
213 })
214 .try_collect()
215}
216
/// Returned by `ModuleSpecInternal`'s `apply_overrides` when the two specs are
/// different variants (e.g. a string and a table), so they cannot be merged.
#[derive(Error, Debug)]
#[error("cannot resolve ambiguous platform override for `build.modules`.")]
pub struct ModuleSpecAmbiguousPlatformOverride;
220
221impl PartialOverride for ModuleSpecInternal {
222 type Err = ModuleSpecAmbiguousPlatformOverride;
223
224 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
225 match (override_spec, self) {
226 (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
227 Ok(b.to_owned())
228 }
229 (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
230 Ok(b.to_owned())
231 }
232 (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
233 ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
234 ),
235 _ => Err(ModuleSpecAmbiguousPlatformOverride),
236 }
237 }
238}
239
/// Returned by `ModuleSpecInternal`'s `PlatformOverridable::on_nil`, which always
/// fails. The message flags reaching it as a bug.
#[derive(Error, Debug)]
#[error("could not resolve platform override for `build.modules`. THIS IS A BUG!")]
pub struct BuildModulesPlatformOverride;
243
244impl PlatformOverridable for ModuleSpecInternal {
245 type Err = BuildModulesPlatformOverride;
246
247 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
248 where
249 T: PlatformOverridable,
250 {
251 Err(BuildModulesPlatformOverride)
252 }
253}
254
/// Returned by `ModulePaths::from_internal` (and the conversions built on it)
/// when a module table has no `sources`, or an empty list.
#[derive(Error, Debug)]
#[error("missing or empty field `sources`")]
pub struct ModulePathsMissingSources;
258
/// Validated table form of a module spec.
///
/// `from_internal` (used by the `TryFrom` and `Deserialize` impls) rejects an
/// empty `sources`. Fields are public, though, so direct construction bypasses
/// that check.
#[derive(Debug, PartialEq, Clone)]
pub struct ModulePaths {
    /// Source files of the module.
    pub sources: Vec<PathBuf>,
    /// Entries of the rockspec's `libraries` field (presumably libraries to link).
    pub libraries: Vec<PathBuf>,
    /// `defines` entries: `("NAME", Some("VALUE"))` for `NAME=VALUE`, `("NAME", None)` otherwise.
    pub defines: Vec<(String, Option<String>)>,
    /// Entries of the rockspec's `incdirs` field (presumably include directories).
    pub incdirs: Vec<PathBuf>,
    /// Entries of the rockspec's `libdirs` field (presumably library search directories).
    pub libdirs: Vec<PathBuf>,
}
272
273impl ModulePaths {
274 fn from_internal(
275 internal: ModulePathsInternal,
276 ) -> Result<ModulePaths, ModulePathsMissingSources> {
277 if internal.sources.is_empty() {
278 Err(ModulePathsMissingSources)
279 } else {
280 Ok(ModulePaths {
281 sources: internal.sources,
282 libraries: internal.libraries,
283 defines: internal.defines,
284 incdirs: internal.incdirs,
285 libdirs: internal.libdirs,
286 })
287 }
288 }
289}
290
291impl TryFrom<ModulePathsInternal> for ModulePaths {
292 type Error = ModulePathsMissingSources;
293
294 fn try_from(internal: ModulePathsInternal) -> Result<Self, Self::Error> {
295 Self::from_internal(internal)
296 }
297}
298
299impl<'de> Deserialize<'de> for ModulePaths {
300 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
301 where
302 D: Deserializer<'de>,
303 {
304 Self::from_internal(ModulePathsInternal::deserialize(deserializer)?)
305 .map_err(de::Error::custom)
306 }
307}
308
/// Raw table form of a module spec, as written in the rockspec.
///
/// Every field is optional (`serde(default)`), so `sources` may be empty here.
/// `ModulePaths::from_internal` enforces that it is not.
#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
pub struct ModulePathsInternal {
    // Path-list fields are read with `deserialize_vec_from_lua_array_or_string`.
    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
    pub sources: Vec<PathBuf>,
    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
    pub libraries: Vec<PathBuf>,
    // `NAME` / `NAME=VALUE` strings, split at the first `=` by `deserialize_definitions`.
    #[serde(default, deserialize_with = "deserialize_definitions")]
    pub defines: Vec<(String, Option<String>)>,
    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
    pub incdirs: Vec<PathBuf>,
    #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
    pub libdirs: Vec<PathBuf>,
}
322
323impl DisplayAsLuaValue for ModulePathsInternal {
324 fn display_lua_value(&self) -> DisplayLuaValue {
325 DisplayLuaValue::Table(vec![
326 DisplayLuaKV {
327 key: "sources".into(),
328 value: DisplayLuaValue::List(
329 self.sources
330 .iter()
331 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
332 .collect(),
333 ),
334 },
335 DisplayLuaKV {
336 key: "libraries".into(),
337 value: DisplayLuaValue::List(
338 self.libraries
339 .iter()
340 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
341 .collect(),
342 ),
343 },
344 DisplayLuaKV {
345 key: "defines".into(),
346 value: DisplayLuaValue::List(
347 self.defines
348 .iter()
349 .map(|(k, v)| {
350 if let Some(v) = v {
351 DisplayLuaValue::String(format!("{k}={v}"))
352 } else {
353 DisplayLuaValue::String(k.clone())
354 }
355 })
356 .collect(),
357 ),
358 },
359 DisplayLuaKV {
360 key: "incdirs".into(),
361 value: DisplayLuaValue::List(
362 self.incdirs
363 .iter()
364 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
365 .collect(),
366 ),
367 },
368 DisplayLuaKV {
369 key: "libdirs".into(),
370 value: DisplayLuaValue::List(
371 self.libdirs
372 .iter()
373 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
374 .collect(),
375 ),
376 },
377 ])
378 }
379}
380
381impl PartialOverride for ModulePathsInternal {
382 type Err = Infallible;
383
384 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
385 Ok(Self {
386 sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
387 libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
388 defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
389 incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
390 libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
391 })
392 }
393}
394
395impl PlatformOverridable for ModulePathsInternal {
396 type Err = Infallible;
397
398 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
399 where
400 T: PlatformOverridable,
401 T: Default,
402 {
403 Ok(PerPlatform::default())
404 }
405}
406
/// Returns a copy of `override_vec` when it has any elements, otherwise a copy of `base`.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
413
#[cfg(test)]
mod tests {
    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use super::*;

    /// Runs `code` in a fresh Lua context and deserializes the global named `key`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    // These tests never await anything, so they run as plain `#[test]`s
    // without spinning up an async runtime.
    #[test]
    fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into());
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into());
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    #[test]
    fn modules_spec_from_lua() {
        let lua_content = "
            build = {
                modules = {
                    foo = 'lua/foo/init.lua',
                    bar = {
                        'lua/bar.lua',
                        'lua/bar/internal.lua',
                    },
                    baz = {
                        sources = {
                            'lua/baz.lua',
                        },
                        defines = { 'USE_BAZ' },
                    },
                },
            }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content, "build").unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => {
                assert_eq!(paths.sources, vec![PathBuf::from("lua/baz.lua")]);
                assert_eq!(paths.defines, vec![("USE_BAZ".to_string(), None)]);
            }
            other => panic!("expected ModulePaths, got {other:?}"),
        }

        // A table without `sources` must be rejected.
        let lua_content_no_sources = "
            build = {
                modules = {
                    baz = {
                        defines = { 'USE_BAZ' },
                    },
                },
            }
        ";
        let result: Result<BuiltinBuildSpec, _> = exec_lua(lua_content_no_sources, "build");
        assert!(result.is_err(), "module table without `sources` should fail");

        let lua_content_complex_defines = "
            build = {
                modules = {
                    baz = {
                        sources = {
                            'lua/baz.lua',
                        },
                        defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },
                    },
                },
            }
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content_complex_defines, "build").unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            other => panic!("expected ModulePaths, got {other:?}"),
        }
    }
}