lux_lib/package/mod.rs

1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12    PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError,
13    VersionReqToVersionError,
14};
15
16use crate::{
17    lockfile::RemotePackageSourceUrl,
18    lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
19    remote_package_source::RemotePackageSource,
20    rockspec::lua_dependency::LuaDependencySpec,
21    variables::HasVariables,
22};
23
/// Error returned when converting a [`PackageReq`] into a [`PackageSpec`] fails.
#[derive(Debug, Error)]
pub enum PackageSpecFromPackageReqError {
    /// The version requirement could not be turned into a single concrete version.
    #[error("invalid version for rock {rock}: {err}")]
    Version {
        /// The package whose requirement failed to convert.
        rock: PackageName,
        /// The underlying requirement-to-version conversion error.
        err: VersionReqToVersionError,
    },
}
32
// A package name paired with one specific version (as opposed to `PackageReq`, which
// pairs a name with a version *requirement*).
// NOTE(review): plain `//` comments are used deliberately — with the `clap` feature this
// struct derives `clap::Args`, and `///` doc comments would become CLI help text.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct PackageSpec {
    name: PackageName,
    version: PackageVersion,
}
39
40impl PackageSpec {
41    pub fn new(name: PackageName, version: PackageVersion) -> Self {
42        Self { name, version }
43    }
44    pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
45        Ok(Self::new(
46            PackageName::new(name),
47            PackageVersion::parse(&version)?,
48        ))
49    }
50    pub fn name(&self) -> &PackageName {
51        &self.name
52    }
53    pub fn version(&self) -> &PackageVersion {
54        &self.version
55    }
56    pub fn into_package_req(self) -> PackageReq {
57        PackageReq {
58            name: self.name,
59            version_req: self.version.into_version_req(),
60        }
61    }
62}
63
64impl TryFrom<PackageReq> for PackageSpec {
65    type Error = PackageSpecFromPackageReqError;
66
67    fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
68        let name = value.name;
69        let version = value.version_req.try_into().map_err(|err| {
70            PackageSpecFromPackageReqError::Version {
71                rock: name.clone(),
72                err,
73            }
74        })?;
75        Ok(Self { name, version })
76    }
77}
78
79impl FromLua for PackageSpec {
80    fn from_lua(
81        value: mlua::prelude::LuaValue,
82        lua: &mlua::prelude::Lua,
83    ) -> mlua::prelude::LuaResult<Self> {
84        let (name, version) = lua.from_value(value)?;
85
86        Self::parse(name, version).into_lua_err()
87    }
88}
89
90impl mlua::UserData for PackageSpec {
91    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
92        fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
93        fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
94    }
95
96    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
97        methods.add_method("to_package_req", |_, this, ()| {
98            Ok(this.clone().into_package_req())
99        })
100    }
101}
102
103impl HasVariables for PackageSpec {
104    fn get_variable(&self, input: &str) -> Option<String> {
105        match input {
106            "PACKAGE" => Some(self.name.to_string()),
107            "VERSION" => Some(self.version.to_string()),
108            _ => None,
109        }
110    }
111}
112
/// A package spec together with the remote source it comes from.
#[derive(Clone, Debug)]
pub(crate) struct RemotePackage {
    /// The package name and version.
    pub package: PackageSpec,
    /// Where the package comes from (see [`RemotePackageSource`]).
    pub source: RemotePackageSource,
    /// `Some` if present in a lockfile
    pub source_url: Option<RemotePackageSourceUrl>,
}
120
121impl RemotePackage {
122    pub fn new(
123        package: PackageSpec,
124        source: RemotePackageSource,
125        source_url: Option<RemotePackageSourceUrl>,
126    ) -> Self {
127        Self {
128            package,
129            source,
130            source_url,
131        }
132    }
133}
134
/// The kind of artifact a remote package is available as.
///
/// Ordering reflects priority: `Binary` > `Rockspec` > `Src`, so the greatest value is the
/// most preferred kind.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    fn cmp(&self, other: &Self) -> Ordering {
        // Numeric priority; a higher rank means the type is preferred.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }
        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Always `Some`: the ordering is total and delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
