Skip to main content

lux_lib/lua_rockspec/
test_spec.rs

1use itertools::Itertools;
2
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::{
11    config::{Config, ConfigBuilder, ConfigError, LuaVersion},
12    lua_rockspec::per_platform_from_intermediate,
13    package::PackageReq,
14    project::{project_toml::LocalProjectTomlValidationError, Project},
15    rockspec::Rockspec,
16};
17
18use super::{PartialOverride, PerPlatform, PlatformOverridable};
19
/// File name of the `nlua` executable on unix-family targets.
///
/// `ValidatedTestSpec::test_config` exports this value as the `LUA` variable.
#[cfg(target_family = "unix")]
const NLUA_EXE: &str = "nlua";
/// File name of the `nlua` executable on windows-family targets.
///
/// `ValidatedTestSpec::test_config` exports this value as the `LUA` variable.
#[cfg(target_family = "windows")]
const NLUA_EXE: &str = "nlua.bat";
24
/// Errors produced while converting a deserialized `TestSpecInternal` into a
/// `TestSpec`.
#[derive(Error, Debug)]
pub enum TestSpecDecodeError {
    /// The test type is `command`, but neither `command` nor `script` is set.
    #[error("the 'command' test type must specify either a 'command' or 'script' field")]
    NoCommandOrScript,
    /// The test type is `command`, and both `command` and `script` are set.
    #[error("the 'command' test type cannot have both 'command' and 'script' fields")]
    CommandAndScript,
}
32
/// Errors produced while resolving a `TestSpec` into a `ValidatedTestSpec`.
#[derive(Error, Debug)]
pub enum TestSpecError {
    /// The spec is `AutoDetect`, and no busted setup was found for the project.
    #[error("could not auto-detect the test spec. Please add one to your lux.toml")]
    NoTestSpecDetected,
    /// Converting the project's TOML into its local form failed.
    #[error("project validation failed:\n{0}")]
    LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
}
40
/// A test specification as decoded from a rockspec or project file.
///
/// Defaults to `AutoDetect`, which is what decoding yields when no test `type`
/// is given.
#[derive(Clone, Debug, PartialEq, Default)]
pub enum TestSpec {
    /// No explicit test type; `to_validated` picks a concrete spec based on
    /// the project, or fails if it cannot.
    #[default]
    AutoDetect,
    /// Test type `busted`.
    Busted(BustedTestSpec),
    /// Test type `busted-nlua`.
    BustedNlua(BustedTestSpec),
    /// Test type `command` with a `command` field.
    Command(CommandTestSpec),
    /// Test type `command` with a `script` field.
    Script(LuaScriptTestSpec),
}
50
/// A `TestSpec` after resolution by `TestSpec::to_validated`: the same
/// variants, minus `AutoDetect`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ValidatedTestSpec {
    /// Run the tests with busted.
    Busted(BustedTestSpec),
    /// Run the tests with busted, using nlua as the interpreter
    /// (see `test_config`, which sets the `LUA` variable accordingly).
    BustedNlua(BustedTestSpec),
    /// Run the tests with a command.
    Command(CommandTestSpec),
    /// Run the tests with a Lua script.
    LuaScript(LuaScriptTestSpec),
}
58
59impl TestSpec {
60    pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
61        self.to_validated(project)
62            .ok()
63            .iter()
64            .flat_map(|spec| spec.test_dependencies())
65            .collect_vec()
66    }
67
68    pub(crate) fn to_validated(
69        &self,
70        project: &Project,
71    ) -> Result<ValidatedTestSpec, TestSpecError> {
72        let project_root = project.root();
73        let toml = project.toml().into_local()?;
74        let test_dependencies = toml.test_dependencies().current_platform();
75        let is_busted = project_root.join(".busted").is_file()
76            || test_dependencies
77                .iter()
78                .any(|dep| dep.name().to_string() == "busted");
79        match self {
80            Self::AutoDetect if is_busted => {
81                if test_dependencies
82                    .iter()
83                    .any(|dep| dep.name().to_string() == "nlua")
84                {
85                    Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
86                } else {
87                    Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
88                }
89            }
90            Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
91            Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
92            Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
93            Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
94            Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
95        }
96    }
97}
98
99impl ValidatedTestSpec {
100    pub fn args(&self) -> Vec<String> {
101        match self {
102            Self::Busted(spec) => spec.flags.clone(),
103            Self::BustedNlua(spec) => {
104                let mut flags = spec.flags.clone();
105                // If there's a .busted config file which has lua set to "nlua",
106                // we tell busted to ignore it, because we set the correct
107                // platform-dependent "nlua" executable in the wrapper script.
108                flags.push("--ignore-lua".into());
109                flags
110            }
111            Self::Command(spec) => spec.flags.clone(),
112            Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
113                .chain(spec.flags.clone())
114                .collect_vec(),
115        }
116    }
117
118    pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
119        match self {
120            Self::BustedNlua(_) => {
121                let config_builder: ConfigBuilder = config.clone().into();
122
123                // XXX: On macos and msvc, Neovim and LuaJIT segfault when
124                // requiring luafilesystem's `lfs` module
125                // if it is built with Lua 5.1 headers.
126                #[cfg(not(any(target_os = "macos", target_env = "msvc")))]
127                let lua_version = LuaVersion::Lua51;
128
129                #[cfg(any(target_os = "macos", target_env = "msvc"))]
130                let lua_version = LuaVersion::LuaJIT;
131
132                Ok(config_builder
133                    .lua_version(Some(lua_version))
134                    .variables(Some(
135                        vec![("LUA".to_string(), NLUA_EXE.to_string())]
136                            .into_iter()
137                            .collect(),
138                    ))
139                    .build()?)
140            }
141            _ => Ok(config.clone()),
142        }
143    }
144
145    fn test_dependencies(&self) -> Vec<PackageReq> {
146        match self {
147            Self::Busted(_) => unsafe { vec![PackageReq::new_unchecked("busted".into(), None)] },
148            Self::BustedNlua(_) => unsafe {
149                vec![
150                    PackageReq::new_unchecked("busted".into(), None),
151                    PackageReq::new_unchecked("nlua".into(), None),
152                ]
153            },
154            Self::Command(_) => Vec::new(),
155            Self::LuaScript(_) => Vec::new(),
156        }
157    }
158}
159
160impl TryFrom<TestSpecInternal> for TestSpec {
161    type Error = TestSpecDecodeError;
162
163    fn try_from(internal: TestSpecInternal) -> Result<Self, Self::Error> {
164        let test_spec = match internal.test_type {
165            Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
166                flags: internal.flags.unwrap_or_default(),
167            })),
168            Some(TestType::BustedNlua) => Ok(Self::BustedNlua(BustedTestSpec {
169                flags: internal.flags.unwrap_or_default(),
170            })),
171            Some(TestType::Command) => match (internal.command, internal.lua_script) {
172                (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
173                (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
174                    script,
175                    flags: internal.flags.unwrap_or_default(),
176                })),
177                (Some(command), None) => Ok(Self::Command(CommandTestSpec {
178                    command,
179                    flags: internal.flags.unwrap_or_default(),
180                })),
181                (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
182            },
183            None => Ok(Self::default()),
184        }?;
185        Ok(test_spec)
186    }
187}
188
impl<'de> Deserialize<'de> for PerPlatform<TestSpec> {
    /// Deserializes a per-platform test spec by delegating to
    /// `per_platform_from_intermediate`, with `TestSpecInternal` as the
    /// intermediate representation.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        per_platform_from_intermediate::<_, TestSpecInternal, _>(deserializer)
    }
}
197
198impl<'de> Deserialize<'de> for TestSpec {
199    fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
200    where
201        D: Deserializer<'de>,
202    {
203        let internal = TestSpecInternal::deserialize(deserializer)?;
204        let test_spec = TestSpec::try_from(internal).map_err(serde::de::Error::custom)?;
205        Ok(test_spec)
206    }
207}
208
/// Settings for a busted test run; shared by the `Busted` and `BustedNlua`
/// variants.
#[derive(Clone, Debug, PartialEq, Default)]
pub struct BustedTestSpec {
    /// Flags taken from the spec's `flags` field; empty when none were given.
    pub(crate) flags: Vec<String>,
}

