Skip to main content

lux_lib/lua_rockspec/
test_spec.rs

1use itertools::Itertools;
2
3use path_slash::PathExt;
4use serde_enum_str::Serialize_enum_str;
5use std::{convert::Infallible, path::PathBuf};
6use thiserror::Error;
7
8use serde::{Deserialize, Deserializer};
9
10use crate::lua_version::LuaVersion;
11
12use crate::{
13    config::{Config, ConfigBuilder, ConfigError},
14    lua_rockspec::per_platform_from_intermediate,
15    package::PackageReq,
16    project::{project_toml::LocalProjectTomlValidationError, Project},
17    rockspec::Rockspec,
18};
19
20use super::{PartialOverride, PerPlatform, PlatformOverridable};
21
/// Executable name of the nlua launcher; busted-nlua test runs use it as the `LUA` variable.
///
/// Every target family other than Windows uses the plain `nlua` name. Previously only
/// `unix` was covered, which left the constant undefined on other families.
#[cfg(not(target_family = "windows"))]
const NLUA_EXE: &str = "nlua";
/// Executable name of the nlua launcher; busted-nlua test runs use it as the `LUA` variable.
///
/// On Windows this is a `.bat` batch file.
#[cfg(target_family = "windows")]
const NLUA_EXE: &str = "nlua.bat";
26
/// Errors raised when converting a raw, deserialized test table into a [`TestSpec`].
#[derive(Error, Debug)]
pub enum TestSpecDecodeError {
    /// `type = "command"` was given without a `command` or a `script` field.
    #[error("the 'command' test type must specify either a 'command' or 'script' field")]
    NoCommandOrScript,
    /// `type = "command"` was given with both a `command` and a `script` field.
    #[error("the 'command' test type cannot have both 'command' and 'script' fields")]
    CommandAndScript,
}
34
/// Errors raised when resolving a [`TestSpec`] against a concrete project.
#[derive(Error, Debug)]
pub enum TestSpecError {
    /// The spec is `AutoDetect`, but busted could not be detected: the project root has no
    /// `.busted` file and no `busted` test dependency is declared.
    #[error("could not auto-detect the test spec. Please add one to your lux.toml")]
    NoTestSpecDetected,
    /// The project's TOML failed validation as a local project.
    #[error("project validation failed:\n{0}")]
    LocalProjectTomlValidation(#[from] LocalProjectTomlValidationError),
}
42
/// A test specification as declared in a `test` table.
#[derive(Clone, Debug, PartialEq, Default)]
pub enum TestSpec {
    /// No `type` was given; the runner is detected from the project in `to_validated`.
    #[default]
    AutoDetect,
    /// `type = "busted"`.
    Busted(BustedTestSpec),
    /// `type = "busted-nlua"`: busted, run with nlua as the Lua interpreter.
    BustedNlua(BustedTestSpec),
    /// `type = "command"` with a `command` field.
    Command(CommandTestSpec),
    /// `type = "command"` with a `script` field.
    Script(LuaScriptTestSpec),
}
52
/// A [`TestSpec`] resolved against a project by `TestSpec::to_validated`.
///
/// Auto-detection has been settled, and the busted-based variants carry the package
/// requirements needed to run them.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ValidatedTestSpec {
    /// Run busted; `dependencies` holds the `busted` requirement.
    Busted {
        spec: BustedTestSpec,
        dependencies: Vec<PackageReq>,
    },
    /// Run busted through nlua; `dependencies` holds the `busted` and `nlua` requirements.
    BustedNlua {
        spec: BustedTestSpec,
        dependencies: Vec<PackageReq>,
    },
    /// Run an external command.
    Command(CommandTestSpec),
    /// Run a Lua script.
    LuaScript(LuaScriptTestSpec),
}
66
67impl TestSpec {
68    pub(crate) fn test_dependencies(&self, project: &Project) -> Vec<PackageReq> {
69        self.to_validated(project)
70            .ok()
71            .iter()
72            .flat_map(|spec| spec.test_dependencies())
73            .collect_vec()
74    }
75
76    pub(crate) fn to_validated(
77        &self,
78        project: &Project,
79    ) -> Result<ValidatedTestSpec, TestSpecError> {
80        let project_root = project.root();
81        let toml = project.toml().into_local()?;
82        let test_dependencies = toml.test_dependencies().current_platform();
83        let busted_dependency = test_dependencies
84            .iter()
85            .find(|dep| dep.name().to_string() == "busted");
86        let busted = busted_dependency
87            .map(|dep| dep.package_req.clone())
88            .unwrap_or_else(|| unsafe { PackageReq::new_unchecked("busted".into(), None) });
89        let nlua_dependency = test_dependencies
90            .iter()
91            .find(|dep| dep.name().to_string() == "nlua");
92        let nlua = nlua_dependency
93            .map(|dep| dep.package_req.clone())
94            .unwrap_or_else(|| unsafe { PackageReq::new_unchecked("nlua".into(), None) });
95        let is_busted = project_root.join(".busted").is_file() || busted_dependency.is_some();
96        match self {
97            Self::AutoDetect if is_busted => {
98                if nlua_dependency.is_some() {
99                    Ok(ValidatedTestSpec::BustedNlua {
100                        spec: BustedTestSpec::default(),
101                        dependencies: vec![busted, nlua],
102                    })
103                } else {
104                    Ok(ValidatedTestSpec::Busted {
105                        spec: BustedTestSpec::default(),
106                        dependencies: vec![busted],
107                    })
108                }
109            }
110            Self::Busted(spec) => Ok(ValidatedTestSpec::Busted {
111                spec: spec.clone(),
112                dependencies: vec![busted],
113            }),
114            Self::BustedNlua(spec) => Ok(ValidatedTestSpec::BustedNlua {
115                spec: spec.clone(),
116                dependencies: vec![busted, nlua],
117            }),
118            Self::Command(spec) => Ok(ValidatedTestSpec::Command(spec.clone())),
119            Self::Script(spec) => Ok(ValidatedTestSpec::LuaScript(spec.clone())),
120            Self::AutoDetect => Err(TestSpecError::NoTestSpecDetected),
121        }
122    }
123}
124
125impl ValidatedTestSpec {
126    pub fn args(&self) -> Vec<String> {
127        match self {
128            Self::Busted {
129                spec,
130                dependencies: _,
131            } => spec.flags.clone(),
132            Self::BustedNlua {
133                spec,
134                dependencies: _,
135            } => {
136                let mut flags = spec.flags.clone();
137                // If there's a .busted config file which has lua set to "nlua",
138                // we tell busted to ignore it, because we set the correct
139                // platform-dependent "nlua" executable in the wrapper script.
140                flags.push("--ignore-lua".into());
141                flags
142            }
143            Self::Command(spec) => spec.flags.clone(),
144            Self::LuaScript(spec) => std::iter::once(spec.script.to_slash_lossy().to_string())
145                .chain(spec.flags.clone())
146                .collect_vec(),
147        }
148    }
149
150    pub(crate) fn test_config(&self, config: &Config) -> Result<Config, ConfigError> {
151        match self {
152            Self::BustedNlua { .. } => {
153                let config_builder: ConfigBuilder = config.clone().into();
154
155                // XXX: On macos and msvc, Neovim and LuaJIT segfault when
156                // requiring luafilesystem's `lfs` module
157                // if it is built with Lua 5.1 headers.
158                #[cfg(not(any(target_os = "macos", target_env = "msvc")))]
159                let lua_version = LuaVersion::Lua51;
160
161                #[cfg(any(target_os = "macos", target_env = "msvc"))]
162                let lua_version = LuaVersion::LuaJIT;
163
164                Ok(config_builder
165                    .lua_version(Some(lua_version))
166                    .variables(Some(
167                        vec![("LUA".to_string(), NLUA_EXE.to_string())]
168                            .into_iter()
169                            .collect(),
170                    ))
171                    .build()?)
172            }
173            _ => Ok(config.clone()),
174        }
175    }
176
177    fn test_dependencies(&self) -> Vec<PackageReq> {
178        match self {
179            Self::Busted {
180                spec: _,
181                dependencies,
182            } => dependencies.clone(),
183            Self::BustedNlua {
184                spec: _,
185                dependencies,
186            } => dependencies.clone(),
187            Self::Command(_) => Vec::new(),
188            Self::LuaScript(_) => Vec::new(),
189        }
190    }
191}
192
193impl TryFrom<TestSpecInternal> for TestSpec {
194    type Error = TestSpecDecodeError;
195
196    fn try_from(internal: TestSpecInternal) -> Result<Self, Self::Error> {
197        let test_spec = match internal.test_type {
198            Some(TestType::Busted) => Ok(Self::Busted(BustedTestSpec {
199                flags: internal.flags.unwrap_or_default(),
200            })),
201            Some(TestType::BustedNlua) => Ok(Self::BustedNlua(BustedTestSpec {
202                flags: internal.flags.unwrap_or_default(),
203            })),
204            Some(TestType::Command) => match (internal.command, internal.lua_script) {
205                (None, None) => Err(TestSpecDecodeError::NoCommandOrScript),
206                (None, Some(script)) => Ok(Self::Script(LuaScriptTestSpec {
207                    script,
208                    flags: internal.flags.unwrap_or_default(),
209                })),
210                (Some(command), None) => Ok(Self::Command(CommandTestSpec {
211                    command,
212                    flags: internal.flags.unwrap_or_default(),
213                })),
214                (Some(_), Some(_)) => Err(TestSpecDecodeError::CommandAndScript),
215            },
216            None => Ok(Self::default()),
217        }?;
218        Ok(test_spec)
219    }
220}
221
/// Deserializes a test table together with its platform-specific overrides.
///
/// Decoding goes through the intermediate [`TestSpecInternal`] representation.
/// Presumably `per_platform_from_intermediate` merges each platform table over the base
/// (via `PartialOverride`) and then converts it with `TryFrom` — confirm against its
/// definition.
impl<'de> Deserialize<'de> for PerPlatform<TestSpec> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        per_platform_from_intermediate::<_, TestSpecInternal, _>(deserializer)
    }
}
230
231impl<'de> Deserialize<'de> for TestSpec {
232    fn deserialize<D>(deserializer: D) -> Result<TestSpec, D::Error>
233    where
234        D: Deserializer<'de>,
235    {
236        let internal = TestSpecInternal::deserialize(deserializer)?;
237        let test_spec = TestSpec::try_from(internal).map_err(serde::de::Error::custom)?;
238        Ok(test_spec)
239    }
240}
241
/// Settings for running tests with busted. Shared by the `busted` and `busted-nlua` types.
#[derive(Clone, Debug, PartialEq, Default)]
pub struct BustedTestSpec {
    /// Extra command-line flags passed to busted.
    pub(crate) flags: Vec<String>,
}

