Skip to main content

lux_lib/which/
mod.rs

1use std::{io, path::PathBuf};
2
3use bon::Builder;
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8    config::Config,
9    lua_rockspec::LuaModule,
10    lua_version::{LuaVersion, LuaVersionUnset},
11    package::PackageReq,
12    tree::TreeError,
13};
14
15/// A rocks module finder.
16#[derive(Builder)]
17#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
18pub struct Which<'a> {
19    #[builder(start_fn)]
20    module: LuaModule,
21    #[builder(start_fn)]
22    config: &'a Config,
23    #[builder(field)]
24    packages: Vec<PackageReq>,
25}
26
27impl<State> WhichBuilder<'_, State>
28where
29    State: which_builder::State,
30{
31    pub fn package(mut self, package: PackageReq) -> Self {
32        self.packages.push(package);
33        self
34    }
35
36    pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
37        self.packages.extend(packages);
38        self
39    }
40
41    pub fn search(self) -> Result<PathBuf, WhichError>
42    where
43        State: which_builder::IsComplete,
44    {
45        do_search(self._build())
46    }
47}
48
49#[derive(Error, Debug)]
50pub enum WhichError {
51    #[error(transparent)]
52    Io(#[from] io::Error),
53    #[error(transparent)]
54    Tree(#[from] TreeError),
55    #[error(transparent)]
56    LuaVersionUnset(#[from] LuaVersionUnset),
57    #[error("lua module {0} not found.")]
58    ModuleNotFound(LuaModule),
59}
60
61fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
62    let config = which.config;
63    let lua_version = LuaVersion::from(config)?;
64    let tree = config.user_tree(lua_version.clone())?;
65    let lockfile = tree.lockfile()?;
66    let local_packages = if which.packages.is_empty() {
67        lockfile.list().into_values().flatten().collect_vec()
68    } else {
69        which
70            .packages
71            .iter()
72            .flat_map(|req| {
73                lockfile
74                    .find_rocks(req)
75                    .into_iter()
76                    .filter_map(|id| lockfile.get(&id))
77                    .cloned()
78                    .collect_vec()
79            })
80            .collect_vec()
81    };
82    local_packages
83        .into_iter()
84        .filter_map(|pkg| {
85            let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
86            let lib_path = rock_layout.lib.join(which.module.to_lib_path());
87            if lib_path.is_file() {
88                return Some(lib_path);
89            }
90            let lua_path = rock_layout.src.join(which.module.to_lua_path());
91            if lua_path.is_file() {
92                return Some(lua_path);
93            }
94            let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
95            if lua_path.is_file() {
96                return Some(lua_path);
97            }
98            None
99        })
100        .next()
101        .ok_or(WhichError::ModuleNotFound(which.module))
102}
103
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::ConfigBuilder;
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Assert that the last two components of `path` are `dir` and `file`.
    fn assert_tail(path: &std::path::Path, dir: &str, file: &str) {
        assert_eq!(path.file_name().unwrap().to_string_lossy(), file);
        assert_eq!(
            path.parent().unwrap().file_name().unwrap().to_string_lossy(),
            dir
        );
    }

    #[tokio::test]
    async fn test_which() {
        // Work on a scratch copy of the sample tree so the fixture stays pristine.
        let fixture =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let scratch = assert_fs::TempDir::new().unwrap();
        scratch.copy_from(&fixture, &["**"]).unwrap();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(scratch.to_path_buf()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();
        let module = |name: &str| LuaModule::from_str(name).unwrap();

        // A pure-Lua module.
        let found = Which::new(module("foo.bar"), &config).search().unwrap();
        assert_tail(&found, "foo", "bar.lua");

        // A native module.
        let found = Which::new(module("bat.baz"), &config).search().unwrap();
        assert_tail(&found, "bat", "baz.so");

        // Restricting the search to packages that don't ship the module must fail.
        for req in ["lua-cjson", "neorg@8.1.1-1"] {
            let outcome = Which::new(module("foo.bar"), &config)
                .package(req.parse().unwrap())
                .search();
            assert!(matches!(outcome, Err(WhichError::ModuleNotFound(_))));
        }
    }
}