impl BustedTestSpec {
    /// Returns the configured flags.
    pub fn flags(&self) -> &Vec<String> {
        &self.flags
    }
}
219
/// Settings for a test run of type `command` that specifies a `command` field.
#[derive(Clone, Debug, PartialEq)]
pub struct CommandTestSpec {
    /// Value of the spec's `command` field.
    pub(crate) command: String,
    /// Flags taken from the spec's `flags` field; empty when none were given.
    pub(crate) flags: Vec<String>,
}

impl CommandTestSpec {
    /// Returns the configured flags.
    pub fn flags(&self) -> &Vec<String> {
        &self.flags
    }

    /// Returns the configured command.
    pub fn command(&self) -> &str {
        &self.command
    }
}
235
/// Settings for a test run of type `command` that specifies a `script` field.
#[derive(Clone, Debug, PartialEq)]
pub struct LuaScriptTestSpec {
    /// Value of the spec's `script` field.
    pub(crate) script: PathBuf,
    /// Flags taken from the spec's `flags` field; empty when none were given.
    pub(crate) flags: Vec<String>,
}

impl LuaScriptTestSpec {
    /// Returns the configured flags.
    pub fn flags(&self) -> &Vec<String> {
        &self.flags
    }

    /// Returns the configured script path.
    pub fn script(&self) -> &PathBuf {
        &self.script
    }
}
251
/// The value of a test spec's `type` field.
///
/// Names are kebab-case on the wire: `busted`, `busted-nlua`, `command`.
#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum TestType {
    /// `busted`
    Busted,
    /// `busted-nlua`
    BustedNlua,
    /// `command`; covers both the `command` and the `script` form.
    Command,
}
259
/// Raw, all-optional form of a test spec as it is deserialized, before it is
/// converted into a `TestSpec` via `TryFrom`.
#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
pub(crate) struct TestSpecInternal {
    /// The `type` field; `None` when absent.
    #[serde(default, rename = "type")]
    pub(crate) test_type: Option<TestType>,
    /// The `flags` field; `None` when absent.
    #[serde(default)]
    pub(crate) flags: Option<Vec<String>>,
    /// The `command` field; `None` when absent.
    #[serde(default)]
    pub(crate) command: Option<String>,
    /// The `script` field (also accepted as `lua_script`); `None` when absent.
    #[serde(default, rename = "script", alias = "lua_script")]
    pub(crate) lua_script: Option<PathBuf>,
}
271
272impl PartialOverride for TestSpecInternal {
273    type Err = Infallible;
274
275    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
276        Ok(TestSpecInternal {
277            test_type: override_opt(&override_spec.test_type, &self.test_type),
278            flags: match (override_spec.flags.clone(), self.flags.clone()) {
279                (Some(override_vec), Some(base_vec)) => {
280                    let merged: Vec<String> =
281                        base_vec.into_iter().chain(override_vec).unique().collect();
282                    Some(merged)
283                }
284                (None, base_vec @ Some(_)) => base_vec,
285                (override_vec @ Some(_), None) => override_vec,
286                _ => None,
287            },
288            command: match override_spec.lua_script.clone() {
289                Some(_) => None,
290                None => override_opt(&override_spec.command, &self.command),
291            },
292            lua_script: match override_spec.command.clone() {
293                Some(_) => None,
294                None => override_opt(&override_spec.lua_script, &self.lua_script),
295            },
296        })
297    }
298}
299
impl PlatformOverridable for TestSpecInternal {
    type Err = Infallible;