impl BustedTestSpec {
    /// Extra command-line flags passed to busted.
    pub fn flags(&self) -> &Vec<String> {
        let Self { flags } = self;
        flags
    }
}
252
/// Settings for running tests with an external command (`type = "command"` + `command`).
#[derive(Clone, Debug, PartialEq)]
pub struct CommandTestSpec {
    /// The command to run.
    pub(crate) command: String,
    /// Extra arguments passed to the command.
    pub(crate) flags: Vec<String>,
}

impl CommandTestSpec {
    /// Extra arguments passed to the command.
    pub fn flags(&self) -> &Vec<String> {
        let Self { flags, .. } = self;
        flags
    }

    /// The command to run.
    pub fn command(&self) -> &str {
        self.command.as_str()
    }
}
268
/// Settings for running tests with a Lua script (`type = "command"` + `script`).
#[derive(Clone, Debug, PartialEq)]
pub struct LuaScriptTestSpec {
    /// Path to the Lua script to run.
    pub(crate) script: PathBuf,
    /// Extra arguments passed after the script path.
    pub(crate) flags: Vec<String>,
}

impl LuaScriptTestSpec {
    /// Extra arguments passed after the script path.
    pub fn flags(&self) -> &Vec<String> {
        let Self { flags, .. } = self;
        flags
    }

