1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12 PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError,
13 VersionReqToVersionError,
14};
15
16use crate::{
17 lockfile::RemotePackageSourceUrl,
18 lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
19 remote_package_source::RemotePackageSource,
20 rockspec::lua_dependency::LuaDependencySpec,
21 variables::HasVariables,
22};
23
24#[derive(Debug, Error)]
25pub enum PackageSpecFromPackageReqError {
26 #[error("invalid version for rock {rock}: {err}")]
27 Version {
28 rock: PackageName,
29 err: VersionReqToVersionError,
30 },
31}
32
33#[derive(Clone, Debug)]
34#[cfg_attr(feature = "clap", derive(clap::Args))]
35pub struct PackageSpec {
36 name: PackageName,
37 version: PackageVersion,
38}
39
40impl PackageSpec {
41 pub fn new(name: PackageName, version: PackageVersion) -> Self {
42 Self { name, version }
43 }
44 pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
45 Ok(Self::new(
46 PackageName::new(name),
47 PackageVersion::parse(&version)?,
48 ))
49 }
50 pub fn name(&self) -> &PackageName {
51 &self.name
52 }
53 pub fn version(&self) -> &PackageVersion {
54 &self.version
55 }
56 pub fn into_package_req(self) -> PackageReq {
57 PackageReq {
58 name: self.name,
59 version_req: self.version.into_version_req(),
60 }
61 }
62}
63
64impl TryFrom<PackageReq> for PackageSpec {
65 type Error = PackageSpecFromPackageReqError;
66
67 fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
68 let name = value.name;
69 let version = value.version_req.try_into().map_err(|err| {
70 PackageSpecFromPackageReqError::Version {
71 rock: name.clone(),
72 err,
73 }
74 })?;
75 Ok(Self { name, version })
76 }
77}
78
79impl FromLua for PackageSpec {
80 fn from_lua(
81 value: mlua::prelude::LuaValue,
82 lua: &mlua::prelude::Lua,
83 ) -> mlua::prelude::LuaResult<Self> {
84 let (name, version) = lua.from_value(value)?;
85
86 Self::parse(name, version).into_lua_err()
87 }
88}
89
90impl mlua::UserData for PackageSpec {
91 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
92 fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
93 fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
94 }
95
96 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
97 methods.add_method("to_package_req", |_, this, ()| {
98 Ok(this.clone().into_package_req())
99 })
100 }
101}
102
103impl HasVariables for PackageSpec {
104 fn get_variable(&self, input: &str) -> Option<String> {
105 match input {
106 "PACKAGE" => Some(self.name.to_string()),
107 "VERSION" => Some(self.version.to_string()),
108 _ => None,
109 }
110 }
111}
112
113#[derive(Clone, Debug)]
114pub(crate) struct RemotePackage {
115 pub package: PackageSpec,
116 pub source: RemotePackageSource,
117 pub source_url: Option<RemotePackageSourceUrl>,
119}
120
121impl RemotePackage {
122 pub fn new(
123 package: PackageSpec,
124 source: RemotePackageSource,
125 source_url: Option<RemotePackageSourceUrl>,
126 ) -> Self {
127 Self {
128 package,
129 source,
130 source_url,
131 }
132 }
133}
134
/// The form in which a remote package is available.
///
/// Variants are ordered `Src < Rockspec < Binary` (independent of declaration order), so
/// sorting candidates puts binaries last and taking the maximum picks a binary when one
/// is present.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl RemotePackageType {
    /// Position of this variant in the `Src < Rockspec < Binary` ordering.
    fn ordering_rank(&self) -> u8 {
        match self {
            Self::Src => 0,
            Self::Rockspec => 1,
            Self::Binary => 2,
        }
    }
}

impl Ord for RemotePackageType {
    fn cmp(&self, other: &Self) -> Ordering {
        self.ordering_rank().cmp(&other.ordering_rank())
    }
}

