1use std::{collections::HashMap, convert::Infallible, fmt::Display, str::FromStr};
2
3use mlua::{FromLua, IntoLua, LuaSerdeExt};
4use serde::{Deserialize, Deserializer};
5use thiserror::Error;
6
7use crate::{
8 lockfile::{OptState, PinnedState},
9 lua_rockspec::{
10 ExternalDependencySpec, PartialOverride, PerPlatform, PlatformOverridable, RockSourceSpec,
11 },
12 package::{PackageName, PackageReq, PackageReqParseError, PackageSpec, PackageVersionReq},
13};
14
15#[derive(Error, Debug)]
16pub enum LuaDependencySpecParseError {
17 #[error(transparent)]
18 PackageReq(#[from] PackageReqParseError),
19}
20
21#[derive(Debug, Clone, PartialEq)]
22pub struct LuaDependencySpec {
23 pub(crate) package_req: PackageReq,
24 pub(crate) pin: PinnedState,
25 pub(crate) opt: OptState,
26 pub(crate) source: Option<RockSourceSpec>,
27}
28
29impl LuaDependencySpec {
30 pub fn package_req(&self) -> &PackageReq {
31 &self.package_req
32 }
33 pub fn pin(&self) -> &PinnedState {
34 &self.pin
35 }
36 pub fn opt(&self) -> &OptState {
37 &self.opt
38 }
39 pub fn source(&self) -> &Option<RockSourceSpec> {
40 &self.source
41 }
42 pub fn into_package_req(self) -> PackageReq {
43 self.package_req
44 }
45 pub fn name(&self) -> &PackageName {
46 self.package_req.name()
47 }
48 pub fn version_req(&self) -> &PackageVersionReq {
49 self.package_req.version_req()
50 }
51 pub fn matches(&self, package: &PackageSpec) -> bool {
52 self.package_req.matches(package)
53 }
54}
55
56impl From<PackageName> for LuaDependencySpec {
57 fn from(name: PackageName) -> Self {
58 Self {
59 package_req: PackageReq::from(name),
60 pin: PinnedState::default(),
61 opt: OptState::default(),
62 source: None,
63 }
64 }
65}
66
67impl From<PackageReq> for LuaDependencySpec {
68 fn from(package_req: PackageReq) -> Self {
69 Self {
70 package_req,
71 pin: PinnedState::default(),
72 opt: OptState::default(),
73 source: None,
74 }
75 }
76}
77
78impl FromStr for LuaDependencySpec {
79 type Err = LuaDependencySpecParseError;
80
81 fn from_str(str: &str) -> Result<Self, LuaDependencySpecParseError> {
82 let package_req = PackageReq::from_str(str)?;
83 Ok(Self {
84 package_req,
85 pin: PinnedState::default(),
86 opt: OptState::default(),
87 source: None,
88 })
89 }
90}
91
92impl Display for LuaDependencySpec {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 if self.version_req().is_any() {
95 self.name().fmt(f)
96 } else {
97 f.write_str(format!("{} {}", self.name(), self.version_req()).as_str())
98 }
99 }
100}
101
102impl PartialOverride for Vec<LuaDependencySpec> {
106 type Err = Infallible;
107
108 fn apply_overrides(&self, override_vec: &Self) -> Result<Self, Self::Err> {
109 let mut result_map: HashMap<String, LuaDependencySpec> = self
110 .iter()
111 .map(|dep| (dep.name().clone().to_string(), dep.clone()))
112 .collect();
113 for override_dep in override_vec {
114 result_map.insert(
115 override_dep.name().clone().to_string(),
116 override_dep.clone(),
117 );
118 }
119 Ok(result_map.into_values().collect())
120 }
121}
122
123impl PlatformOverridable for Vec<LuaDependencySpec> {
124 type Err = Infallible;
125
126 fn on_nil<T>() -> Result<super::PerPlatform<T>, <Self as PlatformOverridable>::Err>
127 where
128 T: PlatformOverridable,
129 T: Default,
130 {
131 Ok(PerPlatform::default())
132 }
133}
134
135impl FromLua for LuaDependencySpec {
136 fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
137 let package_req = lua.from_value(value)?;
138 Ok(Self {
139 package_req,
140 pin: PinnedState::default(),
141 opt: OptState::default(),
142 source: None,
143 })
144 }
145}
146
147impl<'de> Deserialize<'de> for LuaDependencySpec {
148 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
149 where
150 D: Deserializer<'de>,
151 {
152 let package_req = PackageReq::deserialize(deserializer)?;
153 Ok(Self {
154 package_req,
155 pin: PinnedState::default(),
156 opt: OptState::default(),
157 source: None,
158 })
159 }
160}
161
162impl mlua::UserData for LuaDependencySpec {
163 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
164 methods.add_method("name", |_, this, ()| Ok(this.name().to_string()));
165 methods.add_method("version_req", |_, this, ()| {
166 Ok(this.version_req().to_string())
167 });
168 methods.add_method("matches", |_, this, package: PackageSpec| {
169 Ok(this.matches(&package))
170 });
171 methods.add_method("package_req", |_, this, ()| Ok(this.package_req().clone()));
172 }
173}
174
175pub enum DependencyType<T> {
176 Regular(Vec<T>),
177 Build(Vec<T>),
178 Test(Vec<T>),
179 External(HashMap<String, ExternalDependencySpec>),
180}
181
182impl<T> IntoLua for DependencyType<T>
183where
184 T: IntoLua,
185{
186 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
187 let table = lua.create_table()?;
188
189 match self {
190 DependencyType::Regular(deps) => {
191 table.set("regular", deps)?;
192 }
193 DependencyType::Build(deps) => {
194 table.set("build", deps)?;
195 }
196 DependencyType::Test(deps) => {
197 table.set("test", deps)?;
198 }
199 DependencyType::External(deps) => {
200 table.set("external", deps)?;
201 }
202 }
203
204 Ok(mlua::Value::Table(table))
205 }
206}
207
208impl<T> FromLua for DependencyType<T>
209where
210 T: FromLua,
211{
212 fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result<Self> {
213 let tbl = value
214 .as_table()
215 .ok_or(mlua::Error::FromLuaConversionError {
216 from: "Value",
217 to: "table".to_string(),
218 message: Some("Expected a table".to_string()),
219 })?;
220
221 let deps = {
222 if let Some(regular) = tbl.get("regular")? {
223 DependencyType::Regular(regular)
224 } else if let Some(build) = tbl.get("build")? {
225 DependencyType::Build(build)
226 } else if let Some(test) = tbl.get("test")? {
227 DependencyType::Test(test)
228 } else if let Some(external) = tbl.get("external")? {
229 DependencyType::External(external)
230 } else {
231 return Err(mlua::Error::FromLuaConversionError {
232 from: "table",
233 to: "DependencyType".to_string(),
234 message: Some(
235 "expected a table with `regular`, `build`, `test` or `external`"
236 .to_string(),
237 ),
238 });
239 }
240 };
241
242 Ok(deps)
243 }
244}
245
/// One category of Lua (rock) dependencies.
///
/// Unlike `DependencyType`, it has no `external` variant. When converted to
/// Lua, it becomes a table with a single key: `regular`, `build` or `test`.
pub enum LuaDependencyType<T> {
    /// Dependencies stored under the `regular` key.
    Regular(Vec<T>),
    /// Dependencies stored under the `build` key.
    Build(Vec<T>),
    /// Dependencies stored under the `test` key.
    Test(Vec<T>),
}
251
252impl<T> IntoLua for LuaDependencyType<T>
253where
254 T: IntoLua,
255{
256 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
257 let table = lua.create_table()?;
258
259 match self {
260 LuaDependencyType::Regular(deps) => {
261 table.set("regular", deps)?;
262 }
263 LuaDependencyType::Build(deps) => {
264 table.set("build", deps)?;
265 }
266 LuaDependencyType::Test(deps) => {
267 table.set("test", deps)?;
268 }
269 }
270
271 Ok(mlua::Value::Table(table))
272 }
273}
274
275impl<T> FromLua for LuaDependencyType<T>
276where
277 T: FromLua,
278{
279 fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result<Self> {
280 let tbl = value
281 .as_table()
282 .ok_or(mlua::Error::FromLuaConversionError {
283 from: "Value",
284 to: "table".to_string(),
285 message: Some("Expected a table".to_string()),
286 })?;
287
288 let deps = {
289 if let Some(regular) = tbl.get("regular")? {
290 LuaDependencyType::Regular(regular)
291 } else if let Some(build) = tbl.get("build")? {
292 LuaDependencyType::Build(build)
293 } else if let Some(test) = tbl.get("test")? {
294 LuaDependencyType::Test(test)
295 } else {
296 return Err(mlua::Error::FromLuaConversionError {
297 from: "table",
298 to: "LuaDependencyType".to_string(),
299 message: Some("expected a table with `regular`, `build` or `test`".to_string()),
300 });
301 }
302 };
303
304 Ok(deps)
305 }
306}
307
#[cfg(test)]
mod test {

    use super::*;

    /// Overriding `neorg` replaces the base entry, while `foo` (base only) and
    /// `bar` (override only) are both kept.
    #[tokio::test]
    async fn test_override_lua_dependency_spec() {
        let parse = |spec: &str| -> LuaDependencySpec { spec.parse().unwrap() };

        let neorg_a = parse("neorg 1.0.0");
        let neorg_b = parse("neorg 2.0.0");
        let foo = parse("foo 1.0.0");
        let bar = parse("bar 1.0.0");

        let base = vec![neorg_a, foo.clone()];
        let overrides = vec![neorg_b.clone(), bar.clone()];
        let merged = base.apply_overrides(&overrides).unwrap();

        assert_eq!(merged.len(), 3);
        let expected = [neorg_b, foo, bar];
        let hits = merged.iter().filter(|dep| expected.contains(dep)).count();
        assert_eq!(hits, 3);
    }
}