    /// Path to the Lua script to run.
    pub fn script(&self) -> &PathBuf {
        let Self { script, .. } = self;
        script
    }
}
284
/// The `type` field of a test table, using kebab-case names
/// (`busted`, `busted-nlua`, `command`).
#[derive(Debug, Deserialize, Serialize_enum_str, PartialEq, Clone)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum TestType {
    Busted,
    BustedNlua,
    // Covers both the `command` and the `script` forms; see `TryFrom<TestSpecInternal>`.
    Command,
}
292
/// Raw, field-by-field form of a test table, before validation into a [`TestSpec`].
///
/// All fields are optional so that platform-specific tables can be partial overrides.
#[derive(Debug, PartialEq, Deserialize, Default, Clone)]
pub(crate) struct TestSpecInternal {
    /// The `type` key; `None` means auto-detect.
    #[serde(default, rename = "type")]
    pub(crate) test_type: Option<TestType>,
    /// Extra runner flags.
    #[serde(default)]
    pub(crate) flags: Option<Vec<String>>,
    /// Command to run (only meaningful for `type = "command"`).
    #[serde(default)]
    pub(crate) command: Option<String>,
    /// Lua script to run, from the `script` key (also accepted as `lua_script`).
    #[serde(default, rename = "script", alias = "lua_script")]
    pub(crate) lua_script: Option<PathBuf>,
}
304
305impl PartialOverride for TestSpecInternal {
306    type Err = Infallible;
307
308    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
309        Ok(TestSpecInternal {
310            test_type: override_opt(&override_spec.test_type, &self.test_type),
311            flags: match (override_spec.flags.clone(), self.flags.clone()) {
312                (Some(override_vec), Some(base_vec)) => {
313                    let merged: Vec<String> =
314                        base_vec.into_iter().chain(override_vec).unique().collect();
315                    Some(merged)
316                }
317                (None, base_vec @ Some(_)) => base_vec,
318                (override_vec @ Some(_), None) => override_vec,
319                _ => None,
320            },
321            command: match override_spec.lua_script.clone() {
322                Some(_) => None,
323                None => override_opt(&override_spec.command, &self.command),
324            },
325            lua_script: match override_spec.command.clone() {
326                Some(_) => None,
327                None => override_opt(&override_spec.lua_script, &self.lua_script),
328            },
329        })
330    }
331}
332
/// Platform-override hooks for the raw test table.
impl PlatformOverridable for TestSpecInternal {
    // This impl never fails.
    type Err = Infallible;

