1use itertools::Itertools;
2use mlua::{FromLua, IntoLua, UserData};
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::{
11 config::{Config, ConfigBuilder, ConfigError, LuaVersion},
12 package::PackageReq,
13 project::{project_toml::LocalProjectTomlValidationError, Project},
14 rockspec::Rockspec,
15};
16
17use super::{
18 FromPlatformOverridable, PartialOverride, PerPlatform, PerPlatformWrapper, PlatformOverridable,
19};
20
/// File name of the `nlua` executable, used as the `LUA` variable for
/// `busted-nlua` test runs.
///
/// Windows uses `nlua.bat`. Every other target uses `nlua`; this includes
/// targets that are neither Unix nor Windows, which previously left the
/// constant undefined and failed to compile.
const NLUA_EXE: &str = if cfg!(target_family = "windows") {
    "nlua.bat"
} else {
    "nlua"
};
25
26#[derive(Error, Debug)]
27pub enum TestSpecDecodeError {
28 #[error("the 'command' test type must specify either a 'command' or 'script' field")]
29 NoCommandOrScript,
30 #[error("the 'command' test type cannot have both 'command' and 'script' fields")]
31 CommandAndScript,
32}
33
34#[derive(Error, Debug)]
35pub enum TestSpecError {
36 #[error("could not auto-detect the test spec. Please add one to your lux.toml")]
37 NoTestSpecDetected,
38 #[error("project validation failed:\n{0}")]
39 LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
40}
41
42#[derive(Clone, Debug, PartialEq, Default)]
43pub enum TestSpec {
44 #[default]
45 AutoDetect,
46 Busted(BustedTestSpec),
47 BustedNlua(BustedTestSpec),
48 Command(CommandTestSpec),
49 Script(LuaScriptTestSpec),
50}
51
52#[derive(Clone, Debug, PartialEq)]
53pub(crate) enum ValidatedTestSpec {
54 Busted(BustedTestSpec),
55 BustedNlua(BustedTestSpec),
56 Command(CommandTestSpec),
57 LuaScript(LuaScriptTestSpec),
58}
59
60impl TestSpec {
61 pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
62 self.to_validated(project)
63 .ok()
64 .iter()
65 .flat_map(|spec| spec.test_dependencies())
66 .collect_vec()
67 }
68
69 pub(crate) fn to_validated(
70 &self,
71 project: &Project,
72 ) -> Result<ValidatedTestSpec, TestSpecError> {
73 let project_root = project.root();
74 let toml = project.toml().into_local()?;
75 let test_dependencies = toml.test_dependencies().current_platform();
76 let is_busted = project_root.join(".busted").is_file()
77 || test_dependencies
78 .iter()
79 .any(|dep| dep.name().to_string() == "busted");
80 match self {
81 Self::AutoDetect if is_busted => {
82 if test_dependencies
83 .iter()
84 .any(|dep| dep.name().to_string() == "nlua")
85 {
86 Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
87 } else {
88 Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
89 }
90 }
91 Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
92 Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
93 Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
94 Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
95 Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
96 }
97 }
98}
99
100impl ValidatedTestSpec {
101 pub fn args(&self) -> Vec<String> {
102 match self {
103 Self::Busted(spec) => spec.flags.clone(),
104 Self::BustedNlua(spec) => {
105 let mut flags = spec.flags.clone();
106 flags.push("--ignore-lua".into());
110 flags
111 }
112 Self::Command(spec) => spec.flags.clone(),
113 Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
114 .chain(spec.flags.clone())
115 .collect_vec(),
116 }
117 }
118
119 pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
120 match self {
121 Self::BustedNlua(_) => {
122 let config_builder: ConfigBuilder = config.clone().into();
123
124 #[cfg(not(any(target_os = "macos", target_env = "msvc")))]
128 let lua_version = LuaVersion::Lua51;
129
130 #[cfg(any(target_os = "macos", target_env = "msvc"))]
131 let lua_version = LuaVersion::LuaJIT;
132
133 Ok(config_builder
134 .lua_version(Some(lua_version))
135 .variables(Some(
136 vec![("LUA".to_string(), NLUA_EXE.to_string())]
137 .into_iter()
138 .collect(),
139 ))
140 .build()?)
141 }
142 _ => Ok(config.clone()),
143 }
144 }
145
146 fn test_dependencies(&self) -> Vec<PackageReq> {
147 match self {
148 Self::Busted(_) => unsafe { vec![PackageReq::new_unchecked("busted".into(), None)] },
149 Self::BustedNlua(_) => unsafe {
150 vec![
151 PackageReq::new_unchecked("busted".into(), None),
152 PackageReq::new_unchecked("nlua".into(), None),
153 ]
154 },
155 Self::Command(_) => Vec::new(),
156 Self::LuaScript(_) => Vec::new(),
157 }
158 }
159}
160
161impl IntoLua for TestSpec {
162 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
163 let table = lua.create_table()?;
164 match self {
165 TestSpec::AutoDetect => table.set("auto_detect", true)?,
166 TestSpec::Busted(busted_test_spec) => table.set("busted", busted_test_spec)?,
167 TestSpec::BustedNlua(busted_test_spec) => table.set("busted-nlua", busted_test_spec)?,
168 TestSpec::Command(command_test_spec) => table.set("command", command_test_spec)?,
169 TestSpec::Script(script_test_spec) => table.set("script", script_test_spec)?,
170 }
171 Ok(mlua::Value::Table(table))
172 }
173}
174
175impl FromPlatformOverridable<TestSpecInternal, Self> for TestSpec {
176 type Err = TestSpecDecodeError;
177
178 fn from_platform_overridable(internal: TestSpecInternal) -> Result<Self, Self::Err> {
179 let test_spec = match internal.test_type {
180 Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
181 flags: internal.flags.unwrap_or_default(),
182 })),
183 Some(TestType::BustedNlua) => Ok(Self::BustedNlua(BustedTestSpec {
184 flags: internal.flags.unwrap_or_default(),
185 })),
186 Some(TestType::Command) => match (internal.command, internal.lua_script) {
187 (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
188 (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
189 script,
190 flags: internal.flags.unwrap_or_default(),
191 })),
192 (Some(command), None) => Ok(Self::Command(CommandTestSpec {
193 command,
194 flags: internal.flags.unwrap_or_default(),
195 })),
196 (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
197 },
198 None => Ok(Self::default()),
199 }?;
200 Ok(test_spec)
201 }
202}
203
204impl FromLua for PerPlatform<TestSpec> {
205 fn from_lua(
206 value: mlua::prelude::LuaValue,
207 lua: &mlua::prelude::Lua,
208 ) -> mlua::prelude::LuaResult<Self> {
209 let wrapper = PerPlatformWrapper::from_lua(value, lua)?;
210 Ok(wrapper.un_per_platform)
211 }
212}
213
214impl<'de> Deserialize<'de> for TestSpec {
215 fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
216 where
217 D: Deserializer<'de>,
218 {
219 let internal = TestSpecInternal::deserialize(deserializer)?;
220 let test_spec =
221 TestSpec::from_platform_overridable(internal).map_err(serde::de::Error::custom)?;
222 Ok(test_spec)
223 }
224}
225
/// Settings shared by the `busted` and `busted-nlua` test types.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BustedTestSpec {
    /// Extra command-line flags for the test runner.
    pub(crate) flags: Vec<String>,
}
230
231impl UserData for BustedTestSpec {
232 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
233 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
234 }
235}
236
/// Settings for `type = "command"` with a `command` field.
#[derive(Debug, Clone, PartialEq)]
pub struct CommandTestSpec {
    /// The command to run.
    pub(crate) command: String,
    /// Extra command-line flags for the command.
    pub(crate) flags: Vec<String>,
}
242
243impl UserData for CommandTestSpec {
244 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
245 methods.add_method("command", |_, this, _: ()| Ok(this.command.clone()));
246 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
247 }
248}
249
/// Settings for `type = "command"` with a `script` field.
#[derive(Debug, Clone, PartialEq)]
pub struct LuaScriptTestSpec {
    /// Path to the Lua script to run.
    pub(crate) script: PathBuf,
    /// Extra command-line flags passed after the script path.
    pub(crate) flags: Vec<String>,
}
255
256impl UserData for LuaScriptTestSpec {
257 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
258 methods.add_method("script", |_, this, _: ()| Ok(this.script.clone()));
259 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
260 }
261}
262
263#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
264#[serde(rename_all = "kebab-case")]
265pub(crate) enum TestType {
266 Busted,
267 BustedNlua,
268 Command,
269}
270
271#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
272pub(crate) struct TestSpecInternal {
273 #[serde(default, rename = "type")]
274 pub(crate) test_type: Option<TestType>,
275 #[serde(default)]
276 pub(crate) flags: Option<Vec<String>>,
277 #[serde(default)]
278 pub(crate) command: Option<String>,
279 #[serde(default, rename = "script", alias = "lua_script")]
280 pub(crate) lua_script: Option<PathBuf>,
281}
282
283impl PartialOverride for TestSpecInternal {
284 type Err = Infallible;
285
286 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
287 Ok(TestSpecInternal {
288 test_type: override_opt(&override_spec.test_type, &self.test_type),
289 flags: match (override_spec.flags.clone(), self.flags.clone()) {
290 (Some(override_vec), Some(base_vec)) => {
291 let merged: Vec<String> =
292 base_vec.into_iter().chain(override_vec).unique().collect();
293 Some(merged)
294 }
295 (None, base_vec @ Some(_)) => base_vec,
296 (override_vec @ Some(_), None) => override_vec,
297 _ => None,
298 },
299 command: match override_spec.lua_script.clone() {
300 Some(_) => None,
301 None => override_opt(&override_spec.command, &self.command),
302 },
303 lua_script: match override_spec.command.clone() {
304 Some(_) => None,
305 None => override_opt(&override_spec.lua_script, &self.lua_script),
306 },
307 })
308 }
309}
310
311impl PlatformOverridable for TestSpecInternal {
312 type Err = Infallible;
313
314 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
315 where
316 T: PlatformOverridable,
317 T: Default,
318 {
319 Ok(PerPlatform::default())
320 }
321}
322
/// Returns a clone of `override_opt` if it holds a value, otherwise a clone of
/// `base`.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
329
#[cfg(test)]
mod tests {
    use mlua::{Error, FromLua, Lua};

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Executes `lua_content` in a fresh Lua state and decodes the global
    /// `test` table it defines into a per-platform test spec.
    fn decode(lua_content: &str) -> Result<PerPlatform<TestSpec>, Error> {
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua)
    }

    /// Turns string literals into an owned flag list.
    fn flags(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    #[test]
    fn test_spec_from_lua() {
        // An empty table auto-detects.
        let lua_content = "
        test = {\n
        }\n
        ";
        assert!(matches!(
            decode(lua_content).unwrap().default,
            TestSpec::AutoDetect
        ));

        // A bare busted type has no flags.
        let lua_content = "
        test = {\n
        type = 'busted',\n
        }\n
        ";
        assert_eq!(
            decode(lua_content).unwrap().default,
            TestSpec::Busted(BustedTestSpec::default())
        );

        // Busted with flags.
        let lua_content = "
        test = {\n
        type = 'busted',\n
        flags = { 'foo', 'bar' },\n
        }\n
        ";
        assert_eq!(
            decode(lua_content).unwrap().default,
            TestSpec::Busted(BustedTestSpec {
                flags: flags(&["foo", "bar"]),
            })
        );

        // A command type needs either `command` or `script`...
        let lua_content = "
        test = {\n
        type = 'command',\n
        }\n
        ";
        assert!(decode(lua_content).is_err());

        // ...but not both.
        let lua_content = "
        test = {\n
        type = 'command',\n
        command = 'foo',\n
        script = 'bar',\n
        }\n
        ";
        assert!(decode(lua_content).is_err());

        // Command with flags.
        let lua_content = "
        test = {\n
        type = 'command',\n
        command = 'baz',\n
        flags = { 'foo', 'bar' },\n
        }\n
        ";
        assert_eq!(
            decode(lua_content).unwrap().default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: flags(&["foo", "bar"]),
            })
        );

        // Script with flags.
        let lua_content = "
        test = {\n
        type = 'command',\n
        script = 'test.lua',\n
        flags = { 'foo', 'bar' },\n
        }\n
        ";
        assert_eq!(
            decode(lua_content).unwrap().default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: flags(&["foo", "bar"]),
            })
        );

        // Per-platform overrides.
        let lua_content = "
        test = {\n
        type = 'command',\n
        command = 'baz',\n
        flags = { 'foo', 'bar' },\n
        platforms = {\n
        unix = { flags = { 'baz' }, },\n
        macosx = {\n
        script = 'bat.lua',\n
        flags = { 'bat' },\n
        },\n
        linux = { type = 'busted' },\n
        },\n
        }\n
        ";
        let spec = decode(lua_content).unwrap();
        assert_eq!(
            spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: flags(&["foo", "bar"]),
            })
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::Unix),
            Some(&TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: flags(&["foo", "bar", "baz"]),
            }))
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::MacOSX),
            Some(&TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: flags(&["foo", "bar", "bat", "baz"]),
            }))
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::Linux),
            Some(&TestSpec::Busted(BustedTestSpec {
                flags: flags(&["foo", "bar", "baz"]),
            }))
        );

        // busted-nlua type.
        let lua_content = "
        test = {\n
        type = 'busted-nlua',\n
        }";
        assert_eq!(
            decode(lua_content).unwrap().default,
            TestSpec::BustedNlua(BustedTestSpec { flags: Vec::new() })
        );
    }
}