1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12 PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError, SpecRev,
13 VersionReqToVersionError,
14};
15
16pub(crate) use version::SpecRevIterator;
17
18use crate::{
19 lockfile::RemotePackageSourceUrl,
20 lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
21 package::version::HasModRev,
22 remote_package_source::RemotePackageSource,
23 rockspec::lua_dependency::LuaDependencySpec,
24 variables::{GetVariableError, HasVariables},
25};
26
27#[derive(Debug, Error)]
28pub enum PackageSpecFromPackageReqError {
29 #[error("invalid version for rock {rock}: {err}")]
30 Version {
31 rock: PackageName,
32 err: VersionReqToVersionError,
33 },
34}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37#[cfg_attr(feature = "clap", derive(clap::Args))]
38pub struct PackageSpec {
39 name: PackageName,
40 version: PackageVersion,
41}
42
43impl PackageSpec {
44 pub fn new(name: PackageName, version: PackageVersion) -> Self {
45 Self { name, version }
46 }
47 pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
48 Ok(Self::new(
49 PackageName::new(name),
50 PackageVersion::parse(&version)?,
51 ))
52 }
53 pub fn name(&self) -> &PackageName {
54 &self.name
55 }
56 pub fn version(&self) -> &PackageVersion {
57 &self.version
58 }
59 pub fn into_package_req(self) -> PackageReq {
60 PackageReq {
61 name: self.name,
62 version_req: self.version.into_version_req(),
63 }
64 }
65}
66
67impl TryFrom<PackageReq> for PackageSpec {
68 type Error = PackageSpecFromPackageReqError;
69
70 fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
71 let name = value.name;
72 let version = value.version_req.try_into().map_err(|err| {
73 PackageSpecFromPackageReqError::Version {
74 rock: name.clone(),
75 err,
76 }
77 })?;
78 Ok(Self { name, version })
79 }
80}
81
82impl FromLua for PackageSpec {
83 fn from_lua(
84 value: mlua::prelude::LuaValue,
85 lua: &mlua::prelude::Lua,
86 ) -> mlua::prelude::LuaResult<Self> {
87 let (name, version) = lua.from_value(value)?;
88
89 Self::parse(name, version).into_lua_err()
90 }
91}
92
93impl mlua::UserData for PackageSpec {
94 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
95 fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
96 fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
97 }
98
99 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
100 methods.add_method("to_package_req", |_, this, ()| {
101 Ok(this.clone().into_package_req())
102 })
103 }
104}
105
106impl HasVariables for PackageSpec {
107 fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
108 Ok(match input {
109 "PACKAGE" => Some(self.name.to_string()),
110 "VERSION" => Some(self.version.to_modrev_string()),
111 _ => None,
112 })
113 }
114}
115
116#[derive(Clone, Debug)]
117pub(crate) struct RemotePackage {
118 pub package: PackageSpec,
119 pub source: RemotePackageSource,
120 pub source_url: Option<RemotePackageSourceUrl>,
122}
123
124impl RemotePackage {
125 pub fn new(
126 package: PackageSpec,
127 source: RemotePackageSource,
128 source_url: Option<RemotePackageSourceUrl>,
129 ) -> Self {
130 Self {
131 package,
132 source,
133 source_url,
134 }
135 }
136}
137
/// The form a remote package comes in.
///
/// Ordered so that `Src < Rockspec < Binary`.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    fn cmp(&self, other: &Self) -> Ordering {
        // Map each variant to a rank; a higher rank compares greater.
        let rank = |kind: &Self| -> u8 {
            match kind {
                Self::Src => 0,
                Self::Rockspec => 1,
                Self::Binary => 2,
            }
        };
        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Total order: always delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