    /// Presumably called when the test table is `nil`/absent.
    ///
    /// Returns `PerPlatform::default()` instead of an error. That is presumably the default
    /// spec (auto-detection) with no platform overrides — confirm against `PerPlatform`'s
    /// `Default` impl.
    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
    where
        T: PlatformOverridable,
        T: Default,
    {
        Ok(PerPlatform::default())
    }
}
344
/// Returns a clone of `override_opt` if it is set, otherwise a clone of `base`.
fn override_opt<T: Clone>(override_opt: &Option<T>, base: &Option<T>) -> Option<T> {
    override_opt.as_ref().or(base.as_ref()).cloned()
}
351
#[cfg(test)]
mod tests {

    use ottavino::{Closure, Executor, Fuel, Lua};
    use ottavino_util::serde::from_value;

    use crate::lua_rockspec::PlatformIdentifier;

    use super::*;

    /// Loads `code` as a Lua chunk, steps it once, and deserializes the global named `key`.
    ///
    /// The single `step` gets a fuel budget of `i32::MAX`, presumably enough to run these
    /// small chunks to completion.
    fn exec_lua<T: serde::de::DeserializeOwned>(
        code: &str,
        key: &'static str,
    ) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
        })
    }

    #[tokio::test]
    pub async fn test_spec_from_lua() {
        // An empty table has no `type`, so it decodes to auto-detection.
        let lua_content = "
        test = {\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert!(matches!(test_spec.default, TestSpec::AutoDetect));
        // `type = 'busted'` without flags decodes to busted with no flags.
        let lua_content = "
        test = {\n
            type = 'busted',\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec::default())
        );
        // Busted flags are preserved in order.
        let lua_content = "
        test = {\n
            type = 'busted',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `type = 'command'` with neither `command` nor `script` is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `type = 'command'` with both `command` and `script` is rejected.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'foo',\n
            script = 'bar',\n
        }\n
        ";
        let result: Result<PerPlatform<TestSpec>, _> = exec_lua(lua_content, "test");
        let _err = result.unwrap_err();
        // `type = 'command'` with `command` decodes to `TestSpec::Command`.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // `type = 'command'` with `script` decodes to `TestSpec::Script`.
        let lua_content = "
        test = {\n
            type = 'command',\n
            script = 'test.lua',\n
            flags = { 'foo', 'bar' },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::Script(LuaScriptTestSpec {
                script: PathBuf::from("test.lua"),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // Platform overrides are merged over the base table.
        let lua_content = "
        test = {\n
            type = 'command',\n
            command = 'baz',\n
            flags = { 'foo', 'bar' },\n
            platforms = {\n
                unix = { flags = { 'baz' }, },\n
                macosx = {\n
                    script = 'bat.lua',\n
                    flags = { 'bat' },\n
                },\n
                linux = { type = 'busted' },\n
            },\n
        }\n
        ";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        // The base spec is unaffected by the platform tables.
        assert_eq!(
            test_spec.default,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into()],
            })
        );
        // unix: its flags are appended to the base flags.
        let unix = test_spec
            .per_platform
            .get(&PlatformIdentifier::Unix)
            .unwrap();
        assert_eq!(
            *unix,
            TestSpec::Command(CommandTestSpec {
                command: "baz".into(),
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // macosx: setting `script` replaces the inherited `command`. The expected flags also
        // end with unix's 'baz', so the macosx entry has the unix override applied as well.
        let macosx = test_spec
            .per_platform
            .get(&PlatformIdentifier::MacOSX)
            .unwrap();
        assert_eq!(
            *macosx,
            TestSpec::Script(LuaScriptTestSpec {
                script: "bat.lua".into(),
                flags: vec!["foo".into(), "bar".into(), "bat".into(), "baz".into()],
            })
        );
        // linux: switches the type to busted, keeping the base + unix flags.
        let linux = test_spec
            .per_platform
            .get(&PlatformIdentifier::Linux)
            .unwrap();
        assert_eq!(
            *linux,
            TestSpec::Busted(BustedTestSpec {
                flags: vec!["foo".into(), "bar".into(), "baz".into()],
            })
        );
        // The kebab-case `busted-nlua` type decodes to `TestSpec::BustedNlua`.
        let lua_content = "
        test = {\n
            type = 'busted-nlua',\n
        }";
        let test_spec: PerPlatform<TestSpec> = exec_lua(lua_content, "test").unwrap();
        assert_eq!(
            test_spec.default,
            TestSpec::BustedNlua(BustedTestSpec { flags: Vec::new() })
        );
    }
}