1use reqwest::Url;
2use serde::{de, Deserialize, Deserializer};
3use std::{convert::Infallible, fs, io, ops::Deref, path::PathBuf, str::FromStr};
4use thiserror::Error;
5
6use crate::{
7 git::{
8 url::{RemoteGitUrl, RemoteGitUrlParseError},
9 GitSource,
10 },
11 lua_rockspec::per_platform_from_intermediate,
12};
13
14use super::{
15 DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, PartialOverride, PerPlatform,
16 PlatformOverridable,
17};
18
/// The local-layout part of a rockspec `source` table.
///
/// Populated from the optional `source.file` (-> `archive_name`) and `source.dir`
/// (-> `unpack_dir`) fields of the raw table.
#[derive(Default, Deserialize, Clone, Debug, PartialEq)]
pub struct LocalRockSource {
    /// The `source.file` value, if the rockspec set one.
    pub archive_name: Option<PathBuf>,
    /// The `source.dir` value, if the rockspec set one.
    pub unpack_dir: Option<PathBuf>,
}
24
/// A rockspec `source` whose URL has been resolved into a [`RockSourceSpec`], together with
/// its local archive/unpack settings.
///
/// Dereferences to [`LocalRockSource`], so `archive_name`/`unpack_dir` can be read directly.
#[derive(Deserialize, Clone, Debug, PartialEq)]
pub struct RemoteRockSource {
    /// Local archive name / unpack directory overrides.
    pub(crate) local: LocalRockSource,
    /// Where and how the source is fetched.
    pub source_spec: RockSourceSpec,
}
30
31impl From<RockSourceSpec> for RemoteRockSource {
32 fn from(source_spec: RockSourceSpec) -> Self {
33 Self {
34 local: LocalRockSource::default(),
35 source_spec,
36 }
37 }
38}
39
40impl Deref for RemoteRockSource {
41 type Target = LocalRockSource;
42
43 fn deref(&self) -> &Self::Target {
44 &self.local
45 }
46}
47
/// Errors produced when converting a raw rockspec `source` table into a [`RemoteRockSource`].
#[derive(Error, Debug)]
pub enum RockSourceError {
    /// `tag`/`branch` were combined in an unsupported way: both set, or either one set on a
    /// non-git source URL.
    #[error("invalid rockspec source field combination")]
    InvalidCombination,
    /// The `url` field could not be parsed.
    #[error(transparent)]
    SourceUrl(#[from] SourceUrlError),
    /// The `url` field was absent.
    #[error("source URL missing")]
    SourceUrlMissing,
}
57
58impl From<RockSourceInternal> for LocalRockSource {
59 fn from(internal: RockSourceInternal) -> Self {
60 LocalRockSource {
61 archive_name: internal.file,
62 unpack_dir: internal.dir,
63 }
64 }
65}
66
67impl TryFrom<RockSourceInternal> for RemoteRockSource {
68 type Error = RockSourceError;
69
70 fn try_from(internal: RockSourceInternal) -> Result<Self, Self::Error> {
71 let local = LocalRockSource::from(internal.clone());
72
73 let url = SourceUrl::from_str(&internal.url.ok_or(RockSourceError::SourceUrlMissing)?)?;
76
77 let source_spec = match (url, internal.tag, internal.branch) {
78 (source, None, None) => Ok(RockSourceSpec::default_from_source_url(source)),
79 (SourceUrl::Git(url), Some(tag), None) => Ok(RockSourceSpec::Git(GitSource {
80 url,
81 checkout_ref: Some(tag),
82 })),
83 (SourceUrl::Git(url), None, Some(branch)) => Ok(RockSourceSpec::Git(GitSource {
84 url,
85 checkout_ref: Some(branch),
86 })),
87 _ => Err(RockSourceError::InvalidCombination),
88 }?;
89
90 Ok(RemoteRockSource { source_spec, local })
91 }
92}
93
/// Deserializes a per-platform rockspec `source`.
///
/// The value is first read as raw [`RockSourceInternal`] table(s) and then turned into
/// [`RemoteRockSource`]s by `per_platform_from_intermediate`. Presumably that conversion goes
/// through `TryFrom<RockSourceInternal>` after platform overrides are applied; confirm in
/// `per_platform_from_intermediate`.
impl<'de> Deserialize<'de> for PerPlatform<RemoteRockSource> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        per_platform_from_intermediate::<_, RockSourceInternal, _>(deserializer)
    }
}
102
/// Where a rock's source is fetched from.
#[derive(Debug, PartialEq, Clone)]
pub enum RockSourceSpec {
    /// A git repository, optionally pinned to a tag or branch via `checkout_ref`.
    Git(GitSource),
    /// A local path (from a `file://` source URL).
    File(PathBuf),
    /// A remote, non-git URL.
    Url(Url),
}
109
110impl RockSourceSpec {
111 fn default_from_source_url(url: SourceUrl) -> Self {
112 match url {
113 SourceUrl::File(path) => Self::File(path),
114 SourceUrl::Url(url) => Self::Url(url),
115 SourceUrl::Git(url) => Self::Git(GitSource {
116 url,
117 checkout_ref: None,
118 }),
119 }
120 }
121}
122
123impl<'de> Deserialize<'de> for RockSourceSpec {
124 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
125 where
126 D: Deserializer<'de>,
127 {
128 let url = String::deserialize(deserializer)?;
129 Ok(RockSourceSpec::default_from_source_url(
130 url.parse().map_err(de::Error::custom)?,
131 ))
132 }
133}
134
135impl DisplayAsLuaKV for RockSourceSpec {
136 fn display_lua(&self) -> DisplayLuaKV {
137 match self {
138 RockSourceSpec::Git(git_source) => git_source.display_lua(),
139 RockSourceSpec::File(path) => {
140 let mut source_tbl = Vec::new();
141 source_tbl.push(DisplayLuaKV {
142 key: "url".to_string(),
143 value: DisplayLuaValue::String(format!("file:://{}", path.display())),
144 });
145 DisplayLuaKV {
146 key: "source".to_string(),
147 value: DisplayLuaValue::Table(source_tbl),
148 }
149 }
150 RockSourceSpec::Url(url) => {
151 let mut source_tbl = Vec::new();
152 source_tbl.push(DisplayLuaKV {
153 key: "url".to_string(),
154 value: DisplayLuaValue::String(format!("{url}")),
155 });
156 DisplayLuaKV {
157 key: "source".to_string(),
158 value: DisplayLuaValue::Table(source_tbl),
159 }
160 }
161 }
162 }
163}
164
// Raw, un-validated contents of a rockspec `source` table as deserialized from Lua.
// Converted into `LocalRockSource` / `RemoteRockSource` by the `From` / `TryFrom` impls in this
// module; `display_lua` renders it back as a `source = { ... }` table.
#[derive(Debug, PartialEq, Deserialize, Clone, Default, lux_macros::DisplayAsLuaKV)]
#[display_lua(key = "source")]
pub(crate) struct RockSourceInternal {
    // `source.url`; required by the `TryFrom` conversion (else `SourceUrlMissing`).
    #[serde(default)]
    pub(crate) url: Option<String>,
    // `source.file`; becomes `LocalRockSource::archive_name`.
    pub(crate) file: Option<PathBuf>,
    // `source.dir`; becomes `LocalRockSource::unpack_dir`.
    pub(crate) dir: Option<PathBuf>,
    // `source.tag`; only valid for git URLs and mutually exclusive with `branch`.
    pub(crate) tag: Option<String>,
    // `source.branch`; only valid for git URLs and mutually exclusive with `tag`.
    pub(crate) branch: Option<String>,
}
177
178impl PartialOverride for RockSourceInternal {
179 type Err = Infallible;
180
181 fn apply_overrides(&self, override_spec: &Self) -> Result<Self, Self::Err> {
182 Ok(Self {
183 url: override_opt(override_spec.url.as_ref(), self.url.as_ref()),
184 file: override_opt(override_spec.file.as_ref(), self.file.as_ref()),
185 dir: override_opt(override_spec.dir.as_ref(), self.dir.as_ref()),
186 tag: match &override_spec.branch {
187 None => override_opt(override_spec.tag.as_ref(), self.tag.as_ref()),
188 _ => None,
189 },
190 branch: match &override_spec.tag {
191 None => override_opt(override_spec.branch.as_ref(), self.branch.as_ref()),
192 _ => None,
193 },
194 })
195 }
196}
197
/// Error returned by `RockSourceInternal`'s `PlatformOverridable::on_nil`, i.e. (presumably)
/// when a rockspec's `source` value is nil/absent.
#[derive(Error, Debug)]
#[error("missing source")]
pub struct RockSourceMissingSource;
201
/// A `source` table is mandatory: a nil value is always an error rather than an empty
/// per-platform value.
impl PlatformOverridable for RockSourceInternal {
    type Err = RockSourceMissingSource;

