lux_lib/lua_rockspec/test_spec.rs

1use itertools::Itertools;
2use mlua::{FromLua, IntoLua, UserData};
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::{
11    config::{Config, ConfigBuilder, ConfigError, LuaVersion},
12    package::PackageReq,
13    project::{project_toml::LocalProjectTomlValidationError, Project},
14    rockspec::Rockspec,
15};
16
17use super::{
18    DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, FromPlatformOverridable, PartialOverride,
19    PerPlatform, PerPlatformWrapper, PlatformOverridable,
20};
21
/// Launcher used as the `LUA` variable when running busted through nlua (Windows).
#[cfg(target_family = "windows")]
const NLUA_EXE: &str = "nlua.bat";
/// Launcher used as the `LUA` variable when running busted through nlua (Unix).
#[cfg(target_family = "unix")]
const NLUA_EXE: &str = "nlua";
26
/// Errors raised while decoding raw test fields (`TestSpecInternal`) into a [`TestSpec`].
#[derive(Error, Debug)]
pub enum TestSpecDecodeError {
    /// A `command` test type set neither `command` nor `script`.
    #[error("'command' test type must specify 'command' or 'script' field")]
    NoCommandOrScript,
    /// A `command` test type set both `command` and `script`.
    #[error("'command' test type cannot have both 'command' and 'script' fields")]
    CommandAndScript,
}
34
/// Errors raised when resolving a [`TestSpec`] into a runnable spec for a project.
#[derive(Error, Debug)]
pub enum TestSpecError {
    /// Auto-detection was requested, but no supported test runner was detected.
    #[error("could not auto-detect test spec. Please add one to your lux.toml")]
    NoTestSpecDetected,
    /// The project's toml could not be converted into a local project toml.
    #[error(transparent)]
    LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
}
42
/// A test specification, decoded from a Lua `test` table or via serde.
///
/// [`TestSpec::AutoDetect`] (the default) defers the choice of test runner until the
/// spec is resolved against a project.
#[derive(Clone, Debug, PartialEq)]
pub enum TestSpec {
    /// Pick a test runner from the project when the spec is resolved.
    AutoDetect,
    /// Run tests with busted.
    Busted(BustedTestSpec),
    /// Run tests with busted, using nlua as the Lua interpreter.
    BustedNlua(BustedTestSpec),
    /// Run an arbitrary command.
    Command(CommandTestSpec),
    /// Run a Lua script.
    Script(LuaScriptTestSpec),
}
51
/// A [`TestSpec`] resolved against a project, with auto-detection already applied.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ValidatedTestSpec {
    /// Run tests with busted.
    Busted(BustedTestSpec),
    /// Run tests with busted, using nlua as the Lua interpreter (see `test_config`).
    BustedNlua(BustedTestSpec),
    /// Run an arbitrary command.
    Command(CommandTestSpec),
    /// Run a Lua script.
    LuaScript(LuaScriptTestSpec),
}
59
60impl TestSpec {
61    pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
62        self.to_validated(project)
63            .ok()
64            .iter()
65            .flat_map(|spec| spec.test_dependencies())
66            .collect_vec()
67    }
68
69    pub(crate) fn to_validated(
70        &self,
71        project: &Project,
72    ) -> Result<ValidatedTestSpec, TestSpecError> {
73        let project_root = project.root();
74        let toml = project.toml().into_local()?;
75        let test_dependencies = toml.test_dependencies().current_platform();
76        let is_busted = project_root.join(".busted").is_file()
77            || test_dependencies
78                .iter()
79                .any(|dep| dep.name().to_string() == "busted");
80        match self {
81            Self::AutoDetect if is_busted => {
82                if test_dependencies
83                    .iter()
84                    .any(|dep| dep.name().to_string() == "nlua")
85                {
86                    Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
87                } else {
88                    Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
89                }
90            }
91            Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
92            Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
93            Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
94            Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
95            Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
96        }
97    }
98}
99
100impl ValidatedTestSpec {
101    pub fn args(&self) -> Vec<String> {
102        match self {
103            Self::Busted(spec) => spec.flags.clone(),
104            Self::BustedNlua(spec) => spec.flags.clone(),
105            Self::Command(spec) => spec.flags.clone(),
106            Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
107                .chain(spec.flags.clone())
108                .collect_vec(),
109        }
110    }
111
112    pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
113        match self {
114            Self::BustedNlua(_) => {
115                let config_builder: ConfigBuilder = config.clone().into();
116                Ok(config_builder
117                    .lua_version(Some(LuaVersion::Lua51))
118                    .variables(Some(
119                        vec![("LUA".to_string(), NLUA_EXE.to_string())]
120                            .into_iter()
121                            .collect(),
122                    ))
123                    .build()?)
124            }
125            _ => Ok(config.clone()),
126        }
127    }
128
129    fn test_dependencies(&self) -> Vec<PackageReq> {
130        match self {
131            Self::Busted(_) => vec![PackageReq::new("busted".into(), None).unwrap()],
132            Self::BustedNlua(_) => vec![
133                PackageReq::new("busted".into(), None).unwrap(),
134                PackageReq::new("nlua".into(), None).unwrap(),
135            ],
136            Self::Command(_) => Vec::new(),
137            Self::LuaScript(_) => Vec::new(),
138        }
139    }
140}
141
142impl Default for TestSpec {
143    fn default() -> Self {
144        Self::AutoDetect
145    }
146}
147
148impl IntoLua for TestSpec {
149    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
150        let table = lua.create_table()?;
151        match self {
152            TestSpec::AutoDetect => table.set("auto_detect", true)?,
153            TestSpec::Busted(busted_test_spec) => table.set("busted", busted_test_spec)?,
154            TestSpec::BustedNlua(busted_test_spec) => table.set("busted-nlua", busted_test_spec)?,
155            TestSpec::Command(command_test_spec) => table.set("command", command_test_spec)?,
156            TestSpec::Script(script_test_spec) => table.set("script", script_test_spec)?,
157        }
158        Ok(mlua::Value::Table(table))
159    }
160}
161
162impl FromPlatformOverridable<TestSpecInternal, Self> for TestSpec {
163    type Err = TestSpecDecodeError;
164
165    fn from_platform_overridable(internal: TestSpecInternal) -> Result<Self, Self::Err> {
166        let test_spec = match internal.test_type {
167            Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
168                flags: internal.flags.unwrap_or_default(),
169            })),
170            Some(TestType::Command) => match (internal.command, internal.lua_script) {
171                (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
172                (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
173                    script,
174                    flags: internal.flags.unwrap_or_default(),
175                })),
176                (Some(command), None) => Ok(Self::Command(CommandTestSpec {
177                    command,
178                    flags: internal.flags.unwrap_or_default(),
179                })),
180                (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
181            },
182            None => Ok(Self::default()),
183        }?;
184        Ok(test_spec)
185    }
186}
187
188impl FromLua for PerPlatform<TestSpec> {
189    fn from_lua(
190        value: mlua::prelude::LuaValue,
191        lua: &mlua::prelude::Lua,
192    ) -> mlua::prelude::LuaResult<Self> {
193        let wrapper = PerPlatformWrapper::from_lua(value, lua)?;
194        Ok(wrapper.un_per_platform)
195    }
196}
197
198impl<'de> Deserialize<'de> for TestSpec {
199    fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
200    where
201        D: Deserializer<'de>,
202    {
203        let internal = TestSpecInternal::deserialize(deserializer)?;
204        let test_spec =
205            TestSpec::from_platform_overridable(internal).map_err(serde::de::Error::custom)?;
206        Ok(test_spec)
207    }
208}
209
/// Settings for running tests with busted (with or without nlua).
#[derive(Clone, Debug, PartialEq, Default)]
pub struct BustedTestSpec {
    /// Extra arguments, returned as-is by `ValidatedTestSpec::args`.
    pub(crate) flags: Vec<String>,
}
214
215impl UserData for BustedTestSpec {
216    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
217        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
218    }
219}
220
/// Settings for running tests via an arbitrary command.
#[derive(Clone, Debug, PartialEq)]
pub struct CommandTestSpec {
    /// The command, taken from the spec's `command` field.
    pub(crate) command: String,
    /// Extra arguments, returned as-is by `ValidatedTestSpec::args`.
    pub(crate) flags: Vec<String>,
}
226
227impl UserData for CommandTestSpec {
228    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
229        methods.add_method("command", |_, this, _: ()| Ok(this.command.clone()));
230        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
231    }
232}
233
/// Settings for running tests via a Lua script.
#[derive(Clone, Debug, PartialEq)]
pub struct LuaScriptTestSpec {
    /// Path of the script, taken from the spec's `script` field.
    pub(crate) script: PathBuf,
    /// Extra arguments; `ValidatedTestSpec::args` places them after the script path.
    pub(crate) flags: Vec<String>,
}
239
240impl UserData for LuaScriptTestSpec {
241    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
242        methods.add_method("script", |_, this, _: ()| Ok(this.script.clone()));
243        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
244    }
245}
246
/// Value of a test spec's `type` field, deserialized from lowercase names
/// (`busted`, `command`).
#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum TestType {
    /// Run tests with busted.
    Busted,
    /// Run either a `command` or a Lua `script`.
    Command,
}
253
/// Raw test spec fields as read from the source table.
///
/// Platform overrides are merged at this level. The merged fields are then decoded into a
/// [`TestSpec`] by `TestSpec::from_platform_overridable`.
#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
pub(crate) struct TestSpecInternal {
    /// The `type` field. `None` decodes to `TestSpec::AutoDetect`.
    #[serde(default, rename = "type")]
    pub(crate) test_type: Option<TestType>,
    /// Extra arguments for the test runner.
    #[serde(default)]
    pub(crate) flags: Option<Vec<String>>,
    /// Command to run (only used by the `command` type).
    #[serde(default)]
    pub(crate) command: Option<String>,
    /// Lua script to run (only used by the `command` type). Read from `script`, or from
    /// `lua_script` as an alias.
    #[serde(default, rename = "script", alias = "lua_script")]
    pub(crate) lua_script: Option<PathBuf>,
}
265
266impl PartialOverride for TestSpecInternal {
267    type Err = Infallible;
268
269    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
270        Ok(TestSpecInternal {
271            test_type: override_opt(&override_spec.test_type, &self.test_type),
272            flags: match (override_spec.flags.clone(), self.flags.clone()) {
273                (Some(override_vec), Some(base_vec)) => {
274                    let merged: Vec<String> =
275                        base_vec.into_iter().chain(override_vec).unique().collect();
276                    Some(merged)
277                }
278                (None, base_vec @ Some(_)) => base_vec,
279                (override_vec @ Some(_), None) => override_vec,
280                _ => None,
281            },
282            command: match override_spec.lua_script.clone() {
283                Some(_) => None,
284                None => override_opt(&override_spec.command, &self.command),
285            },
286            lua_script: match override_spec.command.clone() {
287                Some(_) => None,
288                None => override_opt(&override_spec.lua_script, &self.lua_script),
289            },
290        })
291    }
292}
293
294impl PlatformOverridable for TestSpecInternal {
295    type Err = Infallible;
296
297    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
298    where
299        T: PlatformOverridable,
300        T: Default,
301    {
302        Ok(PerPlatform::default())
303    }
304}
305
/// Returns a clone of `override_opt` if it is set, otherwise a clone of `base`.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
312
313impl DisplayAsLuaKV for TestSpecInternal {
314    fn display_lua(&self) -> DisplayLuaKV {
315        let mut result = Vec::new();
316
317        if let Some(test_type) = &self.test_type {
318            result.push(DisplayLuaKV {
319                key: "type".to_string(),
320                value: DisplayLuaValue::String(test_type.to_string()),
321            });
322        }
323        if let Some(flags) = &self.flags {
324            result.push(DisplayLuaKV {
325                key: "flags".to_string(),
326                value: DisplayLuaValue::List(
327                    flags
328                        .iter()
329                        .map(|flag| DisplayLuaValue::String(flag.clone()))
330                        .collect(),
331                ),
332            });
333        }
334        if let Some(command) = &self.command {
335            result.push(DisplayLuaKV {
336                key: "command".to_string(),
337                value: DisplayLuaValue::String(command.clone()),
338            });
339        }
340        if let Some(script) = &self.lua_script {
341            result.push(DisplayLuaKV {
342                key: "script".to_string(),
343                value: DisplayLuaValue::String(script.to_string_lossy().to_string()),
344            });
345        }
346
347        DisplayLuaKV {
348            key: "test".to_string(),
349            value: DisplayLuaValue::Table(result),
350        }
351    }
352}
353
// Tests for decoding Lua `test` tables into `PerPlatform<TestSpec>`.
#[cfg(test)]
mod tests {