impl PartialOrd for RemotePackageType {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
163
/// Selects which kinds of remote packages are acceptable.
///
/// Each flag presumably corresponds to the `RemotePackageType` variant of the same name —
/// TODO confirm against the code that consumes this filter.
#[derive(Clone, Debug)]
pub struct RemotePackageTypeFilterSpec {
    /// Accept packages available as a rockspec.
    pub rockspec: bool,
    /// Accept source packages.
    pub src: bool,
    /// Accept binary packages.
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Accepts every package type.
    fn default() -> Self {
        Self {
            rockspec: true,
            src: true,
            binary: true,
        }
    }
}
183
184#[derive(Error, Debug)]
185pub enum ParseRemotePackageError {
186 #[error("unable to parse package {0}. expected format: `name@version`")]
187 InvalidInput(String),
188 #[error("unable to parse package {package_str}: {error}")]
189 InvalidPackageVersion {
190 #[source]
191 error: PackageVersionParseError,
192 package_str: String,
193 },
194}
195
196impl FromStr for PackageSpec {
197 type Err = ParseRemotePackageError;
198
199 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
200 let (name, version) = s
201 .split_once('@')
202 .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
203
204 Self::parse(name.to_string(), version.to_string()).map_err(|error| {
205 ParseRemotePackageError::InvalidPackageVersion {
206 error,
207 package_str: s.to_string(),
208 }
209 })
210 }
211}
212
213#[derive(Debug, Clone, PartialEq)]
215#[cfg_attr(feature = "clap", derive(clap::Args))]
216pub struct PackageReq {
217 pub(crate) name: PackageName,
219 pub(crate) version_req: PackageVersionReq,
221}
222
223impl PackageReq {
224 pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
225 let version_req = match version {
226 Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
227 None => PackageVersionReq::any(),
228 };
229 Ok(Self {
230 name: PackageName::new(name),
231 version_req,
232 })
233 }
234 pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
235 Self::from_str(pkg_constraints)
236 }
237 pub fn name(&self) -> &PackageName {
238 &self.name
239 }
240 pub fn version_req(&self) -> &PackageVersionReq {
241 &self.version_req
242 }
243 pub fn matches(&self, package: &PackageSpec) -> bool {
246 self.name == package.name && self.version_req.matches(&package.version)
247 }
248}
249
250impl Display for PackageReq {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 if self.version_req.is_any() {
253 self.name.fmt(f)
254 } else {
255 f.write_str(format!("{}{}", self.name, self.version_req).as_str())
256 }
257 }
258}
259
260impl From<PackageSpec> for PackageReq {
261 fn from(value: PackageSpec) -> Self {
262 value.into_package_req()
263 }
264}
265
266impl From<PackageName> for PackageReq {
267 fn from(name: PackageName) -> Self {
268 Self {
269 name,
270 version_req: PackageVersionReq::any(),
271 }
272 }
273}
274
275impl FromLua for PackageReq {
276 fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
277 let str: String = lua.from_value(value)?;
278 Self::parse(&str).into_lua_err()
279 }
280}
281
282impl mlua::UserData for PackageReq {
283 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
284 methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
285 methods.add_method("version_req", |_, this, ()| {
286 Ok(this.version_req.to_string())
287 });
288 methods.add_method("matches", |_, this, package: PackageSpec| {
289 Ok(this.matches(&package))
290 });
291 }
292}
293
294pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
296pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
297pub(crate) struct TestDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
298
299impl DisplayAsLuaKV for Dependencies<'_> {
300 fn display_lua(&self) -> DisplayLuaKV {
301 DisplayLuaKV {
302 key: "dependencies".to_string(),
303 value: DisplayLuaValue::List(
304 self.0
305 .iter()
306 .map(|req| DisplayLuaValue::String(req.to_string()))
307 .collect(),
308 ),
309 }
310 }
311}
312
313impl DisplayAsLuaKV for BuildDependencies<'_> {
314 fn display_lua(&self) -> DisplayLuaKV {
315 DisplayLuaKV {
316 key: "build_dependencies".to_string(),
317 value: DisplayLuaValue::List(
318 self.0
319 .iter()
320 .map(|req| DisplayLuaValue::String(req.to_string()))
321 .collect(),
322 ),
323 }
324 }
325}
326
327impl DisplayAsLuaKV for TestDependencies<'_> {
328 fn display_lua(&self) -> DisplayLuaKV {
329 DisplayLuaKV {
330 key: "test_dependencies".to_string(),
331 value: DisplayLuaValue::List(
332 self.0
333 .iter()
334 .map(|req| DisplayLuaValue::String(req.to_string()))
335 .collect(),
336 ),
337 }
338 }
339}
340
341#[derive(Error, Debug)]
342pub enum PackageReqParseError {
343 #[error("could not parse dependency name from {0}")]
344 InvalidDependencyName(String),
345 #[error("could not parse version requirement in '{str}': {error}")]
346 InvalidPackageVersionReq {
347 #[source]
348 error: PackageVersionReqError,
349 str: String,
350 },
351}
352
353impl FromStr for PackageReq {
354 type Err = PackageReqParseError;
355
356 fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
357 let rock_name_str = str
358 .chars()
359 .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
360 .collect::<String>();
361
362 if rock_name_str.is_empty() {
363 return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
364 }
365
366 let constraints = str.trim_start_matches(&rock_name_str).trim();
367 let version_req = match constraints {
368 "" => PackageVersionReq::any(),
369 constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
370 PackageReqParseError::InvalidPackageVersionReq {
371 error,
372 str: str.to_string(),
373 }
374 })?,
375 };
376 Ok(Self {
377 name: PackageName::new(rock_name_str),
378 version_req,
379 })
380 }
381}
382
383impl<'de> Deserialize<'de> for PackageReq {
384 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
385 where
386 D: Deserializer<'de>,
387 {
388 let s = String::deserialize(deserializer)?;
389 Self::from_str(&s).map_err(de::Error::custom)
390 }
391}
392
/// The name of a package.
///
/// Every constructor in this module goes through [`PackageName::new`], which lowercases
/// its input, so e.g. `LuaFileSystem` and `luafilesystem` produce equal names.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PackageName(String);
396
397impl IntoLua for PackageName {
398 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
399 self.0.into_lua(lua)
400 }
401}
402
403impl PackageName {
404 pub fn new(name: String) -> Self {
405 Self(name.to_lowercase())
406 }
407}
408
409impl<'de> Deserialize<'de> for PackageName {
410 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
411 where
412 D: serde::Deserializer<'de>,
413 {
414 Ok(PackageName::new(String::deserialize(deserializer)?))
415 }
416}
417
418impl FromLua for PackageName {
419 fn from_lua(
420 value: mlua::prelude::LuaValue,
421 lua: &mlua::prelude::Lua,
422 ) -> mlua::prelude::LuaResult<Self> {
423 Ok(Self::new(String::from_lua(value, lua)?))
424 }
425}
426
427impl Serialize for PackageName {
428 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
429 where
430 S: serde::Serializer,
431 {
432 self.0.serialize(serializer)
433 }
434}
435
436impl From<&str> for PackageName {
437 fn from(value: &str) -> Self {
438 Self::new(value.into())
439 }
440}
441
442impl Display for PackageName {
443 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 f.write_str(self.0.as_str())
445 }
446}
447
448#[derive(Debug)]
449pub struct PackageNameList(Vec<PackageName>);
450
451impl PackageNameList {
452 pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
453 Self(package_names)
454 }
455}
456
457impl Display for PackageNameList {
458 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459 f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    // All of these tests are synchronous, so the plain `#[test]` harness is enough;
    // an async runtime per test adds nothing.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        // Versions with missing trailing components compare equal to the full form.
        for version in ["1.0.0", "1.0", "1"] {
            let neorg = PackageSpec::parse("neorg".into(), version.into()).unwrap();
            assert_eq!(
                neorg.version().cmp(&expected_version),
                std::cmp::Ordering::Equal,
                "version {version} should compare equal to 1.0.0"
            );
        }
    }

    #[test]
    fn parse_lua_package_req() {
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg >= 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg > 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn parse_package_req_errors() {
        assert!(matches!(
            "".parse::<PackageReq>(),
            Err(PackageReqParseError::InvalidDependencyName(_))
        ));
        assert!(matches!(
            ">= 1.0".parse::<PackageReq>(),
            Err(PackageReqParseError::InvalidDependencyName(_))
        ));
    }

    #[test]
    fn parse_package_spec_from_str() {
        let spec: PackageSpec = "Neorg@1.0.0".parse().unwrap();
        assert_eq!(spec.name().to_string(), "neorg");
        assert!(matches!(
            "neorg".parse::<PackageSpec>(),
            Err(ParseRemotePackageError::InvalidInput(_))
        ));
    }

    #[test]
    fn display_package_name_list() {
        let names = PackageNameList::new(vec!["neorg".into(), "LuaFileSystem".into()]);
        assert_eq!(names.to_string(), "neorg, luafilesystem");
        assert_eq!(PackageNameList::new(Vec::new()).to_string(), "");
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }
}