Skip to main content

lux_lib/lua_rockspec/
test_spec.rs

1use itertools::Itertools;
2use mlua::{FromLua, IntoLua, UserData};
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::{
11    config::{Config, ConfigBuilder, ConfigError, LuaVersion},
12    package::PackageReq,
13    project::{project_toml::LocalProjectTomlValidationError, Project},
14    rockspec::Rockspec,
15};
16
17use super::{
18    FromPlatformOverridable, PartialOverride, PerPlatform, PerPlatformWrapper, PlatformOverridable,
19};
20
/// File name of the `nlua` executable on Unix-family targets.
///
/// `ValidatedTestSpec::test_config` stores this value in the `LUA` variable
/// for `busted-nlua` test runs.
#[cfg(unix)]
const NLUA_EXE: &str = "nlua";
/// File name of the `nlua` executable on Windows-family targets.
///
/// `ValidatedTestSpec::test_config` stores this value in the `LUA` variable
/// for `busted-nlua` test runs.
#[cfg(windows)]
const NLUA_EXE: &str = "nlua.bat";
25
26#[derive(Error, Debug)]
27pub enum TestSpecDecodeError {
28    #[error("the 'command' test type must specify either a 'command' or 'script' field")]
29    NoCommandOrScript,
30    #[error("the 'command' test type cannot have both 'command' and 'script' fields")]
31    CommandAndScript,
32}
33
34#[derive(Error, Debug)]
35pub enum TestSpecError {
36    #[error("could not auto-detect the test spec. Please add one to your lux.toml")]
37    NoTestSpecDetected,
38    #[error("project validation failed:\n{0}")]
39    LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
40}
41
42#[derive(Clone, Debug, PartialEq, Default)]
43pub enum TestSpec {
44    #[default]
45    AutoDetect,
46    Busted(BustedTestSpec),
47    BustedNlua(BustedTestSpec),
48    Command(CommandTestSpec),
49    Script(LuaScriptTestSpec),
50}
51
52#[derive(Clone, Debug, PartialEq)]
53pub(crate) enum ValidatedTestSpec {
54    Busted(BustedTestSpec),
55    BustedNlua(BustedTestSpec),
56    Command(CommandTestSpec),
57    LuaScript(LuaScriptTestSpec),
58}
59
60impl TestSpec {
61    pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
62        self.to_validated(project)
63            .ok()
64            .iter()
65            .flat_map(|spec| spec.test_dependencies())
66            .collect_vec()
67    }
68
69    pub(crate) fn to_validated(
70        &self,
71        project: &Project,
72    ) -> Result<ValidatedTestSpec, TestSpecError> {
73        let project_root = project.root();
74        let toml = project.toml().into_local()?;
75        let test_dependencies = toml.test_dependencies().current_platform();
76        let is_busted = project_root.join(".busted").is_file()
77            || test_dependencies
78                .iter()
79                .any(|dep| dep.name().to_string() == "busted");
80        match self {
81            Self::AutoDetect if is_busted => {
82                if test_dependencies
83                    .iter()
84                    .any(|dep| dep.name().to_string() == "nlua")
85                {
86                    Ok(ValidatedTestSpec::BustedNlua(BustedTestSpec::default()))
87                } else {
88                    Ok(ValidatedTestSpec::Busted(BustedTestSpec::default()))
89                }
90            }
91            Self::Busted(spec) => Ok(ValidatedTestSpec::Busted(spec.clone())),
92            Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua(spec.clone())),
93            Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
94            Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
95            Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
96        }
97    }
98}
99
100impl ValidatedTestSpec {
101    pub fn args(&self) -> Vec<String> {
102        match self {
103            Self::Busted(spec) => spec.flags.clone(),
104            Self::BustedNlua(spec) => {
105                let mut flags = spec.flags.clone();
106                // If there's a .busted config file which has lua set to "nlua",
107                // we tell busted to ignore it, because we set the correct
108                // platform-dependent "nlua" executable in the wrapper script.
109                flags.push("--ignore-lua".into());
110                flags
111            }
112            Self::Command(spec) => spec.flags.clone(),
113            Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
114                .chain(spec.flags.clone())
115                .collect_vec(),
116        }
117    }
118
119    pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
120        match self {
121            Self::BustedNlua(_) => {
122                let config_builder: ConfigBuilder = config.clone().into();
123
124                // XXX: On macos and msvc, Neovim and LuaJIT segfault when
125                // requiring luafilesystem's `lfs` module
126                // if it is built with Lua 5.1 headers.
127                #[cfg(not(any(target_os = "macos", target_env = "msvc")))]
128                let lua_version = LuaVersion::Lua51;
129
130                #[cfg(any(target_os = "macos", target_env = "msvc"))]
131                let lua_version = LuaVersion::LuaJIT;
132
133                Ok(config_builder
134                    .lua_version(Some(lua_version))
135                    .variables(Some(
136                        vec![("LUA".to_string(), NLUA_EXE.to_string())]
137                            .into_iter()
138                            .collect(),
139                    ))
140                    .build()?)
141            }
142            _ => Ok(config.clone()),
143        }
144    }
145
146    fn test_dependencies(&self) -> Vec<PackageReq> {
147        match self {
148            Self::Busted(_) => unsafe { vec![PackageReq::new_unchecked("busted".into(), None)] },
149            Self::BustedNlua(_) => unsafe {
150                vec![
151                    PackageReq::new_unchecked("busted".into(), None),
152                    PackageReq::new_unchecked("nlua".into(), None),
153                ]
154            },
155            Self::Command(_) => Vec::new(),
156            Self::LuaScript(_) => Vec::new(),
157        }
158    }
159}
160
161impl IntoLua for TestSpec {
162    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
163        let table = lua.create_table()?;
164        match self {
165            TestSpec::AutoDetect => table.set("auto_detect", true)?,
166            TestSpec::Busted(busted_test_spec) => table.set("busted", busted_test_spec)?,
167            TestSpec::BustedNlua(busted_test_spec) => table.set("busted-nlua", busted_test_spec)?,
168            TestSpec::Command(command_test_spec) => table.set("command", command_test_spec)?,
169            TestSpec::Script(script_test_spec) => table.set("script", script_test_spec)?,
170        }
171        Ok(mlua::Value::Table(table))
172    }
173}
174
175impl FromPlatformOverridable<TestSpecInternal, Self> for TestSpec {
176    type Err = TestSpecDecodeError;
177
178    fn from_platform_overridable(internal: TestSpecInternal) -> Result<Self, Self::Err> {
179        let test_spec = match internal.test_type {
180            Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
181                flags: internal.flags.unwrap_or_default(),
182            })),
183            Some(TestType::BustedNlua) => Ok(Self::BustedNlua(BustedTestSpec {
184                flags: internal.flags.unwrap_or_default(),
185            })),
186            Some(TestType::Command) => match (internal.command, internal.lua_script) {
187                (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
188                (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
189                    script,
190                    flags: internal.flags.unwrap_or_default(),
191                })),
192                (Some(command), None) => Ok(Self::Command(CommandTestSpec {
193                    command,
194                    flags: internal.flags.unwrap_or_default(),
195                })),
196                (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
197            },
198            None => Ok(Self::default()),
199        }?;
200        Ok(test_spec)
201    }
202}
203
204impl FromLua for PerPlatform<TestSpec> {
205    fn from_lua(
206        value: mlua::prelude::LuaValue,
207        lua: &mlua::prelude::Lua,
208    ) -> mlua::prelude::LuaResult<Self> {
209        let wrapper = PerPlatformWrapper::from_lua(value, lua)?;
210        Ok(wrapper.un_per_platform)
211    }
212}
213
214impl<'de> Deserialize<'de> for TestSpec {
215    fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
216    where
217        D: Deserializer<'de>,
218    {
219        let internal = TestSpecInternal::deserialize(deserializer)?;
220        let test_spec =
221            TestSpec::from_platform_overridable(internal).map_err(serde::de::Error::custom)?;
222        Ok(test_spec)
223    }
224}
225
226#[derive(Clone, Debug, PartialEq, Default)]
227pub struct BustedTestSpec {
228    pub(crate) flags: Vec<String>,
229}
230
231impl UserData for BustedTestSpec {
232    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
233        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
234    }
235}
236
237#[derive(Clone, Debug, PartialEq)]
238pub struct CommandTestSpec {
239    pub(crate) command: String,
240    pub(crate) flags: Vec<String>,
241}
242
243impl UserData for CommandTestSpec {
244    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
245        methods.add_method("command", |_, this, _: ()| Ok(this.command.clone()));
246        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
247    }
248}
249
250#[derive(Clone, Debug, PartialEq)]
251pub struct LuaScriptTestSpec {
252    pub(crate) script: PathBuf,
253    pub(crate) flags: Vec<String>,
254}
255
256impl UserData for LuaScriptTestSpec {
257    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
258        methods.add_method("script", |_, this, _: ()| Ok(this.script.clone()));
259        methods.add_method("flags", |_, this, _: ()| Ok(this.flags.clone()));
260    }
261}
262
263#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
264#[serde(rename_all = "kebab-case")]
265pub(crate) enum TestType {
266    Busted,
267    BustedNlua,
268    Command,
269}
270
271#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
272pub(crate) struct TestSpecInternal {
273    #[serde(default, rename = "type")]
274    pub(crate) test_type: Option<TestType>,
275    #[serde(default)]
276    pub(crate) flags: Option<Vec<String>>,
277    #[serde(default)]
278    pub(crate) command: Option<String>,
279    #[serde(default, rename = "script", alias = "lua_script")]
280    pub(crate) lua_script: Option<PathBuf>,
281}
282
283impl PartialOverride for TestSpecInternal {
284    type Err = Infallible;
285
286    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
287        Ok(TestSpecInternal {
288            test_type: override_opt(&override_spec.test_type, &self.test_type),
289            flags: match (override_spec.flags.clone(), self.flags.clone()) {
290                (Some(override_vec), Some(base_vec)) => {
291                    let merged: Vec<String> =
292                        base_vec.into_iter().chain(override_vec).unique().collect();
293                    Some(merged)
294                }
295                (None, base_vec @ Some(_)) => base_vec,
296                (override_vec @ Some(_), None) => override_vec,
297                _ => None,
298            },
299            command: match override_spec.lua_script.clone() {
300                Some(_) => None,
301                None => override_opt(&override_spec.command, &self.command),
302            },
303            lua_script: match override_spec.command.clone() {
304                Some(_) => None,
305                None => override_opt(&override_spec.lua_script, &self.lua_script),
306            },
307        })
308    }
309}
310
311impl PlatformOverridable for TestSpecInternal {
312    type Err = Infallible;
313
314    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
315    where
316        T: PlatformOverridable,
317        T: Default,
318    {
319        Ok(PerPlatform::default())
320    }
321}
322
/// Returns a clone of `override_opt` when it holds a value, otherwise a clone
/// of `base`.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
329
#[cfg(test)]
mod tests {