    /// Returns `PerPlatform::default()`; this never fails.
    ///
    /// NOTE(review): judging by the name, this is presumably called when the
    /// value being deserialized is nil/absent. Confirm against the
    /// `PlatformOverridable` trait.
    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
    where
        T: PlatformOverridable,
        T: Default,
    {
        Ok(PerPlatform::default())
    }
}
311
/// Picks the override value when it is `Some`, otherwise falls back to `base`.
///
/// Both inputs are only borrowed; the chosen value (if any) is cloned.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
318
#[cfg(test)]
mod tests {

    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Loads and runs `code` as a Lua chunk, then deserializes the global
    /// named `key` into `T`.
    ///
    /// The executor is stepped once with a fuel budget of `i32::MAX`.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    /// Decodes a series of Lua `test` tables and checks the resulting specs,
    /// including error cases and per-platform overrides.
    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table decodes to AutoDetect.
        let lua_content = "
        test = {\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert!(matches!(test_spec.default, TestSpec::AutoDetect));
        // `busted` without flags decodes to the default busted spec.
        let lua_content = "
        test = {\n
            type = 'busted',\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec::default())
        );
        // `busted` with flags keeps the flags in order.
        let lua_content = "
        test = {\n
            type = 'busted',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `command` with neither `command` nor `script` is an error.
        let lua_content = "
        test = {\n
            type = 'command',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `command` with both `command` and `script` is an error.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'foo',\n
            script = 'bar',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `command` with a `command` field decodes to a Command spec.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `command` with a `script` field decodes to a Script spec.
        let lua_content = "
        test = {\n
            type = 'command',\n
            script = 'test.lua',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // Per-platform overrides: flags are merged, a `script` override
        // replaces the base `command`, and a `type` override switches the
        // variant.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
            platforms = {\n
                unix = { flags = { 'baz' }, },\n
                macosx = {\n
                    script = 'bat.lua',\n
                    flags = { 'bat' },\n
                },\n
                linux = { type = 'busted' },\n
            },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        // The default entry is unaffected by the platform overrides.
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        let unix = test_spec
            .per_platform
            .get(&PlatformIdentifier::Unix)
            .unwrap();
        assert_eq!(
            *unix,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // The macosx entry is expected to also carry the unix flag `baz`.
        let macosx = test_spec
            .per_platform
            .get(&PlatformIdentifier::MacOSX)
            .unwrap();
        assert_eq!(
            *macosx,
            TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: vec!["foo".into(), "bar".into(), "bat".into(), "baz".into()],
            })
        );
        // The linux entry is expected to also carry the unix flag `baz`.
        let linux = test_spec
            .per_platform
            .get(&PlatformIdentifier::Linux)
            .unwrap();
        assert_eq!(
            *linux,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // `busted-nlua` decodes to the BustedNlua variant.
        let lua_content = "
        test = {\n
            type = 'busted-nlua',\n
        }";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::BustedNlua(BustedTestSpec { flags: Vec::new() })
        );
    }
}
483}