1use itertools::Itertools;
2use mlua::{ExternalResult, FromLua, IntoLua, LuaSerdeExt};
3use serde::{de, Deserialize, Deserializer, Serialize};
4use std::{cmp::Ordering, fmt::Display, str::FromStr};
5use thiserror::Error;
6
7mod outdated;
8mod version;
9
10pub use outdated::*;
11pub use version::{
12 PackageVersion, PackageVersionParseError, PackageVersionReq, PackageVersionReqError,
13 VersionReqToVersionError,
14};
15
16use crate::{
17 lockfile::RemotePackageSourceUrl,
18 lua_rockspec::{DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue},
19 package::version::HasModRev,
20 remote_package_source::RemotePackageSource,
21 rockspec::lua_dependency::LuaDependencySpec,
22 variables::{GetVariableError, HasVariables},
23};
24
25#[derive(Debug, Error)]
26pub enum PackageSpecFromPackageReqError {
27 #[error("invalid version for rock {rock}: {err}")]
28 Version {
29 rock: PackageName,
30 err: VersionReqToVersionError,
31 },
32}
33
34#[derive(Clone, Debug)]
35#[cfg_attr(feature = "clap", derive(clap::Args))]
36pub struct PackageSpec {
37 name: PackageName,
38 version: PackageVersion,
39}
40
41impl PackageSpec {
42 pub fn new(name: PackageName, version: PackageVersion) -> Self {
43 Self { name, version }
44 }
45 pub fn parse(name: String, version: String) -> Result<Self, PackageVersionParseError> {
46 Ok(Self::new(
47 PackageName::new(name),
48 PackageVersion::parse(&version)?,
49 ))
50 }
51 pub fn name(&self) -> &PackageName {
52 &self.name
53 }
54 pub fn version(&self) -> &PackageVersion {
55 &self.version
56 }
57 pub fn into_package_req(self) -> PackageReq {
58 PackageReq {
59 name: self.name,
60 version_req: self.version.into_version_req(),
61 }
62 }
63}
64
65impl TryFrom<PackageReq> for PackageSpec {
66 type Error = PackageSpecFromPackageReqError;
67
68 fn try_from(value: PackageReq) -> Result<Self, Self::Error> {
69 let name = value.name;
70 let version = value.version_req.try_into().map_err(|err| {
71 PackageSpecFromPackageReqError::Version {
72 rock: name.clone(),
73 err,
74 }
75 })?;
76 Ok(Self { name, version })
77 }
78}
79
80impl FromLua for PackageSpec {
81 fn from_lua(
82 value: mlua::prelude::LuaValue,
83 lua: &mlua::prelude::Lua,
84 ) -> mlua::prelude::LuaResult<Self> {
85 let (name, version) = lua.from_value(value)?;
86
87 Self::parse(name, version).into_lua_err()
88 }
89}
90
91impl mlua::UserData for PackageSpec {
92 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
93 fields.add_field_method_get("name", |_, this| Ok(this.name.to_string()));
94 fields.add_field_method_get("version", |_, this| Ok(this.version.to_string()));
95 }
96
97 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
98 methods.add_method("to_package_req", |_, this, ()| {
99 Ok(this.clone().into_package_req())
100 })
101 }
102}
103
104impl HasVariables for PackageSpec {
105 fn get_variable(&self, input: &str) -> Result<Option<String>, GetVariableError> {
106 Ok(match input {
107 "PACKAGE" => Some(self.name.to_string()),
108 "VERSION" => Some(self.version.to_modrev_string()),
109 _ => None,
110 })
111 }
112}
113
114#[derive(Clone, Debug)]
115pub(crate) struct RemotePackage {
116 pub package: PackageSpec,
117 pub source: RemotePackageSource,
118 pub source_url: Option<RemotePackageSourceUrl>,
120}
121
122impl RemotePackage {
123 pub fn new(
124 package: PackageSpec,
125 source: RemotePackageSource,
126 source_url: Option<RemotePackageSourceUrl>,
127 ) -> Self {
128 Self {
129 package,
130 source,
131 source_url,
132 }
133 }
134}
135
/// The kind of a remote package.
///
/// Values are totally ordered as `Src < Rockspec < Binary`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum RemotePackageType {
    Rockspec,
    Src,
    Binary,
}

impl Ord for RemotePackageType {
    /// Compares two package types by rank: `Src` lowest, then `Rockspec`,
    /// then `Binary` highest.
    fn cmp(&self, other: &Self) -> Ordering {
        // The rank is independent of declaration order, so it is spelled out here.
        fn rank(kind: &RemotePackageType) -> u8 {
            match kind {
                RemotePackageType::Src => 0,
                RemotePackageType::Rockspec => 1,
                RemotePackageType::Binary => 2,
            }
        }

        rank(self).cmp(&rank(other))
    }
}

