Skip to main content

lux_lib/package/
mod.rs

1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12    PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError, SpecRev,
13    VersionReqToVersionError,
14};
15
16pub(crate) use version::SpecRevIterator;
17
18use crate::{
19    lockfile::RemotePackageSourceUrl,
20    lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
21    package::version::HasModRev,
22    remote_package_source::RemotePackageSource,
23    rockspec::lua_dependency::LuaDependencySpec,
24    variables::{GetVariableError, HasVariables},
25};
26
27#[derive(Debug, Error)]
28pub enum PackageSpecFromPackageReqError {
29    #[error("invalid version for rock {rock}: {err}")]
30    Version {
31        rock: PackageName,
32        err: VersionReqToVersionError,
33    },
34}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37#[cfg_attr(feature = "clap", derive(clap::Args))]
38pub struct PackageSpec {
39    name: PackageName,
40    version: PackageVersion,
41}
42
43impl PackageSpec {
44    pub fn new(name: PackageName, version: PackageVersion) -> Self {
45        Self { name, version }
46    }
47    pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
48        Ok(Self::new(
49            PackageName::new(name),
50            PackageVersion::parse(&version)?,
51        ))
52    }
53    pub fn name(&self) -> &PackageName {
54        &self.name
55    }
56    pub fn version(&self) -> &PackageVersion {
57        &self.version
58    }
59    pub fn into_package_req(self) -> PackageReq {
60        PackageReq {
61            name: self.name,
62            version_req: self.version.into_version_req(),
63        }
64    }
65}
66
67impl TryFrom<PackageReq> for PackageSpec {
68    type Error = PackageSpecFromPackageReqError;
69
70    fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
71        let name = value.name;
72        let version = value.version_req.try_into().map_err(|err| {
73            PackageSpecFromPackageReqError::Version {
74                rock: name.clone(),
75                err,
76            }
77        })?;
78        Ok(Self { name, version })
79    }
80}
81
82impl FromLua for PackageSpec {
83    fn from_lua(
84        value: mlua::prelude::LuaValue,
85        lua: &mlua::prelude::Lua,
86    ) -> mlua::prelude::LuaResult<Self> {
87        let (name, version) = lua.from_value(value)?;
88
89        Self::parse(name, version).into_lua_err()
90    }
91}
92
93impl mlua::UserData for PackageSpec {
94    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
95        fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
96        fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
97    }
98
99    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
100        methods.add_method("to_package_req", |_, this, ()| {
101            Ok(this.clone().into_package_req())
102        })
103    }
104}
105
106impl HasVariables for PackageSpec {
107    fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
108        Ok(match input {
109            "PACKAGE" => Some(self.name.to_string()),
110            "VERSION" => Some(self.version.to_modrev_string()),
111            _ => None,
112        })
113    }
114}
115
116#[derive(Clone, Debug)]
117pub(crate) struct RemotePackage {
118    pub package: PackageSpec,
119    pub source: RemotePackageSource,
120    /// `Some` if present in a lockfile
121    pub source_url: Option<RemotePackageSourceUrl>,
122}
123
124impl RemotePackage {
125    pub fn new(
126        package: PackageSpec,
127        source: RemotePackageSource,
128        source_url: Option<RemotePackageSourceUrl>,
129    ) -> Self {
130        Self {
131            package,
132            source,
133            source_url,
134        }
135    }
136}
137
/// The kind of artifact a remote package is available as.
///
/// Ordered by priority: `Binary > Rockspec > Src`.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    fn cmp(&self, other: &Self) -> Ordering {
        // Priority: binary > rockspec > src, expressed as a numeric rank so that the
        // ordering is defined in exactly one place.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }
        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