163
/// Selects which [`RemotePackageType`]s should be considered; each flag opts a type in.
#[derive(Clone)]
pub struct RemotePackageTypeFilterSpec {
    /// Include Rockspec
    pub rockspec: bool,
    /// Include Src
    pub src: bool,
    /// Include Binary
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Includes every package type; callers narrow the selection by clearing flags.
    fn default() -> Self {
        let include = true;
        RemotePackageTypeFilterSpec {
            binary: include,
            rockspec: include,
            src: include,
        }
    }
}
183
/// Error returned when parsing a `name@version` string into a [`PackageSpec`].
#[derive(Error, Debug)]
pub enum ParseRemotePackageError {
    /// The input contained no `@` separating name and version; holds the full input.
    #[error("unable to parse package {0}. expected format: `name@version`")]
    InvalidInput(String),
    /// The part after the first `@` could not be parsed as a version.
    #[error("unable to parse package {package_str}: {error}")]
    InvalidPackageVersion {
        /// The underlying version parse failure.
        #[source]
        error: PackageVersionParseError,
        /// The full input string that was being parsed.
        package_str: String,
    },
}
195
196impl FromStr for PackageSpec {
197    type Err = ParseRemotePackageError;
198
199    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
200        let (name, version) = s
201            .split_once('@')
202            .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
203
204        Self::parse(name.to_string(), version.to_string()).map_err(|error| {
205            ParseRemotePackageError::InvalidPackageVersion {
206                error,
207                package_str: s.to_string(),
208            }
209        })
210    }
211}
212
/// A lua package requirement with a name and an optional version requirement.
// NOTE(review): with the `clap` feature this derives `clap::Args`, so the `///` comments on
// the fields below are likely surfaced as CLI help text — edit them with that in mind.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct PackageReq {
    /// The name of the package.
    pub(crate) name: PackageName,
    /// The version requirement, for example "1.0.0" or ">=1.0.0".
    pub(crate) version_req: PackageVersionReq,
}
222
223impl PackageReq {
224    pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
225        let version_req = match version {
226            Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
227            None => PackageVersionReq::any(),
228        };
229        Ok(Self {
230            name: PackageName::new(name),
231            version_req,
232        })
233    }
234    pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
235        Self::from_str(pkg_constraints)
236    }
237    pub fn name(&self) -> &PackageName {
238        &self.name
239    }
240    pub fn version_req(&self) -> &PackageVersionReq {
241        &self.version_req
242    }
243    /// Evaluate whether the given package satisfies the package requirement
244    /// given by `self`.
245    pub fn matches(&self, package: &PackageSpec) -> bool {
246        self.name == package.name && self.version_req.matches(&package.version)
247    }
248}
249
250impl Display for PackageReq {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        if self.version_req.is_any() {
253            self.name.fmt(f)
254        } else {
255            f.write_str(format!("{}{}", self.name, self.version_req).as_str())
256        }
257    }
258}
259
260impl From<PackageSpec> for PackageReq {
261    fn from(value: PackageSpec) -> Self {
262        value.into_package_req()
263    }
264}
265
266impl From<PackageName> for PackageReq {
267    fn from(name: PackageName) -> Self {
268        Self {
269            name,
270            version_req: PackageVersionReq::any(),
271        }
272    }
273}
274
275impl FromLua for PackageReq {
276    fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
277        let str: String = lua.from_value(value)?;
278        Self::parse(&str).into_lua_err()
279    }
280}
281
282impl mlua::UserData for PackageReq {
283    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
284        methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
285        methods.add_method("version_req", |_, this, ()| {
286            Ok(this.version_req.to_string())
287        });
288        methods.add_method("matches", |_, this, package: PackageSpec| {
289            Ok(this.matches(&package))
290        });
291    }
292}
293
// Wrapper structs for proper serialization of various dependency types.
// Each wraps the same list type but renders it under a different Lua key via `DisplayAsLuaKV`.