166
/// Per-type toggles for filtering remote packages by [`RemotePackageType`]-like kind.
///
/// NOTE(review): `true` presumably means "allowed"; confirm against the filter's consumers.
#[derive(Clone)]
pub struct RemotePackageTypeFilterSpec {
    /// Toggle for rockspec packages.
    pub rockspec: bool,
    /// Toggle for source packages.
    pub src: bool,
    /// Toggle for binary packages.
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// All three toggles set to `true`.
    fn default() -> Self {
        let enabled = true;
        Self {
            binary: enabled,
            src: enabled,
            rockspec: enabled,
        }
    }
}
186
187#[derive(Error, Debug)]
188pub enum ParseRemotePackageError {
189 #[error("unable to parse package {0}. expected format: `name@version`")]
190 InvalidInput(String),
191 #[error("unable to parse package {package_str}:\n{error}")]
192 InvalidPackageVersion {
193 #[source]
194 error: PackageVersionParseError,
195 package_str: String,
196 },
197}
198
199impl FromStr for PackageSpec {
200 type Err = ParseRemotePackageError;
201
202 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
203 let (name, version) = s
204 .split_once('@')
205 .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
206
207 Self::parse(name.to_string(), version.to_string()).map_err(|error| {
208 ParseRemotePackageError::InvalidPackageVersion {
209 error,
210 package_str: s.to_string(),
211 }
212 })
213 }
214}
215
216#[derive(Debug, Clone, PartialEq, Eq, Hash)]
218#[cfg_attr(feature = "clap", derive(clap::Args))]
219pub struct PackageReq {
220 pub(crate) name: PackageName,
222 pub(crate) version_req: PackageVersionReq,
224}
225
226impl PackageReq {
227 pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
228 let version_req = match version {
229 Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
230 None => PackageVersionReq::any(),
231 };
232 Ok(Self {
233 name: PackageName::new(name),
234 version_req,
235 })
236 }
237
238 pub unsafe fn new_unchecked(name: String, version: Option<String>) -> Self {
243 Self::new(name, version).unwrap_unchecked()
244 }
245
246 pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
247 Self::from_str(pkg_constraints)
248 }
249 pub fn name(&self) -> &PackageName {
250 &self.name
251 }
252 pub fn version_req(&self) -> &PackageVersionReq {
253 &self.version_req
254 }
255 pub fn matches(&self, package: &PackageSpec) -> bool {
258 self.name == package.name && self.version_req.matches(&package.version)
259 }
260}
261
262impl Display for PackageReq {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 if self.version_req.is_any() {
265 self.name.fmt(f)
266 } else {
267 f.write_str(format!("{}{}", self.name, self.version_req).as_str())
268 }
269 }
270}
271
272impl From<PackageSpec> for PackageReq {
273 fn from(value: PackageSpec) -> Self {
274 value.into_package_req()
275 }
276}
277
278impl From<PackageName> for PackageReq {
279 fn from(name: PackageName) -> Self {
280 Self {
281 name,
282 version_req: PackageVersionReq::any(),
283 }
284 }
285}
286
287impl FromLua for PackageReq {
288 fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
289 let str: String = lua.from_value(value)?;
290 Self::parse(&str).into_lua_err()
291 }
292}
293
294impl mlua::UserData for PackageReq {
295 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
296 methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
297 methods.add_method("version_req", |_, this, ()| {
298 Ok(this.version_req.to_string())
299 });
300 methods.add_method("matches", |_, this, package: PackageSpec| {
301 Ok(this.matches(&package))
302 });
303 }
304}
305
306pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
308pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
309
310impl DisplayAsLuaKV for Dependencies<'_> {
311 fn display_lua(&self) -> DisplayLuaKV {
312 DisplayLuaKV {
313 key: "dependencies".to_string(),
314 value: DisplayLuaValue::List(
315 self.0
316 .iter()
317 .map(|req| DisplayLuaValue::String(req.to_string()))
318 .collect(),
319 ),
320 }
321 }
322}
323
324impl DisplayAsLuaKV for BuildDependencies<'_> {
325 fn display_lua(&self) -> DisplayLuaKV {
326 DisplayLuaKV {
327 key: "build_dependencies".to_string(),
328 value: DisplayLuaValue::List(
329 self.0
330 .iter()
331 .map(|req| DisplayLuaValue::String(req.to_string()))
332 .collect(),
333 ),
334 }
335 }
336}
337
338#[derive(Error, Debug)]
339pub enum PackageReqParseError {
340 #[error("could not parse dependency name from {0}")]
341 InvalidDependencyName(String),
342 #[error("could not parse version requirement in '{str}': {error}")]
343 InvalidPackageVersionReq {
344 #[source]
345 error: PackageVersionReqError,
346 str: String,
347 },
348}
349
350impl FromStr for PackageReq {
351 type Err = PackageReqParseError;
352
353 fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
354 let rock_name_str = str
355 .chars()
356 .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
357 .collect::<String>();
358
359 if rock_name_str.is_empty() {
360 return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
361 }
362
363 let constraints = str.trim_start_matches(&rock_name_str).trim();
364 let version_req = match constraints {
365 "" => PackageVersionReq::any(),
366 constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
367 PackageReqParseError::InvalidPackageVersionReq {
368 error,
369 str: str.to_string(),
370 }
371 })?,
372 };
373 Ok(Self {
374 name: PackageName::new(rock_name_str),
375 version_req,
376 })
377 }
378}
379
380impl<'de> Deserialize<'de> for PackageReq {
381 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
382 where
383 D: Deserializer<'de>,
384 {
385 let s = String::deserialize(deserializer)?;
386 Self::from_str(&s).map_err(de::Error::custom)
387 }
388}
389
390#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
392pub struct PackageName(String);
393
394impl IntoLua for PackageName {
395 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
396 self.0.into_lua(lua)
397 }
398}
399
400impl PackageName {
401 pub fn new(name: String) -> Self {
402 Self(name.to_lowercase())
403 }
404}
405
406impl<'de> Deserialize<'de> for PackageName {
407 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
408 where
409 D: serde::Deserializer<'de>,
410 {
411 Ok(PackageName::new(String::deserialize(deserializer)?))
412 }
413}
414
415impl FromLua for PackageName {
416 fn from_lua(
417 value: mlua::prelude::LuaValue,
418 lua: &mlua::prelude::Lua,
419 ) -> mlua::prelude::LuaResult<Self> {
420 Ok(Self::new(String::from_lua(value, lua)?))
421 }
422}
423
424impl Serialize for PackageName {
425 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
426 where
427 S: serde::Serializer,
428 {
429 self.0.serialize(serializer)
430 }
431}
432
433impl From<&str> for PackageName {
434 fn from(value: &str) -> Self {
435 Self::new(value.into())
436 }
437}
438
439impl Display for PackageName {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 f.write_str(self.0.as_str())
442 }
443}
444
445#[derive(Debug)]
446pub struct PackageNameList(Vec<PackageName>);
447
448impl PackageNameList {
449 pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
450 Self(package_names)
451 }
452}
453
454impl Display for PackageNameList {
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing in these tests awaits, so plain `#[test]` is used instead of starting
    // a tokio runtime for every test.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        // Compare with `cmp` (ordering equality), as "1", "1.0" and "1.0.0" are expected
        // to order equal. `assert_eq!` reports the actual ordering on failure.
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1.0".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn parse_lua_package_req() {
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg >= 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg > 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }

    #[test]
    fn package_spec_req_roundtrips() {
        let spec = PackageSpec::new("foo".into(), "1.0.0-1".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());

        let spec = PackageSpec::new("foo".into(), "1.0.0".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());
    }

    #[test]
    fn package_req_spec_roundtrips() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());
    }

    #[test]
    fn package_spec_parses_same_as_name_version() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0-1".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));
    }
}