1use itertools::Itertools;
2
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12 PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError, SpecRev,
13 VersionReqToVersionError,
14};
15
16pub(crate) use version::SpecRevIterator;
17
18use crate::{
19 lockfile::RemotePackageSourceUrl,
20 lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
21 package::version::HasModRev,
22 remote_package_source::RemotePackageSource,
23 rockspec::lua_dependency::LuaDependencySpec,
24 variables::{GetVariableError, HasVariables},
25};
26
/// Error returned when converting a [`PackageReq`] into a [`PackageSpec`] fails.
#[derive(Debug, Error)]
pub enum PackageSpecFromPackageReqError {
    /// The version requirement of `rock` could not be converted into a version.
    // NOTE(review): `err` is neither named `source` nor marked `#[source]`, so
    // `Error::source()` returns `None`; the cause is only visible through the message.
    #[error("invalid version for rock {rock}: {err}")]
    Version {
        /// The package whose requirement was being converted.
        rock: PackageName,
        /// Why the requirement could not be converted.
        err: VersionReqToVersionError,
    },
}
35
// A package name paired with a concrete version (as opposed to `PackageReq`, which pairs a
// name with a version *requirement*).
// Plain `//` comments are used on purpose: with the `clap` feature enabled, `clap::Args`
// would pick up `///` doc comments as CLI help text.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct PackageSpec {
    // The package name.
    name: PackageName,
    // The package version.
    version: PackageVersion,
}
42
43impl PackageSpec {
44 pub fn new(name: PackageName, version: PackageVersion) -> Self {
45 Self { name, version }
46 }
47 pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
48 Ok(Self::new(
49 PackageName::new(name),
50 PackageVersion::parse(&version)?,
51 ))
52 }
53 pub fn name(&self) -> &PackageName {
54 &self.name
55 }
56 pub fn version(&self) -> &PackageVersion {
57 &self.version
58 }
59 pub fn into_package_req(self) -> PackageReq {
60 PackageReq {
61 name: self.name,
62 version_req: self.version.into_version_req(),
63 }
64 }
65}
66
67impl TryFrom<PackageReq> for PackageSpec {
68 type Error = PackageSpecFromPackageReqError;
69
70 fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
71 let name = value.name;
72 let version = value.version_req.try_into().map_err(|err| {
73 PackageSpecFromPackageReqError::Version {
74 rock: name.clone(),
75 err,
76 }
77 })?;
78 Ok(Self { name, version })
79 }
80}
81
82impl<'de> Deserialize<'de> for PackageSpec {
83 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84 where
85 D: Deserializer<'de>,
86 {
87 let (name, version) = <_>::deserialize(deserializer)?;
88 Self::parse(name, version).map_err(serde::de::Error::custom)
89 }
90}
91
92impl HasVariables for PackageSpec {
93 fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
94 Ok(match input {
95 "PACKAGE" => Some(self.name.to_string()),
96 "VERSION" => Some(self.version.to_modrev_string()),
97 _ => None,
98 })
99 }
100}
101
/// A package spec together with the remote source it is associated with.
#[derive(Clone, Debug)]
pub(crate) struct RemotePackage {
    /// The package name and version.
    pub package: PackageSpec,
    /// The remote source of the package.
    pub source: RemotePackageSource,
    /// URL of the source, if known.
    pub source_url: Option<RemotePackageSourceUrl>,
}
109
110impl RemotePackage {
111 pub fn new(
112 package: PackageSpec,
113 source: RemotePackageSource,
114 source_url: Option<RemotePackageSourceUrl>,
115 ) -> Self {
116 Self {
117 package,
118 source,
119 source_url,
120 }
121 }
122}
123
/// The kind of artefact a remote package is available as.
///
/// Ordered by a fixed rank (see the `Ord` impl below): `Src < Rockspec < Binary`.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum RemotePackageType {
    /// A rockspec.
    Rockspec,
    /// A source package.
    Src,
    /// A binary package.
    Binary,
}

