use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, path::PathBuf};
use crate::{ErrorKind, ResolvedInputSource};
/// Base against which relative links are resolved: either a local
/// filesystem path or a remote URL.
///
/// Deserialization goes through `TryFrom<String>` (see the
/// `#[serde(try_from = "String")]` attribute below), so only strings accepted
/// by that conversion can be deserialized into a `Base`.
// NOTE(review): `Serialize` is derived normally (no `into = "String"`), so a
// `Base` serializes in the externally tagged enum form (e.g. `{"Local": ...}`)
// rather than as a plain string, and that output would not deserialize back.
// Confirm whether any caller serializes a `Base` before relying on round-trips.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
// `PathBuf` and `Url` have different sizes; presumably boxing the larger
// variant was judged not worth the indirection — TODO confirm.
#[allow(variant_size_differences)]
#[serde(try_from = "String")]
pub enum Base {
    /// A local filesystem path.
    Local(PathBuf),
    /// A remote URL.
    Remote(Url),
}
impl Base {
    /// Resolves `link` relative to this base.
    ///
    /// * [`Base::Remote`]: `link` is resolved against the URL with
    ///   [`Url::join`].
    /// * [`Base::Local`]: `link` is appended to the path with
    ///   [`Path::join`](std::path::Path::join), and the result is converted
    ///   to a `file://` URL with [`Url::from_file_path`].
    ///
    /// Returns `None` if the URL join fails or if the joined local path
    /// cannot be turned into a file URL (for example, a relative path).
    #[must_use]
    pub(crate) fn join(&self, link: &str) -> Option<Url> {
        match self {
            Self::Remote(url) => url.join(link).ok(),
            Self::Local(path) => Url::from_file_path(path.join(link)).ok(),
        }
    }

    /// Derives a base from a resolved input source.
    ///
    /// For a remote URL source, the base keeps everything except the path,
    /// query and fragment, so `https://example.com:1234/foo?x#y` becomes
    /// `https://example.com:1234/`.
    ///
    /// Returns `None` for non-remote sources and for URLs that cannot be a
    /// base (e.g. `data:` or `mailto:` URLs). This matches the rejection
    /// applied by `Base::try_from`.
    pub(crate) fn from_source(source: &ResolvedInputSource) -> Option<Base> {
        match source {
            ResolvedInputSource::RemoteUrl(url) => {
                // Clearing the path of an opaque-path URL would produce a
                // meaningless "base" that relative links can never join onto.
                if url.cannot_be_a_base() {
                    return None;
                }
                // Clone the inner `Url` directly instead of the whole `Box`.
                let mut base_url: Url = (**url).clone();
                base_url.set_path("");
                base_url.set_query(None);
                base_url.set_fragment(None);
                Some(Base::Remote(base_url))
            }
            _ => None,
        }
    }
}
impl TryFrom<&str> for Base {
    type Error = ErrorKind;

    /// Parses `value` as either an absolute local path or a remote base URL.
    ///
    /// Absolute local paths are checked first. On Windows, a drive-letter
    /// path such as `C:\site` is also a syntactically valid URL with the
    /// single-letter scheme `c`. Trying `Url::parse` first would therefore
    /// misclassify it, and usually reject it as "cannot be a base". On Unix an
    /// absolute path starts with `/`, which never parses as an absolute URL,
    /// so the check order makes no difference there.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::InvalidBase`] in two cases:
    /// * `value` parses as a URL that cannot be a base (e.g. a `data:` URL);
    /// * `value` is neither an absolute path nor a URL with a scheme.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let path = PathBuf::from(value);
        if path.is_absolute() {
            return Ok(Self::Local(path));
        }

        match Url::parse(value) {
            Ok(url) if url.cannot_be_a_base() => Err(ErrorKind::InvalidBase(
                value.to_string(),
                "The given URL cannot be used as a base URL".to_string(),
            )),
            Ok(url) => Ok(Self::Remote(url)),
            Err(_) => Err(ErrorKind::InvalidBase(
                value.to_string(),
                "Base must either be a URL (with scheme) or an absolute local path".to_string(),
            )),
        }
    }
}
impl TryFrom<String> for Base {
    type Error = ErrorKind;

    /// Owned-string counterpart of the `&str` conversion.
    ///
    /// This impl is what `#[serde(try_from = "String")]` on [`Base`] uses
    /// during deserialization. All validation is delegated to the borrowed
    /// form.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        <Self as TryFrom<&str>>::try_from(&value)
    }
}
#[cfg(test)]
mod test_base {
    use super::*;
    use crate::Result;

    #[test]
    fn test_valid_remote() -> Result<()> {
        let parsed = Base::try_from("https://endler.dev")?;
        let expected = Base::Remote(Url::parse("https://endler.dev").unwrap());
        assert_eq!(parsed, expected);
        Ok(())
    }

    #[test]
    fn test_invalid_url() {
        let result = Base::try_from("data:text/plain,Hello?World#");
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_local_path_string_as_base() -> Result<()> {
        for case in ["/tmp/lychee", "/tmp/lychee/"] {
            let parsed = Base::try_from(case)?;
            assert_eq!(parsed, Base::Local(PathBuf::from(case)));
        }
        Ok(())
    }

    #[test]
    fn test_invalid_local_path_string_as_base() {
        let rejected = ["a", "tmp/lychee/", "example.com", "../nonlocal"];
        for case in rejected {
            assert!(Base::try_from(case).is_err());
        }
    }

    #[test]
    fn test_valid_local() -> Result<()> {
        // Keep the TempDir alive for the duration of the test.
        let dir = tempfile::tempdir().unwrap();
        let dir_str = dir.path().to_str().unwrap();
        Base::try_from(dir_str)?;
        Ok(())
    }

    #[test]
    fn test_get_base_from_url() {
        let cases = [
            ("https://example.com", "https://example.com"),
            ("https://example.com?query=something", "https://example.com"),
            ("https://example.com/#anchor", "https://example.com"),
            ("https://example.com/foo/bar", "https://example.com"),
            (
                "https://example.com:1234/foo/bar",
                "https://example.com:1234",
            ),
        ];
        for (input, want) in cases {
            let source = ResolvedInputSource::RemoteUrl(Box::new(Url::parse(input).unwrap()));
            let want = Base::Remote(Url::parse(want).unwrap());
            assert_eq!(Base::from_source(&source), Some(want));
        }
    }
}