Skip to main content

lux_lib/rockspec/
lua_dependency.rs

1use std::{collections::HashMap, convert::Infallible, fmt::Display, str::FromStr};
2
3use serde::{Deserialize, Deserializer};
4use thiserror::Error;
5
6use crate::{
7    lockfile::{OptState, PinnedState},
8    lua_rockspec::{
9        ExternalDependencySpec, PartialOverride, PerPlatform, PlatformOverridable, RockSourceSpec,
10    },
11    package::{PackageName, PackageReq, PackageReqParseError, PackageSpec, PackageVersionReq},
12};
13
/// Error produced when parsing a string into a [`LuaDependencySpec`] fails.
#[derive(Error, Debug)]
pub enum LuaDependencySpecParseError {
    /// The package requirement could not be parsed. The variant is
    /// `transparent`, so the wrapped error's message and source are
    /// forwarded as-is.
    #[error(transparent)]
    PackageReq(#[from] PackageReqParseError),
}
19
/// A dependency on a Lua package: a package requirement together with its
/// pinned state, optional state and an optional source specification.
#[derive(Debug, Clone, PartialEq)]
pub struct LuaDependencySpec {
    /// Package name and version requirement of the dependency.
    pub(crate) package_req: PackageReq,
    /// Pinned state; every conversion in this file sets it to
    /// `PinnedState::default()`.
    pub(crate) pin: PinnedState,
    /// Optional state; every conversion in this file sets it to
    /// `OptState::default()`.
    pub(crate) opt: OptState,
    /// Source specification; `None` for every conversion in this file.
    // NOTE(review): presumably an explicit location to fetch the rock from —
    // confirm against the code that populates it.
    pub(crate) source: Option<RockSourceSpec>,
}
27
28impl LuaDependencySpec {
29    pub fn package_req(&self) -> &PackageReq {
30        &self.package_req
31    }
32    pub fn pin(&self) -> &PinnedState {
33        &self.pin
34    }
35    pub fn opt(&self) -> &OptState {
36        &self.opt
37    }
38    pub fn source(&self) -> &Option<RockSourceSpec> {
39        &self.source
40    }
41    pub fn into_package_req(self) -> PackageReq {
42        self.package_req
43    }
44    pub fn name(&self) -> &PackageName {
45        self.package_req.name()
46    }
47    pub fn version_req(&self) -> &PackageVersionReq {
48        self.package_req.version_req()
49    }
50    pub fn matches(&self, package: &PackageSpec) -> bool {
51        self.package_req.matches(package)
52    }
53}
54
55impl From<PackageName> for LuaDependencySpec {
56    fn from(name: PackageName) -> Self {
57        Self {
58            package_req: PackageReq::from(name),
59            pin: PinnedState::default(),
60            opt: OptState::default(),
61            source: None,
62        }
63    }
64}
65
66impl From<PackageReq> for LuaDependencySpec {
67    fn from(package_req: PackageReq) -> Self {
68        Self {
69            package_req,
70            pin: PinnedState::default(),
71            opt: OptState::default(),
72            source: None,
73        }
74    }
75}
76
77impl FromStr for LuaDependencySpec {
78    type Err = LuaDependencySpecParseError;
79
80    fn from_str(str: &str) -> Result<Self, LuaDependencySpecParseError> {
81        let package_req = PackageReq::from_str(str)?;
82        Ok(Self {
83            package_req,
84            pin: PinnedState::default(),
85            opt: OptState::default(),
86            source: None,
87        })
88    }
89}
90
91impl Display for LuaDependencySpec {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        if self.version_req().is_any() {
94            self.name().fmt(f)
95        } else {
96            f.write_str(format!("{}{}", self.name(), self.version_req()).as_str())
97        }
98    }
99}
100
101/// Override `base_deps` with `override_deps`
102/// - Adds missing dependencies
103/// - Replaces dependencies with the same name
104impl PartialOverride for Vec<LuaDependencySpec> {
105    type Err = Infallible;
106
107    fn apply_overrides(&self, override_vec: &Self) -> Result<Self, Self::Err> {
108        let mut result_map: HashMap<String, LuaDependencySpec> = self
109            .iter()
110            .map(|dep| (dep.name().clone().to_string(), dep.clone()))
111            .collect();
112        for override_dep in override_vec {
113            result_map.insert(
114                override_dep.name().clone().to_string(),
115                override_dep.clone(),
116            );
117        }
118        Ok(result_map.into_values().collect())
119    }
120}
121
122impl PlatformOverridable for Vec<LuaDependencySpec> {
123    type Err = Infallible;
124
125    fn on_nil<T>() -> Result<super::PerPlatform<T>, <Self as PlatformOverridable>::Err>
126    where
127        T: PlatformOverridable,
128        T: Default,
129    {
130        Ok(PerPlatform::default())
131    }
132}
133
134impl<'de> Deserialize<'de> for LuaDependencySpec {
135    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136    where
137        D: Deserializer<'de>,
138    {
139        let package_req = PackageReq::deserialize(deserializer)?;
140        Ok(Self {
141            package_req,
142            pin: PinnedState::default(),
143            opt: OptState::default(),
144            source: None,
145        })
146    }
147}
148
/// A single group of dependencies, tagged by kind.
///
/// Variant names are deserialized in lowercase, i.e. from the keys
/// `regular`, `build`, `test` and `external`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DependencyType<T> {
    /// Dependencies listed under the `regular` key.
    Regular(Vec<T>),
    /// Dependencies listed under the `build` key.
    Build(Vec<T>),
    /// Dependencies listed under the `test` key.
    Test(Vec<T>),
    /// External dependencies listed under the `external` key, keyed by name.
    External(HashMap<String, ExternalDependencySpec>),
}
157
/// A single group of Lua dependencies, tagged by kind.
///
/// Like [`DependencyType`] but without the `external` variant. Variant names
/// are deserialized in lowercase, i.e. from the keys `regular`, `build` and
/// `test`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LuaDependencyType<T> {
    /// Dependencies listed under the `regular` key.
    Regular(Vec<T>),
    /// Dependencies listed under the `build` key.
    Build(Vec<T>),
    /// Dependencies listed under the `test` key.
    Test(Vec<T>),
}
165
#[cfg(test)]
mod test {

