1use std::{collections::HashMap, convert::Infallible, fmt::Display, str::FromStr};
2
3use serde::{Deserialize, Deserializer};
4use thiserror::Error;
5
6use crate::{
7 lockfile::{OptState, PinnedState},
8 lua_rockspec::{
9 ExternalDependencySpec, PartialOverride, PerPlatform, PlatformOverridable, RockSourceSpec,
10 },
11 package::{PackageName, PackageReq, PackageReqParseError, PackageSpec, PackageVersionReq},
12};
13
14#[derive(Error, Debug)]
15pub enum LuaDependencySpecParseError {
16 #[error(transparent)]
17 PackageReq(#[from] PackageReqParseError),
18}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct LuaDependencySpec {
22 pub(crate) package_req: PackageReq,
23 pub(crate) pin: PinnedState,
24 pub(crate) opt: OptState,
25 pub(crate) source: Option<RockSourceSpec>,
26}
27
28impl LuaDependencySpec {
29 pub fn package_req(&self) -> &PackageReq {
30 &self.package_req
31 }
32 pub fn pin(&self) -> &PinnedState {
33 &self.pin
34 }
35 pub fn opt(&self) -> &OptState {
36 &self.opt
37 }
38 pub fn source(&self) -> &Option<RockSourceSpec> {
39 &self.source
40 }
41 pub fn into_package_req(self) -> PackageReq {
42 self.package_req
43 }
44 pub fn name(&self) -> &PackageName {
45 self.package_req.name()
46 }
47 pub fn version_req(&self) -> &PackageVersionReq {
48 self.package_req.version_req()
49 }
50 pub fn matches(&self, package: &PackageSpec) -> bool {
51 self.package_req.matches(package)
52 }
53}
54
55impl From<PackageName> for LuaDependencySpec {
56 fn from(name: PackageName) -> Self {
57 Self {
58 package_req: PackageReq::from(name),
59 pin: PinnedState::default(),
60 opt: OptState::default(),
61 source: None,
62 }
63 }
64}
65
66impl From<PackageReq> for LuaDependencySpec {
67 fn from(package_req: PackageReq) -> Self {
68 Self {
69 package_req,
70 pin: PinnedState::default(),
71 opt: OptState::default(),
72 source: None,
73 }
74 }
75}
76
77impl FromStr for LuaDependencySpec {
78 type Err = LuaDependencySpecParseError;
79
80 fn from_str(str: &str) -> Result<Self, LuaDependencySpecParseError> {
81 let package_req = PackageReq::from_str(str)?;
82 Ok(Self {
83 package_req,
84 pin: PinnedState::default(),
85 opt: OptState::default(),
86 source: None,
87 })
88 }
89}
90
91impl Display for LuaDependencySpec {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 if self.version_req().is_any() {
94 self.name().fmt(f)
95 } else {
96 f.write_str(format!("{}{}", self.name(), self.version_req()).as_str())
97 }
98 }
99}
100
101impl PartialOverride for Vec<LuaDependencySpec> {
105 type Err = Infallible;
106
107 fn apply_overrides(&self, override_vec: &Self) -> Result<Self, Self::Err> {
108 let mut result_map: HashMap<String, LuaDependencySpec> = self
109 .iter()
110 .map(|dep| (dep.name().clone().to_string(), dep.clone()))
111 .collect();
112 for override_dep in override_vec {
113 result_map.insert(
114 override_dep.name().clone().to_string(),
115 override_dep.clone(),
116 );
117 }
118 Ok(result_map.into_values().collect())
119 }
120}
121
122impl PlatformOverridable for Vec<LuaDependencySpec> {
123 type Err = Infallible;
124
125 fn on_nil<T>() -> Result<super::PerPlatform<T>, <Self as PlatformOverridable>::Err>
126 where
127 T: PlatformOverridable,
128 T: Default,
129 {
130 Ok(PerPlatform::default())
131 }
132}
133
134impl<'de> Deserialize<'de> for LuaDependencySpec {
135 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136 where
137 D: Deserializer<'de>,
138 {
139 let package_req = PackageReq::deserialize(deserializer)?;
140 Ok(Self {
141 package_req,
142 pin: PinnedState::default(),
143 opt: OptState::default(),
144 source: None,
145 })
146 }
147}
148
149#[derive(Debug, Deserialize)]
150#[serde(rename_all = "lowercase")]
151pub enum DependencyType<T> {
152 Regular(Vec<T>),
153 Build(Vec<T>),
154 Test(Vec<T>),
155 External(HashMap<String, ExternalDependencySpec>),
156}
157
158#[derive(Debug, Deserialize)]
159#[serde(rename_all = "lowercase")]
160pub enum LuaDependencyType<T> {
161 Regular(Vec<T>),
162 Build(Vec<T>),
163 Test(Vec<T>),
164}
165
#[cfg(test)]
mod test {

