Skip to main content

lux_lib/package/
mod.rs

1use itertools::Itertools;
2
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12    PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError, SpecRev,
13    VersionReqToVersionError,
14};
15
16pub(crate) use version::SpecRevIterator;
17
18use crate::{
19    lockfile::RemotePackageSourceUrl,
20    lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
21    package::version::HasModRev,
22    remote_package_source::RemotePackageSource,
23    rockspec::lua_dependency::LuaDependencySpec,
24    variables::{GetVariableError, HasVariables},
25};
26
27#[derive(Debug, Error)]
28pub enum PackageSpecFromPackageReqError {
29    #[error("invalid version for rock {rock}: {err}")]
30    Version {
31        rock: PackageName,
32        err: VersionReqToVersionError,
33    },
34}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37#[cfg_attr(feature = "clap", derive(clap::Args))]
38pub struct PackageSpec {
39    name: PackageName,
40    version: PackageVersion,
41}
42
43impl PackageSpec {
44    pub fn new(name: PackageName, version: PackageVersion) -> Self {
45        Self { name, version }
46    }
47    pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
48        Ok(Self::new(
49            PackageName::new(name),
50            PackageVersion::parse(&version)?,
51        ))
52    }
53    pub fn name(&self) -> &PackageName {
54        &self.name
55    }
56    pub fn version(&self) -> &PackageVersion {
57        &self.version
58    }
59    pub fn into_package_req(self) -> PackageReq {
60        PackageReq {
61            name: self.name,
62            version_req: self.version.into_version_req(),
63        }
64    }
65}
66
67impl TryFrom<PackageReq> for PackageSpec {
68    type Error = PackageSpecFromPackageReqError;
69
70    fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
71        let name = value.name;
72        let version = value.version_req.try_into().map_err(|err| {
73            PackageSpecFromPackageReqError::Version {
74                rock: name.clone(),
75                err,
76            }
77        })?;
78        Ok(Self { name, version })
79    }
80}
81
82impl<'de> Deserialize<'de> for PackageSpec {
83    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84    where
85        D: Deserializer<'de>,
86    {
87        let (name, version) = <_>::deserialize(deserializer)?;
88        Self::parse(name, version).map_err(serde::de::Error::custom)
89    }
90}
91
92impl HasVariables for PackageSpec {
93    fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
94        Ok(match input {
95            "PACKAGE" => Some(self.name.to_string()),
96            "VERSION" => Some(self.version.to_modrev_string()),
97            _ => None,
98        })
99    }
100}
101
102#[derive(Clone, Debug)]
103pub(crate) struct RemotePackage {
104    pub package: PackageSpec,
105    pub source: RemotePackageSource,
106    /// `Some` if present in a lockfile
107    pub source_url: Option<RemotePackageSourceUrl>,
108}
109
110impl RemotePackage {
111    pub fn new(
112        package: PackageSpec,
113        source: RemotePackageSource,
114        source_url: Option<RemotePackageSourceUrl>,
115    ) -> Self {
116        Self {
117            package,
118            source,
119            source_url,
120        }
121    }
122}
123
/// The kind of artifact a remote package is available as.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    /// Orders by priority `Binary > Rockspec > Src`, so the highest-priority kind sorts last.
    fn cmp(&self, other: &Self) -> Ordering {
        // A larger rank means a higher priority.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }
        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Total order: always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
152
/// Flags choosing which remote package types to include; every flag defaults to `true`.
// `Debug` is derived so this public type can be logged and used in `assert_eq!`-style
// diagnostics, like the other public types in this module.
#[derive(Clone, Debug)]
pub struct RemotePackageTypeFilterSpec {
    /// Include Rockspec
    pub rockspec: bool,
    /// Include Src
    pub src: bool,
    /// Include Binary
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Includes every package type.
    fn default() -> Self {
        Self {
            rockspec: true,
            src: true,
            binary: true,
        }
    }
}
172
173#[derive(Error, Debug)]
174pub enum ParseRemotePackageError {
175    #[error("unable to parse package {0}. expected format: `name@version`")]
176    InvalidInput(String),
177    #[error("unable to parse package {package_str}:\n{error}")]
178    InvalidPackageVersion {
179        #[source]
180        error: PackageVersionParseError,
181        package_str: String,
182    },
183}
184
185impl FromStr for PackageSpec {
186    type Err = ParseRemotePackageError;
187
188    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
189        let (name, version) = s
190            .split_once('@')
191            .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
192
193        Self::parse(name.to_string(), version.to_string()).map_err(|error| {
194            ParseRemotePackageError::InvalidPackageVersion {
195                error,
196                package_str: s.to_string(),
197            }
198        })
199    }
200}
201
202/// A lua package requirement with a name and an optional version requirement.
203#[derive(Debug, Clone, PartialEq, Eq, Hash)]
204#[cfg_attr(feature = "clap", derive(clap::Args))]
205pub struct PackageReq {
206    /// The name of the package.
207    pub(crate) name: PackageName,
208    /// The version requirement, for example "1.0.0" or ">=1.0.0".
209    pub(crate) version_req: PackageVersionReq,
210}
211
212impl PackageReq {
213    pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
214        let version_req = match version {
215            Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
216            None => PackageVersionReq::any(),
217        };
218        Ok(Self {
219            name: PackageName::new(name),
220            version_req,
221        })
222    }
223
224    /// # Safety
225    ///
226    /// Call this only if you are absolutely sure that a `PackageReq` can be parsed from the arguments.
227    /// Calling this with invalid arguments is undefined behaviour.
228    pub unsafe fn new_unchecked(name: String, version: Option<String>) -> Self {
229        Self::new(name, version).unwrap_unchecked()
230    }
231
232    pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
233        Self::from_str(pkg_constraints)
234    }
235    pub fn name(&self) -> &PackageName {
236        &self.name
237    }
238    pub fn version_req(&self) -> &PackageVersionReq {
239        &self.version_req
240    }
241    /// Evaluate whether the given package satisfies the package requirement
242    /// given by `self`.
243    pub fn matches(&self, package: &PackageSpec) -> bool {
244        self.name == package.name && self.version_req.matches(&package.version)
245    }
246}
247
248impl Display for PackageReq {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        if self.version_req.is_any() {
251            self.name.fmt(f)
252        } else {
253            f.write_str(format!("{}{}", self.name, self.version_req).as_str())
254        }
255    }
256}
257
258impl From<PackageSpec> for PackageReq {
259    fn from(value: PackageSpec) -> Self {
260        value.into_package_req()
261    }
262}
263
264impl From<PackageName> for PackageReq {
265    fn from(name: PackageName) -> Self {
266        Self {
267            name,
268            version_req: PackageVersionReq::any(),
269        }
270    }
271}
272
273/// Wrapper structs for proper serialization of various dependency types.
274pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
275pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
276
277impl DisplayAsLuaKV for Dependencies<'_> {
278    fn display_lua(&self) -> DisplayLuaKV {
279        DisplayLuaKV {
280            key: "dependencies".to_string(),
281            value: DisplayLuaValue::List(
282                self.0
283                    .iter()
284                    .map(|req| DisplayLuaValue::String(req.to_string()))
285                    .collect(),
286            ),
287        }
288    }
289}
290
291impl DisplayAsLuaKV for BuildDependencies<'_> {
292    fn display_lua(&self) -> DisplayLuaKV {
293        DisplayLuaKV {
294            key: "build_dependencies".to_string(),
295            value: DisplayLuaValue::List(
296                self.0
297                    .iter()
298                    .map(|req| DisplayLuaValue::String(req.to_string()))
299                    .collect(),
300            ),
301        }
302    }
303}
304
305#[derive(Error, Debug)]
306pub enum PackageReqParseError {
307    #[error("could not parse dependency name from {0}")]
308    InvalidDependencyName(String),
309    #[error("could not parse version requirement in '{str}': {error}")]
310    InvalidPackageVersionReq {
311        #[source]
312        error: PackageVersionReqError,
313        str: String,
314    },
315}
316
317impl FromStr for PackageReq {
318    type Err = PackageReqParseError;
319
320    fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
321        let rock_name_str = str
322            .chars()
323            .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
324            .collect::<String>();
325
326        if rock_name_str.is_empty() {
327            return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
328        }
329
330        let constraints = str.trim_start_matches(&rock_name_str).trim();
331        let version_req = match constraints {
332            "" => PackageVersionReq::any(),
333            constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
334                PackageReqParseError::InvalidPackageVersionReq {
335                    error,
336                    str: str.to_string(),
337                }
338            })?,
339        };
340        Ok(Self {
341            name: PackageName::new(rock_name_str),
342            version_req,
343        })
344    }
345}
346
347impl<'de> Deserialize<'de> for PackageReq {
348    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
349    where
350        D: Deserializer<'de>,
351    {
352        let s = String::deserialize(deserializer)?;
353        Self::from_str(&s).map_err(de::Error::custom)
354    }
355}
356
/// A luarocks package name. The inner string is always stored lowercased
/// (every constructor in this module goes through [`PackageName::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
pub struct PackageName(String);

