Skip to main content

lux_lib/lua_rockspec/
test_spec.rs

1use itertools::Itertools;
2
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::lua_version::LuaVersion;
11
12use crate::{
13    config::{Config, ConfigBuilder, ConfigError},
14    lua_rockspec::per_platform_from_intermediate,
15    package::PackageReq,
16    project::{project_toml::LocalProjectTomlValidationError, Project},
17    rockspec::Rockspec,
18};
19
20use super::{PartialOverride, PerPlatform, PlatformOverridable};
21
/// Executable set as the `LUA` variable when running busted-nlua tests.
#[cfg(unix)]
const NLUA_EXE: &str = "nlua";
/// Executable set as the `LUA` variable when running busted-nlua tests
/// (a `.bat` wrapper on Windows).
#[cfg(windows)]
const NLUA_EXE: &str = "nlua.bat";
26
27#[derive(Error, Debug)]
28pub enum TestSpecDecodeError {
29    #[error("the 'command' test type must specify either a 'command' or 'script' field")]
30    NoCommandOrScript,
31    #[error("the 'command' test type cannot have both 'command' and 'script' fields")]
32    CommandAndScript,
33}
34
35#[derive(Error, Debug)]
36pub enum TestSpecError {
37    #[error("could not auto-detect the test spec. Please add one to your lux.toml")]
38    NoTestSpecDetected,
39    #[error("project validation failed:\n{0}")]
40    LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
41}
42
43#[derive(Clone, Debug, PartialEq, Default)]
44pub enum TestSpec {
45    #[default]
46    AutoDetect,
47    Busted(BustedTestSpec),
48    BustedNlua(BustedTestSpec),
49    Command(CommandTestSpec),
50    Script(LuaScriptTestSpec),
51}
52
53#[derive(Clone, Debug, PartialEq)]
54pub(crate) enum ValidatedTestSpec {
55    Busted(BustedTestSpec),
56    BustedNlua(BustedTestSpec),
57    Command(CommandTestSpec),
58    LuaScript(LuaScriptTestSpec),
59}
60
61impl TestSpec {
62    pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
63        self.to_validated(project)
64            .ok()
65            .iter()
66            .flat_map(|spec| spec.test_dependencies())
67            .collect_vec()
68    }
69
70    pub(crate) fn to_validated(
71        &self,
72        project: &Project,
73    ) -> Result<ValidatedTestSpec, TestSpecError> {
74        let project_root = project.root();
75        let toml = project.toml().into_local()?;
76        let test_dependencies = toml.test_dependencies().current_platform();
77        let is_busted = project_root.join(".busted").is_file()
78            || test_dependencies
79                .iter()
80                .any(|dep| dep.name().to_string() == "busted");
81        match self {
82            Self::AutoDetect if is_busted => {
83                if test_dependencies
84                    .iter()
85                    .any(|dep| dep.name().to_string() == "nlua")
86                {
87                    Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
88                } else {
89                    Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
90                }
91            }
92            Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
93            Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
94            Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
95            Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
96            Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
97        }
98    }
99}
100
101impl ValidatedTestSpec {
102    pub fn args(&self) -> Vec<String> {
103        match self {
104            Self::Busted(spec) => spec.flags.clone(),
105            Self::BustedNlua(spec) => {
106                let mut flags = spec.flags.clone();
107                // If there's a .busted config file which has lua set to "nlua",
108                // we tell busted to ignore it, because we set the correct
109                // platform-dependent "nlua" executable in the wrapper script.
110                flags.push("--ignore-lua".into());
111                flags
112            }
113            Self::Command(spec) => spec.flags.clone(),
114            Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
115                .chain(spec.flags.clone())
116                .collect_vec(),
117        }
118    }
119
120    pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
121        match self {
122            Self::BustedNlua(_) => {
123                let config_builder: ConfigBuilder = config.clone().into();
124
125                // XXX: On macos and msvc, Neovim and LuaJIT segfault when
126                // requiring luafilesystem's `lfs` module
127                // if it is built with Lua 5.1 headers.
128                #[cfg(not(any(target_os = "macos", target_env = "msvc")))]
129                let lua_version = LuaVersion::Lua51;
130
131                #[cfg(any(target_os = "macos", target_env = "msvc"))]
132                let lua_version = LuaVersion::LuaJIT;
133
134                Ok(config_builder
135                    .lua_version(Some(lua_version))
136                    .variables(Some(
137                        vec![("LUA".to_string(), NLUA_EXE.to_string())]
138                            .into_iter()
139                            .collect(),
140                    ))
141                    .build()?)
142            }
143            _ => Ok(config.clone()),
144        }
145    }
146
147    fn test_dependencies(&self) -> Vec<PackageReq> {
148        match self {
149            Self::Busted(_) => unsafe { vec![PackageReq::new_unchecked("busted".into(), None)] },
150            Self::BustedNlua(_) => unsafe {
151                vec![
152                    PackageReq::new_unchecked("busted".into(), None),
153                    PackageReq::new_unchecked("nlua".into(), None),
154                ]
155            },
156            Self::Command(_) => Vec::new(),
157            Self::LuaScript(_) => Vec::new(),
158        }
159    }
160}
161
162impl TryFrom<TestSpecInternal> for TestSpec {
163    type Error = TestSpecDecodeError;
164
165    fn try_from(internal: TestSpecInternal) -> Result<Self, Self::Error> {
166        let test_spec = match internal.test_type {
167            Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
168                flags: internal.flags.unwrap_or_default(),
169            })),
170            Some(TestType::BustedNlua) => Ok(Self::BustedNlua(BustedTestSpec {
171                flags: internal.flags.unwrap_or_default(),
172            })),
173            Some(TestType::Command) => match (internal.command, internal.lua_script) {
174                (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
175                (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
176                    script,
177                    flags: internal.flags.unwrap_or_default(),
178                })),
179                (Some(command), None) => Ok(Self::Command(CommandTestSpec {
180                    command,
181                    flags: internal.flags.unwrap_or_default(),
182                })),
183                (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
184            },
185            None => Ok(Self::default()),
186        }?;
187        Ok(test_spec)
188    }
189}
190
191impl<'de> Deserialize<'de> for PerPlatform<TestSpec> {
192    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193    where
194        D: Deserializer<'de>,
195    {
196        per_platform_from_intermediate::<_, TestSpecInternal, _>(deserializer)
197    }
198}
199
200impl<'de> Deserialize<'de> for TestSpec {
201    fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
202    where
203        D: Deserializer<'de>,
204    {
205        let internal = TestSpecInternal::deserialize(deserializer)?;
206        let test_spec = TestSpec::try_from(internal).map_err(serde::de::Error::custom)?;
207        Ok(test_spec)
208    }
209}
210
/// Settings for a busted (or busted-nlua) test run.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BustedTestSpec {
    /// Additional flags for the busted run.
    pub(crate) flags: Vec<String>,
}