    use mlua::{Error, FromLua, Lua};

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Runs `lua_content` in a fresh Lua state and decodes its global `test`
    /// table. Panics if the chunk fails to run or `test` cannot be read.
    fn decode_test_table(lua_content: &str) -> Result<PerPlatform<TestSpec>, Error> {
        let lua = Lua::new();
        lua.load(lua_content).exec().unwrap();
        PerPlatform::from_lua(lua.globals().get("test").unwrap(), &lua)
    }

    /// Builds an owned string vector from string slices.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table falls back to auto-detection.
        let lua_content = "
        test = {\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert!(matches!(decoded.default, TestSpec::AutoDetect));

        // busted without flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(decoded.default, TestSpec::Busted(BustedTestSpec::default()));

        // busted with flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(
            decoded.default,
            TestSpec::Busted(BustedTestSpec {
                flags: strings(&["foo", "bar"]),
            })
        );

        // The command type with neither command nor script is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
        }\n
        ";
        let _err = decode_test_table(lua_content).unwrap_err();

        // The command type with both command and script is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'foo',\n
            script = 'bar',\n
        }\n
        ";
        let _err = decode_test_table(lua_content).unwrap_err();

        // The command type with a command.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(
            decoded.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: strings(&["foo", "bar"]),
            })
        );

        // The command type with a script.
        let lua_content = "
        test = {\n
            type = 'command',\n
            script = 'test.lua',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(
            decoded.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: strings(&["foo", "bar"]),
            })
        );

        // Per-platform overrides.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
            platforms = {\n
                unix = { flags = { 'baz' }, },\n
                macosx = {\n
                    script = 'bat.lua',\n
                    flags = { 'bat' },\n
                },\n
                linux = { type = 'busted' },\n
            },\n
        }\n
        ";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(
            decoded.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: strings(&["foo", "bar"]),
            })
        );
        let for_platform = |platform: PlatformIdentifier| {
            decoded.per_platform.get(&platform).unwrap().clone()
        };
        assert_eq!(
            for_platform(PlatformIdentifier::Unix),
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: strings(&["foo", "bar", "baz"]),
            })
        );
        assert_eq!(
            for_platform(PlatformIdentifier::MacOSX),
            TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: strings(&["foo", "bar", "bat", "baz"]),
            })
        );
        assert_eq!(
            for_platform(PlatformIdentifier::Linux),
            TestSpec::Busted(BustedTestSpec {
                flags: strings(&["foo", "bar", "baz"]),
            })
        );

        // busted-nlua without flags.
        let lua_content = "
        test = {\n
            type = 'busted-nlua',\n
        }";
        let decoded = decode_test_table(lua_content).unwrap();
        assert_eq!(
            decoded.default,
            TestSpec::BustedNlua(BustedTestSpec { flags: Vec::new() })
        );
    }
}