1use std::{io, path::PathBuf};
2
3use bon::{builder, Builder};
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8 config::{Config, LuaVersion, LuaVersionUnset},
9 lua_rockspec::LuaModule,
10 package::PackageReq,
11 tree::TreeError,
12};
13
/// Query that locates the file providing a Lua module in the user tree.
///
/// Start with `Which::new(module, config)`, optionally narrow the candidate
/// packages with the builder's `package` / `packages` methods, and run the
/// query with the builder's `search` method.
#[derive(Builder)]
// The finishing function is private (`vis = ""`), so a `Which` can only be
// consumed through the builder's `search` method.
#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
pub struct Which<'a> {
    /// The Lua module to locate (e.g. `foo.bar`).
    #[builder(start_fn)]
    module: LuaModule,
    /// Configuration used to resolve the Lua version and the user tree.
    #[builder(start_fn)]
    config: &'a Config,
    /// Package requirements restricting the search; when empty, every package
    /// recorded in the tree's lockfile is searched.
    #[builder(field)]
    packages: Vec<PackageReq>,
}
25
26impl<State> WhichBuilder<'_, State>
27where
28 State: which_builder::State,
29{
30 pub fn package(mut self, package: PackageReq) -> Self {
31 self.packages.push(package);
32 self
33 }
34
35 pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
36 self.packages.extend(packages);
37 self
38 }
39
40 pub fn search(self) -> Result<PathBuf, WhichError>
41 where
42 State: which_builder::IsComplete,
43 {
44 do_search(self._build())
45 }
46}
47
48#[derive(Error, Debug)]
49pub enum WhichError {
50 #[error(transparent)]
51 Io(#[from] io::Error),
52 #[error(transparent)]
53 Tree(#[from] TreeError),
54 #[error(transparent)]
55 LuaVersionUnset(#[from] LuaVersionUnset),
56 #[error("lua module {0} not found.")]
57 ModuleNotFound(LuaModule),
58}
59
60fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
61 let config = which.config;
62 let lua_version = LuaVersion::from(config)?;
63 let tree = config.user_tree(lua_version.clone())?;
64 let lockfile = tree.lockfile()?;
65 let local_packages = if which.packages.is_empty() {
66 lockfile
67 .list()
68 .into_iter()
69 .flat_map(|(_, pkgs)| pkgs)
70 .collect_vec()
71 } else {
72 which
73 .packages
74 .iter()
75 .flat_map(|req| {
76 lockfile
77 .find_rocks(req)
78 .into_iter()
79 .map(|id| lockfile.get(&id).unwrap())
80 .cloned()
81 .collect_vec()
82 })
83 .collect_vec()
84 };
85 local_packages
86 .into_iter()
87 .filter_map(|pkg| {
88 let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
89 let lib_path = rock_layout.lib.join(which.module.to_lib_path());
90 if lib_path.is_file() {
91 return Some(lib_path);
92 }
93 let lua_path = rock_layout.src.join(which.module.to_lua_path());
94 if lua_path.is_file() {
95 return Some(lua_path);
96 }
97 let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
98 if lua_path.is_file() {
99 return Some(lua_path);
100 }
101 None
102 })
103 .next()
104 .ok_or(WhichError::ModuleNotFound(which.module))
105}
106
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::{ConfigBuilder, LuaVersion};
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Runs a `which` query for `module`, restricted to `packages` when non-empty.
    fn lookup(config: &Config, module: &str, packages: &[&str]) -> Result<PathBuf, WhichError> {
        let module = LuaModule::from_str(module).unwrap();
        let requests = packages
            .iter()
            .map(|req| req.parse::<PackageReq>().unwrap());
        Which::new(module, config).packages(requests).search()
    }

    /// Asserts that `path` ends with `<dir>/<file>`.
    fn assert_tail(path: &std::path::Path, dir: &str, file: &str) {
        assert_eq!(path.file_name().unwrap().to_string_lossy(), file);
        let parent = path.parent().unwrap();
        assert_eq!(parent.file_name().unwrap().to_string_lossy(), dir);
    }

    #[tokio::test]
    async fn test_which() {
        let sample_tree =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let temp = assert_fs::TempDir::new().unwrap();
        temp.copy_from(&sample_tree, &["**"]).unwrap();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(temp.to_path_buf()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();

        let pure_lua = lookup(&config, "foo.bar", &[]).unwrap();
        assert_tail(&pure_lua, "foo", "bar.lua");

        let native = lookup(&config, "bat.baz", &[]).unwrap();
        assert_tail(&native, "bat", "baz.so");

        // Restricting to packages that do not ship `foo.bar` must fail the lookup.
        for package in ["lua-cjson", "neorg@8.1.1-1"] {
            let result = lookup(&config, "foo.bar", &[package]);
            assert!(matches!(result, Err(WhichError::ModuleNotFound(_))));
        }
    }
}