lux_lib/rockspec/lua_dependency.rs

1use std::{collections::HashMap, convert::Infallible, fmt::Display, str::FromStr};
2
3use mlua::{FromLua, IntoLua, LuaSerdeExt};
4use serde::{Deserialize, Deserializer};
5use thiserror::Error;
6
7use crate::{
8    lockfile::{OptState, PinnedState},
9    lua_rockspec::{
10        ExternalDependencySpec, PartialOverride, PerPlatform, PlatformOverridable, RockSourceSpec,
11    },
12    package::{PackageName, PackageReq, PackageReqParseError, PackageSpec, PackageVersionReq},
13};
14
15#[derive(Error, Debug)]
16pub enum LuaDependencySpecParseError {
17    #[error(transparent)]
18    PackageReq(#[from] PackageReqParseError),
19}
20
21#[derive(Debug, Clone, PartialEq)]
22pub struct LuaDependencySpec {
23    pub(crate) package_req: PackageReq,
24    pub(crate) pin: PinnedState,
25    pub(crate) opt: OptState,
26    pub(crate) source: Option<RockSourceSpec>,
27}
28
29impl LuaDependencySpec {
30    pub fn package_req(&self) -> &PackageReq {
31        &self.package_req
32    }
33    pub fn pin(&self) -> &PinnedState {
34        &self.pin
35    }
36    pub fn opt(&self) -> &OptState {
37        &self.opt
38    }
39    pub fn source(&self) -> &Option<RockSourceSpec> {
40        &self.source
41    }
42    pub fn into_package_req(self) -> PackageReq {
43        self.package_req
44    }
45    pub fn name(&self) -> &PackageName {
46        self.package_req.name()
47    }
48    pub fn version_req(&self) -> &PackageVersionReq {
49        self.package_req.version_req()
50    }
51    pub fn matches(&self, package: &PackageSpec) -> bool {
52        self.package_req.matches(package)
53    }
54}
55
56impl From<PackageName> for LuaDependencySpec {
57    fn from(name: PackageName) -> Self {
58        Self {
59            package_req: PackageReq::from(name),
60            pin: PinnedState::default(),
61            opt: OptState::default(),
62            source: None,
63        }
64    }
65}
66
67impl From<PackageReq> for LuaDependencySpec {
68    fn from(package_req: PackageReq) -> Self {
69        Self {
70            package_req,
71            pin: PinnedState::default(),
72            opt: OptState::default(),
73            source: None,
74        }
75    }
76}
77
78impl FromStr for LuaDependencySpec {
79    type Err = LuaDependencySpecParseError;
80
81    fn from_str(str: &str) -> Result<Self, LuaDependencySpecParseError> {
82        let package_req = PackageReq::from_str(str)?;
83        Ok(Self {
84            package_req,
85            pin: PinnedState::default(),
86            opt: OptState::default(),
87            source: None,
88        })
89    }
90}
91
92impl Display for LuaDependencySpec {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        if self.version_req().is_any() {
95            self.name().fmt(f)
96        } else {
97            f.write_str(format!("{} {}", self.name(), self.version_req()).as_str())
98        }
99    }
100}
101
102/// Override `base_deps` with `override_deps`
103/// - Adds missing dependencies
104/// - Replaces dependencies with the same name
105impl PartialOverride for Vec<LuaDependencySpec> {
106    type Err = Infallible;
107
108    fn apply_overrides(&self, override_vec: &Self) -> Result<Self, Self::Err> {
109        let mut result_map: HashMap<String, LuaDependencySpec> = self
110            .iter()
111            .map(|dep| (dep.name().clone().to_string(), dep.clone()))
112            .collect();
113        for override_dep in override_vec {
114            result_map.insert(
115                override_dep.name().clone().to_string(),
116                override_dep.clone(),
117            );
118        }
119        Ok(result_map.into_values().collect())
120    }
121}
122
123impl PlatformOverridable for Vec<LuaDependencySpec> {
124    type Err = Infallible;
125
126    fn on_nil<T>() -> Result<super::PerPlatform<T>, <Self as PlatformOverridable>::Err>
127    where
128        T: PlatformOverridable,
129        T: Default,
130    {
131        Ok(PerPlatform::default())
132    }
133}
134
135impl FromLua for LuaDependencySpec {
136    fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
137        let package_req = lua.from_value(value)?;
138        Ok(Self {
139            package_req,
140            pin: PinnedState::default(),
141            opt: OptState::default(),
142            source: None,
143        })
144    }
145}
146
147impl<'de> Deserialize<'de> for LuaDependencySpec {
148    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
149    where
150        D: Deserializer<'de>,
151    {
152        let package_req = PackageReq::deserialize(deserializer)?;
153        Ok(Self {
154            package_req,
155            pin: PinnedState::default(),
156            opt: OptState::default(),
157            source: None,
158        })
159    }
160}
161
162impl mlua::UserData for LuaDependencySpec {
163    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
164        methods.add_method("name", |_, this, ()| Ok(this.name().to_string()));
165        methods.add_method("version_req", |_, this, ()| {
166            Ok(this.version_req().to_string())
167        });
168        methods.add_method("matches", |_, this, package: PackageSpec| {
169            Ok(this.matches(&package))
170        });
171        methods.add_method("package_req", |_, this, ()| Ok(this.package_req().clone()));
172    }
173}
174
175pub enum DependencyType<T> {
176    Regular(Vec<T>),
177    Build(Vec<T>),
178    Test(Vec<T>),
179    External(HashMap<String, ExternalDependencySpec>),
180}
181
182impl<T> IntoLua for DependencyType<T>
183where
184    T: IntoLua,
185{
186    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
187        let table = lua.create_table()?;
188
189        match self {
190            DependencyType::Regular(deps) => {
191                table.set("regular", deps)?;
192            }
193            DependencyType::Build(deps) => {
194                table.set("build", deps)?;
195            }
196            DependencyType::Test(deps) => {
197                table.set("test", deps)?;
198            }
199            DependencyType::External(deps) => {
200                table.set("external", deps)?;
201            }
202        }
203
204        Ok(mlua::Value::Table(table))
205    }
206}
207
208impl<T> FromLua for DependencyType<T>
209where
210    T: FromLua,
211{
212    fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result<Self> {
213        let tbl = value
214            .as_table()
215            .ok_or(mlua::Error::FromLuaConversionError {
216                from: "Value",
217                to: "table".to_string(),
218                message: Some("Expected a table".to_string()),
219            })?;
220
221        let deps = {
222            if let Some(regular) = tbl.get("regular")? {
223                DependencyType::Regular(regular)
224            } else if let Some(build) = tbl.get("build")? {
225                DependencyType::Build(build)
226            } else if let Some(test) = tbl.get("test")? {
227                DependencyType::Test(test)
228            } else if let Some(external) = tbl.get("external")? {
229                DependencyType::External(external)
230            } else {
231                return Err(mlua::Error::FromLuaConversionError {
232                    from: "table",
233                    to: "DependencyType".to_string(),
234                    message: Some(
235                        "expected a table with `regular`, `build`, `test` or `external`"
236                            .to_string(),
237                    ),
238                });
239            }
240        };
241
242        Ok(deps)
243    }
244}
245
/// A Lua dependency collection tagged with its kind. Unlike `DependencyType`,
/// it has no external-dependency variant.
pub enum LuaDependencyType<T> {
    /// Regular dependencies.
    Regular(Vec<T>),
    /// Build dependencies.
    Build(Vec<T>),
    /// Test dependencies.
    Test(Vec<T>),
}
251
252impl<T> IntoLua for LuaDependencyType<T>
253where
254    T: IntoLua,
255{
256    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
257        let table = lua.create_table()?;
258
259        match self {
260            LuaDependencyType::Regular(deps) => {
261                table.set("regular", deps)?;
262            }
263            LuaDependencyType::Build(deps) => {
264                table.set("build", deps)?;
265            }
266            LuaDependencyType::Test(deps) => {
267                table.set("test", deps)?;
268            }
269        }
270
271        Ok(mlua::Value::Table(table))
272    }
273}
274
275impl<T> FromLua for LuaDependencyType<T>
276where
277    T: FromLua,
278{
279    fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result<Self> {
280        let tbl = value
281            .as_table()
282            .ok_or(mlua::Error::FromLuaConversionError {
283                from: "Value",
284                to: "table".to_string(),
285                message: Some("Expected a table".to_string()),
286            })?;
287
288        let deps = {
289            if let Some(regular) = tbl.get("regular")? {
290                LuaDependencyType::Regular(regular)
291            } else if let Some(build) = tbl.get("build")? {
292                LuaDependencyType::Build(build)
293            } else if let Some(test) = tbl.get("test")? {
294                LuaDependencyType::Test(test)
295            } else {
296                return Err(mlua::Error::FromLuaConversionError {
297                    from: "table",
298                    to: "LuaDependencyType".to_string(),
299                    message: Some("expected a table with `regular`, `build` or `test`".to_string()),
300                });
301            }
302        };
303
304        Ok(deps)
305    }
306}
307
#[cfg(test)]
mod test {

    use super::*;

    /// Overriding replaces the same-named base dependency and appends new ones.
    /// This test is synchronous, so it needs no async runtime.
    #[test]
    fn test_override_lua_dependency_spec() {
        let neorg_a: LuaDependencySpec = "neorg 1.0.0".parse().unwrap();
        let neorg_b: LuaDependencySpec = "neorg 2.0.0".parse().unwrap();
        let foo: LuaDependencySpec = "foo 1.0.0".parse().unwrap();
        let bar: LuaDependencySpec = "bar 1.0.0".parse().unwrap();
        let base_vec = vec![neorg_a.clone(), foo.clone()];
        let override_vec = vec![neorg_b.clone(), bar.clone()];
        let result = base_vec.apply_overrides(&override_vec).unwrap();
        assert_eq!(result.len(), 3);
        // The overridden base entry must be gone.
        assert!(!result.contains(&neorg_a));
        assert_eq!(
            result
                .iter()
                .filter(|dep| **dep == neorg_b || **dep == foo || **dep == bar)
                .count(),
            3
        );
    }
}