1use std::{collections::HashMap, convert::Infallible, fmt::Display, str::FromStr};
2
3use serde::{Deserialize, Deserializer};
4use thiserror::Error;
5
6use crate::{
7 lockfile::{OptState, PinnedState},
8 lua_rockspec::{
9 ExternalDependencySpec, PartialOverride, PerPlatform, PlatformOverridable, RockSourceSpec,
10 },
11 package::{PackageName, PackageReq, PackageReqParseError, PackageSpec, PackageVersionReq},
12};
13
14#[derive(Error, Debug)]
15pub enum LuaDependencySpecParseError {
16 #[error(transparent)]
17 PackageReq(#[from] PackageReqParseError),
18}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct LuaDependencySpec {
22 pub(crate) package_req: PackageReq,
23 pub(crate) pin: PinnedState,
24 pub(crate) opt: OptState,
25 pub(crate) source: Option<RockSourceSpec>,
26}
27
28impl LuaDependencySpec {
29 pub fn package_req(&self) -> &PackageReq {
30 &self.package_req
31 }
32 pub fn pin(&self) -> &PinnedState {
33 &self.pin
34 }
35 pub fn opt(&self) -> &OptState {
36 &self.opt
37 }
38 pub fn source(&self) -> &Option<RockSourceSpec> {
39 &self.source
40 }
41 pub fn into_package_req(self) -> PackageReq {
42 self.package_req
43 }
44 pub fn name(&self) -> &PackageName {
45 self.package_req.name()
46 }
47 pub fn version_req(&self) -> &PackageVersionReq {
48 self.package_req.version_req()
49 }
50 pub fn matches(&self, package: &PackageSpec) -> bool {
51 self.package_req.matches(package)
52 }
53}
54
55impl From<PackageName> for LuaDependencySpec {
56 fn from(name: PackageName) -> Self {
57 Self {
58 package_req: PackageReq::from(name),
59 pin: PinnedState::default(),
60 opt: OptState::default(),
61 source: None,
62 }
63 }
64}
65
66impl From<PackageReq> for LuaDependencySpec {
67 fn from(package_req: PackageReq) -> Self {
68 Self {
69 package_req,
70 pin: PinnedState::default(),
71 opt: OptState::default(),
72 source: None,
73 }
74 }
75}
76
77impl FromStr for LuaDependencySpec {
78 type Err = LuaDependencySpecParseError;
79
80 fn from_str(str: &str) -> Result<Self, LuaDependencySpecParseError> {
81 let package_req = PackageReq::from_str(str)?;
82 Ok(Self {
83 package_req,
84 pin: PinnedState::default(),
85 opt: OptState::default(),
86 source: None,
87 })
88 }
89}
90
91impl Display for LuaDependencySpec {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 if self.version_req().is_any() {
94 self.name().fmt(f)
95 } else {
96 f.write_str(format!("{}{}", self.name(), self.version_req()).as_str())
97 }
98 }
99}
100
101impl PartialOverride for Vec<LuaDependencySpec> {
105 type Err = Infallible;
106
107 fn apply_overrides(&self, override_vec: &Self) -> Result<Self, Self::Err> {
108 let mut result_map: HashMap<String, LuaDependencySpec> = self
109 .iter()
110 .map(|dep| (dep.name().clone().to_string(), dep.clone()))
111 .collect();
112 for override_dep in override_vec {
113 result_map.insert(
114 override_dep.name().clone().to_string(),
115 override_dep.clone(),
116 );
117 }
118 Ok(result_map.into_values().collect())
119 }
120}
121
122impl PlatformOverridable for Vec<LuaDependencySpec> {
123 type Err = Infallible;
124
125 fn on_nil<T>() -> Result<super::PerPlatform<T>, <Self as PlatformOverridable>::Err>
126 where
127 T: PlatformOverridable,
128 T: Default,
129 {
130 Ok(PerPlatform::default())
131 }
132}
133
134impl<'de> Deserialize<'de> for LuaDependencySpec {
135 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136 where
137 D: Deserializer<'de>,
138 {
139 let package_req = PackageReq::deserialize(deserializer)?;
140 Ok(Self {
141 package_req,
142 pin: PinnedState::default(),
143 opt: OptState::default(),
144 source: None,
145 })
146 }
147}
148
149#[derive(Debug, Deserialize)]
150#[serde(rename_all = "lowercase")]
151pub enum DependencyType<T> {
152 Regular(Vec<T>),
153 Build(Vec<T>),
154 Test(Vec<T>),
155 External(HashMap<String, ExternalDependencySpec>),
156}
157
158impl<T> DependencyType<T> {
159 pub fn as_ref(&self) -> DependencyType<&T> {
160 match *self {
161 Self::Regular(ref x) => DependencyType::Regular(x.iter().collect()),
162 Self::Build(ref x) => DependencyType::Build(x.iter().collect()),
163 Self::Test(ref x) => DependencyType::Test(x.iter().collect()),
164 Self::External(ref x) => DependencyType::External(x.clone()),
165 }
166 }
167}
168
169#[derive(Debug, Deserialize)]
170#[serde(rename_all = "lowercase")]
171pub enum LuaDependencyType<T> {
172 Regular(Vec<T>),
173 Build(Vec<T>),
174 Test(Vec<T>),
175}
176
#[cfg(test)]
mod test {
    use ottavino::{Closure, Executor, Fuel, Lua, Value};
    use ottavino_util::serde::from_value;
    use path_slash::PathBufExt;