impl Ord for RemotePackageType {
    /// Compares by rank: `Src` lowest, then `Rockspec`, then `Binary` highest.
    fn cmp(&self, other: &Self) -> Ordering {
        // Explicit ranks instead of `derive(Ord)`: the declaration order above does not match
        // the desired ordering.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }
        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Total order; always delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
152
/// Flags selecting which kinds of remote package (rockspec, source, binary) are included by a
/// filter. [`Default`] enables all three.
// `Debug`, `PartialEq` and `Eq` are derived so the public type can be logged and compared;
// all fields are plain `bool`s, so these derives are purely additive.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RemotePackageTypeFilterSpec {
    /// Include rockspec packages.
    pub rockspec: bool,
    /// Include source packages.
    pub src: bool,
    /// Include binary packages.
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Enables every package kind.
    fn default() -> Self {
        Self {
            rockspec: true,
            src: true,
            binary: true,
        }
    }
}
172
/// Error returned when parsing a [`PackageSpec`] from a `name@version` string fails.
#[derive(Error, Debug)]
pub enum ParseRemotePackageError {
    /// Returned by `PackageSpec::from_str` when the input contains no `@`; holds the input.
    #[error("unable to parse package {0}. expected format: `name@version`")]
    InvalidInput(String),
    /// The part after the `@` could not be parsed as a package version.
    // NOTE(review): the message embeds `{error}` while `error` is also exposed via `#[source]`;
    // error reporters that print the `source()` chain will show the cause twice — confirm
    // this is intended.
    #[error("unable to parse package {package_str}:\n{error}")]
    InvalidPackageVersion {
        /// The underlying version parse failure.
        #[source]
        error: PackageVersionParseError,
        /// The full input string that failed to parse.
        package_str: String,
    },
}
184
185impl FromStr for PackageSpec {
186 type Err = ParseRemotePackageError;
187
188 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
189 let (name, version) = s
190 .split_once('@')
191 .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
192
193 Self::parse(name.to_string(), version.to_string()).map_err(|error| {
194 ParseRemotePackageError::InvalidPackageVersion {
195 error,
196 package_str: s.to_string(),
197 }
198 })
199 }
200}
201
// A package name paired with a version requirement, e.g. `lua >= 5.1`.
// Plain `//` comments are used on purpose: with the `clap` feature enabled, `clap::Args`
// would pick up `///` doc comments as CLI help text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct PackageReq {
    // The package name.
    pub(crate) name: PackageName,
    // The accepted versions; `PackageVersionReq::any()` when unconstrained.
    pub(crate) version_req: PackageVersionReq,
}
211
212impl PackageReq {
213 pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
214 let version_req = match version {
215 Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
216 None => PackageVersionReq::any(),
217 };
218 Ok(Self {
219 name: PackageName::new(name),
220 version_req,
221 })
222 }
223
224 pub unsafe fn new_unchecked(name: String, version: Option<String>) -> Self {
229 Self::new(name, version).unwrap_unchecked()
230 }
231
232 pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
233 Self::from_str(pkg_constraints)
234 }
235 pub fn name(&self) -> &PackageName {
236 &self.name
237 }
238 pub fn version_req(&self) -> &PackageVersionReq {
239 &self.version_req
240 }
241 pub fn matches(&self, package: &PackageSpec) -> bool {
244 self.name == package.name && self.version_req.matches(&package.version)
245 }
246}
247
248impl Display for PackageReq {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 if self.version_req.is_any() {
251 self.name.fmt(f)
252 } else {
253 f.write_str(format!("{}{}", self.name, self.version_req).as_str())
254 }
255 }
256}
257
258impl From<PackageSpec> for PackageReq {
259 fn from(value: PackageSpec) -> Self {
260 value.into_package_req()
261 }
262}
263
264impl From<PackageName> for PackageReq {
265 fn from(name: PackageName) -> Self {
266 Self {
267 name,
268 version_req: PackageVersionReq::any(),
269 }
270 }
271}
272
273pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
275pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
276
277impl DisplayAsLuaKV for Dependencies<'_> {
278 fn display_lua(&self) -> DisplayLuaKV {
279 DisplayLuaKV {
280 key: "dependencies".to_string(),
281 value: DisplayLuaValue::List(
282 self.0
283 .iter()
284 .map(|req| DisplayLuaValue::String(req.to_string()))
285 .collect(),
286 ),
287 }
288 }
289}
290
291impl DisplayAsLuaKV for BuildDependencies<'_> {
292 fn display_lua(&self) -> DisplayLuaKV {
293 DisplayLuaKV {
294 key: "build_dependencies".to_string(),
295 value: DisplayLuaValue::List(
296 self.0
297 .iter()
298 .map(|req| DisplayLuaValue::String(req.to_string()))
299 .collect(),
300 ),
301 }
302 }
303}
304
/// Error returned when parsing a [`PackageReq`] from a string such as `"lua >= 5.1"` fails.
#[derive(Error, Debug)]
pub enum PackageReqParseError {
    /// The input does not start with a package-name character (alphanumeric, `-`, `_`, `.`);
    /// holds the full input.
    #[error("could not parse dependency name from {0}")]
    InvalidDependencyName(String),
    /// The text following the name could not be parsed as a version requirement.
    // NOTE(review): the message embeds `{error}` while `error` is also exposed via `#[source]`;
    // error reporters that print the `source()` chain will show the cause twice — confirm
    // this is intended.
    #[error("could not parse version requirement in '{str}': {error}")]
    InvalidPackageVersionReq {
        /// The underlying version-requirement parse failure.
        #[source]
        error: PackageVersionReqError,
        /// The full input string that failed to parse.
        str: String,
    },
}
316
317impl FromStr for PackageReq {
318 type Err = PackageReqParseError;
319
320 fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
321 let rock_name_str = str
322 .chars()
323 .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
324 .collect::<String>();
325
326 if rock_name_str.is_empty() {
327 return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
328 }
329
330 let constraints = str.trim_start_matches(&rock_name_str).trim();
331 let version_req = match constraints {
332 "" => PackageVersionReq::any(),
333 constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
334 PackageReqParseError::InvalidPackageVersionReq {
335 error,
336 str: str.to_string(),
337 }
338 })?,
339 };
340 Ok(Self {
341 name: PackageName::new(rock_name_str),
342 version_req,
343 })
344 }
345}
346
347impl<'de> Deserialize<'de> for PackageReq {
348 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
349 where
350 D: Deserializer<'de>,
351 {
352 let s = String::deserialize(deserializer)?;
353 Self::from_str(&s).map_err(de::Error::custom)
354 }
355}
356
/// A package name. Names built through [`PackageName::new`] are stored lower-cased, so
/// equality, ordering and hashing ignore the case of the original input.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
pub struct PackageName(String);

