1use std::{io, path::PathBuf};
2
3use bon::Builder;
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8 config::Config,
9 lua_rockspec::LuaModule,
10 lua_version::{LuaVersion, LuaVersionUnset},
11 package::PackageReq,
12 tree::TreeError,
13};
14
/// Query that locates the file providing a Lua module.
///
/// Start with `Which::new(module, &config)`, optionally narrow the lookup
/// with `package` / `packages`, then run it with `search`.
#[derive(Builder)]
#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
pub struct Which<'a> {
    // The Lua module to look for (first positional argument of `new`).
    #[builder(start_fn)]
    module: LuaModule,
    // Configuration the search reads the Lua version and user tree from
    // (second positional argument of `new`).
    #[builder(start_fn)]
    config: &'a Config,
    // Package requirements to restrict the search to; when empty, every
    // package listed in the lockfile is searched.
    #[builder(field)]
    packages: Vec<PackageReq>,
}
26
27impl<State> WhichBuilder<'_, State>
28where
29 State: which_builder::State,
30{
31 pub fn package(mut self, package: PackageReq) -> Self {
32 self.packages.push(package);
33 self
34 }
35
36 pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
37 self.packages.extend(packages);
38 self
39 }
40
41 pub fn search(self) -> Result<PathBuf, WhichError>
42 where
43 State: which_builder::IsComplete,
44 {
45 do_search(self._build())
46 }
47}
48
/// Errors that can occur while searching for a Lua module.
#[derive(Error, Debug)]
pub enum WhichError {
    /// An underlying I/O error, reported transparently.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A [`TreeError`], reported transparently.
    #[error(transparent)]
    Tree(#[from] TreeError),
    /// A [`LuaVersionUnset`] error, reported transparently.
    #[error(transparent)]
    LuaVersionUnset(#[from] LuaVersionUnset),
    /// None of the searched packages contains a file for the requested module.
    #[error("lua module {0} not found.")]
    ModuleNotFound(LuaModule),
}
60
61fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
62 let config = which.config;
63 let lua_version = LuaVersion::from(config)?;
64 let tree = config.user_tree(lua_version.clone())?;
65 let lockfile = tree.lockfile()?;
66 let local_packages = if which.packages.is_empty() {
67 lockfile
68 .list()
69 .into_iter()
70 .flat_map(|(_, pkgs)| pkgs)
71 .collect_vec()
72 } else {
73 which
74 .packages
75 .iter()
76 .flat_map(|req| {
77 lockfile
78 .find_rocks(req)
79 .into_iter()
80 .filter_map(|id| lockfile.get(&id))
81 .cloned()
82 .collect_vec()
83 })
84 .collect_vec()
85 };
86 local_packages
87 .into_iter()
88 .filter_map(|pkg| {
89 let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
90 let lib_path = rock_layout.lib.join(which.module.to_lib_path());
91 if lib_path.is_file() {
92 return Some(lib_path);
93 }
94 let lua_path = rock_layout.src.join(which.module.to_lua_path());
95 if lua_path.is_file() {
96 return Some(lua_path);
97 }
98 let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
99 if lua_path.is_file() {
100 return Some(lua_path);
101 }
102 None
103 })
104 .next()
105 .ok_or(WhichError::ModuleNotFound(which.module))
106}
107
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::ConfigBuilder;
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Asserts that `found` is named `file` and sits directly inside a
    /// directory named `dir`.
    fn assert_located(found: &std::path::Path, file: &str, dir: &str) {
        assert_eq!(found.file_name().unwrap().to_string_lossy(), file);
        assert_eq!(
            found
                .parent()
                .unwrap()
                .file_name()
                .unwrap()
                .to_string_lossy(),
            dir
        );
    }

    /// Searches a copy of the sample tree for a Lua source module and a
    /// native module, then checks that restricting the search to other
    /// packages reports the module as not found.
    #[tokio::test]
    async fn test_which() {
        // Work on a temporary copy so the checked-in sample tree is untouched.
        let sample_tree =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let temp = assert_fs::TempDir::new().unwrap();
        temp.copy_from(&sample_tree, &["**"]).unwrap();
        let tree_path = temp.to_path_buf();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(tree_path.clone()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();

        let lua_hit = Which::new(LuaModule::from_str("foo.bar").unwrap(), &config)
            .search()
            .unwrap();
        assert_located(&lua_hit, "bar.lua", "foo");

        let lib_hit = Which::new(LuaModule::from_str("bat.baz").unwrap(), &config)
            .search()
            .unwrap();
        assert_located(&lib_hit, "baz.so", "bat");

        // Restricting the search to these packages must not find `foo.bar`.
        for spec in ["lua-cjson", "neorg@8.1.1-1"] {
            let outcome = Which::new(LuaModule::from_str("foo.bar").unwrap(), &config)
                .package(spec.parse().unwrap())
                .search();
            assert!(matches!(outcome, Err(WhichError::ModuleNotFound(_))));
        }
    }
}