Skip to main content

lux_lib/which/mod.rs

1use std::{io, path::PathBuf};
2
3use bon::Builder;
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8    config::Config,
9    lua_rockspec::LuaModule,
10    lua_version::{LuaVersion, LuaVersionUnset},
11    package::PackageReq,
12    tree::TreeError,
13};
14
15/// A rocks module finder.
16#[derive(Builder)]
17#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
18pub struct Which<'a> {
19    #[builder(start_fn)]
20    module: LuaModule,
21    #[builder(start_fn)]
22    config: &'a Config,
23    #[builder(field)]
24    packages: Vec<PackageReq>,
25}
26
27impl<State> WhichBuilder<'_, State>
28where
29    State: which_builder::State,
30{
31    pub fn package(mut self, package: PackageReq) -> Self {
32        self.packages.push(package);
33        self
34    }
35
36    pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
37        self.packages.extend(packages);
38        self
39    }
40
41    pub fn search(self) -> Result<PathBuf, WhichError>
42    where
43        State: which_builder::IsComplete,
44    {
45        do_search(self._build())
46    }
47}
48
49#[derive(Error, Debug)]
50pub enum WhichError {
51    #[error(transparent)]
52    Io(#[from] io::Error),
53    #[error(transparent)]
54    Tree(#[from] TreeError),
55    #[error(transparent)]
56    LuaVersionUnset(#[from] LuaVersionUnset),
57    #[error("lua module {0} not found.")]
58    ModuleNotFound(LuaModule),
59}
60
61fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
62    let config = which.config;
63    let lua_version = LuaVersion::from(config)?;
64    let tree = config.user_tree(lua_version.clone())?;
65    let lockfile = tree.lockfile()?;
66    let local_packages = if which.packages.is_empty() {
67        lockfile
68            .list()
69            .into_iter()
70            .flat_map(|(_, pkgs)| pkgs)
71            .collect_vec()
72    } else {
73        which
74            .packages
75            .iter()
76            .flat_map(|req| {
77                lockfile
78                    .find_rocks(req)
79                    .into_iter()
80                    .filter_map(|id| lockfile.get(&id))
81                    .cloned()
82                    .collect_vec()
83            })
84            .collect_vec()
85    };
86    local_packages
87        .into_iter()
88        .filter_map(|pkg| {
89            let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
90            let lib_path = rock_layout.lib.join(which.module.to_lib_path());
91            if lib_path.is_file() {
92                return Some(lib_path);
93            }
94            let lua_path = rock_layout.src.join(which.module.to_lua_path());
95            if lua_path.is_file() {
96                return Some(lua_path);
97            }
98            let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
99            if lua_path.is_file() {
100                return Some(lua_path);
101            }
102            None
103        })
104        .next()
105        .ok_or(WhichError::ModuleNotFound(which.module))
106}
107
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::ConfigBuilder;
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Resolves `module` with `config` and asserts the hit is a file named `file`
    /// inside a directory named `parent`.
    fn assert_resolves_to(config: &Config, module: &str, parent: &str, file: &str) {
        let found = Which::new(LuaModule::from_str(module).unwrap(), config)
            .search()
            .unwrap();
        let file_name = found.file_name().unwrap().to_string_lossy();
        assert_eq!(file_name, file);
        let parent_name = found
            .parent()
            .unwrap()
            .file_name()
            .unwrap()
            .to_string_lossy();
        assert_eq!(parent_name, parent);
    }

    /// Asserts that `module` is reported as not found when the search is restricted to `req`.
    fn assert_not_found_in(config: &Config, module: &str, req: &str) {
        let outcome = Which::new(LuaModule::from_str(module).unwrap(), config)
            .package(req.parse().unwrap())
            .search();
        assert!(matches!(outcome, Err(WhichError::ModuleNotFound(_))));
    }

    #[tokio::test]
    async fn test_which() {
        let fixture =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let temp = assert_fs::TempDir::new().unwrap();
        temp.copy_from(&fixture, &["**"]).unwrap();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(temp.to_path_buf()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();

        assert_resolves_to(&config, "foo.bar", "foo", "bar.lua");
        assert_resolves_to(&config, "bat.baz", "bat", "baz.so");
        assert_not_found_in(&config, "foo.bar", "lua-cjson");
        assert_not_found_in(&config, "foo.bar", "neorg@8.1.1-1");
    }
}