1use itertools::Itertools;
2use serde::{de, Deserialize, Deserializer};
3use std::{collections::HashMap, convert::Infallible, fmt::Display, path::PathBuf, str::FromStr};
4use thiserror::Error;
5
6use crate::{
7 build::utils::c_dylib_extension,
8 lua_rockspec::{
9 deserialize_vec_from_lua_array_or_string, normalize_lua_value, DisplayAsLuaValue,
10 PartialOverride, PerPlatform, PlatformOverridable,
11 },
12};
13
14use super::{DisplayLuaKV, DisplayLuaValue};
15
16#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
17pub struct BuiltinBuildSpec {
18 pub modules: HashMap<LuaModule, ModuleSpec>,
20}
21
22#[derive(Debug, PartialEq, Eq, Deserialize, Default, Clone, Hash)]
23pub struct LuaModule(String);
24
25impl LuaModule {
26 pub fn to_lua_path(&self) -> PathBuf {
27 self.to_file_path(".lua")
28 }
29
30 pub fn to_lua_init_path(&self) -> PathBuf {
31 self.to_path_buf().join("init.lua")
32 }
33
34 pub fn to_lib_path(&self) -> PathBuf {
35 self.to_file_path(&format!(".{}", c_dylib_extension()))
36 }
37
38 fn to_path_buf(&self) -> PathBuf {
39 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR))
40 }
41
42 fn to_file_path(&self, extension: &str) -> PathBuf {
43 PathBuf::from(self.0.replace('.', std::path::MAIN_SEPARATOR_STR) + extension)
44 }
45
46 pub fn from_pathbuf(path: PathBuf) -> Self {
47 let extension = path
48 .extension()
49 .map(|ext| ext.to_string_lossy().to_string())
50 .unwrap_or("".into());
51 let module = path
52 .to_string_lossy()
53 .trim_end_matches(format!("init.{extension}").as_str())
54 .trim_end_matches(format!(".{extension}").as_str())
55 .trim_end_matches(std::path::MAIN_SEPARATOR_STR)
56 .replace(std::path::MAIN_SEPARATOR_STR, ".");
57 LuaModule(module)
58 }
59
60 pub fn join(&self, other: &LuaModule) -> LuaModule {
61 LuaModule(format!("{}.{}", self.0, other.0))
62 }
63
64 pub fn as_str(&self) -> &str {
65 self.0.as_str()
66 }
67}
68
69#[derive(Error, Debug)]
70#[error("could not parse lua module from {0}.")]
71pub struct ParseLuaModuleError(String);
72
73impl FromStr for LuaModule {
74 type Err = ParseLuaModuleError;
75
76 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 Ok(LuaModule(s.into()))
79 }
80}
81
82impl Display for LuaModule {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 self.0.fmt(f)
85 }
86}
87
88#[derive(Debug, PartialEq, Clone)]
89pub enum ModuleSpec {
90 SourcePath(PathBuf),
92 SourcePaths(Vec<PathBuf>),
94 ModulePaths(ModulePaths),
95}
96
97impl ModuleSpec {
98 pub fn from_internal(
99 internal: ModuleSpecInternal,
100 ) -> Result<ModuleSpec, ModulePathsMissingSources> {
101 match internal {
102 ModuleSpecInternal::SourcePath(path) => Ok(ModuleSpec::SourcePath(path)),
103 ModuleSpecInternal::SourcePaths(paths) => Ok(ModuleSpec::SourcePaths(paths)),
104 ModuleSpecInternal::ModulePaths(module_paths) => Ok(ModuleSpec::ModulePaths(
105 ModulePaths::from_internal(module_paths)?,
106 )),
107 }
108 }
109}
110
111impl<'de> Deserialize<'de> for ModuleSpec {
112 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
113 where
114 D: Deserializer<'de>,
115 {
116 Self::from_internal(ModuleSpecInternal::deserialize(deserializer)?)
117 .map_err(de::Error::custom)
118 }
119}
120
121impl TryFrom<ModuleSpecInternal> for ModuleSpec {
122 type Error = ModulePathsMissingSources;
123
124 fn try_from(internal: ModuleSpecInternal) -> Result<Self, Self::Error> {
125 Self::from_internal(internal)
126 }
127}
128
129#[derive(Debug, PartialEq, Clone)]
130pub enum ModuleSpecInternal {
131 SourcePath(PathBuf),
132 SourcePaths(Vec<PathBuf>),
133 ModulePaths(ModulePathsInternal),
134}
135
136impl<'de> Deserialize<'de> for ModuleSpecInternal {
137 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138 where
139 D: Deserializer<'de>,
140 {
141 let value = normalize_lua_value(serde_value::Value::deserialize(deserializer)?);
142 match value {
143 serde_value::Value::String(s) => Ok(Self::SourcePath(PathBuf::from(s))),
144 serde_value::Value::Seq(_) => {
145 let src_paths: Vec<PathBuf> =
146 value.deserialize_into().map_err(de::Error::custom)?;
147 Ok(Self::SourcePaths(src_paths))
148 }
149 serde_value::Value::Map(_) => {
150 let module_paths: ModulePathsInternal =
151 value.deserialize_into().map_err(de::Error::custom)?;
152 Ok(Self::ModulePaths(module_paths))
153 }
154 _ => Err(de::Error::custom(format!(
155 "expected a string, list, or table for module spec, got: {value:?}"
156 ))),
157 }
158 }
159}
160
161impl DisplayAsLuaValue for ModuleSpecInternal {
162 fn display_lua_value(&self) -> DisplayLuaValue {
163 match self {
164 ModuleSpecInternal::SourcePath(path) => {
165 DisplayLuaValue::String(path.to_string_lossy().into())
166 }
167 ModuleSpecInternal::SourcePaths(paths) => DisplayLuaValue::List(
168 paths
169 .iter()
170 .map(|p| DisplayLuaValue::String(p.to_string_lossy().into()))
171 .collect(),
172 ),
173 ModuleSpecInternal::ModulePaths(module_paths) => module_paths.display_lua_value(),
174 }
175 }
176}
177
178fn deserialize_definitions<'de, D>(
179 deserializer: D,
180) -> Result<Vec<(String, Option<String>)>, D::Error>
181where
182 D: Deserializer<'de>,
183{
184 let values: Vec<String> = deserialize_vec_from_lua_array_or_string(deserializer)?;
185 values
186 .iter()
187 .map(|val| {
188 if let Some((key, value)) = val.split_once('=') {
189 Ok((key.into(), Some(value.into())))
190 } else {
191 Ok((val.into(), None))
192 }
193 })
194 .try_collect()
195}
196
197#[derive(Error, Debug)]
198#[error("cannot resolve ambiguous platform override for `build.modules`.")]
199pub struct ModuleSpecAmbiguousPlatformOverride;
200
201impl PartialOverride for ModuleSpecInternal {
202 type Err = ModuleSpecAmbiguousPlatformOverride;
203
204 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
205 match (override_spec, self) {
206 (ModuleSpecInternal::SourcePath(_), b @ ModuleSpecInternal::SourcePath(_)) => {
207 Ok(b.to_owned())
208 }
209 (ModuleSpecInternal::SourcePaths(_), b @ ModuleSpecInternal::SourcePaths(_)) => {
210 Ok(b.to_owned())
211 }
212 (ModuleSpecInternal::ModulePaths(a), ModuleSpecInternal::ModulePaths(b)) => Ok(
213 ModuleSpecInternal::ModulePaths(a.apply_overrides(b).unwrap()),
214 ),
215 _ => Err(ModuleSpecAmbiguousPlatformOverride),
216 }
217 }
218}
219
220#[derive(Error, Debug)]
221#[error("could not resolve platform override for `build.modules`. THIS IS A BUG!")]
222pub struct BuildModulesPlatformOverride;
223
224impl PlatformOverridable for ModuleSpecInternal {
225 type Err = BuildModulesPlatformOverride;
226
227 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
228 where
229 T: PlatformOverridable,
230 {
231 Err(BuildModulesPlatformOverride)
232 }
233}
234
235#[derive(Error, Debug)]
236#[error("missing or empty field `sources`")]
237pub struct ModulePathsMissingSources;
238
/// A validated table-style module entry; `sources` is guaranteed to be non-empty.
#[derive(Debug, Clone, PartialEq)]
pub struct ModulePaths {
    /// Source files of the module (at least one).
    pub sources: Vec<PathBuf>,
    /// The rockspec `libraries` field for this module.
    pub libraries: Vec<PathBuf>,
    /// Definitions as `(NAME, Some(VALUE))` for `NAME=VALUE`, or `(NAME, None)`.
    pub defines: Vec<(String, Option<String>)>,
    /// The rockspec `incdirs` field for this module.
    pub incdirs: Vec<PathBuf>,
    /// The rockspec `libdirs` field for this module.
    pub libdirs: Vec<PathBuf>,
}
252
253impl ModulePaths {
254 fn from_internal(
255 internal: ModulePathsInternal,
256 ) -> Result<ModulePaths, ModulePathsMissingSources> {
257 if internal.sources.is_empty() {
258 Err(ModulePathsMissingSources)
259 } else {
260 Ok(ModulePaths {
261 sources: internal.sources,
262 libraries: internal.libraries,
263 defines: internal.defines,
264 incdirs: internal.incdirs,
265 libdirs: internal.libdirs,
266 })
267 }
268 }
269}
270
271impl TryFrom<ModulePathsInternal> for ModulePaths {
272 type Error = ModulePathsMissingSources;
273
274 fn try_from(internal: ModulePathsInternal) -> Result<Self, Self::Error> {
275 Self::from_internal(internal)
276 }
277}
278
279impl<'de> Deserialize<'de> for ModulePaths {
280 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
281 where
282 D: Deserializer<'de>,
283 {
284 Self::from_internal(ModulePathsInternal::deserialize(deserializer)?)
285 .map_err(de::Error::custom)
286 }
287}
288
289#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
290pub struct ModulePathsInternal {
291 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
292 pub sources: Vec<PathBuf>,
293 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
294 pub libraries: Vec<PathBuf>,
295 #[serde(default, deserialize_with = "deserialize_definitions")]
296 pub defines: Vec<(String, Option<String>)>,
297 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
298 pub incdirs: Vec<PathBuf>,
299 #[serde(default, deserialize_with = "deserialize_vec_from_lua_array_or_string")]
300 pub libdirs: Vec<PathBuf>,
301}
302
303impl DisplayAsLuaValue for ModulePathsInternal {
304 fn display_lua_value(&self) -> DisplayLuaValue {
305 DisplayLuaValue::Table(vec![
306 DisplayLuaKV {
307 key: "sources".into(),
308 value: DisplayLuaValue::List(
309 self.sources
310 .iter()
311 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
312 .collect(),
313 ),
314 },
315 DisplayLuaKV {
316 key: "libraries".into(),
317 value: DisplayLuaValue::List(
318 self.libraries
319 .iter()
320 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
321 .collect(),
322 ),
323 },
324 DisplayLuaKV {
325 key: "defines".into(),
326 value: DisplayLuaValue::List(
327 self.defines
328 .iter()
329 .map(|(k, v)| {
330 if let Some(v) = v {
331 DisplayLuaValue::String(format!("{k}={v}"))
332 } else {
333 DisplayLuaValue::String(k.clone())
334 }
335 })
336 .collect(),
337 ),
338 },
339 DisplayLuaKV {
340 key: "incdirs".into(),
341 value: DisplayLuaValue::List(
342 self.incdirs
343 .iter()
344 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
345 .collect(),
346 ),
347 },
348 DisplayLuaKV {
349 key: "libdirs".into(),
350 value: DisplayLuaValue::List(
351 self.libdirs
352 .iter()
353 .map(|s| DisplayLuaValue::String(s.to_string_lossy().into()))
354 .collect(),
355 ),
356 },
357 ])
358 }
359}
360
361impl PartialOverride for ModulePathsInternal {
362 type Err = Infallible;
363
364 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
365 Ok(Self {
366 sources: override_vec(override_spec.sources.as_ref(), self.sources.as_ref()),
367 libraries: override_vec(override_spec.libraries.as_ref(), self.libraries.as_ref()),
368 defines: override_vec(override_spec.defines.as_ref(), self.defines.as_ref()),
369 incdirs: override_vec(override_spec.incdirs.as_ref(), self.incdirs.as_ref()),
370 libdirs: override_vec(override_spec.libdirs.as_ref(), self.libdirs.as_ref()),
371 })
372 }
373}
374
375impl PlatformOverridable for ModulePathsInternal {
376 type Err = Infallible;
377
378 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
379 where
380 T: PlatformOverridable,
381 T: Default,
382 {
383 Ok(PerPlatform::default())
384 }
385}
386
/// Returns a copy of `override_vec`, or a copy of `base` when `override_vec` is empty.
///
/// An empty override therefore means "not specified"; it cannot be used to clear `base`.
fn override_vec<T: Clone>(override_vec: &[T], base: &[T]) -> Vec<T> {
    let chosen = if override_vec.is_empty() {
        base
    } else {
        override_vec
    };
    chosen.to_vec()
}
393
#[cfg(test)]
mod tests {
    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use super::*;