    use super::*;

    /// Runs a Lua chunk to completion and deserializes its return value into `T`.
    fn eval_lua<T: serde::de::DeserializeOwned>(code: &str) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(executor.take_result::<Value<'_>>(ctx)??).map_err(ottavino::Error::from)
        })
    }

    /// Asserts that `deps` renders exactly as `neorg==1.0.0`, `foo==1.0.0`, in that order.
    fn assert_neorg_and_foo(deps: &[LuaDependencySpec]) {
        let rendered: Vec<String> = deps.iter().map(ToString::to_string).collect();
        assert_eq!(rendered, ["neorg==1.0.0", "foo==1.0.0"]);
    }

    // Plain `#[test]`: nothing here is async, so no runtime is needed.
    #[test]
    fn test_override_lua_dependency_spec() {
        let neorg_a: LuaDependencySpec = "neorg 1.0.0".parse().unwrap();
        let neorg_b: LuaDependencySpec = "neorg 2.0.0".parse().unwrap();
        let foo: LuaDependencySpec = "foo 1.0.0".parse().unwrap();
        let bar: LuaDependencySpec = "bar 1.0.0".parse().unwrap();
        let base_vec = vec![neorg_a, foo.clone()];
        let override_vec = vec![neorg_b.clone(), bar.clone()];
        let result = base_vec.apply_overrides(&override_vec).unwrap();
        assert_eq!(result.len(), 3);
        // Order-independent: each expected (overridden or kept) entry is present.
        for expected in [&neorg_b, &foo, &bar] {
            assert!(result.contains(expected), "missing dependency {expected}");
        }
    }

    #[test]
    fn test_dependency_type_from_lua() {
        let regular_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let external_deps: DependencyType<ExternalDependencySpec> = eval_lua(
            r#"return { external = { foo = { header = "foo.h", library = "libfoo.so" }, bar = { header = "bar.h" } } }"#,
        )
        .unwrap();

        match regular_deps {
            DependencyType::Regular(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            DependencyType::Build(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            DependencyType::Test(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        match external_deps {
            DependencyType::External(deps) => {
                assert_eq!(deps.len(), 2);
                assert_eq!(
                    deps["foo"].header.as_ref().unwrap().to_slash_lossy(),
                    "foo.h"
                );
                assert_eq!(
                    deps["foo"].library.as_ref().unwrap().to_slash_lossy(),
                    "libfoo.so"
                );

                assert_eq!(
                    deps["bar"].header.as_ref().unwrap().to_slash_lossy(),
                    "bar.h"
                );
                assert!(deps["bar"].library.is_none());
            }
            _ => panic!("Expected external dependencies"),
        }

        // An empty table matches no variant and must fail to deserialize.
        let _err: ottavino::ExternError =
            eval_lua::<DependencyType<ExternalDependencySpec>>("return {}").unwrap_err();
    }

    #[test]
    fn test_lua_dependency_type_from_lua() {
        let regular_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();

        match regular_deps {
            LuaDependencyType::Regular(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            LuaDependencyType::Build(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            LuaDependencyType::Test(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        // An empty table matches no variant and must fail to deserialize.
        eval_lua::<LuaDependencyType<LuaDependencySpec>>("return {}").unwrap_err();
    }
}