166
/// Selects which kinds of remote package (rockspec, source, binary) to include.
///
/// The [`Default`] value includes all three kinds.
#[derive(Clone)]
pub struct RemotePackageTypeFilterSpec {
    /// Include Rockspec
    pub rockspec: bool,
    /// Include Src
    pub src: bool,
    /// Include Binary
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    fn default() -> Self {
        let include = true;
        Self {
            rockspec: include,
            src: include,
            binary: include,
        }
    }
}
186
187#[derive(Error, Debug)]
188pub enum ParseRemotePackageError {
189    #[error("unable to parse package {0}. expected format: `name@version`")]
190    InvalidInput(String),
191    #[error("unable to parse package {package_str}:\n{error}")]
192    InvalidPackageVersion {
193        #[source]
194        error: PackageVersionParseError,
195        package_str: String,
196    },
197}
198
199impl FromStr for PackageSpec {
200    type Err = ParseRemotePackageError;
201
202    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
203        let (name, version) = s
204            .split_once('@')
205            .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
206
207        Self::parse(name.to_string(), version.to_string()).map_err(|error| {
208            ParseRemotePackageError::InvalidPackageVersion {
209                error,
210                package_str: s.to_string(),
211            }
212        })
213    }
214}
215
216/// A lua package requirement with a name and an optional version requirement.
217#[derive(Debug, Clone, PartialEq, Eq, Hash)]
218#[cfg_attr(feature = "clap", derive(clap::Args))]
219pub struct PackageReq {
220    /// The name of the package.
221    pub(crate) name: PackageName,
222    /// The version requirement, for example "1.0.0" or ">=1.0.0".
223    pub(crate) version_req: PackageVersionReq,
224}
225
226impl PackageReq {
227    pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
228        let version_req = match version {
229            Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
230            None => PackageVersionReq::any(),
231        };
232        Ok(Self {
233            name: PackageName::new(name),
234            version_req,
235        })
236    }
237
238    /// # Safety
239    ///
240    /// Call this only if you are absolutely sure that a `PackageReq` can be parsed from the arguments.
241    /// Calling this with invalid arguments is undefined behaviour.
242    pub unsafe fn new_unchecked(name: String, version: Option<String>) -> Self {
243        Self::new(name, version).unwrap_unchecked()
244    }
245
246    pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
247        Self::from_str(pkg_constraints)
248    }
249    pub fn name(&self) -> &PackageName {
250        &self.name
251    }
252    pub fn version_req(&self) -> &PackageVersionReq {
253        &self.version_req
254    }
255    /// Evaluate whether the given package satisfies the package requirement
256    /// given by `self`.
257    pub fn matches(&self, package: &PackageSpec) -> bool {
258        self.name == package.name && self.version_req.matches(&package.version)
259    }
260}
261
262impl Display for PackageReq {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        if self.version_req.is_any() {
265            self.name.fmt(f)
266        } else {
267            f.write_str(format!("{}{}", self.name, self.version_req).as_str())
268        }
269    }
270}
271
272impl From<PackageSpec> for PackageReq {
273    fn from(value: PackageSpec) -> Self {
274        value.into_package_req()
275    }
276}
277
278impl From<PackageName> for PackageReq {
279    fn from(name: PackageName) -> Self {
280        Self {
281            name,
282            version_req: PackageVersionReq::any(),
283        }
284    }
285}
286
287impl FromLua for PackageReq {
288    fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
289        let str: String = lua.from_value(value)?;
290        Self::parse(&str).into_lua_err()
291    }
292}
293
294impl mlua::UserData for PackageReq {
295    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
296        methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
297        methods.add_method("version_req", |_, this, ()| {
298            Ok(this.version_req.to_string())
299        });
300        methods.add_method("matches", |_, this, package: PackageSpec| {
301            Ok(this.matches(&package))
302        });
303    }
304}
305
306/// Wrapper structs for proper serialization of various dependency types.
307pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
308pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
309
310impl DisplayAsLuaKV for Dependencies<'_> {
311    fn display_lua(&self) -> DisplayLuaKV {
312        DisplayLuaKV {
313            key: "dependencies".to_string(),
314            value: DisplayLuaValue::List(
315                self.0
316                    .iter()
317                    .map(|req| DisplayLuaValue::String(req.to_string()))
318                    .collect(),
319            ),
320        }
321    }
322}
323
324impl DisplayAsLuaKV for BuildDependencies<'_> {
325    fn display_lua(&self) -> DisplayLuaKV {
326        DisplayLuaKV {
327            key: "build_dependencies".to_string(),
328            value: DisplayLuaValue::List(
329                self.0
330                    .iter()
331                    .map(|req| DisplayLuaValue::String(req.to_string()))
332                    .collect(),
333            ),
334        }
335    }
336}
337
338#[derive(Error, Debug)]
339pub enum PackageReqParseError {
340    #[error("could not parse dependency name from {0}")]
341    InvalidDependencyName(String),
342    #[error("could not parse version requirement in '{str}': {error}")]
343    InvalidPackageVersionReq {
344        #[source]
345        error: PackageVersionReqError,
346        str: String,
347    },
348}
349
350impl FromStr for PackageReq {
351    type Err = PackageReqParseError;
352
353    fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
354        let rock_name_str = str
355            .chars()
356            .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
357            .collect::<String>();
358
359        if rock_name_str.is_empty() {
360            return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
361        }
362
363        let constraints = str.trim_start_matches(&rock_name_str).trim();
364        let version_req = match constraints {
365            "" => PackageVersionReq::any(),
366            constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
367                PackageReqParseError::InvalidPackageVersionReq {
368                    error,
369                    str: str.to_string(),
370                }
371            })?,
372        };
373        Ok(Self {
374            name: PackageName::new(rock_name_str),
375            version_req,
376        })
377    }
378}
379
380impl<'de> Deserialize<'de> for PackageReq {
381    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
382    where
383        D: Deserializer<'de>,
384    {
385        let s = String::deserialize(deserializer)?;
386        Self::from_str(&s).map_err(de::Error::custom)
387    }
388}
389
390/// A luarocks package name, which is always lowercase
391#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
392pub struct PackageName(String);
393
394impl IntoLua for PackageName {
395    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
396        self.0.into_lua(lua)
397    }
398}
399
400impl PackageName {
401    pub fn new(name: String) -> Self {
402        Self(name.to_lowercase())
403    }
404}
405
406impl<'de> Deserialize<'de> for PackageName {
407    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
408    where
409        D: serde::Deserializer<'de>,
410    {
411        Ok(PackageName::new(String::deserialize(deserializer)?))
412    }
413}
414
415impl FromLua for PackageName {
416    fn from_lua(
417        value: mlua::prelude::LuaValue,
418        lua: &mlua::prelude::Lua,
419    ) -> mlua::prelude::LuaResult<Self> {
420        Ok(Self::new(String::from_lua(value, lua)?))
421    }
422}
423
424impl Serialize for PackageName {
425    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
426    where
427        S: serde::Serializer,
428    {
429        self.0.serialize(serializer)
430    }
431}
432
433impl From<&str> for PackageName {
434    fn from(value: &str) -> Self {
435        Self::new(value.into())
436    }
437}
438
439impl Display for PackageName {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        f.write_str(self.0.as_str())
442    }
443}
444
445#[derive(Debug)]
446pub struct PackageNameList(Vec<PackageName>);
447
448impl PackageNameList {
449    pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
450        Self(package_names)
451    }
452}
453
454impl Display for PackageNameList {
455    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456        f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    // All of these tests are synchronous, so they use the plain `#[test]` harness instead of
    // starting a tokio runtime for each one.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1.0".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn parse_lua_package_req() {
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        // Operators may also appear HTML-escaped (`&equals;`, `&lt;`, `&gt;`).
        let package_req: PackageReq = "neorg &equals; 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg >= 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg &gt; 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        // Testing incomplete version constraints
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }

    #[test]
    fn package_spec_req_roundtrips() {
        let spec = PackageSpec::new("foo".into(), "1.0.0-1".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());

        let spec = PackageSpec::new("foo".into(), "1.0.0".parse().unwrap());
        let req: PackageReq = spec.clone().into();
        assert_eq!(spec, req.try_into().unwrap());
    }

    #[test]
    fn package_req_spec_roundtrips() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        assert_eq!(req, spec.into());
    }

    #[test]
    fn package_spec_parses_same_as_name_version() {
        let req: PackageReq = "foo@1.0.0".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));

        let req: PackageReq = "foo@1.0.0-1".parse().unwrap();
        let spec: PackageSpec = req.clone().try_into().unwrap();
        let name: PackageName = "foo".into();
        let version: PackageVersion = "1.0.0-1".parse().unwrap();
        assert_eq!(spec, PackageSpec::new(name, version));
    }
}