impl PartialOrd for RemotePackageType {
    /// Always `Some`, delegating to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
164
/// A set of flags, one per kind of remote package.
///
/// NOTE(review): the flags presumably select which package kinds a search or
/// download accepts — confirm against the callers.
#[derive(Clone, Debug)]
pub struct RemotePackageTypeFilterSpec {
    /// Flag for the rockspec kind.
    pub rockspec: bool,
    /// Flag for the src kind.
    pub src: bool,
    /// Flag for the binary kind.
    pub binary: bool,
}

impl Default for RemotePackageTypeFilterSpec {
    /// Returns a filter with every flag set to `true`.
    fn default() -> Self {
        Self {
            rockspec: true,
            src: true,
            binary: true,
        }
    }
}
184
185#[derive(Error, Debug)]
186pub enum ParseRemotePackageError {
187 #[error("unable to parse package {0}. expected format: `name@version`")]
188 InvalidInput(String),
189 #[error("unable to parse package {package_str}: {error}")]
190 InvalidPackageVersion {
191 #[source]
192 error: PackageVersionParseError,
193 package_str: String,
194 },
195}
196
197impl FromStr for PackageSpec {
198 type Err = ParseRemotePackageError;
199
200 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
201 let (name, version) = s
202 .split_once('@')
203 .ok_or_else(|| ParseRemotePackageError::InvalidInput(s.to_string()))?;
204
205 Self::parse(name.to_string(), version.to_string()).map_err(|error| {
206 ParseRemotePackageError::InvalidPackageVersion {
207 error,
208 package_str: s.to_string(),
209 }
210 })
211 }
212}
213
214#[derive(Debug, Clone, PartialEq)]
216#[cfg_attr(feature = "clap", derive(clap::Args))]
217pub struct PackageReq {
218 pub(crate) name: PackageName,
220 pub(crate) version_req: PackageVersionReq,
222}
223
224impl PackageReq {
225 pub fn new(name: String, version: Option<String>) -> Result<Self, PackageVersionReqError> {
226 let version_req = match version {
227 Some(version_req_str) => PackageVersionReq::parse(version_req_str.as_str())?,
228 None => PackageVersionReq::any(),
229 };
230 Ok(Self {
231 name: PackageName::new(name),
232 version_req,
233 })
234 }
235 pub fn parse(pkg_constraints: &str) -> Result<Self, PackageReqParseError> {
236 Self::from_str(pkg_constraints)
237 }
238 pub fn name(&self) -> &PackageName {
239 &self.name
240 }
241 pub fn version_req(&self) -> &PackageVersionReq {
242 &self.version_req
243 }
244 pub fn matches(&self, package: &PackageSpec) -> bool {
247 self.name == package.name && self.version_req.matches(&package.version)
248 }
249}
250
251impl Display for PackageReq {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 if self.version_req.is_any() {
254 self.name.fmt(f)
255 } else {
256 f.write_str(format!("{}{}", self.name, self.version_req).as_str())
257 }
258 }
259}
260
261impl From<PackageSpec> for PackageReq {
262 fn from(value: PackageSpec) -> Self {
263 value.into_package_req()
264 }
265}
266
267impl From<PackageName> for PackageReq {
268 fn from(name: PackageName) -> Self {
269 Self {
270 name,
271 version_req: PackageVersionReq::any(),
272 }
273 }
274}
275
276impl FromLua for PackageReq {
277 fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
278 let str: String = lua.from_value(value)?;
279 Self::parse(&str).into_lua_err()
280 }
281}
282
283impl mlua::UserData for PackageReq {
284 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
285 methods.add_method("name", |_, this, ()| Ok(this.name.to_string()));
286 methods.add_method("version_req", |_, this, ()| {
287 Ok(this.version_req.to_string())
288 });
289 methods.add_method("matches", |_, this, package: PackageSpec| {
290 Ok(this.matches(&package))
291 });
292 }
293}
294
295pub(crate) struct Dependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
297pub(crate) struct BuildDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
298pub(crate) struct TestDependencies<'a>(pub(crate) &'a Vec<LuaDependencySpec>);
299
300impl DisplayAsLuaKV for Dependencies<'_> {
301 fn display_lua(&self) -> DisplayLuaKV {
302 DisplayLuaKV {
303 key: "dependencies".to_string(),
304 value: DisplayLuaValue::List(
305 self.0
306 .iter()
307 .map(|req| DisplayLuaValue::String(req.to_string()))
308 .collect(),
309 ),
310 }
311 }
312}
313
314impl DisplayAsLuaKV for BuildDependencies<'_> {
315 fn display_lua(&self) -> DisplayLuaKV {
316 DisplayLuaKV {
317 key: "build_dependencies".to_string(),
318 value: DisplayLuaValue::List(
319 self.0
320 .iter()
321 .map(|req| DisplayLuaValue::String(req.to_string()))
322 .collect(),
323 ),
324 }
325 }
326}
327
328impl DisplayAsLuaKV for TestDependencies<'_> {
329 fn display_lua(&self) -> DisplayLuaKV {
330 DisplayLuaKV {
331 key: "test_dependencies".to_string(),
332 value: DisplayLuaValue::List(
333 self.0
334 .iter()
335 .map(|req| DisplayLuaValue::String(req.to_string()))
336 .collect(),
337 ),
338 }
339 }
340}
341
342#[derive(Error, Debug)]
343pub enum PackageReqParseError {
344 #[error("could not parse dependency name from {0}")]
345 InvalidDependencyName(String),
346 #[error("could not parse version requirement in '{str}': {error}")]
347 InvalidPackageVersionReq {
348 #[source]
349 error: PackageVersionReqError,
350 str: String,
351 },
352}
353
354impl FromStr for PackageReq {
355 type Err = PackageReqParseError;
356
357 fn from_str(str: &str) -> Result<Self, PackageReqParseError> {
358 let rock_name_str = str
359 .chars()
360 .peeking_take_while(|t| t.is_alphanumeric() || matches!(t, '-' | '_' | '.'))
361 .collect::<String>();
362
363 if rock_name_str.is_empty() {
364 return Err(PackageReqParseError::InvalidDependencyName(str.to_string()));
365 }
366
367 let constraints = str.trim_start_matches(&rock_name_str).trim();
368 let version_req = match constraints {
369 "" => PackageVersionReq::any(),
370 constraints => PackageVersionReq::parse(constraints.trim_start()).map_err(|error| {
371 PackageReqParseError::InvalidPackageVersionReq {
372 error,
373 str: str.to_string(),
374 }
375 })?,
376 };
377 Ok(Self {
378 name: PackageName::new(rock_name_str),
379 version_req,
380 })
381 }
382}
383
384impl<'de> Deserialize<'de> for PackageReq {
385 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
386 where
387 D: Deserializer<'de>,
388 {
389 let s = String::deserialize(deserializer)?;
390 Self::from_str(&s).map_err(de::Error::custom)
391 }
392}
393
/// The name of a package.
///
/// [`PackageName::new`] lowercases its input, so names built through it
/// compare case-insensitively.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PackageName(String);
397
398impl IntoLua for PackageName {
399 fn into_lua(self, lua: &mlua::Lua) -> mlua::Result<mlua::Value> {
400 self.0.into_lua(lua)
401 }
402}
403
404impl PackageName {
405 pub fn new(name: String) -> Self {
406 Self(name.to_lowercase())
407 }
408}
409
410impl<'de> Deserialize<'de> for PackageName {
411 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
412 where
413 D: serde::Deserializer<'de>,
414 {
415 Ok(PackageName::new(String::deserialize(deserializer)?))
416 }
417}
418
419impl FromLua for PackageName {
420 fn from_lua(
421 value: mlua::prelude::LuaValue,
422 lua: &mlua::prelude::Lua,
423 ) -> mlua::prelude::LuaResult<Self> {
424 Ok(Self::new(String::from_lua(value, lua)?))
425 }
426}
427
428impl Serialize for PackageName {
429 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
430 where
431 S: serde::Serializer,
432 {
433 self.0.serialize(serializer)
434 }
435}
436
437impl From<&str> for PackageName {
438 fn from(value: &str) -> Self {
439 Self::new(value.into())
440 }
441}
442
443impl Display for PackageName {
444 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445 f.write_str(self.0.as_str())
446 }
447}
448
449#[derive(Debug)]
450pub struct PackageNameList(Vec<PackageName>);
451
452impl PackageNameList {
453 pub(crate) fn new(package_names: Vec<PackageName>) -> Self {
454 Self(package_names)
455 }
456}
457
458impl Display for PackageNameList {
459 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460 f.write_str(self.0.iter().map(|pkg| pkg.to_string()).join(", ").as_str())
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`PackageSpec`] from string slices, panicking if the version is rejected.
    fn spec(name: &str, version: &str) -> PackageSpec {
        PackageSpec::parse(name.to_string(), version.to_string()).unwrap()
    }

    /// Parses a [`PackageReq`] from a requirement string, panicking on failure.
    fn req(constraints: &str) -> PackageReq {
        constraints.parse().unwrap()
    }

    #[tokio::test]
    async fn parse_name() {
        let lowercase: PackageName = "neorg".into();
        assert_eq!(lowercase.to_string(), "neorg");
        // Mixed-case input is lowercased.
        let mixed_case: PackageName = "LuaFileSystem".into();
        assert_eq!(mixed_case.to_string(), "luafilesystem");
    }

    #[tokio::test]
    async fn parse_lua_package() {
        let expected_version = PackageVersion::parse("1.0.0").unwrap();
        assert_eq!(spec("neorg", "1.0.0").name().to_string(), "neorg");
        // Shorter version strings compare equal to the full three-component version.
        for raw_version in ["1.0.0", "1.0", "1"] {
            let neorg = spec("neorg", raw_version);
            assert!(matches!(
                neorg.version().cmp(&expected_version),
                std::cmp::Ordering::Equal
            ));
        }
    }

    #[tokio::test]
    async fn parse_lua_package_req() {
        // Requirements created through the constructor.
        let exact = PackageReq::new("foo".into(), Some("1.0.0".into())).unwrap();
        assert!(exact.matches(&spec("foo", "1.0.0")));
        assert!(!exact.matches(&spec("bar", "1.0.0")));
        assert!(!exact.matches(&spec("foo", "2.0.0")));
        let at_least = PackageReq::new("foo".into(), Some(">= 1.0.0".into())).unwrap();
        assert!(at_least.matches(&spec("foo", "2.0.0")));

        // Name extraction: (requirement string, expected package name).
        let name_cases = [
            ("lua >= 5.1", "lua"),
            ("lua>=5.1", "lua"),
            ("toml-edit >= 0.1.0", "toml-edit"),
            ("plugin.nvim >= 0.1.0", "plugin.nvim"),
            ("lfs", "lfs"),
            ("neorg 1.0.0", "neorg"),
        ];
        for (constraints, expected_name) in name_cases {
            assert_eq!(req(constraints).name.to_string(), expected_name);
        }

        // Matching: (requirement string, package name, package version, expected result).
        let match_cases = [
            ("neorg 1.0.0", "neorg", "1.0.0", true),
            ("neorg 1.0.0", "neorg", "2.0.0", false),
            ("neorg 2.0.0", "neorg", "2.0.0", true),
            ("neorg = 2.0.0", "neorg", "2.0.0", true),
            ("neorg == 2.0.0", "neorg", "2.0.0", true),
            ("neorg = 2.0.0", "neorg", "2.0.0", true),
            ("neorg >= 1.0, < 2.0", "neorg", "1.5", true),
            ("neorg > 1.0, < 2.0", "neorg", "1.11.0", true),
            ("neorg > 1.0, < 2.0", "neorg", "3.0.0", false),
            ("neorg > 1.0, < 2.0", "neorg", "0.5", false),
            ("neorg ~> 1", "neorg", "0.5", false),
            ("neorg ~> 1", "neorg", "3", false),
            ("neorg ~> 1", "neorg", "1.5", true),
            ("neorg ~> 1.4", "neorg", "1.3", false),
            ("neorg ~> 1.4", "neorg", "1.5", false),
            ("neorg ~> 1.4", "neorg", "1.4.10", true),
            ("neorg ~> 1.4", "neorg", "1.4", true),
            ("neorg ~> 1.0.5", "neorg", "1.0.4", false),
            ("neorg ~> 1.0.5", "neorg", "1.0.5", true),
            ("neorg ~> 1.0.5", "neorg", "1.0.6", false),
            ("lua-utils.nvim ~> 1.1-1", "lua-utils.nvim", "1.1.4", true),
            ("lua-utils.nvim ~> 1.1-1", "lua-utils.nvim", "1.1.5", true),
            ("lua-utils.nvim ~> 1.1-1", "lua-utils.nvim", "1.2-1", false),
        ];
        for (constraints, name, version, expected) in match_cases {
            let package_req = req(constraints);
            let package = spec(name, version);
            assert_eq!(
                package_req.matches(&package),
                expected,
                "requirement `{}` against {}@{}",
                constraints,
                name,
                version
            );
        }
    }

    #[tokio::test]
    pub async fn remote_package_type_priorities() {
        let mut rock_types = vec![
            RemotePackageType::Binary,
            RemotePackageType::Src,
            RemotePackageType::Rockspec,
        ];
        rock_types.sort();
        assert_eq!(
            rock_types,
            vec![
                RemotePackageType::Src,
                RemotePackageType::Rockspec,
                RemotePackageType::Binary,
            ]
        );
    }
}