    use ottavino::{Closure, Executor, Fuel, Lua, Value};
    use ottavino_util::serde::from_value;
    use path_slash::PathBufExt;

    use super::*;

    /// Runs `code` as a Lua chunk and deserializes the value it returns
    /// into `T`.
    fn eval_lua<T: serde::de::DeserializeOwned>(code: &str) -> Result<T, ottavino::ExternError> {
        Lua::core().try_enter(|ctx| {
            let closure = Closure::load(ctx, None, code.as_bytes())?;
            let executor = Executor::start(ctx, closure.into(), ());
            executor.step(ctx, &mut Fuel::with(i32::MAX))?;
            from_value(executor.take_result::<Value<'_>>(ctx)??).map_err(ottavino::Error::from)
        })
    }

    /// Asserts that `deps` is exactly `neorg 1.0.0` followed by `foo 1.0.0`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn assert_neorg_and_foo(deps: &[LuaDependencySpec]) {
        assert_eq!(deps.len(), 2);
        assert_eq!(deps[0].to_string(), "neorg==1.0.0");
        assert_eq!(deps[1].to_string(), "foo==1.0.0");
    }

    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_override_lua_dependency_spec() {
        let neorg_a: LuaDependencySpec = "neorg 1.0.0".parse().unwrap();
        let neorg_b: LuaDependencySpec = "neorg 2.0.0".parse().unwrap();
        let foo: LuaDependencySpec = "foo 1.0.0".parse().unwrap();
        let bar: LuaDependencySpec = "bar 1.0.0".parse().unwrap();
        let base_vec = vec![neorg_a, foo.clone()];
        let override_vec = vec![neorg_b.clone(), bar.clone()];
        let result = base_vec.apply_overrides(&override_vec).unwrap();
        assert_eq!(result.len(), 3);
        // Three entries that contain each of the three distinct expected
        // specs means the result is exactly that set, with no duplicates.
        for expected in [&neorg_b, &foo, &bar] {
            assert!(
                result.contains(expected),
                "missing dependency after override: {expected}"
            );
        }
    }

    #[test]
    fn test_dependency_type_from_lua() {
        let regular_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: DependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let external_deps: DependencyType<ExternalDependencySpec> = eval_lua(
            r#"return { external = { foo = { header = "foo.h", library = "libfoo.so" }, bar = { header = "bar.h" } } }"#,
        )
        .unwrap();

        match regular_deps {
            DependencyType::Regular(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            DependencyType::Build(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            DependencyType::Test(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        match external_deps {
            DependencyType::External(deps) => {
                assert_eq!(deps.len(), 2);
                assert_eq!(
                    deps["foo"].header.as_ref().unwrap().to_slash_lossy(),
                    "foo.h"
                );
                assert_eq!(
                    deps["foo"].library.as_ref().unwrap().to_slash_lossy(),
                    "libfoo.so"
                );

                assert_eq!(
                    deps["bar"].header.as_ref().unwrap().to_slash_lossy(),
                    "bar.h"
                );
                assert!(deps["bar"].library.is_none());
            }
            _ => panic!("Expected external dependencies"),
        }

        // An empty table names no variant and must be rejected.
        let _err: ottavino::ExternError =
            eval_lua::<DependencyType<ExternalDependencySpec>>("return {}").unwrap_err();
    }

    #[test]
    fn test_lua_dependency_type_from_lua() {
        let regular_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { regular = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let build_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { build = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();
        let test_deps: LuaDependencyType<LuaDependencySpec> =
            eval_lua(r#"return { test = {"neorg 1.0.0", "foo 1.0.0"} }"#).unwrap();

        match regular_deps {
            LuaDependencyType::Regular(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected regular dependencies"),
        }

        match build_deps {
            LuaDependencyType::Build(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected build dependencies"),
        }

        match test_deps {
            LuaDependencyType::Test(deps) => assert_neorg_and_foo(&deps),
            _ => panic!("Expected test dependencies"),
        }

        // An empty table names no variant and must be rejected.
        eval_lua::<LuaDependencyType<LuaDependencySpec>>("return {}").unwrap_err();
    }
}