use crate::ErrorKind;
use glob::Pattern;
use reqwest::Url;
use serde::{Deserialize, Deserializer, Serialize};
use std::borrow::Cow;
use std::fmt::Display;
use std::path::PathBuf;
use std::result::Result;
/// An input source as specified by the user.
///
/// Unlike `ResolvedInputSource`, this may still hold a glob pattern (`FsGlob`), which
/// `ResolvedInputSource` has no variant for.
///
/// NOTE(review): `Deserialize` is derived here and expects serde's enum representation,
/// while `Serialize` is implemented by hand to emit the `Display` string. A serialized value
/// therefore does not deserialize back into the same variant; confirm this is intended.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[non_exhaustive]
pub enum InputSource {
    /// A remote URL. `InputSource::new` only produces `http`/`https` URLs here.
    RemoteUrl(Box<Url>),
    /// A filesystem glob pattern.
    FsGlob {
        /// The compiled pattern; deserialized from a string via
        /// `InputSource::deserialize_pattern`.
        #[serde(deserialize_with = "InputSource::deserialize_pattern")]
        pattern: Pattern,
        /// Whether matching should ignore case (presumably honored wherever the glob is
        /// expanded; not used in this file).
        ignore_case: bool,
    },
    /// A filesystem path. `InputSource::new` only returns this for paths that existed when
    /// the input was parsed.
    FsPath(PathBuf),
    /// Standard input (given as `-` to `InputSource::new`, displayed as `stdin`).
    Stdin,
    /// Raw string content. Not produced by `InputSource::new`; presumably built directly by
    /// callers (TODO confirm).
    String(Cow<'static, str>),
}
impl InputSource {
    /// The literal input that selects standard input as the source.
    const STDIN: &str = "-";

    /// Parses a user-supplied input string into an [`InputSource`].
    ///
    /// Resolution order:
    /// 1. `"-"` becomes [`InputSource::Stdin`].
    /// 2. Anything [`Url::parse`] accepts becomes [`InputSource::RemoteUrl`] if its scheme is
    ///    `http` or `https`. Any other scheme is rejected as an invalid file.
    /// 3. Input containing glob metacharacters (anything [`Pattern::escape`] would alter)
    ///    becomes [`InputSource::FsGlob`] with the given `glob_ignore_case` flag.
    /// 4. An existing path becomes [`InputSource::FsPath`].
    /// 5. On Unix only, a missing path that is relative and does not start with `~` or `.`
    ///    is retried as `http://{input}`, giving [`InputSource::RemoteUrl`].
    ///
    /// # Errors
    ///
    /// - [`ErrorKind::InvalidFile`] for URLs with a non-HTTP(S) scheme, and for paths that do
    ///   not exist and are not eligible for the `http://` fallback.
    /// - The `ErrorKind` converted (via `?`) from a `Pattern::new` failure for a malformed
    ///   glob.
    /// - [`ErrorKind::ParseUrl`] if the `http://` fallback does not parse.
    pub fn new(input: &str, glob_ignore_case: bool) -> Result<Self, ErrorKind> {
        if input == Self::STDIN {
            return Ok(InputSource::Stdin);
        }

        if let Ok(url) = Url::parse(input) {
            return match url.scheme() {
                "http" | "https" => Ok(InputSource::RemoteUrl(Box::new(url))),
                _ => Err(ErrorKind::InvalidFile(PathBuf::from(input))),
            };
        }

        // `Pattern::escape` only rewrites glob metacharacters, so any difference means the
        // input contains at least one of them.
        let is_glob = glob::Pattern::escape(input) != input;
        if is_glob {
            return Ok(InputSource::FsGlob {
                pattern: Pattern::new(input)?,
                ignore_case: glob_ignore_case,
            });
        }

        let path = PathBuf::from(input);
        if path.exists() {
            return Ok(InputSource::FsPath(path));
        }

        // Only Unix reinterprets a missing path as a scheme-less URL. Elsewhere (e.g. Windows,
        // whose paths use `\`) a missing path is simply an invalid file. `cfg!` is used instead
        // of `#[cfg]` on the tail expression so the function also compiles on targets that are
        // neither Unix nor Windows.
        //
        // Inputs starting with `~` or `.`, and absolute paths, are clearly meant as paths.
        // Prefixing `/docs/page.md` with `http://` would wrongly make `docs` the hostname.
        let looks_like_path =
            input.starts_with('~') || input.starts_with('.') || path.is_absolute();
        if !cfg!(unix) || looks_like_path {
            return Err(ErrorKind::InvalidFile(path));
        }

        let url = Url::parse(&format!("http://{input}"))
            .map_err(|e| ErrorKind::ParseUrl(e, "Input is not a valid URL".to_string()))?;
        Ok(InputSource::RemoteUrl(Box::new(url)))
    }

