Skip to main content

lux_lib/lua_rockspec/
rock_source.rs

1use mlua::{FromLua, IntoLua, Lua, UserData, Value};
2use path_slash::PathBufExt;
3use reqwest::Url;
4use serde::{de, Deserialize, Deserializer};
5use std::{convert::Infallible, fs, io, ops::Deref, path::PathBuf, str::FromStr};
6use thiserror::Error;
7
8use crate::git::{
9    url::{RemoteGitUrl, RemoteGitUrlParseError},
10    GitSource,
11};
12
13use super::{
14    DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, FromPlatformOverridable, PartialOverride,
15    PerPlatform, PerPlatformWrapper, PlatformOverridable,
16};
17
18#[derive(Default, Deserialize, Clone, Debug, PartialEq)]
19pub struct LocalRockSource {
20    pub archive_name: Option<PathBuf>,
21    pub unpack_dir: Option<PathBuf>,
22}
23
24#[derive(Deserialize, Clone, Debug, PartialEq)]
25pub struct RemoteRockSource {
26    pub(crate) local: LocalRockSource,
27    pub source_spec: RockSourceSpec,
28}
29
30impl From<RockSourceSpec> for RemoteRockSource {
31    fn from(source_spec: RockSourceSpec) -> Self {
32        Self {
33            local: LocalRockSource::default(),
34            source_spec,
35        }
36    }
37}
38
39impl UserData for RemoteRockSource {
40    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
41        methods.add_method("source_spec", |_, this, _: ()| Ok(this.source_spec.clone()));
42        methods.add_method("archive_name", |_, this, _: ()| {
43            Ok(this.local.archive_name.clone())
44        });
45        methods.add_method("unpack_dir", |_, this, _: ()| {
46            Ok(this.local.unpack_dir.clone())
47        });
48    }
49}
50
51impl Deref for RemoteRockSource {
52    type Target = LocalRockSource;
53
54    fn deref(&self) -> &Self::Target {
55        &self.local
56    }
57}
58
59#[derive(Error, Debug)]
60pub enum RockSourceError {
61    #[error("invalid rockspec source field combination")]
62    InvalidCombination,
63    #[error(transparent)]
64    SourceUrl(#[from] SourceUrlError),
65    #[error("source URL missing")]
66    SourceUrlMissing,
67}
68
69impl FromPlatformOverridable<RockSourceInternal, Self> for LocalRockSource {
70    type Err = Infallible;
71
72    fn from_platform_overridable(internal: RockSourceInternal) -> Result<Self, Self::Err> {
73        Ok(LocalRockSource {
74            archive_name: internal.file,
75            unpack_dir: internal.dir,
76        })
77    }
78}
79
80impl FromPlatformOverridable<RockSourceInternal, Self> for RemoteRockSource {
81    type Err = RockSourceError;
82
83    fn from_platform_overridable(internal: RockSourceInternal) -> Result<Self, Self::Err> {
84        let local = LocalRockSource::from_platform_overridable(internal.clone()).unwrap();
85
86        // The rockspec.source table allows invalid combinations
87        // This ensures that invalid combinations are caught while parsing.
88        let url = SourceUrl::from_str(&internal.url.ok_or(RockSourceError::SourceUrlMissing)?)?;
89
90        let source_spec = match (url, internal.tag, internal.branch) {
91            (source, None, None) => Ok(RockSourceSpec::default_from_source_url(source)),
92            (SourceUrl::Git(url), Some(tag), None) => Ok(RockSourceSpec::Git(GitSource {
93                url,
94                checkout_ref: Some(tag),
95            })),
96            (SourceUrl::Git(url), None, Some(branch)) => Ok(RockSourceSpec::Git(GitSource {
97                url,
98                checkout_ref: Some(branch),
99            })),
100            _ => Err(RockSourceError::InvalidCombination),
101        }?;
102
103        Ok(RemoteRockSource { source_spec, local })
104    }
105}
106
107impl FromLua for PerPlatform<RemoteRockSource> {
108    fn from_lua(value: Value, lua: &Lua) -> mlua::Result<Self> {
109        let wrapper = PerPlatformWrapper::from_lua(value, lua)?;
110        Ok(wrapper.un_per_platform)
111    }
112}
113
114#[derive(Debug, PartialEq, Clone)]
115pub enum RockSourceSpec {
116    Git(GitSource),
117    File(PathBuf),
118    Url(Url),
119}
120
121impl IntoLua for RockSourceSpec {
122    fn into_lua(self, lua: &Lua) -> mlua::Result<Value> {
123        let table = lua.create_table()?;
124
125        match self {
126            RockSourceSpec::Git(git) => {
127                table.set("git", git.into_lua(lua)?)?;
128            }
129            RockSourceSpec::File(path) => {
130                table.set("file", path.to_string_lossy().to_string())?;
131            }
132            RockSourceSpec::Url(url) => {
133                table.set("url", url.to_string())?;
134            }
135        };
136
137        Ok(Value::Table(table))
138    }
139}
140
141impl RockSourceSpec {
142    fn default_from_source_url(url: SourceUrl) -> Self {
143        match url {
144            SourceUrl::File(path) => Self::File(path),
145            SourceUrl::Url(url) => Self::Url(url),
146            SourceUrl::Git(url) => Self::Git(GitSource {
147                url,
148                checkout_ref: None,
149            }),
150        }
151    }
152}
153
154impl<'de> Deserialize<'de> for RockSourceSpec {
155    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
156    where
157        D: Deserializer<'de>,
158    {
159        let url = String::deserialize(deserializer)?;
160        Ok(RockSourceSpec::default_from_source_url(
161            url.parse().map_err(de::Error::custom)?,
162        ))
163    }
164}
165
166impl DisplayAsLuaKV for RockSourceSpec {
167    fn display_lua(&self) -> DisplayLuaKV {
168        match self {
169            RockSourceSpec::Git(git_source) => git_source.display_lua(),
170            RockSourceSpec::File(path) => {
171                let mut source_tbl = Vec::new();
172                source_tbl.push(DisplayLuaKV {
173                    key: "url".to_string(),
174                    value: DisplayLuaValue::String(format!("file:://{}", path.display())),
175                });
176                DisplayLuaKV {
177                    key: "source".to_string(),
178                    value: DisplayLuaValue::Table(source_tbl),
179                }
180            }
181            RockSourceSpec::Url(url) => {
182                let mut source_tbl = Vec::new();
183                source_tbl.push(DisplayLuaKV {
184                    key: "url".to_string(),
185                    value: DisplayLuaValue::String(format!("{url}")),
186                });
187                DisplayLuaKV {
188                    key: "source".to_string(),
189                    value: DisplayLuaValue::Table(source_tbl),
190                }
191            }
192        }
193    }
194}
195
196/// Used as a helper for Deserialize,
197/// because the Rockspec schema allows invalid rockspecs (╯°□°)╯︵ ┻━┻
198#[derive(Debug, PartialEq, Deserialize, Clone, Default)]
199pub(crate) struct RockSourceInternal {
200    #[serde(default)]
201    pub(crate) url: Option<String>,
202    pub(crate) file: Option<PathBuf>,
203    pub(crate) dir: Option<PathBuf>,
204    pub(crate) tag: Option<String>,
205    pub(crate) branch: Option<String>,
206}
207
208impl PartialOverride for RockSourceInternal {
209    type Err = Infallible;
210
211    fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
212        Ok(Self {
213            url: override_opt(override_spec.url.as_ref(), self.url.as_ref()),
214            file: override_opt(override_spec.file.as_ref(), self.file.as_ref()),
215            dir: override_opt(override_spec.dir.as_ref(), self.dir.as_ref()),
216            tag: match &override_spec.branch {
217                None => override_opt(override_spec.tag.as_ref(), self.tag.as_ref()),
218                _ => None,
219            },
220            branch: match &override_spec.tag {
221                None => override_opt(override_spec.branch.as_ref(), self.branch.as_ref()),
222                _ => None,
223            },
224        })
225    }
226}
227
228impl DisplayAsLuaKV for RockSourceInternal {
229    fn display_lua(&self) -> DisplayLuaKV {
230        let mut result = Vec::new();
231
232        if let Some(url) = &self.url {
233            result.push(DisplayLuaKV {
234                key: "url".to_string(),
235                value: DisplayLuaValue::String(url.clone()),
236            });
237        }
238        if let Some(file) = &self.file {
239            result.push(DisplayLuaKV {
240                key: "file".to_string(),
241                value: DisplayLuaValue::String(file.to_slash_lossy().to_string()),
242            });
243        }
244        if let Some(dir) = &self.dir {
245            result.push(DisplayLuaKV {
246                key: "dir".to_string(),
247                value: DisplayLuaValue::String(dir.to_slash_lossy().to_string()),
248            });
249        }
250        if let Some(tag) = &self.tag {
251            result.push(DisplayLuaKV {
252                key: "tag".to_string(),
253                value: DisplayLuaValue::String(tag.clone()),
254            });
255        }
256        if let Some(branch) = &self.branch {
257            result.push(DisplayLuaKV {
258                key: "branch".to_string(),
259                value: DisplayLuaValue::String(branch.clone()),
260            });
261        }
262
263        DisplayLuaKV {
264            key: "source".to_string(),
265            value: DisplayLuaValue::Table(result),
266        }
267    }
268}
269
270#[derive(Error, Debug)]
271#[error("missing source")]
272pub struct RockSourceMissingSource;
273
274impl PlatformOverridable for RockSourceInternal {
275    type Err = RockSourceMissingSource;
276
277    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
278    where
279        T: PlatformOverridable,
280    {
281        Err(RockSourceMissingSource)
282    }
283}
284
/// Returns a clone of `override_opt` if it is set, otherwise a clone of `base`
/// (which may itself be `None`).
fn override_opt<T: Clone>(override_opt: Option<&T>, base: Option<&T>) -> Option<T> {
    if let Some(value) = override_opt {
        return Some(value.clone());
    }
    base.cloned()
}
288
289/// Internal helper for parsing
290#[derive(Debug, PartialEq, Clone)]
291pub(crate) enum SourceUrl {
292    /// For URLs in the local filesystem
293    File(PathBuf),
294    /// Web URLs
295    Url(Url),
296    /// For the Git source control manager
297    Git(RemoteGitUrl),
298}
299
300#[derive(Error, Debug)]
301#[error("failed to parse source url: {0}")]
302pub enum SourceUrlError {
303    Io(#[from] io::Error),
304    Git(#[from] RemoteGitUrlParseError),
305    Url(#[source] <Url as FromStr>::Err),
306    #[error("lux does not support rockspecs with CVS sources.")]
307    CVS,
308    #[error("lux does not support rockspecs with mercurial sources.")]
309    Mercurial,
310    #[error("lux does not support rockspecs with SSCM sources.")]
311    SSCM,
312    #[error("lux does not support rockspecs with SVN sources.")]
313    SVN,
314    #[error("unsupported source URL prefix: '{0}+' in URL {1}")]
315    UnsupportedPrefix(String, String),
316    #[error("unsupported source URL: {0}")]
317    Unsupported(String),
318}
319
320impl FromStr for SourceUrl {
321    type Err = SourceUrlError;
322
323    fn from_str(str: &str) -> Result<Self, Self::Err> {
324        match str.split_once("+") {
325            Some(("git" | "gitrec", url)) => Ok(Self::Git(url.parse()?)),
326            Some((prefix, _)) => Err(SourceUrlError::UnsupportedPrefix(
327                prefix.to_string(),
328                str.to_string(),
329            )),
330            None => match str {
331                s if s.starts_with("file://") => {
332                    let path_buf: PathBuf = s.trim_start_matches("file://").into();
333                    let path = fs::canonicalize(&path_buf)?;
334                    Ok(Self::File(path))
335                }
336                s if s.starts_with("git://") => {
337                    Ok(Self::Git(s.replacen("git", "https", 1).parse()?))
338                }
339                s if s.ends_with(".git") => Ok(Self::Git(s.parse()?)),
340                s if starts_with_any(s, ["https://", "http://", "ftp://"].into()) => {
341                    Ok(Self::Url(s.parse().map_err(SourceUrlError::Url)?))
342                }
343                s if s.starts_with("cvs://") => Err(SourceUrlError::CVS),
344                s if starts_with_any(
345                    s,
346                    ["hg://", "hg+http://", "hg+https://", "hg+ssh://"].into(),
347                ) =>
348                {
349                    Err(SourceUrlError::Mercurial)
350                }
351                s if s.starts_with("sscm://") => Err(SourceUrlError::SSCM),
352                s if s.starts_with("svn://") => Err(SourceUrlError::SVN),
353                s => Err(SourceUrlError::Unsupported(s.to_string())),
354            },
355        }
356    }
357}
358
359impl<'de> Deserialize<'de> for SourceUrl {
360    fn deserialize<D>(deserializer: D) -> Result<SourceUrl, D::Error>
361    where
362        D: Deserializer<'de>,
363    {
364        SourceUrl::from_str(&String::deserialize(deserializer)?).map_err(de::Error::custom)
365    }
366}
367
/// Returns `true` if `str` begins with at least one of `prefixes`.
fn starts_with_any(str: &str, prefixes: Vec<&str>) -> bool {
    for prefix in prefixes {
        if str.starts_with(prefix) {
            return true;
        }
    }
    false
}
371
#[cfg(test)]
mod tests {

    use assert_fs::TempDir;

    use super::*;

    #[test]
    fn parse_source_url() {
        let dir = TempDir::new().unwrap();
        let url: SourceUrl = format!("file://{}", dir.to_string_lossy()).parse().unwrap();
        // `file://` paths are canonicalized by `SourceUrl::from_str`, and the
        // temporary directory may sit behind a symlink (e.g. `/var` ->
        // `/private/var` on macOS) or gain a `\\?\` prefix (Windows), so
        // compare against the canonical path rather than the raw one.
        let expected = dir.path().canonicalize().unwrap();
        assert_eq!(url, SourceUrl::File(expected));
        let url: SourceUrl = "ftp://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
        let url: SourceUrl = "git://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        // We don't support file-like URLs, as they are not remote.
        SourceUrl::from_str("git+file:///path/to/repo.git").unwrap_err();
        let url: SourceUrl = "git+http://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "git+https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "git+ssh://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "gitrec+https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
        let url: SourceUrl = "http://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
    }
}