1use itertools::Itertools;
2use mlua::{FromLua, IntoLua, UserData};
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::{
11 config::{Config, ConfigBuilder, ConfigError, LuaVersion},
12 package::PackageReq,
13 project::{project_toml::LocalProjectTomlValidationError, Project},
14 rockspec::Rockspec,
15};
16
17use super::{
18 DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, FromPlatformOverridable, PartialOverride,
19 PerPlatform, PerPlatformWrapper, PlatformOverridable,
20};
21
/// Executable assigned to the `LUA` config variable for busted-nlua test runs (Unix).
#[cfg(unix)]
const NLUA_EXE: &str = "nlua";
/// Executable assigned to the `LUA` config variable for busted-nlua test runs (Windows).
#[cfg(windows)]
const NLUA_EXE: &str = "nlua.bat";
26
27#[derive(Error, Debug)]
28pub enum TestSpecDecodeError {
29 #[error("'command' test type must specify 'command' or 'script' field")]
30 NoCommandOrScript,
31 #[error("'command' test type cannot have both 'command' and 'script' fields")]
32 CommandAndScript,
33}
34
35#[derive(Error, Debug)]
36pub enum TestSpecError {
37 #[error("could not auto-detect test spec. Please add one to your lux.toml")]
38 NoTestSpecDetected,
39 #[error(transparent)]
40 LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
41}
42
43#[derive(Clone, Debug, PartialEq)]
44pub enum TestSpec {
45 AutoDetect,
46 Busted(BustedTestSpec),
47 BustedNlua(BustedTestSpec),
48 Command(CommandTestSpec),
49 Script(LuaScriptTestSpec),
50}
51
52#[derive(Clone, Debug, PartialEq)]
53pub(crate) enum ValidatedTestSpec {
54 Busted(BustedTestSpec),
55 BustedNlua(BustedTestSpec),
56 Command(CommandTestSpec),
57 LuaScript(LuaScriptTestSpec),
58}
59
60impl TestSpec {
61 pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
62 self.to_validated(project)
63 .ok()
64 .iter()
65 .flat_map(|spec| spec.test_dependencies())
66 .collect_vec()
67 }
68
69 pub(crate) fn to_validated(
70 &self,
71 project: &Project,
72 ) -> Result<ValidatedTestSpec, TestSpecError> {
73 let project_root = project.root();
74 let toml = project.toml().into_local()?;
75 let test_dependencies = toml.test_dependencies().current_platform();
76 let is_busted = project_root.join(".busted").is_file()
77 || test_dependencies
78 .iter()
79 .any(|dep| dep.name().to_string() == "busted");
80 match self {
81 Self::AutoDetect if is_busted => {
82 if test_dependencies
83 .iter()
84 .any(|dep| dep.name().to_string() == "nlua")
85 {
86 Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
87 } else {
88 Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
89 }
90 }
91 Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
92 Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
93 Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
94 Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
95 Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
96 }
97 }
98}
99
100impl ValidatedTestSpec {
101 pub fn args(&self) -> Vec<String> {
102 match self {
103 Self::Busted(spec) => spec.flags.clone(),
104 Self::BustedNlua(spec) => spec.flags.clone(),
105 Self::Command(spec) => spec.flags.clone(),
106 Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
107 .chain(spec.flags.clone())
108 .collect_vec(),
109 }
110 }
111
112 pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
113 match self {
114 Self::BustedNlua(_) => {
115 let config_builder: ConfigBuilder = config.clone().into();
116 Ok(config_builder
117 .lua_version(Some(LuaVersion::Lua51))
118 .variables(Some(
119 vec![("LUA".to_string(), NLUA_EXE.to_string())]
120 .into_iter()
121 .collect(),
122 ))
123 .build()?)
124 }
125 _ => Ok(config.clone()),
126 }
127 }
128
129 fn test_dependencies(&self) -> Vec<PackageReq> {
130 match self {
131 Self::Busted(_) => vec![PackageReq::new("busted".into(), None).unwrap()],
132 Self::BustedNlua(_) => vec![
133 PackageReq::new("busted".into(), None).unwrap(),
134 PackageReq::new("nlua".into(), None).unwrap(),
135 ],
136 Self::Command(_) => Vec::new(),
137 Self::LuaScript(_) => Vec::new(),
138 }
139 }
140}
141
142impl Default for TestSpec {
143 fn default() -> Self {
144 Self::AutoDetect
145 }
146}
147
148impl IntoLua for TestSpec {
149 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
150 let table = lua.create_table()?;
151 match self {
152 TestSpec::AutoDetect => table.set("auto_detect", true)?,
153 TestSpec::Busted(busted_test_spec) => table.set("busted", busted_test_spec)?,
154 TestSpec::BustedNlua(busted_test_spec) => table.set("busted-nlua", busted_test_spec)?,
155 TestSpec::Command(command_test_spec) => table.set("command", command_test_spec)?,
156 TestSpec::Script(script_test_spec) => table.set("script", script_test_spec)?,
157 }
158 Ok(mlua::Value::Table(table))
159 }
160}
161
162impl FromPlatformOverridable<TestSpecInternal, Self> for TestSpec {
163 type Err = TestSpecDecodeError;
164
165 fn from_platform_overridable(internal: TestSpecInternal) -> Result<Self, Self::Err> {
166 let test_spec = match internal.test_type {
167 Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
168 flags: internal.flags.unwrap_or_default(),
169 })),
170 Some(TestType::Command) => match (internal.command, internal.lua_script) {
171 (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
172 (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
173 script,
174 flags: internal.flags.unwrap_or_default(),
175 })),
176 (Some(command), None) => Ok(Self::Command(CommandTestSpec {
177 command,
178 flags: internal.flags.unwrap_or_default(),
179 })),
180 (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
181 },
182 None => Ok(Self::default()),
183 }?;
184 Ok(test_spec)
185 }
186}
187
188impl FromLua for PerPlatform<TestSpec> {
189 fn from_lua(
190 value: mlua::prelude::LuaValue,
191 lua: &mlua::prelude::Lua,
192 ) -> mlua::prelude::LuaResult<Self> {
193 let wrapper = PerPlatformWrapper::from_lua(value, lua)?;
194 Ok(wrapper.un_per_platform)
195 }
196}
197
198impl<'de> Deserialize<'de> for TestSpec {
199 fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
200 where
201 D: Deserializer<'de>,
202 {
203 let internal = TestSpecInternal::deserialize(deserializer)?;
204 let test_spec =
205 TestSpec::from_platform_overridable(internal).map_err(serde::de::Error::custom)?;
206 Ok(test_spec)
207 }
208}
209
210#[derive(Clone, Debug, PartialEq, Default)]
211pub struct BustedTestSpec {
212 pub(crate) flags: Vec<String>,
213}
214
215impl UserData for BustedTestSpec {
216 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
217 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
218 }
219}
220
221#[derive(Clone, Debug, PartialEq)]
222pub struct CommandTestSpec {
223 pub(crate) command: String,
224 pub(crate) flags: Vec<String>,
225}
226
227impl UserData for CommandTestSpec {
228 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
229 methods.add_method("command", |_, this, _: ()| Ok(this.command.clone()));
230 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
231 }
232}
233
234#[derive(Clone, Debug, PartialEq)]
235pub struct LuaScriptTestSpec {
236 pub(crate) script: PathBuf,
237 pub(crate) flags: Vec<String>,
238}
239
240impl UserData for LuaScriptTestSpec {
241 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
242 methods.add_method("script", |_, this, _: ()| Ok(this.script.clone()));
243 methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
244 }
245}
246
247#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
248#[serde(rename_all = "lowercase")]
249pub(crate) enum TestType {
250 Busted,
251 Command,
252}
253
254#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
255pub(crate) struct TestSpecInternal {
256 #[serde(default, rename = "type")]
257 pub(crate) test_type: Option<TestType>,
258 #[serde(default)]
259 pub(crate) flags: Option<Vec<String>>,
260 #[serde(default)]
261 pub(crate) command: Option<String>,
262 #[serde(default, rename = "script", alias = "lua_script")]
263 pub(crate) lua_script: Option<PathBuf>,
264}
265
266impl PartialOverride for TestSpecInternal {
267 type Err = Infallible;
268
269 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
270 Ok(TestSpecInternal {
271 test_type: override_opt(&override_spec.test_type, &self.test_type),
272 flags: match (override_spec.flags.clone(), self.flags.clone()) {
273 (Some(override_vec), Some(base_vec)) => {
274 let merged: Vec<String> =
275 base_vec.into_iter().chain(override_vec).unique().collect();
276 Some(merged)
277 }
278 (None, base_vec @ Some(_)) => base_vec,
279 (override_vec @ Some(_), None) => override_vec,
280 _ => None,
281 },
282 command: match override_spec.lua_script.clone() {
283 Some(_) => None,
284 None => override_opt(&override_spec.command, &self.command),
285 },
286 lua_script: match override_spec.command.clone() {
287 Some(_) => None,
288 None => override_opt(&override_spec.lua_script, &self.lua_script),
289 },
290 })
291 }
292}
293
294impl PlatformOverridable for TestSpecInternal {
295 type Err = Infallible;
296
297 fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
298 where
299 T: PlatformOverridable,
300 T: Default,
301 {
302 Ok(PerPlatform::default())
303 }
304}
305
/// Returns a clone of `override_opt` if it is `Some`, otherwise a clone of `base`.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
312
313impl DisplayAsLuaKV for TestSpecInternal {
314 fn display_lua(&self) -> DisplayLuaKV {
315 let mut result = Vec::new();
316
317 if let Some(test_type) = &self.test_type {
318 result.push(DisplayLuaKV {
319 key: "type".to_string(),
320 value: DisplayLuaValue::String(test_type.to_string()),
321 });
322 }
323 if let Some(flags) = &self.flags {
324 result.push(DisplayLuaKV {
325 key: "flags".to_string(),
326 value: DisplayLuaValue::List(
327 flags
328 .iter()
329 .map(|flag| DisplayLuaValue::String(flag.clone()))
330 .collect(),
331 ),
332 });
333 }
334 if let Some(command) = &self.command {
335 result.push(DisplayLuaKV {
336 key: "command".to_string(),
337 value: DisplayLuaValue::String(command.clone()),
338 });
339 }
340 if let Some(script) = &self.lua_script {
341 result.push(DisplayLuaKV {
342 key: "script".to_string(),
343 value: DisplayLuaValue::String(script.to_string_lossy().to_string()),
344 });
345 }
346
347 DisplayLuaKV {
348 key: "test".to_string(),
349 value: DisplayLuaValue::Table(result),
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use mlua::{FromLua, Lua};

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Runs `lua_content` in a fresh Lua state and decodes its global `test` value.
    fn decode(lua_content: &str) -> mlua::Result<PerPlatform<TestSpec>> {
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_value = lua.globals().get("test").unwrap();
        PerPlatform::from_lua(test_value, &lua)
    }

    /// Builds an owned flag list from string literals.
    fn owned(flags: &[&str]) -> Vec<String> {
        flags.iter().map(|&flag| flag.to_owned()).collect()
    }

    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table (no `type`) auto-detects.
        let spec = decode("test = {}").unwrap();
        assert!(matches!(spec.default, TestSpec::AutoDetect));

        // busted without flags.
        let spec = decode("test = { type = 'busted' }").unwrap();
        assert_eq!(spec.default, TestSpec::Busted(BustedTestSpec::default()));

        // busted with flags.
        let spec = decode("test = { type = 'busted', flags = { 'foo', 'bar' } }").unwrap();
        assert_eq!(
            spec.default,
            TestSpec::Busted(BustedTestSpec {
                flags: owned(&["foo", "bar"]),
            })
        );

        // `command` requires exactly one of `command` / `script`.
        assert!(decode("test = { type = 'command' }").is_err());
        assert!(
            decode("test = { type = 'command', command = 'foo', script = 'bar' }").is_err()
        );

        // `command` with a command.
        let spec = decode(
            "test = { type = 'command', command = 'baz', flags = { 'foo', 'bar' } }",
        )
        .unwrap();
        assert_eq!(
            spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: owned(&["foo", "bar"]),
            })
        );

        // `command` with a script.
        let spec = decode(
            "test = { type = 'command', script = 'test.lua', flags = { 'foo', 'bar' } }",
        )
        .unwrap();
        assert_eq!(
            spec.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: owned(&["foo", "bar"]),
            })
        );

        // Per-platform overrides.
        let spec = decode(
            "
            test = {
                type = 'command',
                command = 'baz',
                flags = { 'foo', 'bar' },
                platforms = {
                    unix = { flags = { 'baz' } },
                    macosx = {
                        script = 'bat.lua',
                        flags = { 'bat' },
                    },
                    linux = { type = 'busted' },
                },
            }
            ",
        )
        .unwrap();
        assert_eq!(
            spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: owned(&["foo", "bar"]),
            })
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::Unix),
            Some(&TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: owned(&["foo", "bar", "baz"]),
            }))
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::MacOSX),
            Some(&TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: owned(&["foo", "bar", "bat", "baz"]),
            }))
        );
        assert_eq!(
            spec.per_platform.get(&PlatformIdentifier::Linux),
            Some(&TestSpec::Busted(BustedTestSpec {
                flags: owned(&["foo", "bar", "baz"]),
            }))
        );
    }
}