    /// Serde `deserialize_with` helper for the `FsGlob::pattern` field.
    ///
    /// Reads a string and compiles it with [`Pattern::new`]. An invalid pattern is reported
    /// as a custom deserializer error.
    fn deserialize_pattern<'de, D>(deserializer: D) -> Result<Pattern, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        let s = String::deserialize(deserializer)?;
        Pattern::new(&s).map_err(D::Error::custom)
    }
}
/// An input source with no glob variant.
///
/// Mirrors [`InputSource`] minus `FsGlob`. It converts back into [`InputSource`] via
/// `From`, with each variant mapping to the same-named one.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ResolvedInputSource {
    /// A remote URL.
    RemoteUrl(Box<Url>),
    /// A filesystem path.
    FsPath(PathBuf),
    /// Standard input (displayed as `stdin`).
    Stdin,
    /// Raw string content.
    String(Cow<'static, str>),
}
impl From<ResolvedInputSource> for InputSource {
fn from(resolved: ResolvedInputSource) -> Self {
match resolved {
ResolvedInputSource::RemoteUrl(url) => InputSource::RemoteUrl(url),
ResolvedInputSource::FsPath(path) => InputSource::FsPath(path),
ResolvedInputSource::Stdin => InputSource::Stdin,
ResolvedInputSource::String(s) => InputSource::String(s),
}
}
}
impl Display for ResolvedInputSource {
    /// Formats the source as user-facing text: the URL, the filesystem path, `stdin`, or the
    /// raw string.
    ///
    /// A path that is not valid UTF-8 is rendered lossily (invalid sequences become U+FFFD)
    /// via [`std::path::Path::display`]. Previously such a path rendered as an empty string,
    /// which made it unidentifiable in output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RemoteUrl(url) => f.write_str(url.as_str()),
            Self::FsPath(path) => write!(f, "{}", path.display()),
            Self::Stdin => f.write_str("stdin"),
            Self::String(s) => f.write_str(s),
        }
    }
}
impl Serialize for InputSource {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.collect_str(self)
}
}
impl Display for InputSource {
    /// Formats the source as user-facing text: the URL, the glob pattern as originally
    /// written, the filesystem path, `stdin`, or the raw string.
    ///
    /// A path that is not valid UTF-8 is rendered lossily (invalid sequences become U+FFFD)
    /// via [`std::path::Path::display`]. Previously such a path rendered as an empty string,
    /// both here and in serialized output, which builds on this impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RemoteUrl(url) => f.write_str(url.as_str()),
            Self::FsGlob { pattern, .. } => f.write_str(pattern.as_str()),
            Self::FsPath(path) => write!(f, "{}", path.display()),
            Self::Stdin => f.write_str("stdin"),
            Self::String(s) => f.write_str(s),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A glob source must serialize to exactly the pattern string it was built from.
    #[test]
    fn test_pattern_serialization_is_original_pattern() {
        let raw_pattern = "asd[f]*";
        let source = InputSource::FsGlob {
            pattern: Pattern::new(raw_pattern).unwrap(),
            ignore_case: false,
        };

        let serialized = serde_json::to_string(&source).unwrap();
        let expected = serde_json::to_string(raw_pattern).unwrap();

        assert_eq!(serialized, expected);
    }
}