1use reqwest::Url;
2use serde::{Deserialize, Serialize};
3use std::{convert::TryFrom, path::PathBuf};
4
5use crate::{ErrorKind, ResolvedInputSource};
6
/// The base against which relative links are resolved (see `Base::join`).
///
/// Values built through `TryFrom<&str>` / `TryFrom<String>` are validated:
/// `Local` only holds absolute paths and `Remote` only holds URLs that can
/// serve as a base. Deserialization goes through `TryFrom<String>` (see the
/// `serde(try_from)` attribute below), so deserialized values get the same
/// checks. Constructing a variant directly bypasses this validation.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
// `Url` and `PathBuf` differ in size; the lint is silenced rather than boxing
// one of the variants.
#[allow(variant_size_differences)]
#[serde(try_from = "String")]
pub enum Base {
    /// A local filesystem path; the `TryFrom` conversions only accept
    /// absolute paths.
    Local(PathBuf),
    /// A remote URL; the `TryFrom` conversions reject URLs that cannot be a
    /// base (e.g. `data:` URLs).
    Remote(Url),
}
19
20impl Base {
21 #[must_use]
23 pub(crate) fn join(&self, link: &str) -> Option<Url> {
24 match self {
25 Self::Remote(url) => url.join(link).ok(),
26 Self::Local(path) => {
27 let full_path = path.join(link);
28 Url::from_file_path(full_path).ok()
29 }
30 }
31 }
32
33 pub(crate) fn from_source(source: &ResolvedInputSource) -> Option<Base> {
34 match &source {
35 ResolvedInputSource::RemoteUrl(url) => {
36 let mut base_url = url.clone();
38 base_url.set_path("");
39 base_url.set_query(None);
40 base_url.set_fragment(None);
41
42 Some(Base::Remote(*base_url))
44 }
45 _ => None,
47 }
48 }
49}
50
51impl TryFrom<&str> for Base {
52 type Error = ErrorKind;
53
54 fn try_from(value: &str) -> Result<Self, Self::Error> {
55 if let Ok(url) = Url::parse(value) {
56 if url.cannot_be_a_base() {
57 return Err(ErrorKind::InvalidBase(
58 value.to_string(),
59 "The given URL cannot be used as a base URL".to_string(),
60 ));
61 }
62 return Ok(Self::Remote(url));
63 }
64
65 let path = PathBuf::from(value);
69 if path.is_absolute() {
70 Ok(Self::Local(path))
71 } else {
72 Err(ErrorKind::InvalidBase(
73 value.to_string(),
74 "Base must either be a URL (with scheme) or an absolute local path".to_string(),
75 ))
76 }
77 }
78}
79
80impl TryFrom<String> for Base {
81 type Error = ErrorKind;
82
83 fn try_from(value: String) -> Result<Self, Self::Error> {
84 Self::try_from(value.as_str())
85 }
86}
87
#[cfg(test)]
mod test_base {
    use crate::Result;

    use super::*;

    #[test]
    fn test_valid_remote() -> Result<()> {
        let base = Base::try_from("https://endler.dev")?;
        assert_eq!(
            base,
            Base::Remote(Url::parse("https://endler.dev").unwrap())
        );
        Ok(())
    }

    #[test]
    fn test_invalid_url() {
        // `data:` URLs parse fine but cannot act as a base.
        assert!(matches!(
            Base::try_from("data:text/plain,Hello?World#"),
            Err(ErrorKind::InvalidBase(..))
        ));
    }

    #[test]
    fn test_valid_local_path_string_as_base() -> Result<()> {
        let cases = ["/tmp/lychee", "/tmp/lychee/"];

        for case in cases {
            assert_eq!(Base::try_from(case)?, Base::Local(PathBuf::from(case)));
        }
        Ok(())
    }

    #[test]
    fn test_invalid_local_path_string_as_base() {
        let cases = ["a", "tmp/lychee/", "example.com", "../nonlocal"];

        for case in cases {
            assert!(
                Base::try_from(case).is_err(),
                "expected `{}` to be rejected as a base",
                case
            );
        }
    }

    #[test]
    fn test_valid_local() -> Result<()> {
        let dir = tempfile::tempdir().unwrap();
        let base = Base::try_from(dir.path().to_str().unwrap())?;
        assert_eq!(base, Base::Local(dir.path().to_path_buf()));
        Ok(())
    }

    #[test]
    fn test_join_remote() {
        let base = Base::Remote(Url::parse("https://example.com/docs/").unwrap());
        assert_eq!(
            base.join("guide.html"),
            Some(Url::parse("https://example.com/docs/guide.html").unwrap())
        );
    }

    #[test]
    fn test_get_base_from_url() {
        for (url, expected) in [
            ("https://example.com", "https://example.com"),
            ("https://example.com?query=something", "https://example.com"),
            ("https://example.com/#anchor", "https://example.com"),
            ("https://example.com/foo/bar", "https://example.com"),
            (
                "https://example.com:1234/foo/bar",
                "https://example.com:1234",
            ),
        ] {
            let url = Url::parse(url).unwrap();
            let source = ResolvedInputSource::RemoteUrl(Box::new(url));
            let base = Base::from_source(&source);
            let expected = Base::Remote(Url::parse(expected).unwrap());
            assert_eq!(base, Some(expected));
        }
    }
}