impl BustedTestSpec {
    /// Additional flags for the busted run.
    pub fn flags(&self) -> &Vec<String> {
        let Self { flags } = self;
        flags
    }
}
221
/// A test spec that runs a configured command.
#[derive(Debug, Clone, PartialEq)]
pub struct CommandTestSpec {
    /// The command to run.
    pub(crate) command: String,
    /// Extra flags for the command.
    pub(crate) flags: Vec<String>,
}

impl CommandTestSpec {
    /// The command to run.
    pub fn command(&self) -> &str {
        self.command.as_str()
    }

    /// Extra flags for the command.
    pub fn flags(&self) -> &Vec<String> {
        let Self { flags, .. } = self;
        flags
    }
}
237
/// A test spec that runs a Lua script.
#[derive(Debug, Clone, PartialEq)]
pub struct LuaScriptTestSpec {
    /// Path of the script to run.
    pub(crate) script: PathBuf,
    /// Extra flags passed after the script path.
    pub(crate) flags: Vec<String>,
}

impl LuaScriptTestSpec {
    /// Path of the script to run.
    pub fn script(&self) -> &PathBuf {
        let Self { script, .. } = self;
        script
    }

    /// Extra flags passed after the script path.
    pub fn flags(&self) -> &Vec<String> {
        &self.flags
    }
}
253
254#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
255#[serde(rename_all = "kebab-case")]
256pub(crate) enum TestType {
257    Busted,
258    BustedNlua,
259    Command,
260}
261
262#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
263pub(crate) struct TestSpecInternal {
264    #[serde(default, rename = "type")]
265    pub(crate) test_type: Option<TestType>,
266    #[serde(default)]
267    pub(crate) flags: Option<Vec<String>>,
268    #[serde(default)]
269    pub(crate) command: Option<String>,
270    #[serde(default, rename = "script", alias = "lua_script")]
271    pub(crate) lua_script: Option<PathBuf>,
272}
273
274impl PartialOverride for TestSpecInternal {
275    type Err = Infallible;
276
277    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
278        Ok(TestSpecInternal {
279            test_type: override_opt(&override_spec.test_type, &self.test_type),
280            flags: match (override_spec.flags.clone(), self.flags.clone()) {
281                (Some(override_vec), Some(base_vec)) => {
282                    let merged: Vec<String> =
283                        base_vec.into_iter().chain(override_vec).unique().collect();
284                    Some(merged)
285                }
286                (None, base_vec @ Some(_)) => base_vec,
287                (override_vec @ Some(_), None) => override_vec,
288                _ => None,
289            },
290            command: match override_spec.lua_script.clone() {
291                Some(_) => None,
292                None => override_opt(&override_spec.command, &self.command),
293            },
294            lua_script: match override_spec.command.clone() {
295                Some(_) => None,
296                None => override_opt(&override_spec.lua_script, &self.lua_script),
297            },
298        })
299    }
300}
301
302impl PlatformOverridable for TestSpecInternal {
303    type Err = Infallible;
304
305    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
306    where
307        T: PlatformOverridable,
308        T: Default,
309    {
310        Ok(PerPlatform::default())
311    }
312}
313
/// Returns a clone of `preferred` if it is `Some`, otherwise a clone of `fallback`.
fn override_opt<T: Clone>(preferred: &Option<T>, fallback: &Option<T>) -> Option<T> {
    preferred.as_ref().or(fallback.as_ref()).cloned()
}
320
#[cfg(test)]
mod tests {
    // Tests for decoding `TestSpec`s (and their per-platform overrides) from Lua tables.

    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Evaluates `code` in a `Lua::core()` instance and deserializes the global
    /// named `key` into `T`.
    ///
    /// NOTE(review): the `Ok` value returned by `executor.step` is discarded, so
    /// this assumes each snippet finishes within `i32::MAX` fuel. Confirm.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table auto-detects.
        let lua_content = "
        test = {\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert!(matches!(test_spec.default, TestSpec::AutoDetect));
        // `type = 'busted'` without flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec::default())
        );
        // `type = 'busted'` with flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `type = 'command'` without `command` or `script` must fail to decode.
        let lua_content = "
        test = {\n
            type = 'command',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `type = 'command'` with both `command` and `script` must also fail.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'foo',\n
            script = 'bar',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `command` with flags.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // A `script` under `type = 'command'` decodes to `TestSpec::Script`.
        let lua_content = "
        test = {\n
            type = 'command',\n
            script = 'test.lua',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // Per-platform overrides: flags are merged with the default's flags, and an
        // override's `script` or `type` replaces the inherited command/type.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
            platforms = {\n
                unix = { flags = { 'baz' }, },\n
                macosx = {\n
                    script = 'bat.lua',\n
                    flags = { 'bat' },\n
                },\n
                linux = { type = 'busted' },\n
            },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        let unix = test_spec
            .per_platform
            .get(&PlatformIdentifier::Unix)
            .unwrap();
        assert_eq!(
            *unix,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // `macosx` and `linux` also pick up the `unix` override's flags ('baz').
        let macosx = test_spec
            .per_platform
            .get(&PlatformIdentifier::MacOSX)
            .unwrap();
        assert_eq!(
            *macosx,
            TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: vec!["foo".into(), "bar".into(), "bat".into(), "baz".into()],
            })
        );
        let linux = test_spec
            .per_platform
            .get(&PlatformIdentifier::Linux)
            .unwrap();
        assert_eq!(
            *linux,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // The kebab-case `busted-nlua` type decodes to `TestSpec::BustedNlua`.
        let lua_content = "
        test = {\n
            type = 'busted-nlua',\n
        }";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::BustedNlua(BustedTestSpec { flags: Vec::new() })
        );
    }
}