    use mlua::{Error, FromLua, Lua};

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    // NOTE(review): this test awaits nothing; a plain `#[test]` would likely suffice.
    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table decodes to auto-detection.
        let lua_content = "
        test = {\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec = PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert!(matches!(test_spec.default, TestSpec::AutoDetect));
        // `type = 'busted'` without flags decodes to busted with no flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec: PerPlatform<TestSpec> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec::default())
        );
        // busted flags are kept in declaration order.
        let lua_content = "
        test = {\n
            type = 'busted',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec: PerPlatform<TestSpec> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // The `command` type with neither `command` nor `script` is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let result: Result<PerPlatform<TestSpec>, Error> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua);
        let _err = result.unwrap_err();
        // The `command` type with both `command` and `script` is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'foo',\n
            script = 'bar',\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let result: Result<PerPlatform<TestSpec>, Error> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua);
        let _err = result.unwrap_err();
        // `command` plus flags decodes to a command spec.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec: PerPlatform<TestSpec> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `script` plus flags decodes to a Lua script spec.
        let lua_content = "
        test = {\n
            type = 'command',\n
            script = 'test.lua',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec: PerPlatform<TestSpec> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // Platform overrides:
        // - flags are appended to the base flags;
        // - `script` replaces the inherited `command`;
        // - `type` can be overridden.
        // macosx also picks up the `unix` override (note `baz` in its flags).
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
            platforms = {\n
                unix = { flags = { 'baz' }, },\n
                macosx = {\n
                    script = 'bat.lua',\n
                    flags = { 'bat' },\n
                },\n
                linux = { type = 'busted' },\n
            },\n
        }\n
        ";
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        let test_spec: PerPlatform<TestSpec> =
            PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua).unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        let unix = test_spec
            .per_platform
            .get(&PlatformIdentifier::Unix)
            .unwrap();
        assert_eq!(
            *unix,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        let macosx = test_spec
            .per_platform
            .get(&PlatformIdentifier::MacOSX)
            .unwrap();
        assert_eq!(
            *macosx,
            TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: vec!["foo".into(), "bar".into(), "bat".into(), "baz".into()],
            })
        );
        let linux = test_spec
            .per_platform
            .get(&PlatformIdentifier::Linux)
            .unwrap();
        assert_eq!(
            *linux,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
    }
}