    /// Always fails with [`RockSourceMissingSource`].
    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
    where
        T: PlatformOverridable,
    {
        Err(RockSourceMissingSource)
    }
}
212
/// Returns an owned copy of `override_opt` if it is set, otherwise of `base`
/// (`None` when neither is set).
fn override_opt<T: Clone>(override_opt: Option<&T>, base: Option<&T>) -> Option<T> {
    match override_opt {
        Some(value) => Some(value.clone()),
        None => base.cloned(),
    }
}
216
/// A parsed rockspec source URL, classified by how it should be fetched.
#[derive(Debug, PartialEq, Clone)]
pub(crate) enum SourceUrl {
    /// A `file://` URL, canonicalized (via `fs::canonicalize`) to an absolute path at parse
    /// time.
    File(PathBuf),
    /// An `http://`, `https://` or `ftp://` URL to a non-git resource.
    Url(Url),
    /// A git remote: `git+…`/`gitrec+…`, `git://` (rewritten to `https://`), or any URL ending
    /// in `.git`.
    Git(RemoteGitUrl),
}
227
/// Reasons `SourceUrl::from_str` can reject a rockspec source URL.
///
/// Variants without their own `#[error]` attribute (`Io`, `Git`, `Url`) are displayed with the
/// enum-level "failed to parse source url: {0}" message.
#[derive(Error, Debug)]
#[error("failed to parse source url: {0}")]
pub enum SourceUrlError {
    /// A `file://` path could not be canonicalized (e.g. it does not exist).
    Io(#[from] io::Error),
    /// A git URL could not be parsed as a [`RemoteGitUrl`].
    Git(#[from] RemoteGitUrlParseError),
    /// An `http(s)://` / `ftp://` URL could not be parsed as a [`Url`].
    Url(#[source] <Url as FromStr>::Err),
    /// CVS sources are recognised but unsupported.
    #[error("lux does not support rockspecs with CVS sources.")]
    CVS,
    /// Mercurial sources are recognised but unsupported.
    #[error("lux does not support rockspecs with mercurial sources.")]
    Mercurial,
    /// SSCM sources are recognised but unsupported.
    #[error("lux does not support rockspecs with SSCM sources.")]
    SSCM,
    /// SVN sources are recognised but unsupported.
    #[error("lux does not support rockspecs with SVN sources.")]
    SVN,
    /// A `<prefix>+` VCS prefix other than `git`/`gitrec`; holds the prefix and the full URL.
    #[error("unsupported source URL prefix: '{0}+' in URL {1}")]
    UnsupportedPrefix(String, String),
    /// A URL whose scheme is not recognised at all.
    #[error("unsupported source URL: {0}")]
    Unsupported(String),
}
247
248impl FromStr for SourceUrl {
249 type Err = SourceUrlError;
250
251 fn from_str(str: &str) -> Result<Self, Self::Err> {
252 match str.split_once("+") {
253 Some(("git" | "gitrec", url)) => Ok(Self::Git(url.parse()?)),
254 Some((prefix, _)) => Err(SourceUrlError::UnsupportedPrefix(
255 prefix.to_string(),
256 str.to_string(),
257 )),
258 None => match str {
259 s if s.starts_with("file://") => {
260 let path_buf: PathBuf = s.trim_start_matches("file://").into();
261 let path = fs::canonicalize(&path_buf)?;
262 Ok(Self::File(path))
263 }
264 s if s.starts_with("git://") => {
265 Ok(Self::Git(s.replacen("git", "https", 1).parse()?))
266 }
267 s if s.ends_with(".git") => Ok(Self::Git(s.parse()?)),
268 s if starts_with_any(s, ["https://", "http://", "ftp://"].into()) => {
269 Ok(Self::Url(s.parse().map_err(SourceUrlError::Url)?))
270 }
271 s if s.starts_with("cvs://") => Err(SourceUrlError::CVS),
272 s if starts_with_any(
273 s,
274 ["hg://", "hg+http://", "hg+https://", "hg+ssh://"].into(),
275 ) =>
276 {
277 Err(SourceUrlError::Mercurial)
278 }
279 s if s.starts_with("sscm://") => Err(SourceUrlError::SSCM),
280 s if s.starts_with("svn://") => Err(SourceUrlError::SVN),
281 s => Err(SourceUrlError::Unsupported(s.to_string())),
282 },
283 }
284 }
285}
286
287impl<'de> Deserialize<'de> for SourceUrl {
288 fn deserialize<D>(deserializer: D) -> Result<SourceUrl, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 SourceUrl::from_str(&String::deserialize(deserializer)?).map_err(de::Error::custom)
293 }
294}
295
/// Returns `true` if `str` starts with at least one of `prefixes`
/// (`false` for an empty prefix list).
fn starts_with_any(str: &str, prefixes: Vec<&str>) -> bool {
    for candidate in prefixes {
        if str.starts_with(candidate) {
            return true;
        }
    }
    false
}
299
#[cfg(test)]
mod tests {

    use assert_fs::TempDir;

    use super::*;

    /// Runs `code` as a Lua chunk, then deserializes the global variable `key` into `T`.
    ///
    /// Panics (via `unwrap`) if loading, running or deserializing fails, which is the desired
    /// behavior in a test helper.
    fn eval_lua_global<T: serde::de::DeserializeOwned>(code: &str, key: &'static str) -> T {
        use ottavino::{Closure, Executor, Fuel, Lua};
        use ottavino_util::serde::from_value;
        Lua::core()
            .try_enter(|ctx| {
                let closure = Closure::load(ctx, None, code.as_bytes())?;
                let executor = Executor::start(ctx, closure.into(), ());
                // NOTE(review): assumes a single `step` with i32::MAX fuel runs the whole chunk
                // to completion — confirm against ottavino's `Executor::step` semantics.
                executor.step(ctx, &mut Fuel::with(i32::MAX))?;
                from_value(ctx.globals().get_value(ctx, key)).map_err(ottavino::Error::from)
            })
            .unwrap()
    }

    /// A `RockSourceInternal` carrying a `tag` survives rendering to Lua and reading it back.
    #[test]
    pub fn rock_source_internal_roundtrip() {
        let source = RockSourceInternal {
            url: Some("https://github.com/example/repo/archive/v1.0.tar.gz".into()),
            file: Some("repo-1.0.tar.gz".into()),
            dir: Some("repo-1.0".into()),
            tag: Some("v1.0".into()),
            branch: None,
        };
        let lua = source.display_lua().to_string();
        let restored: RockSourceInternal = eval_lua_global(&lua, "source");
        assert_eq!(source, restored);
    }

    /// A `RockSourceInternal` carrying a `branch` survives rendering to Lua and reading it back.
    #[test]
    pub fn rock_source_internal_branch_roundtrip() {
        let source = RockSourceInternal {
            url: Some("git+https://github.com/example/repo.git".into()),
            file: None,
            dir: None,
            tag: None,
            branch: Some("main".into()),
        };
        let lua = source.display_lua().to_string();
        let restored: RockSourceInternal = eval_lua_global(&lua, "source");
        assert_eq!(source, restored);
    }

    /// Exercises `SourceUrl::from_str` on each supported URL form.
    #[tokio::test]
    async fn parse_source_url() {
        let dir = TempDir::new().unwrap();
        // `file://` URLs must point at an existing path, because parsing canonicalizes them.
        // NOTE(review): the equality assumes the temp dir path is already canonical (not the
        // case where the temp root is behind a symlink).
        let url: SourceUrl = format!("file://{}", dir.to_string_lossy()).parse().unwrap();
        assert_eq!(url, SourceUrl::File(dir.path().to_path_buf()));
        let url: SourceUrl = "ftp://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
        let url: SourceUrl = "git://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        // Presumably rejected because a local `file://` path is not a valid `RemoteGitUrl`.
        SourceUrl::from_str("git+file:///path/to/repo.git").unwrap_err();
        let url: SourceUrl = "git+http://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "git+https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "git+ssh://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "gitrec+https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Git { .. }));
        let url: SourceUrl = "https://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
        let url: SourceUrl = "http://example.com/foo/bar".parse().unwrap();
        assert!(matches!(url, SourceUrl::Url { .. }));
    }
}