impl PackageName {
    /// Wraps `name` after lowercasing it with `str::to_lowercase`.
    pub fn new(name: String) -> Self {
        let lowercased = name.to_lowercase();
        PackageName(lowercased)
    }
}
366
367impl<'de> Deserialize<'de> for PackageName {
368    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
369    where
370        D: serde::Deserializer<'de>,
371    {
372        Ok(PackageName::new(String::deserialize(deserializer)?))
373    }
374}
375
376impl Serialize for PackageName {
377    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
378    where
379        S: serde::Serializer,
380    {
381        self.0.serialize(serializer)
382    }
383}
384
385impl From<&str> for PackageName {
386    fn from(value: &str) -> Self {
387        Self::new(value.into())
388    }
389}
390
391impl Display for PackageName {
392    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        f.write_str(self.0.as_str())
394    }
395}
396
397#[derive(Debug)]
398pub struct PackageNameList(Vec<PackageName>);
399
400impl PackageNameList {
401    pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
402        Self(package_names)
403    }
404}
405
406impl Display for PackageNameList {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing in these tests awaits, so they are plain `#[test]` functions rather than
    // `#[tokio::test] async fn`, which would build a runtime for every test.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        // Names are lowercased on construction.
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        // Versions with fewer components are expected to compare equal to "1.0.0".
        let neorg = PackageSpec::parse("neorg".into(), "1.0".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn parse_lua_package_req() {
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        // Name extraction with and without whitespace before the constraint.
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        // Equality operators, including the HTML-escaped form.
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg &equals; 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        // Ranges, including HTML-escaped comparison operators.
        let package_req: PackageReq = "neorg >= 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg &gt; 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        // Pessimistic (`~>`) constraints.
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        // Testing incomplete version constraints
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }

    #[test]
    fn package_spec_req_roundtrips() {
        let spec = PackageSpec::new("foo".into(), "1.0.0-1".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());

        let spec = PackageSpec::new("foo".into(), "1.0.0".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());
    }

    #[test]
    fn package_req_spec_roundtrips() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());
    }

    #[test]
    fn package_spec_parses_same_as_name_version() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0-1".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));
    }
}