lux_lib/which/
mod.rs

1use std::{io, path::PathBuf};
2
3use bon::{builder, Builder};
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8    config::{Config, LuaVersion, LuaVersionUnset},
9    lua_rockspec::LuaModule,
10    package::PackageReq,
11    tree::TreeError,
12};
13
14/// A rocks module finder.
15#[derive(Builder)]
16#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
17pub struct Which<'a> {
18    #[builder(start_fn)]
19    module: LuaModule,
20    #[builder(start_fn)]
21    config: &'a Config,
22    #[builder(field)]
23    packages: Vec<PackageReq>,
24}
25
26impl<State> WhichBuilder<'_, State>
27where
28    State: which_builder::State,
29{
30    pub fn package(mut self, package: PackageReq) -> Self {
31        self.packages.push(package);
32        self
33    }
34
35    pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
36        self.packages.extend(packages);
37        self
38    }
39
40    pub fn search(self) -> Result<PathBuf, WhichError>
41    where
42        State: which_builder::IsComplete,
43    {
44        do_search(self._build())
45    }
46}
47
48#[derive(Error, Debug)]
49pub enum WhichError {
50    #[error(transparent)]
51    Io(#[from] io::Error),
52    #[error(transparent)]
53    Tree(#[from] TreeError),
54    #[error(transparent)]
55    LuaVersionUnset(#[from] LuaVersionUnset),
56    #[error("lua module {0} not found.")]
57    ModuleNotFound(LuaModule),
58}
59
60fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
61    let config = which.config;
62    let lua_version = LuaVersion::from(config)?;
63    let tree = config.user_tree(lua_version.clone())?;
64    let lockfile = tree.lockfile()?;
65    let local_packages = if which.packages.is_empty() {
66        lockfile
67            .list()
68            .into_iter()
69            .flat_map(|(_, pkgs)| pkgs)
70            .collect_vec()
71    } else {
72        which
73            .packages
74            .iter()
75            .flat_map(|req| {
76                lockfile
77                    .find_rocks(req)
78                    .into_iter()
79                    .map(|id| lockfile.get(&id).unwrap())
80                    .cloned()
81                    .collect_vec()
82            })
83            .collect_vec()
84    };
85    local_packages
86        .into_iter()
87        .filter_map(|pkg| {
88            let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
89            let lib_path = rock_layout.lib.join(which.module.to_lib_path());
90            if lib_path.is_file() {
91                return Some(lib_path);
92            }
93            let lua_path = rock_layout.src.join(which.module.to_lua_path());
94            if lua_path.is_file() {
95                return Some(lua_path);
96            }
97            let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
98            if lua_path.is_file() {
99                return Some(lua_path);
100            }
101            None
102        })
103        .next()
104        .ok_or(WhichError::ModuleNotFound(which.module))
105}
106
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::{ConfigBuilder, LuaVersion};
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Searches the bundled sample tree for modules, with and without a
    /// package restriction.
    #[tokio::test]
    async fn test_which() {
        // Run against a temporary copy of the sample tree.
        let sample_tree =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let temp = assert_fs::TempDir::new().unwrap();
        temp.copy_from(&sample_tree, &["**"]).unwrap();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(temp.to_path_buf()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();

        let module = |name: &str| LuaModule::from_str(name).unwrap();
        // Yields (parent directory name, file name) of a found path.
        let dir_and_file = |path: &PathBuf| {
            let file = path.file_name().unwrap().to_string_lossy().into_owned();
            let dir = path
                .parent()
                .unwrap()
                .file_name()
                .unwrap()
                .to_string_lossy()
                .into_owned();
            (dir, file)
        };

        // `foo.bar` is found as `foo/bar.lua`.
        let found = Which::new(module("foo.bar"), &config).search().unwrap();
        assert_eq!(
            dir_and_file(&found),
            ("foo".to_string(), "bar.lua".to_string())
        );

        // `bat.baz` is found as `bat/baz.so`.
        let found = Which::new(module("bat.baz"), &config).search().unwrap();
        assert_eq!(
            dir_and_file(&found),
            ("bat".to_string(), "baz.so".to_string())
        );

        // Restricting the search to either of these packages must not find
        // `foo.bar`.
        for req in ["lua-cjson", "neorg@8.1.1-1"] {
            let result = Which::new(module("foo.bar"), &config)
                .package(req.parse().unwrap())
                .search();
            assert!(matches!(result, Err(WhichError::ModuleNotFound(_))));
        }
    }
}
163}