impl PackageName {
    /// Creates a name from `name`, converting it to lower case with [`str::to_lowercase`]
    /// (Unicode-aware).
    pub fn new(name: String) -> Self {
        let normalised = name.to_lowercase();
        PackageName(normalised)
    }
}
366
367impl<'de> Deserialize<'de> for PackageName {
368 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
369 where
370 D: serde::Deserializer<'de>,
371 {
372 Ok(PackageName::new(String::deserialize(deserializer)?))
373 }
374}
375
376impl Serialize for PackageName {
377 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
378 where
379 S: serde::Serializer,
380 {
381 self.0.serialize(serializer)
382 }
383}
384
385impl From<&str> for PackageName {
386 fn from(value: &str) -> Self {
387 Self::new(value.into())
388 }
389}
390
391impl Display for PackageName {
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 f.write_str(self.0.as_str())
394 }
395}
396
397#[derive(Debug)]
398pub struct PackageNameList(Vec<PackageName>);
399
400impl PackageNameList {
401 pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
402 Self(package_names)
403 }
404}
405
406impl Display for PackageNameList {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    // Every test here is synchronous (nothing is awaited), so the plain `#[test]` harness is
    // used instead of `#[tokio::test]`, which would build an async runtime per test for nothing.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        // Shorter forms of the same version must compare equal.
        let neorg = PackageSpec::parse("neorg".into(), "1.0".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn parse_lua_package_req() {
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));

        // Name extraction, with and without whitespace before the constraint.
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");

        // A bare version, `=` and `==` all match only that version.
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));

        // Comma-separated constraints.
        let package_req: PackageReq = "neorg >= 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg > 1.0, < 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));

        // `~>` constraints.
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }

    #[test]
    fn package_spec_req_roundtrips() {
        let spec = PackageSpec::new("foo".into(), "1.0.0-1".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());

        let spec = PackageSpec::new("foo".into(), "1.0.0".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());
    }

    #[test]
    fn package_req_spec_roundtrips() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());
    }

    #[test]
    fn package_spec_parses_same_as_name_version() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0-1".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));
    }
}