    use ottavino::{Closure, Executor, Fuel, Lua, Value};
    use ottavino_util::serde::from_value;
    use path_slash::PathBufExt;

    use super::*;

    /// Loads and runs the Lua chunk `code`, then deserializes its result value into `T`.
    fn eval_lua<T: serde::de::DeserializeOwned>(code: &str) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(executor.take_result::<Value<'_>>(ctx)??).map_err(ottavino::Error::from)
        })
    }

    /// Asserts that `deps` is exactly `neorg==1.0.0` followed by `foo==1.0.0`.
    /// Every dependency-list fixture below uses this pair.
    fn assert_neorg_then_foo(deps: &[LuaDependencySpec]) {
        assert_eq!(deps.len(), 2);
        assert_eq!(deps[0].to_string(), "neorg==1.0.0");
        assert_eq!(deps[1].to_string(), "foo==1.0.0");
    }

    // Overriding is synchronous, so a plain `#[test]` is enough; no async runtime is needed.
    #[test]
    fn test_override_lua_dependency_spec() {
        let neorg_a: LuaDependencySpec = "neorg 1.0.0".parse().unwrap();
        let neorg_b: LuaDependencySpec = "neorg 2.0.0".parse().unwrap();
        let foo: LuaDependencySpec = "foo 1.0.0".parse().unwrap();
        let bar: LuaDependencySpec = "bar 1.0.0".parse().unwrap();
        let base_vec = vec![neorg_a, foo.clone()];
        let override_vec = vec![neorg_b.clone(), bar.clone()];
        let result = base_vec.apply_overrides(&override_vec).unwrap();
        // The override's `neorg` replaces the base one, `foo` is kept and `bar` is added.
        // This checks membership only, without relying on the order of `result`.
        assert_eq!(result.len(), 3);
        for expected in [&neorg_b, &foo, &bar] {
            assert!(
                result.contains(expected),
                "missing {} in {:?}",
                expected,
                result
            );
        }
    }

    #[test]
    fn test_dependency_type_from_lua() {
        let regular_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let external_deps: DependencyType<ExternalDependencySpec> = eval_lua(
            r#"return { external = { foo = { header = "foo.h", library = "libfoo.so" }, bar = { header = "bar.h" } } }"#,
        )
        .unwrap();

        match regular_deps {
            DependencyType::Regular(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            DependencyType::Build(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            DependencyType::Test(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        match external_deps {
            DependencyType::External(deps) => {
                assert_eq!(deps.len(), 2);
                assert_eq!(
                    deps["foo"].header.as_ref().unwrap().to_slash_lossy(),
                    "foo.h"
                );
                assert_eq!(
                    deps["foo"].library.as_ref().unwrap().to_slash_lossy(),
                    "libfoo.so"
                );

                assert_eq!(
                    deps["bar"].header.as_ref().unwrap().to_slash_lossy(),
                    "bar.h"
                );
                assert!(deps["bar"].library.is_none());
            }
            _ => panic!("Expected external dependencies"),
        }

        // A table without any dependency-kind key must be rejected.
        let _err: ottavino::ExternError =
            eval_lua::<DependencyType<ExternalDependencySpec>>("return {}").unwrap_err();
    }

    #[test]
    fn test_lua_dependency_type_from_lua() {
        let regular_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();

        match regular_deps {
            LuaDependencyType::Regular(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            LuaDependencyType::Build(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            LuaDependencyType::Test(deps) => assert_neorg_then_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        // A table without any dependency-kind key must be rejected.
        eval_lua::<LuaDependencyType<LuaDependencySpec>>("return {}").unwrap_err();
    }
}