lux_lib/package/
mod.rs

1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12    PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError,
13    VersionReqToVersionError,
14};
15
16use crate::{
17    lockfile::RemotePackageSourceUrl,
18    lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
19    package::version::HasModRev,
20    remote_package_source::RemotePackageSource,
21    rockspec::lua_dependency::LuaDependencySpec,
22    variables::{GetVariableError, HasVariables},
23};
24
25#[derive(Debug, Error)]
26pub enum PackageSpecFromPackageReqError {
27    #[error("invalid version for rock {rock}: {err}")]
28    Version {
29        rock: PackageName,
30        err: VersionReqToVersionError,
31    },
32}
33
34#[derive(Clone, Debug)]
35#[cfg_attr(feature = "clap", derive(clap::Args))]
36pub struct PackageSpec {
37    name: PackageName,
38    version: PackageVersion,
39}
40
41impl PackageSpec {
42    pub fn new(name: PackageName, version: PackageVersion) -> Self {
43        Self { name, version }
44    }
45    pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
46        Ok(Self::new(
47            PackageName::new(name),
48            PackageVersion::parse(&version)?,
49        ))
50    }
51    pub fn name(&self) -> &PackageName {
52        &self.name
53    }
54    pub fn version(&self) -> &PackageVersion {
55        &self.version
56    }
57    pub fn into_package_req(self) -> PackageReq {
58        PackageReq {
59            name: self.name,
60            version_req: self.version.into_version_req(),
61        }
62    }
63}
64
65impl TryFrom<PackageReq> for PackageSpec {
66    type Error = PackageSpecFromPackageReqError;
67
68    fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
69        let name = value.name;
70        let version = value.version_req.try_into().map_err(|err| {
71            PackageSpecFromPackageReqError::Version {
72                rock: name.clone(),
73                err,
74            }
75        })?;
76        Ok(Self { name, version })
77    }
78}
79
80impl FromLua for PackageSpec {
81    fn from_lua(
82        value: mlua::prelude::LuaValue,
83        lua: &mlua::prelude::Lua,
84    ) -> mlua::prelude::LuaResult<Self> {
85        let (name, version) = lua.from_value(value)?;
86
87        Self::parse(name, version).into_lua_err()
88    }
89}
90
91impl mlua::UserData for PackageSpec {
92    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
93        fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
94        fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
95    }
96
97    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
98        methods.add_method("to_package_req", |_, this, ()| {
99            Ok(this.clone().into_package_req())
100        })
101    }
102}
103
104impl HasVariables for PackageSpec {
105    fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
106        Ok(match input {
107            "PACKAGE" => Some(self.name.to_string()),
108            "VERSION" => Some(self.version.to_modrev_string()),
109            _ => None,
110        })
111    }
112}
113
114#[derive(Clone, Debug)]
115pub(crate) struct RemotePackage {
116    pub package: PackageSpec,
117    pub source: RemotePackageSource,
118    /// `Some` if present in a lockfile
119    pub source_url: Option<RemotePackageSourceUrl>,
120}
121
122impl RemotePackage {
123    pub fn new(
124        package: PackageSpec,
125        source: RemotePackageSource,
126        source_url: Option<RemotePackageSourceUrl>,
127    ) -> Self {
128        Self {
129            package,
130            source,
131            source_url,
132        }
133    }
134}
135
/// The kind of a remote package.
///
/// Values are ordered by priority: `Binary` > `Rockspec` > `Src`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    /// Compares two package types by priority (binary > rockspec > src).
    fn cmp(&self, other: &Self) -> Ordering {
        // Map each variant to its priority rank; a higher rank sorts later.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }

        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Always `Some`, delegating to the total order defined by [`Ord`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
164
/// Selects which kinds of remote packages are included.
///
/// The [`Default`] value includes every kind.
#[derive(Clone, Debug)]
pub struct RemotePackageTypeFilterSpec {
    /// Include Rockspec
    pub rockspec: bool,
    /// Include Src
    pub src: bool,
    /// Include Binary
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Returns a filter with all three kinds enabled.
    fn default() -> Self {
        Self {
            rockspec: true,
            src: true,
            binary: true,
        }
    }
}
184
185#[derive(Error, Debug)]
186pub enum ParseRemotePackageError {
187    #[error("unable to parse package {0}. expected format: `name@version`")]
188    InvalidInput(String),
189    #[error("unable to parse package {package_str}: {error}")]
190    InvalidPackageVersion {
191        #[source]
192        error: PackageVersionParseError,
193        package_str: String,
194    },
195}
196
197impl FromStr for PackageSpec {
198    type Err = ParseRemotePackageError;
199
200    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
201        let (name, version) = s
202            .split_once('@')
203            .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
204
205        Self::parse(name.to_string(), version.to_string()).map_err(|error| {
206            ParseRemotePackageError::InvalidPackageVersion {
207                error,
208                package_str: s.to_string(),
209            }
210        })
211    }
212}
213
214/// A lua package requirement with a name and an optional version requirement.
215#[derive(Debug, Clone, PartialEq)]
216#[cfg_attr(feature = "clap", derive(clap::Args))]
217pub struct PackageReq {
218    /// The name of the package.
219    pub(crate) name: PackageName,
220    /// The version requirement, for example "1.0.0" or ">=1.0.0".
221    pub(crate) version_req: PackageVersionReq,
222}
223
224impl PackageReq {
225    pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
226        let version_req = match version {
227            Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
228            None => PackageVersionReq::any(),
229        };
230        Ok(Self {
231            name: PackageName::new(name),
232            version_req,
233        })
234    }
235    pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
236        Self::from_str(pkg_constraints)
237    }
238    pub fn name(&self) -> &PackageName {
239        &self.name
240    }
241    pub fn version_req(&self) -> &PackageVersionReq {
242        &self.version_req
243    }
244    /// Evaluate whether the given package satisfies the package requirement
245    /// given by `self`.
246    pub fn matches(&self, package: &PackageSpec) -> bool {
247        self.name == package.name && self.version_req.matches(&package.version)
248    }
249}
250
251impl Display for PackageReq {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        if self.version_req.is_any() {
254            self.name.fmt(f)
255        } else {
256            f.write_str(format!("{}{}", self.name, self.version_req).as_str())
257        }
258    }
259}
260
261impl From<PackageSpec> for PackageReq {
262    fn from(value: PackageSpec) -> Self {
263        value.into_package_req()
264    }
265}
266
267impl From<PackageName> for PackageReq {
268    fn from(name: PackageName) -> Self {
269        Self {
270            name,
271            version_req: PackageVersionReq::any(),
272        }
273    }
274}
275
276impl FromLua for PackageReq {
277    fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
278        let str: String = lua.from_value(value)?;
279        Self::parse(&str).into_lua_err()
280    }
281}
282
283impl mlua::UserData for PackageReq {
284    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
285        methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
286        methods.add_method("version_req", |_, this, ()| {
287            Ok(this.version_req.to_string())
288        });
289        methods.add_method("matches", |_, this, package: PackageSpec| {
290            Ok(this.matches(&package))
291        });
292    }
293}
294
295/// Wrapper structs for proper serialization of various dependency types.
296pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
297pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
298pub(crate) struct TestDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
299
300impl DisplayAsLuaKV for Dependencies<'_> {
301    fn display_lua(&self) -> DisplayLuaKV {
302        DisplayLuaKV {
303            key: "dependencies".to_string(),
304            value: DisplayLuaValue::List(
305                self.0
306                    .iter()
307                    .map(|req| DisplayLuaValue::String(req.to_string()))
308                    .collect(),
309            ),
310        }
311    }
312}
313
314impl DisplayAsLuaKV for BuildDependencies<'_> {
315    fn display_lua(&self) -> DisplayLuaKV {
316        DisplayLuaKV {
317            key: "build_dependencies".to_string(),
318            value: DisplayLuaValue::List(
319                self.0
320                    .iter()
321                    .map(|req| DisplayLuaValue::String(req.to_string()))
322                    .collect(),
323            ),
324        }
325    }
326}
327
328impl DisplayAsLuaKV for TestDependencies<'_> {
329    fn display_lua(&self) -> DisplayLuaKV {
330        DisplayLuaKV {
331            key: "test_dependencies".to_string(),
332            value: DisplayLuaValue::List(
333                self.0
334                    .iter()
335                    .map(|req| DisplayLuaValue::String(req.to_string()))
336                    .collect(),
337            ),
338        }
339    }
340}
341
342#[derive(Error, Debug)]
343pub enum PackageReqParseError {
344    #[error("could not parse dependency name from {0}")]
345    InvalidDependencyName(String),
346    #[error("could not parse version requirement in '{str}': {error}")]
347    InvalidPackageVersionReq {
348        #[source]
349        error: PackageVersionReqError,
350        str: String,
351    },
352}
353
354impl FromStr for PackageReq {
355    type Err = PackageReqParseError;
356
357    fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
358        let rock_name_str = str
359            .chars()
360            .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
361            .collect::<String>();
362
363        if rock_name_str.is_empty() {
364            return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
365        }
366
367        let constraints = str.trim_start_matches(&rock_name_str).trim();
368        let version_req = match constraints {
369            "" => PackageVersionReq::any(),
370            constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
371                PackageReqParseError::InvalidPackageVersionReq {
372                    error,
373                    str: str.to_string(),
374                }
375            })?,
376        };
377        Ok(Self {
378            name: PackageName::new(rock_name_str),
379            version_req,
380        })
381    }
382}
383
384impl<'de> Deserialize<'de> for PackageReq {
385    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
386    where
387        D: Deserializer<'de>,
388    {
389        let s = String::deserialize(deserializer)?;
390        Self::from_str(&s).map_err(de::Error::custom)
391    }
392}
393
/// A luarocks package name, which is always lowercase.
///
/// Construct it with [`PackageName::new`], which lowercases its input.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PackageName(String);
397
398impl IntoLua for PackageName {
399    fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
400        self.0.into_lua(lua)
401    }
402}
403
404impl PackageName {
405    pub fn new(name: String) -> Self {
406        Self(name.to_lowercase())
407    }
408}
409
410impl<'de> Deserialize<'de> for PackageName {
411    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
412    where
413        D: serde::Deserializer<'de>,
414    {
415        Ok(PackageName::new(String::deserialize(deserializer)?))
416    }
417}
418
419impl FromLua for PackageName {
420    fn from_lua(
421        value: mlua::prelude::LuaValue,
422        lua: &mlua::prelude::Lua,
423    ) -> mlua::prelude::LuaResult<Self> {
424        Ok(Self::new(String::from_lua(value, lua)?))
425    }
426}
427
428impl Serialize for PackageName {
429    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
430    where
431        S: serde::Serializer,
432    {
433        self.0.serialize(serializer)
434    }
435}
436
437impl From<&str> for PackageName {
438    fn from(value: &str) -> Self {
439        Self::new(value.into())
440    }
441}
442
443impl Display for PackageName {
444    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445        f.write_str(self.0.as_str())
446    }
447}
448
449#[derive(Debug)]
450pub struct PackageNameList(Vec<PackageName>);
451
452impl PackageNameList {
453    pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
454        Self(package_names)
455    }
456}
457
458impl Display for PackageNameList {
459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a `PackageSpec` from a name and a version string, panicking on
    /// invalid input.
    fn spec(name: &str, version: &str) -> PackageSpec {
        PackageSpec::parse(name.into(), version.into()).unwrap()
    }

    /// Parses a `PackageReq` from a dependency string, panicking on invalid input.
    fn req(input: &str) -> PackageReq {
        input.parse().unwrap()
    }

    // None of these tests awaits anything, so they are plain synchronous `#[test]`s.

    #[test]
    fn parse_name() {
        let mut package_name: PackageName = "neorg".into();
        assert_eq!(package_name.to_string(), "neorg");
        package_name = "LuaFileSystem".into();
        assert_eq!(package_name.to_string(), "luafilesystem");
    }

    #[test]
    fn parse_lua_package() {
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        let neorg = spec("neorg", "1.0.0");
        assert_eq!(neorg.name().to_string(), "neorg");
        assert!(matches!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        ));
        // Shorter spellings of the same version must compare equal to `1.0.0`.
        let neorg = spec("neorg", "1.0");
        assert!(matches!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        ));
        let neorg = spec("neorg", "1");
        assert!(matches!(
            neorg.version().cmp(&expected_version),
            std::cmp::Ordering::Equal
        ));
    }

    #[test]
    fn parse_lua_package_req() {
        // Requirements built with `PackageReq::new`.
        let mut package_req = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(package_req.matches(&spec("foo", "1.0.0")));
        assert!(!package_req.matches(&spec("bar", "1.0.0")));
        assert!(!package_req.matches(&spec("foo", "2.0.0")));
        package_req = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(package_req.matches(&spec("foo", "2.0.0")));

        // Name extraction from dependency strings.
        assert_eq!(req("lua >= 5.1").name.to_string(), "lua");
        assert_eq!(req("lua>=5.1").name.to_string(), "lua");
        assert_eq!(req("toml-edit >= 0.1.0").name.to_string(), "toml-edit");
        assert_eq!(req("plugin.nvim >= 0.1.0").name.to_string(), "plugin.nvim");
        assert_eq!(req("lfs").name.to_string(), "lfs");

        // Exact versions, with and without an explicit operator.
        let package_req = req("neorg 1.0.0");
        assert_eq!(package_req.name.to_string(), "neorg");
        assert!(package_req.matches(&spec("neorg", "1.0.0")));
        let neorg = spec("neorg", "2.0.0");
        assert!(!package_req.matches(&neorg));
        assert!(req("neorg 2.0.0").matches(&neorg));
        assert!(req("neorg = 2.0.0").matches(&neorg));
        assert!(req("neorg == 2.0.0").matches(&neorg));
        assert!(req("neorg &equals; 2.0.0").matches(&neorg));

        // Ranges.
        let package_req = req("neorg >= 1.0, &lt; 2.0");
        assert!(package_req.matches(&spec("neorg", "1.5")));
        let package_req = req("neorg &gt; 1.0, &lt; 2.0");
        assert!(package_req.matches(&spec("neorg", "1.11.0")));
        assert!(!package_req.matches(&spec("neorg", "3.0.0")));
        assert!(!package_req.matches(&spec("neorg", "0.5")));

        // `~>` constraints.
        let package_req = req("neorg ~> 1");
        assert!(!package_req.matches(&spec("neorg", "0.5")));
        assert!(!package_req.matches(&spec("neorg", "3")));
        assert!(package_req.matches(&spec("neorg", "1.5")));
        let package_req = req("neorg ~> 1.4");
        assert!(!package_req.matches(&spec("neorg", "1.3")));
        assert!(!package_req.matches(&spec("neorg", "1.5")));
        assert!(package_req.matches(&spec("neorg", "1.4.10")));
        assert!(package_req.matches(&spec("neorg", "1.4")));
        let package_req = req("neorg ~> 1.0.5");
        assert!(!package_req.matches(&spec("neorg", "1.0.4")));
        assert!(package_req.matches(&spec("neorg", "1.0.5")));
        assert!(!package_req.matches(&spec("neorg", "1.0.6")));

        // Testing incomplete version constraints
        let package_req = req("lua-utils.nvim ~> 1.1-1");
        assert!(package_req.matches(&spec("lua-utils.nvim", "1.1.4")));
        assert!(package_req.matches(&spec("lua-utils.nvim", "1.1.5")));
        assert!(!package_req.matches(&spec("lua-utils.nvim", "1.2-1")));
    }

    #[test]
    fn remote_package_type_priorities() {
        let rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        // Sorting is ascending, so the highest priority (binary) comes last.
        assert_eq!(
            rock_types.into_iter().sorted().collect_vec(),
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }
}