    /// Runs `code` as a Lua chunk, then deserializes the global named `key` into `T`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            // NOTE(review): assumes a single `step` with i32::MAX fuel runs these small
            // fixtures to completion — confirm against ottavino's Executor API.
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    /// `init.<ext>` files fold into their directory; other files keep their stem.
    #[tokio::test]
    pub async fn parse_lua_module_from_path() {
        let lua_module = LuaModule::from_pathbuf("foo/init.lua".into());
        assert_eq!(&lua_module.0, "foo");
        let lua_module = LuaModule::from_pathbuf("foo/bar.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/init.lua".into());
        assert_eq!(&lua_module.0, "foo.bar");
        let lua_module = LuaModule::from_pathbuf("foo/bar/baz.lua".into());
        assert_eq!(&lua_module.0, "foo.bar.baz");
    }

    /// Covers the three `build.modules` shapes, the missing-`sources` error, and
    /// `NAME=VALUE` define parsing.
    #[tokio::test]
    pub async fn modules_spec_from_lua() {
        // String, list, and table entries.
        // NOTE(review): `foo` is assigned twice with the same value; harmless, but
        // likely unintentional.
        let lua_content = "
        build = {\n
            modules = {\n
                foo = 'lua/foo/init.lua',\n
                bar = {\n
                    'lua/bar.lua',\n
                    'lua/bar/internal.lua',\n
                },\n
                baz = {\n
                    sources = {\n
                        'lua/baz.lua',\n
                    },\n
                    defines = { 'USE_BAZ' },\n
                },\n
                foo = 'lua/foo/init.lua',
            },\n
        }\n
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content, "build").unwrap();
        let foo = build_spec
            .modules
            .get(&LuaModule::from_str("foo").unwrap())
            .unwrap();
        assert_eq!(*foo, ModuleSpec::SourcePath("lua/foo/init.lua".into()));
        let bar = build_spec
            .modules
            .get(&LuaModule::from_str("bar").unwrap())
            .unwrap();
        assert_eq!(
            *bar,
            ModuleSpec::SourcePaths(vec!["lua/bar.lua".into(), "lua/bar/internal.lua".into()])
        );
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        assert!(matches!(baz, ModuleSpec::ModulePaths { .. }));
        // A table entry without `sources` must be rejected during deserialization.
        let lua_content_no_sources = "
        build = {\n
            modules = {\n
                baz = {\n
                    defines = { 'USE_BAZ' },\n
                },\n
            },\n
        }\n
        ";
        let result: Result<BuiltinBuildSpec, _> = exec_lua(lua_content_no_sources, "build");
        let _err = result.unwrap_err();
        // Defines split at the first `=`; a bare name has no value.
        let lua_content_complex_defines = "
        build = {\n
            modules = {\n
                baz = {\n
                    sources = {\n
                        'lua/baz.lua',\n
                    },\n
                    defines = { 'USE_BAZ=1', 'ENABLE_LOGGING=true', 'LINK_STATIC' },\n
                },\n
            },\n
        }\n
        ";
        let build_spec: BuiltinBuildSpec = exec_lua(lua_content_complex_defines, "build").unwrap();
        let baz = build_spec
            .modules
            .get(&LuaModule::from_str("baz").unwrap())
            .unwrap();
        match baz {
            ModuleSpec::ModulePaths(paths) => assert_eq!(
                paths.defines,
                vec![
                    ("USE_BAZ".into(), Some("1".into())),
                    ("ENABLE_LOGGING".into(), Some("true".into())),
                    ("LINK_STATIC".into(), None)
                ]
            ),
            _ => panic!(),
        }
    }
}