1use std::{io, path::PathBuf};
2
3use bon::Builder;
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::{
8 config::Config,
9 lua_rockspec::LuaModule,
10 lua_version::{LuaVersion, LuaVersionUnset},
11 package::PackageReq,
12 tree::TreeError,
13};
14
15#[derive(Builder)]
17#[builder(start_fn = new, finish_fn(name = _build, vis = ""))]
18pub struct Which<'a> {
19 #[builder(start_fn)]
20 module: LuaModule,
21 #[builder(start_fn)]
22 config: &'a Config,
23 #[builder(field)]
24 packages: Vec<PackageReq>,
25}
26
27impl<State> WhichBuilder<'_, State>
28where
29 State: which_builder::State,
30{
31 pub fn package(mut self, package: PackageReq) -> Self {
32 self.packages.push(package);
33 self
34 }
35
36 pub fn packages(mut self, packages: impl IntoIterator<Item = PackageReq>) -> Self {
37 self.packages.extend(packages);
38 self
39 }
40
41 pub fn search(self) -> Result<PathBuf, WhichError>
42 where
43 State: which_builder::IsComplete,
44 {
45 do_search(self._build())
46 }
47}
48
49#[derive(Error, Debug)]
50pub enum WhichError {
51 #[error(transparent)]
52 Io(#[from] io::Error),
53 #[error(transparent)]
54 Tree(#[from] TreeError),
55 #[error(transparent)]
56 LuaVersionUnset(#[from] LuaVersionUnset),
57 #[error("lua module {0} not found.")]
58 ModuleNotFound(LuaModule),
59}
60
61fn do_search(which: Which<'_>) -> Result<PathBuf, WhichError> {
62 let config = which.config;
63 let lua_version = LuaVersion::from(config)?;
64 let tree = config.user_tree(lua_version.clone())?;
65 let lockfile = tree.lockfile()?;
66 let local_packages = if which.packages.is_empty() {
67 lockfile.list().into_values().flatten().collect_vec()
68 } else {
69 which
70 .packages
71 .iter()
72 .flat_map(|req| {
73 lockfile
74 .find_rocks(req)
75 .into_iter()
76 .filter_map(|id| lockfile.get(&id))
77 .cloned()
78 .collect_vec()
79 })
80 .collect_vec()
81 };
82 local_packages
83 .into_iter()
84 .filter_map(|pkg| {
85 let rock_layout = tree.installed_rock_layout(&pkg).ok()?;
86 let lib_path = rock_layout.lib.join(which.module.to_lib_path());
87 if lib_path.is_file() {
88 return Some(lib_path);
89 }
90 let lua_path = rock_layout.src.join(which.module.to_lua_path());
91 if lua_path.is_file() {
92 return Some(lua_path);
93 }
94 let lua_path = rock_layout.src.join(which.module.to_lua_init_path());
95 if lua_path.is_file() {
96 return Some(lua_path);
97 }
98 None
99 })
100 .next()
101 .ok_or(WhichError::ModuleNotFound(which.module))
102}
103
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::ConfigBuilder;
    use assert_fs::prelude::PathCopy;
    use std::{path::PathBuf, str::FromStr};

    /// Asserts that `path` names the file `file` inside a directory called `dir`.
    fn assert_file_in_dir(path: &std::path::Path, dir: &str, file: &str) {
        assert_eq!(path.file_name().unwrap().to_string_lossy(), file);
        assert_eq!(
            path.parent().unwrap().file_name().unwrap().to_string_lossy(),
            dir
        );
    }

    #[tokio::test]
    async fn test_which() {
        // Work on a throwaway copy of the fixture tree.
        let fixture =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/sample-tree");
        let temp = assert_fs::TempDir::new().unwrap();
        temp.copy_from(&fixture, &["**"]).unwrap();
        let config = ConfigBuilder::new()
            .unwrap()
            .user_tree(Some(temp.to_path_buf()))
            .lua_version(Some(LuaVersion::Lua51))
            .build()
            .unwrap();
        let module = |name: &str| LuaModule::from_str(name).unwrap();

        // A Lua-source module resolves to its `.lua` file.
        let found = Which::new(module("foo.bar"), &config).search().unwrap();
        assert_file_in_dir(&found, "foo", "bar.lua");

        // A compiled module resolves to its `.so` file.
        let found = Which::new(module("bat.baz"), &config).search().unwrap();
        assert_file_in_dir(&found, "bat", "baz.so");

        // Restricting the search to packages that don't provide the module fails.
        for pkg in ["lua-cjson", "neorg@8.1.1-1"] {
            let result = Which::new(module("foo.bar"), &config)
                .package(pkg.parse().unwrap())
                .search();
            assert!(matches!(result, Err(WhichError::ModuleNotFound(_))));
        }
    }
}