/// Runtime dependencies, rendered under the `dependencies` key.
pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
/// Build-time dependencies, rendered under the `build_dependencies` key.
pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
/// Test dependencies, rendered under the `test_dependencies` key.
pub(crate) struct TestDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
298
299impl DisplayAsLuaKV for Dependencies<'_> {
300    fn display_lua(&self) -> DisplayLuaKV {
301        DisplayLuaKV {
302            key: "dependencies".to_string(),
303            value: DisplayLuaValue::List(
304                self.0
305                    .iter()
306                    .map(|req| DisplayLuaValue::String(req.to_string()))
307                    .collect(),
308            ),
309        }
310    }
311}
312
313impl DisplayAsLuaKV for BuildDependencies<'_> {
314    fn display_lua(&self) -> DisplayLuaKV {
315        DisplayLuaKV {
316            key: "build_dependencies".to_string(),
317            value: DisplayLuaValue::List(
318                self.0
319                    .iter()
320                    .map(|req| DisplayLuaValue::String(req.to_string()))
321                    .collect(),
322            ),
323        }
324    }
325}
326
327impl DisplayAsLuaKV for TestDependencies<'_> {
328    fn display_lua(&self) -> DisplayLuaKV {
329        DisplayLuaKV {
330            key: "test_dependencies".to_string(),
331            value: DisplayLuaValue::List(
332                self.0
333                    .iter()
334                    .map(|req| DisplayLuaValue::String(req.to_string()))
335                    .collect(),
336            ),
337        }
338    }
339}
340
/// Error returned when parsing a dependency string into a [`PackageReq`].
#[derive(Error, Debug)]
pub enum PackageReqParseError {
    /// No package name could be read from the input; holds the full input string.
    #[error("could not parse dependency name from {0}")]
    InvalidDependencyName(String),
    /// The text after the package name is not a valid version requirement.
    #[error("could not parse version requirement in '{str}': {error}")]
    InvalidPackageVersionReq {
        /// The underlying version-requirement parse failure.
        #[source]
        error: PackageVersionReqError,
        /// The full input string that was being parsed.
        str: String,
    },
}
352
353impl FromStr for PackageReq {
354    type Err = PackageReqParseError;
355
356    fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
357        let rock_name_str = str
358            .chars()
359            .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
360            .collect::<String>();
361
362        if rock_name_str.is_empty() {
363            return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
364        }
365
366        let constraints = str.trim_start_matches(&rock_name_str).trim();
367        let version_req = match constraints {
368            "" => PackageVersionReq::any(),
369            constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
370                PackageReqParseError::InvalidPackageVersionReq {
371                    error,
372                    str: str.to_string(),
373                }
374            })?,
375        };
376        Ok(Self {
377            name: PackageName::new(rock_name_str),
378            version_req,
379        })
380    }
381}
382
383impl<'de> Deserialize<'de> for PackageReq {
384    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
385    where
386        D: Deserializer<'de>,
387    {
388        let s = String::deserialize(deserializer)?;
389        Self::from_str(&s).map_err(de::Error::custom)
390    }
391}
392
/// A luarocks package name, which is always lowercase
///
/// The inner string is private. Every constructor in this module (`new`, `From<&str>`,
/// `Deserialize`, `FromLua`) goes through `PackageName::new`, which lowercases its input
/// and so upholds that invariant.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
pub struct PackageName(String);
396
397impl IntoLua for PackageName {
398    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
399        self.0.into_lua(lua)
400    }
401}
402
403impl PackageName {
404    pub fn new(name: String) -> Self {
405        Self(name.to_lowercase())
406    }
407}
408
409impl<'de> Deserialize<'de> for PackageName {
410    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
411    where
412        D: serde::Deserializer<'de>,
413    {
414        Ok(PackageName::new(String::deserialize(deserializer)?))
415    }
416}
417
418impl FromLua for PackageName {
419    fn from_lua(
420        value: mlua::prelude::LuaValue,
421        lua: &mlua::prelude::Lua,
422    ) -> mlua::prelude::LuaResult<Self> {
423        Ok(Self::new(String::from_lua(value, lua)?))
424    }
425}
426
427impl Serialize for PackageName {
428    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
429    where
430        S: serde::Serializer,
431    {
432        self.0.serialize(serializer)
433    }
434}
435
436impl From<&str> for PackageName {
437    fn from(value: &str) -> Self {
438        Self::new(value.into())
439    }
440}
441
442impl Display for PackageName {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        f.write_str(self.0.as_str())
445    }
446}
447
448#[derive(Debug)]
449pub struct PackageNameList(Vec<PackageName>);
450
451impl PackageNameList {
452    pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
453        Self(package_names)
454    }
455}
456
457impl Display for PackageNameList {
458    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459        f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    // Every test here is synchronous, so plain `#[test]` functions are sufficient; the
    // previous `#[tokio::test] async fn` form started an async runtime for no benefit.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        // Names are normalised to lowercase.
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(neorg.name().to_string(), "neorg");
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        // Versions with omitted components compare equal to the fully spelled-out one.
        let neorg = PackageSpec::parse("neorg".into(), "1.0".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
        let neorg = PackageSpec::parse("neorg".into(), "1".into()).unwrap();
        assert_eq!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn parse_lua_package_req() {
        // Construction via `PackageReq::new`.
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("bar".into(), "1.0.0".into()).unwrap()));
        assert!(!package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&PackageSpec::parse("foo".into(), "2.0.0".into()).unwrap()));

        // Name extraction from dependency strings.
        let package_req: PackageReq = "lua >= 5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "lua>=5.1".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lua");
        let package_req: PackageReq = "toml-edit >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "toml-edit");
        let package_req: PackageReq = "plugin.nvim >= 0.1.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "plugin.nvim");
        let package_req: PackageReq = "lfs".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "lfs");

        // Exact-version requirements in their various spellings.
        let package_req: PackageReq = "neorg 1.0.0".parse().unwrap();
        assert_eq!(package_req.name.to_string(), "neorg");
        let neorg = PackageSpec::parse("neorg".into(), "1.0.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "2.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let package_req: PackageReq = "neorg 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg = 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg == 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg &equals; 2.0.0".parse().unwrap();
        assert!(package_req.matches(&neorg));

        // Ranges, including HTML-escaped comparison operators.
        let package_req: PackageReq = "neorg >= 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg &gt; 1.0, &lt; 2.0".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.11.0".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3.0.0".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "0.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));

        // Pessimistic (`~>`) requirements.
        let package_req: PackageReq = "neorg ~> 1".parse().unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.4".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.3".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.5".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4.10".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.4".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let package_req: PackageReq = "neorg ~> 1.0.5".parse().unwrap();
        let neorg = PackageSpec::parse("neorg".into(), "1.0.4".into()).unwrap();
        assert!(!package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.5".into()).unwrap();
        assert!(package_req.matches(&neorg));
        let neorg = PackageSpec::parse("neorg".into(), "1.0.6".into()).unwrap();
        assert!(!package_req.matches(&neorg));

        // Testing incomplete version constraints
        let package_req: PackageReq = "lua-utils.nvim ~> 1.1-1".parse().unwrap();
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.4".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.1.5".into()).unwrap();
        assert!(package_req.matches(&lua_utils));
        let lua_utils = PackageSpec::parse("lua-utils.nvim".into(), "1.2-1".into()).unwrap();
        assert!(!package_req.matches(&lua_utils));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        // Ascending order = increasing priority.
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }
}